memories.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import time
  2. import uuid
  3. from typing import Optional
  4. from open_webui.internal.db import Base, get_db
  5. from pydantic import BaseModel, ConfigDict
  6. from sqlalchemy import BigInteger, Column, String, Text
  7. ####################
  8. # Memory DB Schema
  9. ####################
  class Memory(Base):
      """SQLAlchemy ORM mapping for rows of the ``memory`` table.

      Timestamps are integer columns; this module fills them with
      ``int(time.time())`` when inserting and updating rows.
      """

      __tablename__ = "memory"

      # Declaration order fixes the column order of the generated table.
      id = Column(String(), primary_key=True)
      user_id = Column(String())  # id of the user who owns the memory
      content = Column(Text())
      updated_at = Column(BigInteger())
      created_at = Column(BigInteger())
  class MemoryModel(BaseModel):
      """Pydantic representation of a single ``memory`` row.

      ``from_attributes=True`` allows ``MemoryModel.model_validate`` to read
      the fields straight off a ``Memory`` ORM instance.
      """

      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      content: str
      # Both timestamps are Unix epoch values in whole seconds.
      updated_at: int
      created_at: int
  24. ####################
  25. # Memories table: data-access methods
  26. ####################
  class MemoriesTable:
      """CRUD helpers for the ``memory`` table.

      Each method opens its own database session via ``get_db()`` and hands
      back pydantic ``MemoryModel`` copies rather than live ORM objects.
      The read, update and delete methods swallow database errors and report
      failure as ``None`` / ``False`` instead of raising.
      """

      def insert_new_memory(
          self,
          user_id: str,
          content: str,
      ) -> Optional[MemoryModel]:
          """Create and persist a new memory owned by ``user_id``.

          Args:
              user_id: Id of the owning user.
              content: Text of the memory.

          Returns:
              The stored memory. The return type stays ``Optional`` for
              backward compatibility; database errors are not caught here and
              propagate to the caller.
          """
          with get_db() as db:
              # Read the clock once so a fresh row has created_at == updated_at;
              # two separate time.time() calls could straddle a second boundary.
              now = int(time.time())
              memory = MemoryModel(
                  id=str(uuid.uuid4()),
                  user_id=user_id,
                  content=content,
                  created_at=now,
                  updated_at=now,
              )
              row = Memory(**memory.model_dump())
              db.add(row)
              db.commit()
              db.refresh(row)
              return MemoryModel.model_validate(row)

      def update_memory_by_id_and_user_id(
          self,
          id: str,
          user_id: str,
          content: str,
      ) -> Optional[MemoryModel]:
          """Replace the content of memory ``id`` if it is owned by ``user_id``.

          ``updated_at`` is refreshed to the current time.

          Returns:
              The updated memory, or ``None`` when no memory with this id
              belongs to ``user_id`` or the database operation fails.
          """
          with get_db() as db:
              try:
                  db.query(Memory).filter_by(id=id, user_id=user_id).update(
                      {"content": content, "updated_at": int(time.time())}
                  )
                  db.commit()
                  # Re-read with the SAME ownership filter, in this session.
                  # An id-only lookup would return another user's memory when
                  # the update above matched nothing.
                  memory = (
                      db.query(Memory).filter_by(id=id, user_id=user_id).first()
                  )
                  if memory is None:
                      return None
                  return MemoryModel.model_validate(memory)
              except Exception:
                  return None

      def get_memories(self) -> Optional[list[MemoryModel]]:
          """Return every stored memory across all users.

          Returns:
              A list of memories, or ``None`` if the query fails.
          """
          with get_db() as db:
              try:
                  memories = db.query(Memory).all()
                  return [MemoryModel.model_validate(memory) for memory in memories]
              except Exception:
                  return None

      def get_memories_by_user_id(
          self, user_id: str
      ) -> Optional[list[MemoryModel]]:
          """Return all memories owned by ``user_id``.

          Returns:
              A (possibly empty) list of memories, or ``None`` if the query
              fails.
          """
          with get_db() as db:
              try:
                  memories = db.query(Memory).filter_by(user_id=user_id).all()
                  return [MemoryModel.model_validate(memory) for memory in memories]
              except Exception:
                  return None

      def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
          """Look up a memory by primary key, without any ownership check.

          Returns:
              The memory, or ``None`` if it does not exist or the lookup fails.
          """
          with get_db() as db:
              try:
                  memory = db.get(Memory, id)
                  # Handle "not found" explicitly instead of relying on
                  # model_validate(None) raising inside the try.
                  if memory is None:
                      return None
                  return MemoryModel.model_validate(memory)
              except Exception:
                  return None

      def delete_memory_by_id(self, id: str) -> bool:
          """Delete the memory with primary key ``id``, whoever owns it.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation raised.
          """
          with get_db() as db:
              try:
                  db.query(Memory).filter_by(id=id).delete()
                  db.commit()
                  return True
              except Exception:
                  return False

      def delete_memories_by_user_id(self, user_id: str) -> bool:
          """Delete every memory owned by ``user_id``.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation raised.
          """
          with get_db() as db:
              try:
                  db.query(Memory).filter_by(user_id=user_id).delete()
                  db.commit()
                  return True
              except Exception:
                  return False

      def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
          """Delete memory ``id`` only if it is owned by ``user_id``.

          Returns:
              ``True`` once the delete is committed (also when no row matched),
              ``False`` if the database operation raised.
          """
          with get_db() as db:
              try:
                  db.query(Memory).filter_by(id=id, user_id=user_id).delete()
                  db.commit()
                  return True
              except Exception:
                  return False
  # Shared module-level instance; MemoriesTable keeps no per-instance state.
  Memories: MemoriesTable = MemoriesTable()