memories.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. from fastapi import APIRouter, Depends, HTTPException, Request
  2. from pydantic import BaseModel
  3. import logging
  4. from typing import Optional
  5. from open_webui.models.memories import Memories, MemoryModel
  6. from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
  7. from open_webui.utils.auth import get_verified_user
  8. from open_webui.env import SRC_LOG_LEVELS
  9. log = logging.getLogger(__name__)
  10. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  11. router = APIRouter()
  12. @router.get("/ef")
  13. async def get_embeddings(request: Request):
  14. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  15. ############################
  16. # GetMemories
  17. ############################
  18. @router.get("/", response_model=list[MemoryModel])
  19. async def get_memories(user=Depends(get_verified_user)):
  20. return Memories.get_memories_by_user_id(user.id)
  21. ############################
  22. # AddMemory
  23. ############################
  24. class AddMemoryForm(BaseModel):
  25. content: str
  26. class MemoryUpdateModel(BaseModel):
  27. content: Optional[str] = None
  28. @router.post("/add", response_model=Optional[MemoryModel])
  29. async def add_memory(
  30. request: Request,
  31. form_data: AddMemoryForm,
  32. user=Depends(get_verified_user),
  33. ):
  34. memory = Memories.insert_new_memory(user.id, form_data.content)
  35. VECTOR_DB_CLIENT.upsert(
  36. collection_name=f"user-memory-{user.id}",
  37. items=[
  38. {
  39. "id": memory.id,
  40. "text": memory.content,
  41. "vector": request.app.state.EMBEDDING_FUNCTION(
  42. memory.content, user=user
  43. ),
  44. "metadata": {"created_at": memory.created_at},
  45. }
  46. ],
  47. )
  48. return memory
  49. ############################
  50. # QueryMemory
  51. ############################
  52. class QueryMemoryForm(BaseModel):
  53. content: str
  54. k: Optional[int] = 1
  55. @router.post("/query")
  56. async def query_memory(
  57. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  58. ):
  59. results = VECTOR_DB_CLIENT.search(
  60. collection_name=f"user-memory-{user.id}",
  61. vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
  62. limit=form_data.k,
  63. )
  64. return results
  65. ############################
  66. # ResetMemoryFromVectorDB
  67. ############################
  68. @router.post("/reset", response_model=bool)
  69. async def reset_memory_from_vector_db(
  70. request: Request, user=Depends(get_verified_user)
  71. ):
  72. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  73. memories = Memories.get_memories_by_user_id(user.id)
  74. VECTOR_DB_CLIENT.upsert(
  75. collection_name=f"user-memory-{user.id}",
  76. items=[
  77. {
  78. "id": memory.id,
  79. "text": memory.content,
  80. "vector": request.app.state.EMBEDDING_FUNCTION(
  81. memory.content, user=user
  82. ),
  83. "metadata": {
  84. "created_at": memory.created_at,
  85. "updated_at": memory.updated_at,
  86. },
  87. }
  88. for memory in memories
  89. ],
  90. )
  91. return True
  92. ############################
  93. # DeleteMemoriesByUserId
  94. ############################
  95. @router.delete("/delete/user", response_model=bool)
  96. async def delete_memory_by_user_id(user=Depends(get_verified_user)):
  97. result = Memories.delete_memories_by_user_id(user.id)
  98. if result:
  99. try:
  100. VECTOR_DB_CLIENT.delete_collection(f"user-memory-{user.id}")
  101. except Exception as e:
  102. log.error(e)
  103. return True
  104. return False
  105. ############################
  106. # UpdateMemoryById
  107. ############################
  108. @router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
  109. async def update_memory_by_id(
  110. memory_id: str,
  111. request: Request,
  112. form_data: MemoryUpdateModel,
  113. user=Depends(get_verified_user),
  114. ):
  115. memory = Memories.update_memory_by_id_and_user_id(
  116. memory_id, user.id, form_data.content
  117. )
  118. if memory is None:
  119. raise HTTPException(status_code=404, detail="Memory not found")
  120. if form_data.content is not None:
  121. VECTOR_DB_CLIENT.upsert(
  122. collection_name=f"user-memory-{user.id}",
  123. items=[
  124. {
  125. "id": memory.id,
  126. "text": memory.content,
  127. "vector": request.app.state.EMBEDDING_FUNCTION(
  128. memory.content, user=user
  129. ),
  130. "metadata": {
  131. "created_at": memory.created_at,
  132. "updated_at": memory.updated_at,
  133. },
  134. }
  135. ],
  136. )
  137. return memory
  138. ############################
  139. # DeleteMemoryById
  140. ############################
  141. @router.delete("/{memory_id}", response_model=bool)
  142. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  143. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  144. if result:
  145. VECTOR_DB_CLIENT.delete(
  146. collection_name=f"user-memory-{user.id}", ids=[memory_id]
  147. )
  148. return True
  149. return False