messages.py 10 KB


  1. import json
  2. import time
  3. import uuid
  4. from typing import Optional
  5. from open_webui.internal.db import Base, get_db
  6. from open_webui.models.tags import TagModel, Tag, Tags
  7. from open_webui.models.users import Users, UserNameResponse
  8. from pydantic import BaseModel, ConfigDict
  9. from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
  10. from sqlalchemy import or_, func, select, and_, text
  11. from sqlalchemy.sql import exists
  12. ####################
  13. # Message DB Schema
  14. ####################
class MessageReaction(Base):
    """ORM row for one user's named reaction on one message (table ``message_reaction``)."""

    __tablename__ = "message_reaction"
    id = Column(Text, primary_key=True)
    user_id = Column(Text)  # id of the reacting user
    message_id = Column(Text)  # id of the reacted-to ``message`` row (no ForeignKey declared)
    name = Column(Text)  # reaction name; reactions are grouped by this value when listed
    created_at = Column(BigInteger)  # nanoseconds since the epoch (written from time.time_ns())
# Pydantic view of a MessageReaction row. from_attributes=True lets
# model_validate() build it directly from the ORM object.
class MessageReactionModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # nanoseconds since the epoch (written from time.time_ns())
class Message(Base):
    """ORM row for a channel message (table ``message``).

    A row with ``parent_id`` set is a reply inside the thread rooted at that
    parent; rows with ``parent_id`` NULL are the channel's top-level messages.
    ``reply_to_id`` optionally points at the message this one replies to.
    """

    __tablename__ = "message"
    id = Column(Text, primary_key=True)
    user_id = Column(Text)  # author's user id
    channel_id = Column(Text, nullable=True)
    reply_to_id = Column(Text, nullable=True)  # message being replied to, if any
    parent_id = Column(Text, nullable=True)  # thread root, NULL for top-level messages
    content = Column(Text)
    data = Column(JSON, nullable=True)  # shallow-merged with new keys on update
    meta = Column(JSON, nullable=True)  # shallow-merged with new keys on update
    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
# Pydantic view of a Message row. from_attributes=True lets model_validate()
# build it directly from the ORM object.
class MessageModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    channel_id: Optional[str] = None
    reply_to_id: Optional[str] = None
    parent_id: Optional[str] = None
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    created_at: int  # nanoseconds since the epoch (written from time.time_ns())
    updated_at: int  # nanoseconds since the epoch (written from time.time_ns())
  53. ####################
  54. # Forms
  55. ####################
# Client payload for creating or editing a message. On insert every field is
# stored; on update only content replaces the stored value, data/meta are
# shallow-merged into the existing dicts, and reply_to_id/parent_id are ignored.
class MessageForm(BaseModel):
    content: str
    reply_to_id: Optional[str] = None
    parent_id: Optional[str] = None  # set to post the message as a thread reply
    data: Optional[dict] = None
    meta: Optional[dict] = None
# Aggregate of all reactions with one name on a message.
class Reactions(BaseModel):
    name: str
    user_ids: list[str]  # one entry per reaction row, so a user can appear more than once
    count: int  # number of reaction rows; equals len(user_ids)
# MessageModel plus the author's public name info; None when the author
# lookup found no user or the caller did not populate it.
class MessageUserResponse(MessageModel):
    user: Optional[UserNameResponse] = None
# Adds the message referenced by reply_to_id (None when there is none or it
# could not be loaded).
class MessageReplyToResponse(MessageUserResponse):
    reply_to_message: Optional[MessageUserResponse] = None
