messages.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. import json
  2. import time
  3. import uuid
  4. from typing import Optional
  5. from open_webui.internal.db import Base, get_db
  6. from open_webui.models.tags import TagModel, Tag, Tags
  7. from open_webui.models.users import Users, UserNameResponse
  8. from pydantic import BaseModel, ConfigDict
  9. from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
  10. from sqlalchemy import or_, func, select, and_, text
  11. from sqlalchemy.sql import exists
  12. ####################
  13. # Message DB Schema
  14. ####################
class MessageReaction(Base):
    """ORM table holding one row per reaction a user adds to a message.

    Rows are grouped by ``name`` when read back through
    ``MessageTable.get_reactions_by_message_id``.
    """

    __tablename__ = "message_reaction"
    id = Column(Text, primary_key=True)  # UUID4 string generated on insert
    user_id = Column(Text)  # user who reacted
    message_id = Column(Text)  # id of the reacted-to Message row
    name = Column(Text)  # reaction identifier (presumably an emoji name — verify against callers)
    created_at = Column(BigInteger)  # nanoseconds since epoch (time.time_ns())
class MessageReactionModel(BaseModel):
    # Pydantic view of a MessageReaction row; from_attributes allows building
    # it directly from the ORM object via model_validate().
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    message_id: str
    name: str
    created_at: int  # nanoseconds since epoch (time.time_ns())
class Message(Base):
    """ORM table for channel messages, including thread replies.

    ``parent_id`` places a message inside the thread of another message
    (top-level channel messages have ``parent_id`` NULL), while
    ``reply_to_id`` points at the message being replied to, which is resolved
    and embedded as ``reply_to_message`` in responses.
    """

    __tablename__ = "message"
    id = Column(Text, primary_key=True)  # UUID4 string generated on insert
    user_id = Column(Text)  # author
    channel_id = Column(Text, nullable=True)
    reply_to_id = Column(Text, nullable=True)  # message being replied to, if any
    parent_id = Column(Text, nullable=True)  # thread root; NULL for top-level messages
    content = Column(Text)
    data = Column(JSON, nullable=True)  # shallow-merged with new keys on update
    meta = Column(JSON, nullable=True)  # shallow-merged with new keys on update
    created_at = Column(BigInteger)  # time_ns
    updated_at = Column(BigInteger)  # time_ns
class MessageModel(BaseModel):
    # Pydantic view of a Message row; from_attributes allows building it
    # directly from the ORM object via model_validate().
    model_config = ConfigDict(from_attributes=True)
    id: str
    user_id: str
    channel_id: Optional[str] = None
    reply_to_id: Optional[str] = None  # message being replied to, if any
    parent_id: Optional[str] = None  # thread root; None for top-level messages
    content: str
    data: Optional[dict] = None
    meta: Optional[dict] = None
    created_at: int  # nanoseconds since epoch (time.time_ns())
    updated_at: int  # nanoseconds since epoch (time.time_ns())
  53. ####################
  54. # Forms
  55. ####################
class MessageForm(BaseModel):
    # Payload for creating or editing a message. On creation every field is
    # stored; on update only `content` is replaced and `data`/`meta` are
    # shallow-merged into the stored values (reply_to_id/parent_id are ignored).
    content: str
    reply_to_id: Optional[str] = None
    parent_id: Optional[str] = None
    data: Optional[dict] = None
    meta: Optional[dict] = None
class Reactions(BaseModel):
    # All reactions with the same `name` on one message, aggregated.
    name: str
    # One entry per reaction row, so a user who added the same reaction twice
    # appears twice.
    user_ids: list[str]
    count: int  # equals len(user_ids) when built by get_reactions_by_message_id
class MessageUserResponse(MessageModel):
    # Message plus its author, as a UserNameResponse (None if not supplied).
    user: Optional[UserNameResponse] = None
class MessageReplyToResponse(MessageUserResponse):
    # Message plus the message it replies to (resolved from reply_to_id), or
    # None when there is no reply target or it no longer exists.
    reply_to_message: Optional[MessageUserResponse] = None
