소스 검색

feat: follow ups

Timothy Jaeryang Baek 4 달 전
부모
커밋
9e49fbc8bf

+ 28 - 0
backend/open_webui/config.py

@@ -1411,6 +1411,34 @@ Strictly return in JSON format:
 {{MESSAGES:END:6}}
 </chat_history>"""
 
+
+FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
+    "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE",
+    "task.follow_up.prompt_template",
+    os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""),
+)
+
+DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task:
+Suggest 3-5 relevant follow-up questions or discussion prompts based on the chat history to help continue or deepen the conversation.
+### Guidelines:
+- Make questions concise, clear, and directly related to the discussed topic(s).
+- Only generate follow-ups that make sense given the chat content and do not repeat what was already covered.
+- If the conversation is very short or not specific, suggest more general follow-ups.
+- Use the chat's primary language; default to English if multilingual.
+- Response must be a JSON array of strings, no extra text or formatting.
+### Output:
+JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] }
+### Chat History:
+<chat_history>
+{{MESSAGES:END:6}}
+</chat_history>"""
+
+ENABLE_FOLLOW_UP_GENERATION = PersistentConfig(
+    "ENABLE_FOLLOW_UP_GENERATION",
+    "task.follow_up.enable",
+    os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true",
+)
+
 ENABLE_TAGS_GENERATION = PersistentConfig(
     "ENABLE_TAGS_GENERATION",
     "task.tags.enable",

+ 1 - 0
backend/open_webui/constants.py

@@ -111,6 +111,7 @@ class TASKS(str, Enum):
 
     DEFAULT = lambda task="": f"{task if task else 'generation'}"
     TITLE_GENERATION = "title_generation"
+    FOLLOW_UP_GENERATION = "follow_up_generation"
     TAGS_GENERATION = "tags_generation"
     EMOJI_GENERATION = "emoji_generation"
     QUERY_GENERATION = "query_generation"

+ 6 - 0
backend/open_webui/main.py

@@ -359,10 +359,12 @@ from open_webui.config import (
     TASK_MODEL_EXTERNAL,
     ENABLE_TAGS_GENERATION,
     ENABLE_TITLE_GENERATION,
+    ENABLE_FOLLOW_UP_GENERATION,
     ENABLE_SEARCH_QUERY_GENERATION,
     ENABLE_RETRIEVAL_QUERY_GENERATION,
     ENABLE_AUTOCOMPLETE_GENERATION,
     TITLE_GENERATION_PROMPT_TEMPLATE,
+    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
     TAGS_GENERATION_PROMPT_TEMPLATE,
     IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
@@ -959,6 +961,7 @@ app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = ENABLE_RETRIEVAL_QUERY_GENE
 app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION
 app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
 app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION
+app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION
 
 
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
@@ -966,6 +969,9 @@ app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLA
 app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
     IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
 )
+app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
+    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+)
 
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

+ 95 - 0
backend/open_webui/routers/tasks.py

@@ -9,6 +9,7 @@ import re
 from open_webui.utils.chat import generate_chat_completion
 from open_webui.utils.task import (
     title_generation_template,
+    follow_up_generation_template,
     query_generation_template,
     image_prompt_generation_template,
     autocomplete_generation_template,
@@ -25,6 +26,7 @@ from open_webui.utils.task import get_task_model_id
 
 from open_webui.config import (
     DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
+    DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)):
         "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
+        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
         "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
@@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel):
     ENABLE_AUTOCOMPLETE_GENERATION: bool
     AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
     TAGS_GENERATION_PROMPT_TEMPLATE: str
+    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
+    ENABLE_FOLLOW_UP_GENERATION: bool
     ENABLE_TAGS_GENERATION: bool
     ENABLE_SEARCH_QUERY_GENERATION: bool
     ENABLE_RETRIEVAL_QUERY_GENERATION: bool
@@ -94,6 +100,13 @@ async def update_task_config(
         form_data.TITLE_GENERATION_PROMPT_TEMPLATE
     )
 
