messages.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. import json
  2. import time
  3. import uuid
  4. from typing import Optional
  5. from open_webui.internal.db import Base, get_db
  6. from open_webui.models.tags import TagModel, Tag, Tags
  7. from open_webui.models.users import Users, UserNameResponse
  8. from pydantic import BaseModel, ConfigDict
  9. from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
  10. from sqlalchemy import or_, func, select, and_, text
  11. from sqlalchemy.sql import exists
  12. ####################
  13. # Message DB Schema
  14. ####################
class MessageReaction(Base):
    """ORM row for the ``message_reaction`` table.

    Each row records one user (``user_id``) reacting to one message
    (``message_id``) with a reaction called ``name``. Rows are inserted by
    ``MessageTable.add_reaction_to_message``, which sets ``created_at`` from
    ``time.time_ns()`` (epoch nanoseconds).
    """

    __tablename__ = "message_reaction"
    id = Column(Text, primary_key=True)
    user_id = Column(Text)
    message_id = Column(Text)
    name = Column(Text)
    created_at = Column(BigInteger)
class MessageReactionModel(BaseModel):
    # Pydantic view of a ``MessageReaction`` row. ``from_attributes`` allows
    # ``MessageReactionModel.model_validate(orm_row)``.
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # epoch timestamp; written as time.time_ns() (nanoseconds)
class Message(Base):
    """ORM row for the ``message`` table.

    ``reply_to_id`` points at the message being replied to, and ``parent_id``
    at the root of the thread this message belongs to. Top-level channel
    messages have ``parent_id`` NULL. ``data`` and ``meta`` are free-form
    JSON blobs that ``MessageTable.update_message_by_id`` shallow-merges.
    """

    __tablename__ = "message"
    id = Column(Text, primary_key=True)
    user_id = Column(Text)
    channel_id = Column(Text, nullable=True)
    reply_to_id = Column(Text, nullable=True)
    parent_id = Column(Text, nullable=True)
    content = Column(Text)
    data = Column(JSON, nullable=True)
    meta = Column(JSON, nullable=True)
    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
class MessageModel(BaseModel):
    # Pydantic view of a ``Message`` row. ``from_attributes`` allows
    # ``MessageModel.model_validate(orm_row)``.
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    channel_id: Optional[str] = None
    reply_to_id: Optional[str] = None  # id of the message being replied to
    parent_id: Optional[str] = None  # id of the thread root; None for top-level
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    created_at: int  # epoch timestamp; written as time.time_ns() (nanoseconds)
    updated_at: int  # epoch timestamp; written as time.time_ns() (nanoseconds)
  53. ####################
  54. # Forms
  55. ####################
class MessageForm(BaseModel):
    # Client-supplied fields for creating or editing a message. On update,
    # ``data``/``meta`` are shallow-merged into the stored values rather than
    # replacing them.
    content: str
    reply_to_id: Optional[str] = None
    parent_id: Optional[str] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
class Reactions(BaseModel):
    # All reactions with one ``name`` on a single message, as built by
    # ``MessageTable.get_reactions_by_message_id``.
    name: str
    user_ids: list[str]  # one entry per reaction row, in query order
    count: int  # number of reaction rows with this name (== len(user_ids))
class MessageUserResponse(MessageModel):
    # Message plus its author's public name info; None when the author could
    # not be loaded or was not requested.
    user: Optional[UserNameResponse] = None
class MessageReplyToResponse(MessageUserResponse):
    # Adds the message referenced by ``reply_to_id``. Only one level is kept:
    # the nested value is a MessageUserResponse, which has no reply_to_message.
    reply_to_message: Optional[MessageUserResponse] = None