# Full message view: adds a thread summary and aggregated reactions.
class MessageResponse(MessageReplyToResponse):
    # created_at of the newest thread reply, or None when there are no replies.
    # Declared without a default, so it must be passed explicitly (even as None).
    latest_reply_at: Optional[int]
    reply_count: int  # number of messages whose parent_id is this message
    reactions: list[Reactions]
  74. class MessageTable:
  75. def insert_new_message(
  76. self, form_data: MessageForm, channel_id: str, user_id: str
  77. ) -> Optional[MessageModel]:
  78. with get_db() as db:
  79. id = str(uuid.uuid4())
  80. ts = int(time.time_ns())
  81. message = MessageModel(
  82. **{
  83. "id": id,
  84. "user_id": user_id,
  85. "channel_id": channel_id,
  86. "reply_to_id": form_data.reply_to_id,
  87. "parent_id": form_data.parent_id,
  88. "content": form_data.content,
  89. "data": form_data.data,
  90. "meta": form_data.meta,
  91. "created_at": ts,
  92. "updated_at": ts,
  93. }
  94. )
  95. result = Message(**message.model_dump())
  96. db.add(result)
  97. db.commit()
  98. db.refresh(result)
  99. return MessageModel.model_validate(result) if result else None
  100. def get_message_by_id(self, id: str) -> Optional[MessageReplyToResponse]:
  101. with get_db() as db:
  102. message = db.get(Message, id)
  103. if not message:
  104. return None
  105. reply_to_message = (
  106. self.get_message_by_id(message.reply_to_id)
  107. if message.reply_to_id
  108. else None
  109. )
  110. reactions = self.get_reactions_by_message_id(id)
  111. replies = self.get_thread_replies_by_message_id(id)
  112. user = Users.get_user_by_id(message.user_id)
  113. return MessageReplyToResponse.model_validate(
  114. {
  115. **MessageModel.model_validate(message).model_dump(),
  116. "user": user.model_dump() if user else None,
  117. "reply_to_message": (
  118. reply_to_message.model_dump() if reply_to_message else None
  119. ),
  120. "latest_reply_at": replies[0].created_at if replies else None,
  121. "reply_count": len(replies),
  122. "reactions": reactions,
  123. }
  124. )
  125. def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
  126. with get_db() as db:
  127. all_messages = (
  128. db.query(Message)
  129. .filter_by(parent_id=id)
  130. .order_by(Message.created_at.desc())
  131. .all()
  132. )
  133. return [
  134. MessageReplyToResponse.model_validate(
  135. {
  136. **MessageModel.model_validate(message).model_dump(),
  137. "reply_to_message": (
  138. self.get_message_by_id(message.reply_to_id).model_dump()
  139. if message.reply_to_id
  140. else None
  141. ),
  142. }
  143. )
  144. for message in all_messages
  145. ]
  146. def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
  147. with get_db() as db:
  148. return [
  149. message.user_id
  150. for message in db.query(Message).filter_by(parent_id=id).all()
  151. ]
  152. def get_messages_by_channel_id(
  153. self, channel_id: str, skip: int = 0, limit: int = 50
  154. ) -> list[MessageReplyToResponse]:
  155. with get_db() as db:
  156. all_messages = (
  157. db.query(Message)
  158. .filter_by(channel_id=channel_id, parent_id=None)
  159. .order_by(Message.created_at.desc())
  160. .offset(skip)
  161. .limit(limit)
  162. .all()
  163. )
  164. return [
  165. MessageReplyToResponse.model_validate(
  166. {
  167. **MessageModel.model_validate(message).model_dump(),
  168. "reply_to_message": (
  169. self.get_message_by_id(message.reply_to_id).model_dump()
  170. if message.reply_to_id
  171. else None
  172. ),
  173. }
  174. )
  175. for message in all_messages
  176. ]
  177. def get_messages_by_parent_id(
  178. self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
  179. ) -> list[MessageModel]:
  180. with get_db() as db:
  181. message = db.get(Message, parent_id)
  182. if not message:
  183. return []
  184. all_messages = (
  185. db.query(Message)
  186. .filter_by(channel_id=channel_id, parent_id=parent_id)
  187. .order_by(Message.created_at.desc())
  188. .offset(skip)
  189. .limit(limit)
  190. .all()
  191. )
  192. # If length of all_messages is less than limit, then add the parent message
  193. if len(all_messages) < limit:
  194. all_messages.append(message)
  195. return [MessageModel.model_validate(message) for message in all_messages]
  196. def update_message_by_id(
  197. self, id: str, form_data: MessageForm
  198. ) -> Optional[MessageModel]:
  199. with get_db() as db:
  200. message = db.get(Message, id)
  201. message.content = form_data.content
  202. message.data = {
  203. **(message.data if message.data else {}),
  204. **(form_data.data if form_data.data else {}),
  205. }
  206. message.meta = {
  207. **(message.meta if message.meta else {}),
  208. **(form_data.meta if form_data.meta else {}),
  209. }
  210. message.updated_at = int(time.time_ns())
  211. db.commit()
  212. db.refresh(message)
  213. return MessageModel.model_validate(message) if message else None
  214. def add_reaction_to_message(
  215. self, id: str, user_id: str, name: str
  216. ) -> Optional[MessageReactionModel]:
  217. with get_db() as db:
  218. reaction_id = str(uuid.uuid4())
  219. reaction = MessageReactionModel(
  220. id=reaction_id,
  221. user_id=user_id,
  222. message_id=id,
  223. name=name,
  224. created_at=int(time.time_ns()),
  225. )
  226. result = MessageReaction(**reaction.model_dump())
  227. db.add(result)
  228. db.commit()
  229. db.refresh(result)
  230. return MessageReactionModel.model_validate(result) if result else None
  231. def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
  232. with get_db() as db:
  233. all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
  234. reactions = {}
  235. for reaction in all_reactions:
  236. if reaction.name not in reactions:
  237. reactions[reaction.name] = {
  238. "name": reaction.name,
  239. "user_ids": [],
  240. "count": 0,
  241. }
  242. reactions[reaction.name]["user_ids"].append(reaction.user_id)
  243. reactions[reaction.name]["count"] += 1
  244. return [Reactions(**reaction) for reaction in reactions.values()]
  245. def remove_reaction_by_id_and_user_id_and_name(
  246. self, id: str, user_id: str, name: str
  247. ) -> bool:
  248. with get_db() as db:
  249. db.query(MessageReaction).filter_by(
  250. message_id=id, user_id=user_id, name=name
  251. ).delete()
  252. db.commit()
  253. return True
  254. def delete_reactions_by_id(self, id: str) -> bool:
  255. with get_db() as db:
  256. db.query(MessageReaction).filter_by(message_id=id).delete()
  257. db.commit()
  258. return True
  259. def delete_replies_by_id(self, id: str) -> bool:
  260. with get_db() as db:
  261. db.query(Message).filter_by(parent_id=id).delete()
  262. db.commit()
  263. return True
  264. def delete_message_by_id(self, id: str) -> bool:
  265. with get_db() as db:
  266. db.query(Message).filter_by(id=id).delete()
  267. # Delete all reactions to this message
  268. db.query(MessageReaction).filter_by(message_id=id).delete()
  269. db.commit()
  270. return True
# Module-level shared instance. MessageTable defines no __init__ and keeps no
# instance attributes, so one object serves every importer.
Messages = MessageTable()