|
@@ -44,6 +44,10 @@ class ChatForm(BaseModel):
|
|
|
chat: dict
|
|
|
|
|
|
|
|
|
+class ChatTitleForm(BaseModel):
|
|
|
+ title: str
|
|
|
+
|
|
|
+
|
|
|
class ChatResponse(BaseModel):
|
|
|
id: str
|
|
|
user_id: str
|
|
@@ -93,6 +97,20 @@ class ChatTable:
|
|
|
except:
|
|
|
return None
|
|
|
|
|
|
+ def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
|
|
+ try:
|
|
|
+ query = Chat.update(
|
|
|
+ chat=json.dumps(chat),
|
|
|
+ title=chat["title"] if "title" in chat else "New Chat",
|
|
|
+ timestamp=int(time.time()),
|
|
|
+ ).where(Chat.id == id)
|
|
|
+ query.execute()
|
|
|
+
|
|
|
+ chat = Chat.get(Chat.id == id)
|
|
|
+ return ChatModel(**model_to_dict(chat))
|
|
|
+ except:
|
|
|
+ return None
|
|
|
+
|
|
|
def get_chat_lists_by_user_id(
|
|
|
self, user_id: str, skip: int = 0, limit: int = 50
|
|
|
) -> List[ChatModel]:
|