class MessageResponse(MessageReplyToResponse):
    # Full single-message view with thread statistics and grouped reactions.
    # ``latest_reply_at`` has no default, so it must be supplied (None when
    # the thread has no replies).
    latest_reply_at: Optional[int]
    reply_count: int
    reactions: list[Reactions]
  74. class MessageTable:
  75. def insert_new_message(
  76. self, form_data: MessageForm, channel_id: str, user_id: str
  77. ) -> Optional[MessageModel]:
  78. with get_db() as db:
  79. id = str(uuid.uuid4())
  80. ts = int(time.time_ns())
  81. message = MessageModel(
  82. **{
  83. "id": id,
  84. "user_id": user_id,
  85. "channel_id": channel_id,
  86. "reply_to_id": form_data.reply_to_id,
  87. "parent_id": form_data.parent_id,
  88. "content": form_data.content,
  89. "data": form_data.data,
  90. "meta": form_data.meta,
  91. "created_at": ts,
  92. "updated_at": ts,
  93. }
  94. )
  95. result = Message(**message.model_dump())
  96. db.add(result)
  97. db.commit()
  98. db.refresh(result)
  99. return MessageModel.model_validate(result) if result else None
  100. def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
  101. with get_db() as db:
  102. message = db.get(Message, id)
  103. if not message:
  104. return None
  105. reply_to_message = (
  106. self.get_message_by_id(message.reply_to_id)
  107. if message.reply_to_id
  108. else None
  109. )
  110. reactions = self.get_reactions_by_message_id(id)
  111. thread_replies = self.get_thread_replies_by_message_id(id)
  112. user = Users.get_user_by_id(message.user_id)
  113. return MessageResponse.model_validate(
  114. {
  115. **MessageModel.model_validate(message).model_dump(),
  116. "user": user.model_dump() if user else None,
  117. "reply_to_message": (
  118. reply_to_message.model_dump() if reply_to_message else None
  119. ),
  120. "latest_reply_at": (
  121. thread_replies[0].created_at if thread_replies else None
  122. ),
  123. "reply_count": len(thread_replies),
  124. "reactions": reactions,
  125. }
  126. )
  127. def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
  128. with get_db() as db:
  129. all_messages = (
  130. db.query(Message)
  131. .filter_by(parent_id=id)
  132. .order_by(Message.created_at.desc())
  133. .all()
  134. )
  135. return [
  136. MessageReplyToResponse.model_validate(
  137. {
  138. **MessageModel.model_validate(message).model_dump(),
  139. "reply_to_message": (
  140. self.get_message_by_id(message.reply_to_id).model_dump()
  141. if message.reply_to_id
  142. else None
  143. ),
  144. }
  145. )
  146. for message in all_messages
  147. ]
  148. def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
  149. with get_db() as db:
  150. return [
  151. message.user_id
  152. for message in db.query(Message).filter_by(parent_id=id).all()
  153. ]
  154. def get_messages_by_channel_id(
  155. self, channel_id: str, skip: int = 0, limit: int = 50
  156. ) -> list[MessageReplyToResponse]:
  157. with get_db() as db:
  158. all_messages = (
  159. db.query(Message)
  160. .filter_by(channel_id=channel_id, parent_id=None)
  161. .order_by(Message.created_at.desc())
  162. .offset(skip)
  163. .limit(limit)
  164. .all()
  165. )
  166. return [
  167. MessageReplyToResponse.model_validate(
  168. {
  169. **MessageModel.model_validate(message).model_dump(),
  170. "reply_to_message": (
  171. self.get_message_by_id(message.reply_to_id).model_dump()
  172. if message.reply_to_id
  173. else None
  174. ),
  175. }
  176. )
  177. for message in all_messages
  178. ]
  179. def get_messages_by_parent_id(
  180. self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
  181. ) -> list[MessageReplyToResponse]:
  182. with get_db() as db:
  183. message = db.get(Message, parent_id)
  184. if not message:
  185. return []
  186. all_messages = (
  187. db.query(Message)
  188. .filter_by(channel_id=channel_id, parent_id=parent_id)
  189. .order_by(Message.created_at.desc())
  190. .offset(skip)
  191. .limit(limit)
  192. .all()
  193. )
  194. # If length of all_messages is less than limit, then add the parent message
  195. if len(all_messages) < limit:
  196. all_messages.append(message)
  197. return [
  198. MessageReplyToResponse.model_validate(
  199. {
  200. **MessageModel.model_validate(message).model_dump(),
  201. "reply_to_message": (
  202. self.get_message_by_id(message.reply_to_id).model_dump()
  203. if message.reply_to_id
  204. else None
  205. ),
  206. }
  207. )
  208. for message in all_messages
  209. ]
  210. def update_message_by_id(
  211. self, id: str, form_data: MessageForm
  212. ) -> Optional[MessageModel]:
  213. with get_db() as db:
  214. message = db.get(Message, id)
  215. message.content = form_data.content
  216. message.data = {
  217. **(message.data if message.data else {}),
  218. **(form_data.data if form_data.data else {}),
  219. }
  220. message.meta = {
  221. **(message.meta if message.meta else {}),
  222. **(form_data.meta if form_data.meta else {}),
  223. }
  224. message.updated_at = int(time.time_ns())
  225. db.commit()
  226. db.refresh(message)
  227. return MessageModel.model_validate(message) if message else None
  228. def add_reaction_to_message(
  229. self, id: str, user_id: str, name: str
  230. ) -> Optional[MessageReactionModel]:
  231. with get_db() as db:
  232. reaction_id = str(uuid.uuid4())
  233. reaction = MessageReactionModel(
  234. id=reaction_id,
  235. user_id=user_id,
  236. message_id=id,
  237. name=name,
  238. created_at=int(time.time_ns()),
  239. )
  240. result = MessageReaction(**reaction.model_dump())
  241. db.add(result)
  242. db.commit()
  243. db.refresh(result)
  244. return MessageReactionModel.model_validate(result) if result else None
  245. def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
  246. with get_db() as db:
  247. all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
  248. reactions = {}
  249. for reaction in all_reactions:
  250. if reaction.name not in reactions:
  251. reactions[reaction.name] = {
  252. "name": reaction.name,
  253. "user_ids": [],
  254. "count": 0,
  255. }
  256. reactions[reaction.name]["user_ids"].append(reaction.user_id)
  257. reactions[reaction.name]["count"] += 1
  258. return [Reactions(**reaction) for reaction in reactions.values()]
  259. def remove_reaction_by_id_and_user_id_and_name(
  260. self, id: str, user_id: str, name: str
  261. ) -> bool:
  262. with get_db() as db:
  263. db.query(MessageReaction).filter_by(
  264. message_id=id, user_id=user_id, name=name
  265. ).delete()
  266. db.commit()
  267. return True
  268. def delete_reactions_by_id(self, id: str) -> bool:
  269. with get_db() as db:
  270. db.query(MessageReaction).filter_by(message_id=id).delete()
  271. db.commit()
  272. return True
  273. def delete_replies_by_id(self, id: str) -> bool:
  274. with get_db() as db:
  275. db.query(Message).filter_by(parent_id=id).delete()
  276. db.commit()
  277. return True
  278. def delete_message_by_id(self, id: str) -> bool:
  279. with get_db() as db:
  280. db.query(Message).filter_by(id=id).delete()
  281. # Delete all reactions to this message
  282. db.query(MessageReaction).filter_by(message_id=id).delete()
  283. db.commit()
  284. return True
  285. Messages = MessageTable()