+    request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
+        form_data.ENABLE_FOLLOW_UP_GENERATION
+    )
+    request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
+        form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+    )
+
     request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
         form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
     )
@@ -133,6 +146,8 @@ async def update_task_config(
         "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
         "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
+        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
         "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
         "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
         "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
@@ -231,6 +246,86 @@ async def generate_title(
         )
 
 
+@router.post("/follow_up/completions")
+async def generate_follow_ups(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+
+    if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
+        return JSONResponse(
+            status_code=status.HTTP_200_OK,
+            content={"detail": "Follow-up generation is disabled"},
+        )
+
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data["model"]
+    if model_id not in models:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(
+        model_id,
+        request.app.state.config.TASK_MODEL,
+        request.app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
+    log.debug(
+        f"generating chat title using model {task_model_id} for user {user.email} "
+    )
+
+    if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
+        template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+    else:
+        template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+
+    content = follow_up_generation_template(
+        template,
+        form_data["messages"],
+        {
+            "name": user.name,
+            "location": user.info.get("location") if user.info else None,
+        },
+    )
+
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+            "task": str(TASKS.FOLLOW_UP_GENERATION),
+            "task_body": form_data,
+            "chat_id": form_data.get("chat_id", None),
+        },
+    }
+
+    # Process the payload through the pipeline
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
+    try:
+        return await generate_chat_completion(request, form_data=payload, user=user)
+    except Exception as e:
+        log.error("Exception occurred", exc_info=True)
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content={"detail": "An internal error has occurred."},
+        )
+
+
 @router.post("/tags/completions")
 async def generate_chat_tags(
     request: Request, form_data: dict, user=Depends(get_verified_user)

+ 18 - 0
backend/open_webui/utils/task.py

@@ -207,6 +207,24 @@ def title_generation_template(
     return template
 
 
+def follow_up_generation_template(
+    template: str, messages: list[dict], user: Optional[dict] = None
+) -> str:
+    prompt = get_last_user_message(messages)
+    template = replace_prompt_variable(template, prompt)
+    template = replace_messages_variable(template, messages)
+
+    template = prompt_template(
+        template,
+        **(
+            {"user_name": user.get("name"), "user_location": user.get("location")}
+            if user
+            else {}
+        ),
+    )
+    return template
+
+
 def tags_generation_template(
     template: str, messages: list[dict], user: Optional[dict] = None
 ) -> str:

+ 28 - 0
src/lib/components/admin/Settings/Interface.svelte

@@ -31,6 +31,8 @@
 		TASK_MODEL_EXTERNAL: '',
 		ENABLE_TITLE_GENERATION: true,
 		TITLE_GENERATION_PROMPT_TEMPLATE: '',
+		ENABLE_FOLLOW_UP_GENERATION: true,
+		FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: '',
 		IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: '',
 		ENABLE_AUTOCOMPLETE_GENERATION: true,
 		AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: -1,
@@ -235,6 +237,32 @@
 					</div>
 				{/if}
 
+				<div class="mb-2.5 flex w-full items-center justify-between">
+					<div class=" self-center text-xs font-medium">
+						{$i18n.t('Follow Up Generation')}
+					</div>
+
+					<Switch bind:state={taskConfig.ENABLE_FOLLOW_UP_GENERATION} />
+				</div>
+
+				{#if taskConfig.ENABLE_FOLLOW_UP_GENERATION}
+					<div class="mb-2.5">
+						<div class=" mb-1 text-xs font-medium">{$i18n.t('Follow Up Generation Prompt')}</div>
+
+						<Tooltip
+							content={$i18n.t('Leave empty to use the default prompt, or enter a custom prompt')}
+							placement="top-start"
+						>
+							<Textarea
+								bind:value={taskConfig.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE}
+								placeholder={$i18n.t(
+									'Leave empty to use the default prompt, or enter a custom prompt'
+								)}
+							/>
+						</Tooltip>
+					</div>
+				{/if}
+
 				<div class="mb-2.5 flex w-full items-center justify-between">
 					<div class=" self-center text-xs font-medium">
 						{$i18n.t('Tags Generation')}