1
0
Timothy Jaeryang Baek 2 сар өмнө
parent
commit
652dcabd86

+ 24 - 24
backend/open_webui/models/memories.py

@@ -37,7 +37,7 @@ class MemoryModel(BaseModel):
 
 
 class MemoriesTable:
-    def insert_new_memory(
+    async def insert_new_memory(
         self,
         user_id: str,
         content: str,
@@ -55,15 +55,15 @@ class MemoriesTable:
                 }
             )
             result = Memory(**memory.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
+            await db.add(result)
+            await db.commit()
+            await db.refresh(result)
             if result:
                 return MemoryModel.model_validate(result)
             else:
                 return None
 
-    def update_memory_by_id_and_user_id(
+    async def update_memory_by_id_and_user_id(
         self,
         id: str,
         user_id: str,
@@ -71,73 +71,73 @@ class MemoriesTable:
     ) -> Optional[MemoryModel]:
         async with get_db() as db:
             try:
-                memory = db.get(Memory, id)
+                memory = await db.get(Memory, id)
                 if not memory or memory.user_id != user_id:
                     return None
 
                 memory.content = content
                 memory.updated_at = int(time.time())
 
-                db.commit()
-                return self.get_memory_by_id(id)
+                await db.commit()
+                return await self.get_memory_by_id(id)
             except Exception:
                 return None
 
-    def get_memories(self) -> list[MemoryModel]:
+    async def get_memories(self) -> list[MemoryModel]:
         async with get_db() as db:
             try:
-                memories = db.query(Memory).all()
+                memories = await db.query(Memory).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
             except Exception:
                 return None
 
-    def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
+    async def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
         async with get_db() as db:
             try:
-                memories = db.query(Memory).filter_by(user_id=user_id).all()
+                memories = await db.query(Memory).filter_by(user_id=user_id).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
             except Exception:
                 return None
 
-    def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
+    async def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         async with get_db() as db:
             try:
-                memory = db.get(Memory, id)
+                memory = await db.get(Memory, id)
                 return MemoryModel.model_validate(memory)
             except Exception:
                 return None
 
-    def delete_memory_by_id(self, id: str) -> bool:
+    async def delete_memory_by_id(self, id: str) -> bool:
         async with get_db() as db:
             try:
-                db.query(Memory).filter_by(id=id).delete()
-                db.commit()
+                await db.query(Memory).filter_by(id=id).delete()
+                await db.commit()
 
                 return True
 
             except Exception:
                 return False
 
-    def delete_memories_by_user_id(self, user_id: str) -> bool:
+    async def delete_memories_by_user_id(self, user_id: str) -> bool:
         async with get_db() as db:
             try:
-                db.query(Memory).filter_by(user_id=user_id).delete()
-                db.commit()
+                await db.query(Memory).filter_by(user_id=user_id).delete()
+                await db.commit()
 
                 return True
             except Exception:
                 return False
 
-    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
+    async def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         async with get_db() as db:
             try:
-                memory = db.get(Memory, id)
+                memory = await db.get(Memory, id)
                 if not memory or memory.user_id != user_id:
                     return None
 
                 # Delete the memory
-                db.delete(memory)
-                db.commit()
+                await db.delete(memory)
+                await db.commit()
 
                 return True
             except Exception:

+ 7 - 7
backend/open_webui/routers/memories.py

@@ -27,7 +27,7 @@ async def get_embeddings(request: Request):
 
 @router.get("/", response_model=list[MemoryModel])
 async def get_memories(user=Depends(get_verified_user)):
-    return Memories.get_memories_by_user_id(user.id)
+    return await Memories.get_memories_by_user_id(user.id)
 
 
 ############################
@@ -49,7 +49,7 @@ async def add_memory(
     form_data: AddMemoryForm,
     user=Depends(get_verified_user),
 ):
-    memory = Memories.insert_new_memory(user.id, form_data.content)
+    memory = await Memories.insert_new_memory(user.id, form_data.content)
 
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
@@ -82,7 +82,7 @@ class QueryMemoryForm(BaseModel):
 async def query_memory(
     request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
 ):
-    memories = Memories.get_memories_by_user_id(user.id)
+    memories = await Memories.get_memories_by_user_id(user.id)
     if not memories:
         raise HTTPException(status_code=404, detail="No memories found for user")
 
@@ -104,7 +104,7 @@ async def reset_memory_from_vector_db(
 ):
     VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
 
-    memories = Memories.get_memories_by_user_id(user.id)
+    memories = await Memories.get_memories_by_user_id(user.id)
     VECTOR_DB_CLIENT.upsert(
         collection_name=f"user-memory-{user.id}",
         items=[
@@ -133,7 +133,7 @@ async def reset_memory_from_vector_db(
 
 @router.delete("/delete/user", response_model=bool)
 async def delete_memory_by_user_id(user=Depends(get_verified_user)):
-    result = Memories.delete_memories_by_user_id(user.id)
+    result = await Memories.delete_memories_by_user_id(user.id)
 
     if result:
         try:
@@ -157,7 +157,7 @@ async def update_memory_by_id(
     form_data: MemoryUpdateModel,
     user=Depends(get_verified_user),
 ):
-    memory = Memories.update_memory_by_id_and_user_id(
+    memory = await Memories.update_memory_by_id_and_user_id(
         memory_id, user.id, form_data.content
     )
     if memory is None:
@@ -191,7 +191,7 @@ async def update_memory_by_id(
 
 @router.delete("/{memory_id}", response_model=bool)
 async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
-    result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
+    result = await Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
 
     if result:
         VECTOR_DB_CLIENT.delete(