1
0
Timothy Jaeryang Baek 2 сар өмнө
parent
commit
c512bf3559

+ 42 - 40
backend/open_webui/models/messages.py

@@ -95,7 +95,7 @@ class MessageResponse(MessageModel):
 
 
 
 
 class MessageTable:
 class MessageTable:
-    def insert_new_message(
+    async def insert_new_message(
         self, form_data: MessageForm, channel_id: str, user_id: str
         self, form_data: MessageForm, channel_id: str, user_id: str
     ) -> Optional[MessageModel]:
     ) -> Optional[MessageModel]:
         async with get_db() as db:
         async with get_db() as db:
@@ -117,19 +117,19 @@ class MessageTable:
             )
             )
 
 
             result = Message(**message.model_dump())
             result = Message(**message.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
+            await db.add(result)
+            await db.commit()
+            await db.refresh(result)
             return MessageModel.model_validate(result) if result else None
             return MessageModel.model_validate(result) if result else None
 
 
-    def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
+    async def get_message_by_id(self, id: str) -> Optional[MessageResponse]:
         async with get_db() as db:
         async with get_db() as db:
-            message = db.get(Message, id)
+            message = await db.get(Message, id)
             if not message:
             if not message:
                 return None
                 return None
 
 
-            reactions = self.get_reactions_by_message_id(id)
-            replies = self.get_replies_by_message_id(id)
+            reactions = await self.get_reactions_by_message_id(id)
+            replies = await self.get_replies_by_message_id(id)
 
 
             return MessageResponse(
             return MessageResponse(
                 **{
                 **{
@@ -140,29 +140,29 @@ class MessageTable:
                 }
                 }
             )
             )
 
 
-    def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
+    async def get_replies_by_message_id(self, id: str) -> list[MessageModel]:
         async with get_db() as db:
         async with get_db() as db:
             all_messages = (
             all_messages = (
-                db.query(Message)
+                await db.query(Message)
                 .filter_by(parent_id=id)
                 .filter_by(parent_id=id)
                 .order_by(Message.created_at.desc())
                 .order_by(Message.created_at.desc())
                 .all()
                 .all()
             )
             )
             return [MessageModel.model_validate(message) for message in all_messages]
             return [MessageModel.model_validate(message) for message in all_messages]
 
 
-    def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
+    async def get_reply_user_ids_by_message_id(self, id: str) -> list[str]:
         async with get_db() as db:
         async with get_db() as db:
             return [
             return [
                 message.user_id
                 message.user_id
-                for message in db.query(Message).filter_by(parent_id=id).all()
+                for message in await db.query(Message).filter_by(parent_id=id).all()
             ]
             ]
 
 
-    def get_messages_by_channel_id(
+    async def get_messages_by_channel_id(
         self, channel_id: str, skip: int = 0, limit: int = 50
         self, channel_id: str, skip: int = 0, limit: int = 50
     ) -> list[MessageModel]:
     ) -> list[MessageModel]:
         async with get_db() as db:
         async with get_db() as db:
             all_messages = (
             all_messages = (
-                db.query(Message)
+                await db.query(Message)
                 .filter_by(channel_id=channel_id, parent_id=None)
                 .filter_by(channel_id=channel_id, parent_id=None)
                 .order_by(Message.created_at.desc())
                 .order_by(Message.created_at.desc())
                 .offset(skip)
                 .offset(skip)
@@ -171,17 +171,17 @@ class MessageTable:
             )
             )
             return [MessageModel.model_validate(message) for message in all_messages]
             return [MessageModel.model_validate(message) for message in all_messages]
 
 
-    def get_messages_by_parent_id(
+    async def get_messages_by_parent_id(
         self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
         self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50
     ) -> list[MessageModel]:
     ) -> list[MessageModel]:
         async with get_db() as db:
         async with get_db() as db:
-            message = db.get(Message, parent_id)
+            message = await db.get(Message, parent_id)
 
 
             if not message:
             if not message:
                 return []
                 return []
 
 
             all_messages = (
             all_messages = (
-                db.query(Message)
+                await db.query(Message)
                 .filter_by(channel_id=channel_id, parent_id=parent_id)
                 .filter_by(channel_id=channel_id, parent_id=parent_id)
                 .order_by(Message.created_at.desc())
                 .order_by(Message.created_at.desc())
                 .offset(skip)
                 .offset(skip)
@@ -195,20 +195,20 @@ class MessageTable:
 
 
             return [MessageModel.model_validate(message) for message in all_messages]
             return [MessageModel.model_validate(message) for message in all_messages]
 
 
-    def update_message_by_id(
+    async def update_message_by_id(
         self, id: str, form_data: MessageForm
         self, id: str, form_data: MessageForm
     ) -> Optional[MessageModel]:
     ) -> Optional[MessageModel]:
         async with get_db() as db:
         async with get_db() as db:
-            message = db.get(Message, id)
+            message = await db.get(Message, id)
             message.content = form_data.content
             message.content = form_data.content
             message.data = form_data.data
             message.data = form_data.data
             message.meta = form_data.meta
             message.meta = form_data.meta
             message.updated_at = int(time.time_ns())
             message.updated_at = int(time.time_ns())
-            db.commit()
-            db.refresh(message)
+            await db.commit()
+            await db.refresh(message)
             return MessageModel.model_validate(message) if message else None
             return MessageModel.model_validate(message) if message else None
 
 
-    def add_reaction_to_message(
+    async def add_reaction_to_message(
         self, id: str, user_id: str, name: str
         self, id: str, user_id: str, name: str
     ) -> Optional[MessageReactionModel]:
     ) -> Optional[MessageReactionModel]:
         async with get_db() as db:
         async with get_db() as db:
@@ -221,14 +221,16 @@ class MessageTable:
                 created_at=int(time.time_ns()),
                 created_at=int(time.time_ns()),
             )
             )
             result = MessageReaction(**reaction.model_dump())
             result = MessageReaction(**reaction.model_dump())
-            db.add(result)
-            db.commit()
-            db.refresh(result)
+            await db.add(result)
+            await db.commit()
+            await db.refresh(result)
             return MessageReactionModel.model_validate(result) if result else None
             return MessageReactionModel.model_validate(result) if result else None
 
 
-    def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
+    async def get_reactions_by_message_id(self, id: str) -> list[Reactions]:
         async with get_db() as db:
         async with get_db() as db:
-            all_reactions = db.query(MessageReaction).filter_by(message_id=id).all()
+            all_reactions = (
+                await db.query(MessageReaction).filter_by(message_id=id).all()
+            )
 
 
             reactions = {}
             reactions = {}
             for reaction in all_reactions:
             for reaction in all_reactions:
@@ -243,36 +245,36 @@ class MessageTable:
 
 
             return [Reactions(**reaction) for reaction in reactions.values()]
             return [Reactions(**reaction) for reaction in reactions.values()]
 
 
-    def remove_reaction_by_id_and_user_id_and_name(
+    async def remove_reaction_by_id_and_user_id_and_name(
         self, id: str, user_id: str, name: str
         self, id: str, user_id: str, name: str
     ) -> bool:
     ) -> bool:
         async with get_db() as db:
         async with get_db() as db:
-            db.query(MessageReaction).filter_by(
+            await db.query(MessageReaction).filter_by(
                 message_id=id, user_id=user_id, name=name
                 message_id=id, user_id=user_id, name=name
             ).delete()
             ).delete()
-            db.commit()
+            await db.commit()
             return True
             return True
 
 
-    def delete_reactions_by_id(self, id: str) -> bool:
+    async def delete_reactions_by_id(self, id: str) -> bool:
         async with get_db() as db:
         async with get_db() as db:
-            db.query(MessageReaction).filter_by(message_id=id).delete()
-            db.commit()
+            await db.query(MessageReaction).filter_by(message_id=id).delete()
+            await db.commit()
             return True
             return True
 
 
-    def delete_replies_by_id(self, id: str) -> bool:
+    async def delete_replies_by_id(self, id: str) -> bool:
         async with get_db() as db:
         async with get_db() as db:
-            db.query(Message).filter_by(parent_id=id).delete()
-            db.commit()
+            await db.query(Message).filter_by(parent_id=id).delete()
+            await db.commit()
             return True
             return True
 
 
-    def delete_message_by_id(self, id: str) -> bool:
+    async def delete_message_by_id(self, id: str) -> bool:
         async with get_db() as db:
         async with get_db() as db:
-            db.query(Message).filter_by(id=id).delete()
+            await db.query(Message).filter_by(id=id).delete()
 
 
             # Delete all reactions to this message
             # Delete all reactions to this message
-            db.query(MessageReaction).filter_by(message_id=id).delete()
+            await db.query(MessageReaction).filter_by(message_id=id).delete()
 
 
-            db.commit()
+            await db.commit()
             return True
             return True
 
 
 
 

+ 21 - 21
backend/open_webui/routers/channels.py

@@ -164,7 +164,7 @@ async def get_channel_messages(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
         )
 
 
-    message_list = Messages.get_messages_by_channel_id(id, skip, limit)
+    message_list = await Messages.get_messages_by_channel_id(id, skip, limit)
     users = {}
     users = {}
 
 
     messages = []
     messages = []
@@ -173,7 +173,7 @@ async def get_channel_messages(
             user = await Users.get_user_by_id(message.user_id)
             user = await Users.get_user_by_id(message.user_id)
             users[message.user_id] = user
             users[message.user_id] = user
 
 
-        replies = Messages.get_replies_by_message_id(message.id)
+        replies = await Messages.get_replies_by_message_id(message.id)
         latest_reply_at = replies[0].created_at if replies else None
         latest_reply_at = replies[0].created_at if replies else None
 
 
         messages.append(
         messages.append(
@@ -182,7 +182,7 @@ async def get_channel_messages(
                     **message.model_dump(),
                     **message.model_dump(),
                     "reply_count": len(replies),
                     "reply_count": len(replies),
                     "latest_reply_at": latest_reply_at,
                     "latest_reply_at": latest_reply_at,
-                    "reactions": Messages.get_reactions_by_message_id(message.id),
+                    "reactions": await Messages.get_reactions_by_message_id(message.id),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                 }
                 }
             )
             )
@@ -244,7 +244,7 @@ async def post_new_message(
         )
         )
 
 
     try:
     try:
