|
@@ -24,9 +24,17 @@ from open_webui.constants import ERROR_MESSAGES
|
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
|
|
|
+from open_webui.utils.models import (
|
|
|
+ get_all_models,
|
|
|
+ get_filtered_models,
|
|
|
+)
|
|
|
+from open_webui.utils.chat import generate_chat_completion
|
|
|
+
|
|
|
+
|
|
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
|
|
from open_webui.utils.access_control import has_access, get_users_with_access
|
|
|
from open_webui.utils.webhook import post_webhook
|
|
|
+from open_webui.utils.channels import extract_mentions, replace_mentions
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
|
@@ -221,13 +229,131 @@ async def send_notification(name, webui_url, channel, message, active_user_ids):
|
|
|
return True
|
|
|
|
|
|
|
|
|
-@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
|
|
-async def post_new_message(
|
|
|
- request: Request,
|
|
|
- id: str,
|
|
|
- form_data: MessageForm,
|
|
|
- background_tasks: BackgroundTasks,
|
|
|
- user=Depends(get_verified_user),
|
|
|
+async def model_response_handler(request, channel, message, user):
|
|
|
+ MODELS = {
|
|
|
+ model["id"]: model
|
|
|
+ for model in get_filtered_models(await get_all_models(request, user=user), user)
|
|
|
+ }
|
|
|
+
|
|
|
+ mentions = extract_mentions(message.content)
|
|
|
+ message_content = replace_mentions(message.content)
|
|
|
+
|
|
|
+ # check if any of the mentions are models
|
|
|
+ model_mentions = [mention for mention in mentions if mention["id_type"] == "M"]
|
|
|
+ if not model_mentions:
|
|
|
+ return False
|
|
|
+
|
|
|
+ for mention in model_mentions:
|
|
|
+ model_id = mention["id"]
|
|
|
+ model = MODELS.get(model_id, None)
|
|
|
+
|
|
|
+ if model:
|
|
|
+ try:
|
|
|
+ # reverse to get in chronological order
|
|
|
+ thread_messages = Messages.get_messages_by_parent_id(
|
|
|
+ channel.id,
|
|
|
+ message.parent_id if message.parent_id else message.id,
|
|
|
+ )[::-1]
|
|
|
+
|
|
|
+ response_message, channel = await new_message_handler(
|
|
|
+ request,
|
|
|
+ channel.id,
|
|
|
+ MessageForm(
|
|
|
+ **{
|
|
|
+ "parent_id": (
|
|
|
+ message.parent_id if message.parent_id else message.id
|
|
|
+ ),
|
|
|
+ "content": f"",
|
|
|
+ "data": {},
|
|
|
+ "meta": {
|
|
|
+ "model_id": model_id,
|
|
|
+ "model_name": model.get("name", model_id),
|
|
|
+ },
|
|
|
+ }
|
|
|
+ ),
|
|
|
+ user,
|
|
|
+ )
|
|
|
+
|
|
|
+ thread_history = []
|
|
|
+ message_users = {}
|
|
|
+
|
|
|
+ for thread_message in thread_messages:
|
|
|
+ message_user = None
|
|
|
+ if thread_message.user_id not in message_users:
|
|
|
+ message_user = Users.get_user_by_id(thread_message.user_id)
|
|
|
+ message_users[thread_message.user_id] = message_user
|
|
|
+ else:
|
|
|
+ message_user = message_users[thread_message.user_id]
|
|
|
+
|
|
|
+ if thread_message.meta and thread_message.meta.get(
|
|
|
+ "model_id", None
|
|
|
+ ):
|
|
|
+ # If the message was sent by a model, use the model name
|
|
|
+ message_model_id = thread_message.meta.get("model_id", None)
|
|
|
+ message_model = MODELS.get(message_model_id, None)
|
|
|
+ username = (
|
|
|
+ message_model.get("name", message_model_id)
|
|
|
+ if message_model
|
|
|
+ else message_model_id
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ username = message_user.name if message_user else "Unknown"
|
|
|
+
|
|
|
+ thread_history.append(
|
|
|
+ f"{username}: {replace_mentions(thread_message.content)}"
|
|
|
+ )
|
|
|
+
|
|
|
+ system_message = {
|
|
|
+ "role": "system",
|
|
|
+ "content": f"You are {model.get('name', model_id)}, an AI assistant participating in a threaded conversation. Be helpful, concise, and conversational."
|
|
|
+ + (
|
|
|
+ f"Here's the thread history:\n\n{''.join([f'{msg}' for msg in thread_history])}\n\nContinue the conversation naturally, addressing the most recent message while being aware of the full context."
|
|
|
+ if thread_history
|
|
|
+ else ""
|
|
|
+ ),
|
|
|
+ }
|
|
|
+
|
|
|
+ form_data = {
|
|
|
+ "model": model_id,
|
|
|
+ "messages": [
|
|
|
+ system_message,
|
|
|
+ {
|
|
|
+ "role": "user",
|
|
|
+ "content": f"{user.name if user else 'User'}: {message_content}",
|
|
|
+ },
|
|
|
+ ],
|
|
|
+ "stream": False,
|
|
|
+ }
|
|
|
+
|
|
|
+ res = await generate_chat_completion(
|
|
|
+ request,
|
|
|
+ form_data=form_data,
|
|
|
+ user=user,
|
|
|
+ )
|
|
|
+
|
|
|
+ if res:
|
|
|
+ await update_message_by_id(
|
|
|
+ channel.id,
|
|
|
+ response_message.id,
|
|
|
+ MessageForm(
|
|
|
+ **{
|
|
|
+ "content": res["choices"][0]["message"]["content"],
|
|
|
+ "meta": {
|
|
|
+ "done": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ ),
|
|
|
+ user,
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ log.info(e)
|
|
|
+ pass
|
|
|
+
|
|
|
+ return True
|
|
|
+
|
|
|
+
|
|
|
+async def new_message_handler(
|
|
|
+ request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user)
|
|
|
):
|
|
|
channel = Channels.get_channel_by_id(id)
|
|
|
if not channel:
|
|
@@ -301,21 +427,43 @@ async def post_new_message(
|
|
|
},
|
|
|
to=f"channel:{channel.id}",
|
|
|
)
|
|
|
+ return MessageModel(**message.model_dump()), channel
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(e)
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
|
|
+ )
|
|
|
|
|
|
- active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
|
|
|
|
|
- async def background_handler():
|
|
|
- await send_notification(
|
|
|
- request.app.state.WEBUI_NAME,
|
|
|
- request.app.state.config.WEBUI_URL,
|
|
|
- channel,
|
|
|
- message,
|
|
|
- active_user_ids,
|
|
|
- )
|
|
|
+@router.post("/{id}/messages/post", response_model=Optional[MessageModel])
|
|
|
+async def post_new_message(
|
|
|
+ request: Request,
|
|
|
+ id: str,
|
|
|
+ form_data: MessageForm,
|
|
|
+ background_tasks: BackgroundTasks,
|
|
|
+ user=Depends(get_verified_user),
|
|
|
+):
|
|
|
+
|
|
|
+ try:
|
|
|
+ message, channel = await new_message_handler(request, id, form_data, user)
|
|
|
+ active_user_ids = get_user_ids_from_room(f"channel:{channel.id}")
|
|
|
+
|
|
|
+ async def background_handler():
|
|
|
+ await model_response_handler(request, channel, message, user)
|
|
|
+ await send_notification(
|
|
|
+ request.app.state.WEBUI_NAME,
|
|
|
+ request.app.state.config.WEBUI_URL,
|
|
|
+ channel,
|
|
|
+ message,
|
|
|
+ active_user_ids,
|
|
|
+ )
|
|
|
|
|
|
- background_tasks.add_task(background_handler)
|
|
|
+ background_tasks.add_task(background_handler)
|
|
|
|
|
|
- return MessageModel(**message.model_dump())
|
|
|
+ return message
|
|
|
+
|
|
|
+ except HTTPException as e:
|
|
|
+ raise e
|
|
|
except Exception as e:
|
|
|
log.exception(e)
|
|
|
raise HTTPException(
|