memories.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from fastapi import APIRouter, Depends, HTTPException, Request
  2. from pydantic import BaseModel
  3. import logging
  4. from typing import Optional
  5. from open_webui.models.memories import Memories, MemoryModel
  6. from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
  7. from open_webui.utils.auth import get_verified_user
  8. from open_webui.env import SRC_LOG_LEVELS
# Module-level logger named after this module; its level is read from the
# "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

# Router on which every endpoint in this module is registered.
router = APIRouter()
  12. @router.get("/ef")
  13. async def get_embeddings(request: Request):
  14. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  15. ############################
  16. # GetMemories
  17. ############################
  18. @router.get("/", response_model=list[MemoryModel])
  19. async def get_memories(user=Depends(get_verified_user)):
  20. return Memories.get_memories_by_user_id(user.id)
  21. ############################
  22. # AddMemory
  23. ############################
# Request body for POST /add.
class AddMemoryForm(BaseModel):
    # Required text of the memory; add_memory passes it to
    # Memories.insert_new_memory.
    content: str
# Request body for POST /{memory_id}/update.
class MemoryUpdateModel(BaseModel):
    # Replacement text for the memory. Defaults to None; update_memory_by_id
    # only rewrites the vector-store entry when this is not None.
    content: Optional[str] = None
  28. @router.post("/add", response_model=Optional[MemoryModel])
  29. async def add_memory(
  30. request: Request,
  31. form_data: AddMemoryForm,
  32. user=Depends(get_verified_user),
  33. ):
  34. memory = Memories.insert_new_memory(user.id, form_data.content)
  35. VECTOR_DB_CLIENT.upsert(
  36. collection_name=f"user-memory-{user.id}",
  37. items=[
  38. {
  39. "id": memory.id,
  40. "text": memory.content,
  41. "vector": request.app.state.EMBEDDING_FUNCTION(
  42. memory.content, user=user
  43. ),
  44. "metadata": {"created_at": memory.created_at},
  45. }
  46. ],
  47. )
  48. return memory
  49. ############################
  50. # QueryMemory
  51. ############################
# Request body for POST /query.
class QueryMemoryForm(BaseModel):
    # Text that query_memory embeds and uses as the search vector.
    content: str
    # Forwarded by query_memory as the search `limit`; defaults to 1.
    k: Optional[int] = 1
  55. @router.post("/query")
  56. async def query_memory(
  57. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  58. ):
  59. memories = Memories.get_memories_by_user_id(user.id)
  60. if not memories:
  61. raise HTTPException(status_code=404, detail="No memories found for user")
  62. results = VECTOR_DB_CLIENT.search(
  63. collection_name=f"user-memory-{user.id}",
  64. vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
  65. limit=form_data.k,
  66. )
  67. return results
  68. ############################
  69. # ResetMemoryFromVectorDB
  70. ############################
  71. @router.post("/reset", response_model=bool)
  72. async def reset_memory_from_vector_db(
  73. request: Request, user=Depends(get_verified_user)
  74. ):
  75. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  76. memories = Memories.get_memories_by_user_id(user.id)
  77. VECTOR_DB_CLIENT.upsert(
  78. collection_name=f"user-memory-{user.id}",
  79. items=[
  80. {
  81. "id": memory.id,
  82. "text": memory.content,
  83. "vector": request.app.state.EMBEDDING_FUNCTION(
  84. memory.content, user=user
  85. ),
  86. "metadata": {
  87. "created_at": memory.created_at,
  88. "updated_at": memory.updated_at,
  89. },
  90. }
  91. for memory in memories
  92. ],
  93. )
  94. return True
  95. ############################
  96. # DeleteMemoriesByUserId
  97. ############################
  98. @router.delete("/delete/user", response_model=bool)
  99. async def delete_memory_by_user_id(user=Depends(get_verified_user)):
  100. result = Memories.delete_memories_by_user_id(user.id)
  101. if result:
  102. try:
  103. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  104. except Exception as e:
  105. log.error(e)
  106. return True
  107. return False
  108. ############################
  109. # UpdateMemoryById
  110. ############################
  111. @router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
  112. async def update_memory_by_id(
  113. memory_id: str,
  114. request: Request,
  115. form_data: MemoryUpdateModel,
  116. user=Depends(get_verified_user),
  117. ):
  118. memory = Memories.update_memory_by_id_and_user_id(
  119. memory_id, user.id, form_data.content
  120. )
  121. if memory is None:
  122. raise HTTPException(status_code=404, detail="Memory not found")
  123. if form_data.content is not None:
  124. VECTOR_DB_CLIENT.upsert(
  125. collection_name=f"user-memory-{user.id}",
  126. items=[
  127. {
  128. "id": memory.id,
  129. "text": memory.content,
  130. "vector": request.app.state.EMBEDDING_FUNCTION(
  131. memory.content, user=user
  132. ),
  133. "metadata": {
  134. "created_at": memory.created_at,
  135. "updated_at": memory.updated_at,
  136. },
  137. }
  138. ],
  139. )
  140. return memory
  141. ############################
  142. # DeleteMemoryById
  143. ############################
  144. @router.delete("/{memory_id}", response_model=bool)
  145. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  146. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  147. if result:
  148. VECTOR_DB_CLIENT.delete(
  149. collection_name=f"user-memory-{user.id}", ids=[memory_id]
  150. )
  151. return True
  152. return False