-        message = Messages.insert_new_message(form_data, channel.id, user.id)
+        message = await Messages.insert_new_message(form_data, channel.id, user.id)
 
 
         if message:
         if message:
             event_data = {
             event_data = {
@@ -257,7 +257,7 @@ async def post_new_message(
                             **message.model_dump(),
                             **message.model_dump(),
                             "reply_count": 0,
                             "reply_count": 0,
                             "latest_reply_at": None,
                             "latest_reply_at": None,
-                            "reactions": Messages.get_reactions_by_message_id(
+                            "reactions": await Messages.get_reactions_by_message_id(
                                 message.id
                                 message.id
                             ),
                             ),
                             "user": UserNameResponse(**user.model_dump()),
                             "user": UserNameResponse(**user.model_dump()),
@@ -276,7 +276,7 @@ async def post_new_message(
 
 
             if message.parent_id:
             if message.parent_id:
                 # If this message is a reply, emit to the parent message as well
                 # If this message is a reply, emit to the parent message as well
-                parent_message = Messages.get_message_by_id(message.parent_id)
+                parent_message = await Messages.get_message_by_id(message.parent_id)
 
 
                 if parent_message:
                 if parent_message:
                     await sio.emit(
                     await sio.emit(
@@ -348,7 +348,7 @@ async def get_channel_message(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
         )
 
 
-    message = Messages.get_message_by_id(message_id)
+    message = await Messages.get_message_by_id(message_id)
     if not message:
     if not message:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@@ -397,7 +397,7 @@ async def get_channel_thread_messages(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
         )
 
 
-    message_list = Messages.get_messages_by_parent_id(id, message_id, skip, limit)
+    message_list = await Messages.get_messages_by_parent_id(id, message_id, skip, limit)
     users = {}
     users = {}
 
 
     messages = []
     messages = []
@@ -412,7 +412,7 @@ async def get_channel_thread_messages(
                     **message.model_dump(),
                     **message.model_dump(),
                     "reply_count": 0,
                     "reply_count": 0,
                     "latest_reply_at": None,
                     "latest_reply_at": None,
-                    "reactions": Messages.get_reactions_by_message_id(message.id),
+                    "reactions": await Messages.get_reactions_by_message_id(message.id),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                     "user": UserNameResponse(**users[message.user_id].model_dump()),
                 }
                 }
             )
             )
@@ -438,7 +438,7 @@ async def update_message_by_id(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
         )
 
 
-    message = Messages.get_message_by_id(message_id)
+    message = await Messages.get_message_by_id(message_id)
     if not message:
     if not message:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@@ -461,8 +461,8 @@ async def update_message_by_id(
         )
         )
 
 
     try:
     try:
-        message = Messages.update_message_by_id(message_id, form_data)
-        message = Messages.get_message_by_id(message_id)
+        message = await Messages.update_message_by_id(message_id, form_data)
+        message = await Messages.get_message_by_id(message_id)
 
 
         if message:
         if message:
             await sio.emit(
             await sio.emit(
@@ -521,7 +521,7 @@ async def add_reaction_to_message(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
         )
 
 
-    message = Messages.get_message_by_id(message_id)
+    message = await Messages.get_message_by_id(message_id)
     if not message:
     if not message:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@@ -533,8 +533,8 @@ async def add_reaction_to_message(
         )
         )
 
 
     try:
     try:
-        Messages.add_reaction_to_message(message_id, user.id, form_data.name)
-        message = Messages.get_message_by_id(message_id)
+        await Messages.add_reaction_to_message(message_id, user.id, form_data.name)
+        message = await Messages.get_message_by_id(message_id)
 
 
         await sio.emit(
         await sio.emit(
             "channel-events",
             "channel-events",
@@ -587,7 +587,7 @@ async def remove_reaction_by_id_and_user_id_and_name(
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
             status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT()
         )
         )
 
 
-    message = Messages.get_message_by_id(message_id)
+    message = await Messages.get_message_by_id(message_id)
     if not message:
     if not message:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@@ -599,11 +599,11 @@ async def remove_reaction_by_id_and_user_id_and_name(
         )
         )
 
 
     try:
     try:
-        Messages.remove_reaction_by_id_and_user_id_and_name(
+        await Messages.remove_reaction_by_id_and_user_id_and_name(
             message_id, user.id, form_data.name
             message_id, user.id, form_data.name
         )
         )
 
 
-        message = Messages.get_message_by_id(message_id)
+        message = await Messages.get_message_by_id(message_id)
 
 
         await sio.emit(
         await sio.emit(
             "channel-events",
             "channel-events",
@@ -649,7 +649,7 @@ async def delete_message_by_id(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
         )
         )
 
 
-    message = Messages.get_message_by_id(message_id)
+    message = await Messages.get_message_by_id(message_id)
     if not message:
     if not message:
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
             status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND
@@ -672,7 +672,7 @@ async def delete_message_by_id(
         )
         )
 
 
     try:
     try:
-        Messages.delete_message_by_id(message_id)
+        await Messages.delete_message_by_id(message_id)
         await sio.emit(
         await sio.emit(
             "channel-events",
             "channel-events",
             {
             {
@@ -693,7 +693,7 @@ async def delete_message_by_id(
 
 
         if message.parent_id:
         if message.parent_id:
             # If this message is a reply, emit to the parent message as well
             # If this message is a reply, emit to the parent message as well
-            parent_message = Messages.get_message_by_id(message.parent_id)
+            parent_message = await Messages.get_message_by_id(message.parent_id)
 
 
             if parent_message:
             if parent_message:
                 await sio.emit(
                 await sio.emit(