messages.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import json
  2. import time
  3. import uuid
  4. from typing import Optional
  5. from open_webui.internal.db import Base, get_db
  6. from open_webui.models.tags import TagModel, Tag, Tags
  7. from pydantic import BaseModel, ConfigDict
  8. from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
  9. from sqlalchemy import or_, func, select, and_, text
  10. from sqlalchemy.sql import exists
  11. ####################
  12. # Message DB Schema
  13. ####################
  class Message(Base):
      """SQLAlchemy ORM mapping for the ``message`` table.

      Timestamps are integers written with ``time.time_ns()`` by the
      insert/update helpers in this module, i.e. nanoseconds since the
      Unix epoch.
      """

      __tablename__ = "message"

      # Primary key (a UUID4 string when rows are created via insert_new_message).
      id = Column(Text, primary_key=True)
      user_id = Column(Text)
      # A message is not required to be attached to a channel.
      channel_id = Column(Text, nullable=True)
      content = Column(Text)
      # Optional free-form JSON payloads; their structure is caller-defined.
      data = Column(JSON, nullable=True)
      meta = Column(JSON, nullable=True)
      # Nanosecond epoch timestamps (time.time_ns()).
      created_at = Column(BigInteger)
      updated_at = Column(BigInteger)
  class MessageModel(BaseModel):
      """Pydantic view of a ``Message`` row.

      ``from_attributes=True`` allows ``MessageModel.model_validate`` to be
      called directly on ORM objects.
      """

      model_config = ConfigDict(from_attributes=True)

      id: str
      user_id: str
      channel_id: Optional[str] = None
      content: str
      data: Optional[dict] = None
      meta: Optional[dict] = None
      # Epoch timestamps in nanoseconds (values come from time.time_ns()).
      created_at: int
      updated_at: int
  34. ####################
  35. # Forms
  36. ####################
  class MessageForm(BaseModel):
      """Caller-supplied fields used when creating or updating a message.

      Only ``content`` is mandatory; ``data`` and ``meta`` default to None.
      """

      content: str
      data: Optional[dict] = None
      meta: Optional[dict] = None
  class MessageTable:
      """Data-access helpers for ``Message`` rows.

      Every method opens its own session with ``get_db()`` and converts ORM
      rows to ``MessageModel`` before the session is closed.
      """

      def insert_new_message(
          self, form_data: MessageForm, channel_id: str, user_id: str
      ) -> Optional[MessageModel]:
          """Create and persist a new message.

          Args:
              form_data: Content, data and meta supplied by the caller.
              channel_id: Channel the message is posted to.
              user_id: Author of the message.

          Returns:
              The stored message, re-read from the database after commit.
          """
          with get_db() as db:
              # Read the clock once so a fresh message has identical
              # created_at and updated_at values (two calls could differ).
              now = time.time_ns()
              message = MessageModel(
                  id=str(uuid.uuid4()),
                  user_id=user_id,
                  channel_id=channel_id,
                  content=form_data.content,
                  data=form_data.data,
                  meta=form_data.meta,
                  created_at=now,
                  updated_at=now,
              )
              row = Message(**message.model_dump())
              db.add(row)
              db.commit()
              db.refresh(row)
              return MessageModel.model_validate(row)

      def get_message_by_id(self, id: str) -> Optional[MessageModel]:
          """Return the message with primary key ``id``, or None if absent."""
          with get_db() as db:
              message = db.get(Message, id)
              return MessageModel.model_validate(message) if message else None

      def _list_messages(
          self, skip: int, limit: int, **filters
      ) -> list[MessageModel]:
          """Return messages matching ``filters`` (filter_by kwargs), newest first, paginated."""
          with get_db() as db:
              rows = (
                  db.query(Message)
                  .filter_by(**filters)
                  .order_by(Message.created_at.desc())
                  .offset(skip)
                  .limit(limit)
                  .all()
              )
              return [MessageModel.model_validate(row) for row in rows]

      def get_messages_by_channel_id(
          self, channel_id: str, skip: int = 0, limit: int = 50
      ) -> list[MessageModel]:
          """Return up to ``limit`` messages of a channel, newest first, skipping ``skip``."""
          return self._list_messages(skip, limit, channel_id=channel_id)

      def get_messages_by_user_id(
          self, user_id: str, skip: int = 0, limit: int = 50
      ) -> list[MessageModel]:
          """Return up to ``limit`` messages by a user, newest first, skipping ``skip``."""
          return self._list_messages(skip, limit, user_id=user_id)

      def update_message_by_id(
          self, id: str, form_data: MessageForm
      ) -> Optional[MessageModel]:
          """Overwrite content/data/meta of message ``id`` and bump updated_at.

          ``data`` and ``meta`` are replaced wholesale, so a None in
          ``form_data`` clears the stored value.

          Returns:
              The updated message, or None when no message has that id.
          """
          with get_db() as db:
              message = db.get(Message, id)
              if message is None:
                  # Previously this raised AttributeError; honor the
                  # Optional return type instead.
                  return None
              message.content = form_data.content
              message.data = form_data.data
              message.meta = form_data.meta
              message.updated_at = time.time_ns()
              db.commit()
              db.refresh(message)
              return MessageModel.model_validate(message)

      def delete_message_by_id(self, id: str) -> bool:
          """Delete the message with primary key ``id``.

          Returns True unconditionally, including when no row matched
          (kept for caller compatibility).
          """
          with get_db() as db:
              db.query(Message).filter_by(id=id).delete()
              db.commit()
              return True
  # Shared module-level instance through which callers access message rows.
  Messages: MessageTable = MessageTable()