Explorar o código

feat: fc integration

Timothy J. Baek hai 10 meses
pai
achega
a27175d672
Modificáronse 5 ficheiros con 213 adicións e 38 borrados
  1. 4 22
      backend/apps/webui/routers/tools.py
  2. 23 0
      backend/apps/webui/utils.py
  3. 19 1
      backend/config.py
  4. 162 15
      backend/main.py
  5. 5 0
      backend/utils/task.py

+ 4 - 22
backend/apps/webui/routers/tools.py

@@ -7,6 +7,7 @@ from pydantic import BaseModel
 import json
 
 from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+from apps.webui.utils import load_toolkit_module_by_id
 
 from utils.utils import get_current_user, get_admin_user
 from utils.tools import get_tools_specs
@@ -17,32 +18,13 @@ import os
 
 from config import DATA_DIR
 
+
 TOOLS_DIR = f"{DATA_DIR}/tools"
 os.makedirs(TOOLS_DIR, exist_ok=True)
 
 
 router = APIRouter()
 
-
-def load_toolkit_module_from_path(tools_id, tools_path):
-    spec = util.spec_from_file_location(tools_id, tools_path)
-    module = util.module_from_spec(spec)
-
-    try:
-        spec.loader.exec_module(module)
-        print(f"Loaded module: {module.__name__}")
-        if hasattr(module, "Tools"):
-            return module.Tools()
-        else:
-            raise Exception("No Tools class found")
-    except Exception as e:
-        print(f"Error loading module: {tools_id}")
-
-        # Move the file to the error folder
-        os.rename(tools_path, f"{tools_path}.error")
-        raise e
-
-
 ############################
 # GetToolkits
 ############################
@@ -89,7 +71,7 @@ async def create_new_toolkit(
             with open(toolkit_path, "w") as tool_file:
                 tool_file.write(form_data.content)
 
-            toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path)
+            toolkit_module = load_toolkit_module_by_id(form_data.id)
 
             TOOLS = request.app.state.TOOLS
             TOOLS[form_data.id] = toolkit_module
@@ -149,7 +131,7 @@ async def update_toolkit_by_id(
         with open(toolkit_path, "w") as tool_file:
             tool_file.write(form_data.content)
 
-        toolkit_module = load_toolkit_module_from_path(id, toolkit_path)
+        toolkit_module = load_toolkit_module_by_id(id)
 
         TOOLS = request.app.state.TOOLS
         TOOLS[id] = toolkit_module

+ 23 - 0
backend/apps/webui/utils.py

@@ -0,0 +1,23 @@
+from importlib import util
+import os
+
+from config import TOOLS_DIR
+
+
+def load_toolkit_module_by_id(toolkit_id):
+    toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
+    spec = util.spec_from_file_location(toolkit_id, toolkit_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Tools"):
+            return module.Tools()
+        else:
+            raise Exception("No Tools class found")
+    except Exception as e:
+        print(f"Error loading module: {toolkit_id}")
+        # Move the file to the error folder
+        os.rename(toolkit_path, f"{toolkit_path}.error")
+        raise e

+ 19 - 1
backend/config.py

@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
 Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
 
 
+####################################
+# Tools DIR
+####################################
+
+TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
+Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # LITELLM_CONFIG
 ####################################
@@ -669,7 +677,6 @@ Question:
     ),
 )
 
-
 SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
     "task.search.prompt_length_threshold",
@@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     ),
 )
 
+TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
+    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+    "task.tools.prompt_template",
+    os.environ.get(
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+        """Tools: {{TOOLS}}
+If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks.  Only return the object. Do not return any other text.""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################

+ 162 - 15
backend/main.py

@@ -47,15 +47,24 @@ from pydantic import BaseModel
 from typing import List, Optional
 
 from apps.webui.models.models import Models, ModelModel
+from apps.webui.models.tools import Tools
+from apps.webui.utils import load_toolkit_module_by_id
+
+
 from utils.utils import (
     get_admin_user,
     get_verified_user,
     get_current_user,
     get_http_authorization_cred,
 )
-from utils.task import title_generation_template, search_query_generation_template
+from utils.task import (
+    title_generation_template,
+    search_query_generation_template,
+    tools_function_calling_generation_template,
+)
+from utils.misc import get_last_user_message, add_or_update_system_message
 
-from apps.rag.utils import rag_messages
+from apps.rag.utils import rag_messages, rag_template
 
 from config import (
     CONFIG_DATA,
@@ -82,6 +91,7 @@ from config import (
     TITLE_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     AppConfig,
 )
 from constants import ERROR_MESSAGES
@@ -148,24 +158,71 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
 )
+app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+)
 
 app.state.MODELS = {}
 
 origins = ["*"]
 
-# Custom middleware to add security headers
-# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
-#     async def dispatch(self, request: Request, call_next):
-#         response: Response = await call_next(request)
-#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
-#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
-#         return response
 
+async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
+    tool = Tools.get_tool_by_id(tool_id)
+    tools_specs = json.dumps(tool.specs, indent=2)
+    content = tools_function_calling_generation_template(template, tools_specs)
+
+    payload = {
+        "model": task_model_id,
+        "messages": [
+            {"role": "system", "content": content},
+            {"role": "user", "content": f"Query: {prompt}"},
+        ],
+        "stream": False,
+    }
+
+    payload = filter_pipeline(payload, user)
+    model = app.state.MODELS[task_model_id]
+
+    response = None
+    if model["owned_by"] == "ollama":
+        response = await generate_ollama_chat_completion(
+            OpenAIChatCompletionForm(**payload), user=user
+        )
+    else:
+        response = await generate_openai_chat_completion(payload, user=user)
+
+    print(response)
+    content = response["choices"][0]["message"]["content"]
+
+    # Parse the function response
+    if content != "":
+        result = json.loads(content)
+        print(result)
 
-# app.add_middleware(SecurityHeadersMiddleware)
+        # Call the function
+        if "name" in result:
+            if tool_id in webui_app.state.TOOLS:
+                toolkit_module = webui_app.state.TOOLS[tool_id]
+            else:
+                toolkit_module = load_toolkit_module_by_id(tool_id)
+                webui_app.state.TOOLS[tool_id] = toolkit_module
+
+            function = getattr(toolkit_module, result["name"])
+            function_result = None
+            try:
+                function_result = function(**result["parameters"])
+            except Exception as e:
+                print(e)
+
+            # Add the function result to the system prompt
+            if function_result:
+                return function_result
+
+    return None
 
 
-class RAGMiddleware(BaseHTTPMiddleware):
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         return_citations = False
 
@@ -182,12 +239,65 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            # Remove the citations from the body
             return_citations = data.get("citations", False)
             if "citations" in data:
                 del data["citations"]
 
-            # Example: Add a new key-value pair or modify existing ones
-            # data["modified"] = True  # Example modification
+            # Set the task model
+            task_model_id = data["model"]
+            if task_model_id not in app.state.MODELS:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail="Model not found",
+                )
+
+            # Check if the user has a custom task model
+            # If the user has a custom task model, use that model
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+                if (
+                    app.state.config.TASK_MODEL
+                    and app.state.config.TASK_MODEL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL
+            else:
+                if (
+                    app.state.config.TASK_MODEL_EXTERNAL
+                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+            if "tool_ids" in data:
+                user = get_current_user(
+                    get_http_authorization_cred(request.headers.get("Authorization"))
+                )
+                prompt = get_last_user_message(data["messages"])
+                context = ""
+
+                for tool_id in data["tool_ids"]:
+                    response = await get_function_call_response(
+                        prompt=prompt,
+                        tool_id=tool_id,
+                        template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                        task_model_id=task_model_id,
+                        user=user,
+                    )
+                    print(response)
+
+                    if response:
+                        context += f"\n{response}"
+
+                system_prompt = rag_template(
+                    rag_app.state.config.RAG_TEMPLATE, context, prompt
+                )
+
+                data["messages"] = add_or_update_system_message(
+                    system_prompt, data["messages"]
+                )
+
+                del data["tool_ids"]
+
+            # If docs field is present, generate RAG completions
             if "docs" in data:
                 data = {**data}
                 data["messages"], citations = rag_messages(
@@ -210,7 +320,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
 
             # Replace the request body with the modified one
             request._body = modified_body_bytes
-
             # Set custom header to ensure content-length matches new body length
             request.headers.__dict__["_list"] = [
                 (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@@ -253,7 +362,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
             yield data
 
 
-app.add_middleware(RAGMiddleware)
+app.add_middleware(ChatCompletionMiddleware)
 
 
 def filter_pipeline(payload, user):
@@ -515,6 +624,7 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -524,6 +634,7 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
 
 
 @app.post("/api/task/config/update")
@@ -539,6 +650,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
     app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
         form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
     )
+    app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    )
 
     return {
         "TASK_MODEL": app.state.config.TASK_MODEL,
@@ -546,6 +660,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -659,6 +774,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         return await generate_openai_chat_completion(payload, user=user)
 
 
+@app.post("/api/task/tools/completions")
+async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
+    print("get_tools_function_calling")
+
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+
+    return await get_function_call_response(
+        form_data["prompt"], form_data["tool_id"], template, model_id, user
+    )
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]

+ 5 - 0
backend/utils/task.py

@@ -110,3 +110,8 @@ def search_query_generation_template(
         ),
     )
     return template
+
+
+def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+    template = template.replace("{{TOOLS}}", tools_specs)
+    return template