memories.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. from fastapi import Response, Request
  2. from fastapi import Depends, FastAPI, HTTPException, status
  3. from datetime import datetime, timedelta
  4. from typing import List, Union, Optional
  5. from fastapi import APIRouter
  6. from pydantic import BaseModel
  7. import logging
  8. from apps.webui.models.memories import Memories, MemoryModel
  9. from utils.utils import get_verified_user
  10. from constants import ERROR_MESSAGES
  11. from config import SRC_LOG_LEVELS, CHROMA_CLIENT
  12. log = logging.getLogger(__name__)
  13. log.setLevel(SRC_LOG_LEVELS["MODELS"])
  14. router = APIRouter()
  15. @router.get("/ef")
  16. async def get_embeddings(request: Request):
  17. return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
  18. ############################
  19. # GetMemories
  20. ############################
  21. @router.get("/", response_model=List[MemoryModel])
  22. async def get_memories(user=Depends(get_verified_user)):
  23. return Memories.get_memories_by_user_id(user.id)
  24. ############################
  25. # AddMemory
  26. ############################
  27. class AddMemoryForm(BaseModel):
  28. content: str
  29. class MemoryUpdateModel(BaseModel):
  30. content: Optional[str] = None
  31. @router.post("/add", response_model=Optional[MemoryModel])
  32. async def add_memory(
  33. request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
  34. ):
  35. memory = Memories.insert_new_memory(user.id, form_data.content)
  36. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  37. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  38. collection.upsert(
  39. documents=[memory.content],
  40. ids=[memory.id],
  41. embeddings=[memory_embedding],
  42. metadatas=[{"created_at": memory.created_at}],
  43. )
  44. return memory
  45. @router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
  46. async def update_memory_by_id(
  47. memory_id: str,
  48. request: Request,
  49. form_data: MemoryUpdateModel,
  50. user=Depends(get_verified_user),
  51. ):
  52. memory = Memories.update_memory_by_id(memory_id, form_data.content)
  53. if memory is None:
  54. raise HTTPException(status_code=404, detail="Memory not found")
  55. if form_data.content is not None:
  56. memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
  57. collection = CHROMA_CLIENT.get_or_create_collection(
  58. name=f"user-memory-{user.id}"
  59. )
  60. collection.upsert(
  61. documents=[form_data.content],
  62. ids=[memory.id],
  63. embeddings=[memory_embedding],
  64. metadatas=[
  65. {"created_at": memory.created_at, "updated_at": memory.updated_at}
  66. ],
  67. )
  68. return memory
  69. ############################
  70. # QueryMemory
  71. ############################
  72. class QueryMemoryForm(BaseModel):
  73. content: str
  74. k: Optional[int] = 1
  75. @router.post("/query")
  76. async def query_memory(
  77. request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
  78. ):
  79. query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
  80. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  81. results = collection.query(
  82. query_embeddings=[query_embedding],
  83. n_results=form_data.k, # how many results to return
  84. )
  85. return results
  86. ############################
  87. # ResetMemoryFromVectorDB
  88. ############################
  89. @router.get("/reset", response_model=bool)
  90. async def reset_memory_from_vector_db(
  91. request: Request, user=Depends(get_verified_user)
  92. ):
  93. CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
  94. collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
  95. memories = Memories.get_memories_by_user_id(user.id)
  96. for memory in memories:
  97. memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
  98. collection.upsert(
  99. documents=[memory.content],
  100. ids=[memory.id],
  101. embeddings=[memory_embedding],
  102. )
  103. return True
  104. ############################
  105. # DeleteMemoriesByUserId
  106. ############################
  107. @router.delete("/user", response_model=bool)
  108. async def delete_memory_by_user_id(user=Depends(get_verified_user)):
  109. result = Memories.delete_memories_by_user_id(user.id)
  110. if result:
  111. try:
  112. CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
  113. except Exception as e:
  114. log.error(e)
  115. return True
  116. return False
  117. ############################
  118. # DeleteMemoryById
  119. ############################
  120. @router.delete("/{memory_id}", response_model=bool)
  121. async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
  122. result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
  123. if result:
  124. collection = CHROMA_CLIENT.get_or_create_collection(
  125. name=f"user-memory-{user.id}"
  126. )
  127. collection.delete(ids=[memory_id])
  128. return True
  129. return False