|
@@ -9,6 +9,7 @@ import re
|
|
|
from open_webui.utils.chat import generate_chat_completion
|
|
|
from open_webui.utils.task import (
|
|
|
title_generation_template,
|
|
|
+ follow_up_generation_template,
|
|
|
query_generation_template,
|
|
|
image_prompt_generation_template,
|
|
|
autocomplete_generation_template,
|
|
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
|
|
|
|
|
|
from open_webui.config import (
|
|
|
DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
|
|
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
|
|
|
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
|
|
|
"ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
|
|
|
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
|
|
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
|
|
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
|
|
"ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
|
|
|
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
|
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
|
|
|
ENABLE_AUTOCOMPLETE_GENERATION: bool
|
|
|
AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
|
|
|
TAGS_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+ FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+ ENABLE_FOLLOW_UP_GENERATION: bool
|
|
|
ENABLE_TAGS_GENERATION: bool
|
|
|
ENABLE_SEARCH_QUERY_GENERATION: bool
|
|
|
ENABLE_RETRIEVAL_QUERY_GENERATION: bool
|
|
@@ -94,6 +100,13 @@ async def update_task_config(
|
|
|
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
)
|
|
|
|
|
|
+ request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
|
|
|
+ form_data.ENABLE_FOLLOW_UP_GENERATION
|
|
|
+ )
|
|
|
+ request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+
|
|
|
request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
|
|
|
form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
|
|
|
)
|
|
@@ -133,6 +146,8 @@ async def update_task_config(
|
|
|
"AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
|
|
|
"TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
|
|
|
"ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
|
|
|
+ "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
|
|
|
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
|
|
|
"ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
|
|
|
"ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
|
|
|
"QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
@@ -231,6 +246,86 @@ async def generate_title(
|
|
|
)
|
|
|
|
|
|
|
|
|
+@router.post("/follow_up/completions")
|
|
|
+async def generate_follow_ups(
|
|
|
+ request: Request, form_data: dict, user=Depends(get_verified_user)
|
|
|
+):
|
|
|
+
|
|
|
+ if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_200_OK,
|
|
|
+ content={"detail": "Follow-up generation is disabled"},
|
|
|
+ )
|
|
|
+
|
|
|
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
|
|
|
+ models = {
|
|
|
+ request.state.model["id"]: request.state.model,
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ models = request.app.state.MODELS
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in models:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
+ model_id,
|
|
|
+ request.app.state.config.TASK_MODEL,
|
|
|
+ request.app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ models,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(
|
|
|
+ f"generating chat title using model {task_model_id} for user {user.email} "
|
|
|
+ )
|
|
|
+
|
|
|
+ if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
|
|
|
+ template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
|
|
+ else:
|
|
|
+ template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ content = follow_up_generation_template(
|
|
|
+ template,
|
|
|
+ form_data["messages"],
|
|
|
+ {
|
|
|
+ "name": user.name,
|
|
|
+ "location": user.info.get("location") if user.info else None,
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": task_model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "metadata": {
|
|
|
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
|
|
|
+ "task": str(TASKS.FOLLOW_UP_GENERATION),
|
|
|
+ "task_body": form_data,
|
|
|
+ "chat_id": form_data.get("chat_id", None),
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ # Process the payload through the pipeline
|
|
|
+ try:
|
|
|
+ payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
|
|
+ except Exception as e:
|
|
|
+ raise e
|
|
|
+
|
|
|
+ try:
|
|
|
+ return await generate_chat_completion(request, form_data=payload, user=user)
|
|
|
+ except Exception as e:
|
|
|
+ log.error("Exception occurred", exc_info=True)
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": "An internal error has occurred."},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
@router.post("/tags/completions")
|
|
|
async def generate_chat_tags(
|
|
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|