memories.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import time
  2. import uuid
  3. from typing import Optional
  4. from open_webui.internal.db import Base, get_db
  5. from pydantic import BaseModel, ConfigDict
  6. from sqlalchemy import BigInteger, Column, String, Text
  7. ####################
  8. # Memory DB Schema
  9. ####################
  10. class Memory(Base):
  11. __tablename__ = "memory"
  12. id = Column(String, primary_key=True)
  13. user_id = Column(String)
  14. content = Column(Text)
  15. updated_at = Column(BigInteger)
  16. created_at = Column(BigInteger)
  17. class MemoryModel(BaseModel):
  18. id: str
  19. user_id: str
  20. content: str
  21. updated_at: int # timestamp in epoch
  22. created_at: int # timestamp in epoch
  23. model_config = ConfigDict(from_attributes=True)
####################
# Memories Table
####################
  27. class MemoriesTable:
  28. def insert_new_memory(
  29. self,
  30. user_id: str,
  31. content: str,
  32. ) -> Optional[MemoryModel]:
  33. with get_db() as db:
  34. id = str(uuid.uuid4())
  35. memory = MemoryModel(
  36. **{
  37. "id": id,
  38. "user_id": user_id,
  39. "content": content,
  40. "created_at": int(time.time()),
  41. "updated_at": int(time.time()),
  42. }
  43. )
  44. result = Memory(**memory.model_dump())
  45. db.add(result)
  46. db.commit()
  47. db.refresh(result)
  48. if result:
  49. return MemoryModel.model_validate(result)
  50. else:
  51. return None
  52. def update_memory_by_id_and_user_id(
  53. self,
  54. id: str,
  55. user_id: str,
  56. content: str,
  57. ) -> Optional[MemoryModel]:
  58. with get_db() as db:
  59. try:
  60. memory = db.get(Memory, id)
  61. if not memory or memory.user_id != user_id:
  62. return None
  63. memory.content = content
  64. memory.updated_at = int(time.time())
  65. db.commit()
  66. return self.get_memory_by_id(id)
  67. except Exception:
  68. return None
  69. def get_memories(self) -> list[MemoryModel]:
  70. with get_db() as db:
  71. try:
  72. memories = db.query(Memory).all()
  73. return [MemoryModel.model_validate(memory) for memory in memories]
  74. except Exception:
  75. return None
  76. def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
  77. with get_db() as db:
  78. try:
  79. memories = db.query(Memory).filter_by(user_id=user_id).all()
  80. return [MemoryModel.model_validate(memory) for memory in memories]
  81. except Exception:
  82. return None
  83. def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
  84. with get_db() as db:
  85. try:
  86. memory = db.get(Memory, id)
  87. return MemoryModel.model_validate(memory)
  88. except Exception:
  89. return None
  90. def delete_memory_by_id(self, id: str) -> bool:
  91. with get_db() as db:
  92. try:
  93. db.query(Memory).filter_by(id=id).delete()
  94. db.commit()
  95. return True
  96. except Exception:
  97. return False
  98. def delete_memories_by_user_id(self, user_id: str) -> bool:
  99. with get_db() as db:
  100. try:
  101. db.query(Memory).filter_by(user_id=user_id).delete()
  102. db.commit()
  103. return True
  104. except Exception:
  105. return False
  106. def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
  107. with get_db() as db:
  108. try:
  109. memory = db.get(Memory, id)
  110. if not memory or memory.user_id != user_id:
  111. return None
  112. # Delete the memory
  113. db.delete(memory)
  114. db.commit()
  115. return True
  116. except Exception:
  117. return False
  118. Memories = MemoriesTable()