class MessageResponse(MessageReplyToResponse):
    # Full single-message view: adds thread statistics and aggregated reactions.
    # created_at of the newest thread reply, or None. It has no default, so it
    # must always be passed explicitly (even as None).
    latest_reply_at: Optional[int]
    reply_count: int  # number of messages whose parent_id is this message's id
    reactions: list[Reactions]
  74. class MessageTable:
  75. def insert_new_message(
  76. self, form_data: MessageForm, channel_id: str, user_id: str
  77. ) -> Optional[MessageModel]:
  78. with get_db() as db:
  79. id = str(uuid.uuid4())
  80. ts = int(time.time_ns())
  81. message = MessageModel(
  82. **{
  83. "id": id,
  84. "user_id": user_id,
  85. "channel_id": channel_id,
  86. "reply_to_id": form_data.reply_to_id,
  87. "parent_id": form_data.parent_id,
  88. "content": form_data.content,
  89. "data": form_data.data,
  90. "meta": form_data.meta,
  91. "created_at": ts,
  92. "updated_at": ts,
  93. }
  94. )
  95. result = Message(**message.model_dump())
  96. db.add(result)
  97. db.commit()
  98. db.refresh(result)
  99. return MessageModel.model_validate(result) if result else None
  100. def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
  101. with get_db() as db:
  102. message = db.get(Message, id)
  103. if not message:
  104. return None
  105. reply_to_message = (
  106. self.get_message_by_id(message.reply_to_id)
  107. if message.reply_to_id
  108. else None
  109. )
  110. reactions = self.get_reactions_by_message_id(id)
  111. thread_replies = self.get_thread_replies_by_message_id(id)
  112. user = Users.get_user_by_id(message.user_id)
  113. return MessageResponse.model_validate(
  114. {
  115. **MessageModel.model_validate(message).model_dump(),
  116. "user": user.model_dump() if user else None,
  117. "reply_to_message": (
  118. reply_to_message.model_dump() if reply_to_message else None
  119. ),
  120. "latest_reply_at": (
  121. thread_replies[0].created_at if thread_replies else None
  122. ),
  123. "reply_count": len(thread_replies),
  124. "reactions": reactions,
  125. }
  126. )
  127. def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]:
  128. with get_db() as db:
  129. all_messages = (
  130. db.query(Message)
  131. .filter_by(parent_id=id)
  132. .order_by(Message.created_at.desc())
  133. .all()
  134. )
  135. messages = []
  136. for message in all_messages:
  137. reply_to_message = (
  138. self.get_message_by_id(message.reply_to_id)
  139. if message.reply_to_id
  140. else None
  141. )
  142. messages.append(
  143. MessageReplyToResponse.model_validate(
  144. {
  145. **MessageModel.model_validate(message).model_dump(),
  146. "reply_to_message": (
  147. reply_to_message.model_dump()
  148. if reply_to_message
  149. else None
  150. ),
  151. }
  152. )
  153. )
  154. return messages
  155. def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
  156. with get_db() as db:
  157. return [
  158. message.user_id
  159. for message in db.query(Message).filter_by(parent_id=id).all()
  160. ]
  161. def get_messages_by_channel_id(
  162. self, channel_id: str, skip: int = 0, limit: int = 50
  163. ) -> list[MessageReplyToResponse]:
  164. with get_db() as db:
  165. all_messages = (
  166. db.query(Message)
  167. .filter_by(channel_id=channel_id, parent_id=None)
  168. .order_by(Message.created_at.desc())
  169. .offset(skip)
  170. .limit(limit)
  171. .all()
  172. )
  173. messages = []
  174. for message in all_messages:
  175. reply_to_message = (
  176. self.get_message_by_id(message.reply_to_id)
  177. if message.reply_to_id
  178. else None
  179. )
  180. messages.append(
  181. MessageReplyToResponse.model_validate(
  182. {
  183. **MessageModel.model_validate(message).model_dump(),
  184. "reply_to_message": (
  185. reply_to_message.model_dump()
  186. if reply_to_message
  187. else None
  188. ),
  189. }
  190. )
  191. )
  192. return messages
  193. def get_messages_by_parent_id(
  194. self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
  195. ) -> list[MessageReplyToResponse]:
  196. with get_db() as db:
  197. message = db.get(Message, parent_id)
  198. if not message:
  199. return []
  200. all_messages = (
  201. db.query(Message)
  202. .filter_by(channel_id=channel_id, parent_id=parent_id)
  203. .order_by(Message.created_at.desc())
  204. .offset(skip)
  205. .limit(limit)
  206. .all()
  207. )
  208. # If length of all_messages is less than limit, then add the parent message
  209. if len(all_messages) < limit:
  210. all_messages.append(message)
  211. messages = []
  212. for message in all_messages:
  213. reply_to_message = (
  214. self.get_message_by_id(message.reply_to_id)
  215. if message.reply_to_id
  216. else None
  217. )
  218. messages.append(
  219. MessageReplyToResponse.model_validate(
  220. {
  221. **MessageModel.model_validate(message).model_dump(),
  222. "reply_to_message": (
  223. reply_to_message.model_dump()
  224. if reply_to_message
  225. else None
  226. ),
  227. }
  228. )
  229. )
  230. return messages
  231. def update_message_by_id(
  232. self, id: str, form_data: MessageForm
  233. ) -> Optional[MessageModel]:
  234. with get_db() as db:
  235. message = db.get(Message, id)
  236. message.content = form_data.content
  237. message.data = {
  238. **(message.data if message.data else {}),
  239. **(form_data.data if form_data.data else {}),
  240. }
  241. message.meta = {
  242. **(message.meta if message.meta else {}),
  243. **(form_data.meta if form_data.meta else {}),
  244. }
  245. message.updated_at = int(time.time_ns())
  246. db.commit()
  247. db.refresh(message)
  248. return MessageModel.model_validate(message) if message else None
  249. def add_reaction_to_message(
  250. self, id: str, user_id: str, name: str
  251. ) -> Optional[MessageReactionModel]:
  252. with get_db() as db:
  253. reaction_id = str(uuid.uuid4())
  254. reaction = MessageReactionModel(
  255. id=reaction_id,
  256. user_id=user_id,
  257. message_id=id,
  258. name=name,
  259. created_at=int(time.time_ns()),
  260. )
  261. result = MessageReaction(**reaction.model_dump())
  262. db.add(result)
  263. db.commit()
  264. db.refresh(result)
  265. return MessageReactionModel.model_validate(result) if result else None
  266. def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
  267. with get_db() as db:
  268. all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
  269. reactions = {}
  270. for reaction in all_reactions:
  271. if reaction.name not in reactions:
  272. reactions[reaction.name] = {
  273. "name": reaction.name,
  274. "user_ids": [],
  275. "count": 0,
  276. }
  277. reactions[reaction.name]["user_ids"].append(reaction.user_id)
  278. reactions[reaction.name]["count"] += 1
  279. return [Reactions(**reaction) for reaction in reactions.values()]
  280. def remove_reaction_by_id_and_user_id_and_name(
  281. self, id: str, user_id: str, name: str
  282. ) -> bool:
  283. with get_db() as db:
  284. db.query(MessageReaction).filter_by(
  285. message_id=id, user_id=user_id, name=name
  286. ).delete()
  287. db.commit()
  288. return True
  289. def delete_reactions_by_id(self, id: str) -> bool:
  290. with get_db() as db:
  291. db.query(MessageReaction).filter_by(message_id=id).delete()
  292. db.commit()
  293. return True
  294. def delete_replies_by_id(self, id: str) -> bool:
  295. with get_db() as db:
  296. db.query(Message).filter_by(parent_id=id).delete()
  297. db.commit()
  298. return True
  299. def delete_message_by_id(self, id: str) -> bool:
  300. with get_db() as db:
  301. db.query(Message).filter_by(id=id).delete()
  302. # Delete all reactions to this message
  303. db.query(MessageReaction).filter_by(message_id=id).delete()
  304. db.commit()
  305. return True
# Shared module-level instance; MessageTable keeps no per-instance state.
Messages: MessageTable = MessageTable()