Timothy Jaeryang Baek 2 kuukautta sitten
vanhempi
commit
f4cd24d2ca
2 muutettua tiedostoa jossa 21 lisäystä ja 18 poistoa
  1. 15 13
      backend/open_webui/models/tags.py
  2. 6 5
      backend/open_webui/routers/chats.py

+ 15 - 13
backend/open_webui/models/tags.py

@@ -47,15 +47,15 @@ class TagChatIdForm(BaseModel):
 
 
 class TagTable:
-    def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
+    async def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
         async with get_db() as db:
             id = name.replace(" ", "_").lower()
             tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
             try:
                 result = Tag(**tag.model_dump())
-                db.add(result)
-                db.commit()
-                db.refresh(result)
+                await db.add(result)
+                await db.commit()
+                await db.refresh(result)
                 if result:
                     return TagModel.model_validate(result)
                 else:
@@ -64,42 +64,44 @@ class TagTable:
                 log.exception(f"Error inserting a new tag: {e}")
                 return None
 
-    def get_tag_by_name_and_user_id(
+    async def get_tag_by_name_and_user_id(
         self, name: str, user_id: str
     ) -> Optional[TagModel]:
         try:
             id = name.replace(" ", "_").lower()
             async with get_db() as db:
-                tag = db.query(Tag).filter_by(id=id, user_id=user_id).first()
+                tag = await db.query(Tag).filter_by(id=id, user_id=user_id).first()
                 return TagModel.model_validate(tag)
         except Exception:
             return None
 
-    def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
+    async def get_tags_by_user_id(self, user_id: str) -> list[TagModel]:
         async with get_db() as db:
             return [
                 TagModel.model_validate(tag)
-                for tag in (db.query(Tag).filter_by(user_id=user_id).all())
+                for tag in (await db.query(Tag).filter_by(user_id=user_id).all())
             ]
 
-    def get_tags_by_ids_and_user_id(
+    async def get_tags_by_ids_and_user_id(
         self, ids: list[str], user_id: str
     ) -> list[TagModel]:
         async with get_db() as db:
             return [
                 TagModel.model_validate(tag)
                 for tag in (
-                    db.query(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id).all()
+                    await db.query(Tag)
+                    .filter(Tag.id.in_(ids), Tag.user_id == user_id)
+                    .all()
                 )
             ]
 
-    def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
+    async def delete_tag_by_name_and_user_id(self, name: str, user_id: str) -> bool:
         try:
             async with get_db() as db:
                 id = name.replace(" ", "_").lower()
-                res = db.query(Tag).filter_by(id=id, user_id=user_id).delete()
+                res = await db.query(Tag).filter_by(id=id, user_id=user_id).delete()
                 log.debug(f"res: {res}")
-                db.commit()
+                await db.commit()
                 return True
         except Exception as e:
             log.error(f"delete_tag: {e}")

+ 6 - 5
backend/open_webui/routers/chats.py

@@ -148,9 +148,10 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user)
                 tag_name = " ".join([word.capitalize() for word in tag_id.split("_")])
                 if (
                     tag_id != "none"
-                    and Tags.get_tag_by_name_and_user_id(tag_name, user.id) is None
+                    and await Tags.get_tag_by_name_and_user_id(tag_name, user.id)
+                    is None
                 ):
-                    Tags.insert_new_tag(tag_name, user.id)
+                    await Tags.insert_new_tag(tag_name, user.id)
 
         return ChatResponse(**chat.model_dump())
     except Exception as e:
@@ -261,7 +262,7 @@ async def get_user_archived_chats(user=Depends(get_verified_user)):
 @router.get("/all/tags", response_model=list[TagModel])
 async def get_all_user_tags(user=Depends(get_verified_user)):
     try:
-        tags = Tags.get_tags_by_user_id(user.id)
+        tags = await Tags.get_tags_by_user_id(user.id)
         return tags
     except Exception as e:
         log.exception(e)
@@ -556,7 +557,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
         chat = await Chats.get_chat_by_id(id)
         for tag in chat.meta.get("tags", []):
             if await Chats.count_chats_by_tag_name_and_user_id(tag, user.id) == 1:
-                Tags.delete_tag_by_name_and_user_id(tag, user.id)
+                await Tags.delete_tag_by_name_and_user_id(tag, user.id)
 
         result = await Chats.delete_chat_by_id_and_user_id(id, user.id)
         return result
@@ -694,7 +695,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
                     == 0
                 ):
                     log.debug(f"deleting tag: {tag_id}")
-                    Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
+                    await Tags.delete_tag_by_name_and_user_id(tag_id, user.id)
         else:
             for tag_id in chat.meta.get("tags", []):
                 tag = await Tags.get_tag_by_name_and_user_id(tag_id, user.id)