|
@@ -47,15 +47,24 @@ from pydantic import BaseModel
|
|
|
from typing import List, Optional
|
|
|
|
|
|
from apps.webui.models.models import Models, ModelModel
|
|
|
+from apps.webui.models.tools import Tools
|
|
|
+from apps.webui.utils import load_toolkit_module_by_id
|
|
|
+
|
|
|
+
|
|
|
from utils.utils import (
|
|
|
get_admin_user,
|
|
|
get_verified_user,
|
|
|
get_current_user,
|
|
|
get_http_authorization_cred,
|
|
|
)
|
|
|
-from utils.task import title_generation_template, search_query_generation_template
|
|
|
+from utils.task import (
|
|
|
+ title_generation_template,
|
|
|
+ search_query_generation_template,
|
|
|
+ tools_function_calling_generation_template,
|
|
|
+)
|
|
|
+from utils.misc import get_last_user_message, add_or_update_system_message
|
|
|
|
|
|
-from apps.rag.utils import rag_messages
|
|
|
+from apps.rag.utils import rag_messages, rag_template
|
|
|
|
|
|
from config import (
|
|
|
CONFIG_DATA,
|
|
@@ -82,6 +91,7 @@ from config import (
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
AppConfig,
|
|
|
)
|
|
|
from constants import ERROR_MESSAGES
|
|
@@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
|
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
|
|
)
|
|
|
+app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+)
|
|
|
|
|
|
app.state.MODELS = {}
|
|
|
|
|
|
origins = ["*"]
|
|
|
|
|
|
-# Custom middleware to add security headers
|
|
|
-# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|
|
-# async def dispatch(self, request: Request, call_next):
|
|
|
-# response: Response = await call_next(request)
|
|
|
-# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
|
|
|
-# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
|
|
|
-# return response
|
|
|
|
|
|
+async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
|
|
|
+ tool = Tools.get_tool_by_id(tool_id)
|
|
|
+ tools_specs = json.dumps(tool.specs, indent=2)
|
|
|
+ content = tools_function_calling_generation_template(template, tools_specs)
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": task_model_id,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": content},
|
|
|
+ {"role": "user", "content": f"Query: {prompt}"},
|
|
|
+ ],
|
|
|
+ "stream": False,
|
|
|
+ }
|
|
|
+
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+ model = app.state.MODELS[task_model_id]
|
|
|
+
|
|
|
+ response = None
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ response = await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ response = await generate_openai_chat_completion(payload, user=user)
|
|
|
+
|
|
|
+ print(response)
|
|
|
+ content = response["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ # Parse the function response
|
|
|
+ if content != "":
|
|
|
+ result = json.loads(content)
|
|
|
+ print(result)
|
|
|
|
|
|
-# app.add_middleware(SecurityHeadersMiddleware)
|
|
|
+ # Call the function
|
|
|
+ if "name" in result:
|
|
|
+ if tool_id in webui_app.state.TOOLS:
|
|
|
+ toolkit_module = webui_app.state.TOOLS[tool_id]
|
|
|
+ else:
|
|
|
+ toolkit_module = load_toolkit_module_by_id(tool_id)
|
|
|
+ webui_app.state.TOOLS[tool_id] = toolkit_module
|
|
|
+
|
|
|
+ function = getattr(toolkit_module, result["name"])
|
|
|
+ function_result = None
|
|
|
+ try:
|
|
|
+ function_result = function(**result["parameters"])
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+
|
|
|
+ # Add the function result to the system prompt
|
|
|
+ if function_result:
|
|
|
+ return function_result
|
|
|
+
|
|
|
+ return None
|
|
|
|
|
|
|
|
|
-class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
return_citations = False
|
|
|
|
|
@@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
# Parse string to JSON
|
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
+ # Remove the citations from the body
|
|
|
return_citations = data.get("citations", False)
|
|
|
if "citations" in data:
|
|
|
del data["citations"]
|
|
|
|
|
|
- # Example: Add a new key-value pair or modify existing ones
|
|
|
- # data["modified"] = True # Example modification
|
|
|
+ # Set the task model
|
|
|
+ task_model_id = data["model"]
|
|
|
+ if task_model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ else:
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+
|
|
|
+ if "tool_ids" in data:
|
|
|
+ user = get_current_user(
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization"))
|
|
|
+ )
|
|
|
+ prompt = get_last_user_message(data["messages"])
|
|
|
+ context = ""
|
|
|
+
|
|
|
+ for tool_id in data["tool_ids"]:
|
|
|
+ response = await get_function_call_response(
|
|
|
+ prompt=prompt,
|
|
|
+ tool_id=tool_id,
|
|
|
+ template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
+ task_model_id=task_model_id,
|
|
|
+ user=user,
|
|
|
+ )
|
|
|
+ print(response)
|
|
|
+
|
|
|
+ if response:
|
|
|
+ context += f"\n{response}"
|
|
|
+
|
|
|
+ system_prompt = rag_template(
|
|
|
+ rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
+ )
|
|
|
+
|
|
|
+ data["messages"] = add_or_update_system_message(
|
|
|
+ system_prompt, data["messages"]
|
|
|
+ )
|
|
|
+
|
|
|
+ del data["tool_ids"]
|
|
|
+
|
|
|
+ # If docs field is present, generate RAG completions
|
|
|
if "docs" in data:
|
|
|
data = {**data}
|
|
|
data["messages"], citations = rag_messages(
|
|
@@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
# Replace the request body with the modified one
|
|
|
request._body = modified_body_bytes
|
|
|
-
|
|
|
# Set custom header to ensure content-length matches new body length
|
|
|
request.headers.__dict__["_list"] = [
|
|
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
@@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
yield data
|
|
|
|
|
|
|
|
|
-app.add_middleware(RAGMiddleware)
|
|
|
+app.add_middleware(ChatCompletionMiddleware)
|
|
|
|
|
|
|
|
|
def filter_pipeline(payload, user):
|
|
@@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)):
|
|
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel):
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE: str
|
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
|
|
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
|
|
|
|
|
|
|
|
@app.post("/api/task/config/update")
|
|
@@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
|
|
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
|
|
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
|
|
)
|
|
|
+ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
+ form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
|
|
|
return {
|
|
|
"TASK_MODEL": app.state.config.TASK_MODEL,
|
|
@@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
|
|
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
}
|
|
|
|
|
|
|
|
@@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
|
return await generate_openai_chat_completion(payload, user=user)
|
|
|
|
|
|
|
|
|
+@app.post("/api/task/tools/completions")
|
|
|
+async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("get_tools_function_calling")
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ return await get_function_call_response(
|
|
|
+ form_data["prompt"], form_data["tool_id"], template, model_id, user
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
@app.post("/api/chat/completions")
|
|
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
|
|
model_id = form_data["model"]
|