Browse Source

sort and fix backend imports

Pascal Lim 11 tháng trước cách đây
mục cha
commit
c386d0b1a5
63 tập tin đã thay đổi với 548 bổ sung973 xóa
  1. 19 32
      backend/apps/audio/main.py
  2. 23 35
      backend/apps/images/main.py
  3. 7 9
      backend/apps/images/utils/comfyui.py
  4. 19 33
      backend/apps/ollama/main.py
  5. 21 29
      backend/apps/openai/main.py
  6. 87 116
      backend/apps/rag/main.py
  7. 2 2
      backend/apps/rag/search/brave.py
  8. 2 1
      backend/apps/rag/search/duckduckgo.py
  9. 2 3
      backend/apps/rag/search/google_pse.py
  10. 3 3
      backend/apps/rag/search/jina_search.py
  11. 1 0
      backend/apps/rag/search/main.py
  12. 2 3
      backend/apps/rag/search/searxng.py
  13. 2 2
      backend/apps/rag/search/serper.py
  14. 2 3
      backend/apps/rag/search/serply.py
  15. 2 3
      backend/apps/rag/search/serpstack.py
  16. 1 2
      backend/apps/rag/search/tavily.py
  17. 14 24
      backend/apps/rag/utils.py
  18. 1 2
      backend/apps/socket/main.py
  19. 9 14
      backend/apps/webui/internal/db.py
  20. 5 5
      backend/apps/webui/internal/wrappers.py
  21. 35 41
      backend/apps/webui/main.py
  22. 6 12
      backend/apps/webui/models/auths.py
  23. 6 30
      backend/apps/webui/models/chats.py
  24. 5 15
      backend/apps/webui/models/documents.py
  25. 5 17
      backend/apps/webui/models/files.py
  26. 5 23
      backend/apps/webui/models/functions.py
  27. 5 16
      backend/apps/webui/models/memories.py
  28. 4 7
      backend/apps/webui/models/models.py
  29. 4 13
      backend/apps/webui/models/prompts.py
  30. 7 17
      backend/apps/webui/models/tags.py
  31. 4 15
      backend/apps/webui/models/tools.py
  32. 6 7
      backend/apps/webui/models/users.py
  33. 16 28
      backend/apps/webui/routers/auths.py
  34. 8 31
      backend/apps/webui/routers/chats.py
  35. 3 21
      backend/apps/webui/routers/configs.py
  36. 6 11
      backend/apps/webui/routers/documents.py
  37. 11 36
      backend/apps/webui/routers/files.py
  38. 7 17
      backend/apps/webui/routers/functions.py
  39. 5 11
      backend/apps/webui/routers/memories.py
  40. 4 10
      backend/apps/webui/routers/models.py
  41. 4 10
      backend/apps/webui/routers/prompts.py
  42. 7 13
      backend/apps/webui/routers/tools.py
  43. 9 24
      backend/apps/webui/routers/users.py
  44. 10 17
      backend/apps/webui/routers/utils.py
  45. 4 5
      backend/apps/webui/utils.py
  46. 16 44
      backend/config.py
  47. 6 12
      backend/env.py
  48. 92 98
      backend/main.py
  49. 1 16
      backend/migrations/env.py
  50. 3 3
      backend/migrations/versions/7e5b5dc7342b_init.py
  51. 1 3
      backend/migrations/versions/ca81bd47c050_add_config_table.py
  52. 1 3
      backend/test/apps/webui/routers/test_auths.py
  53. 1 3
      backend/test/apps/webui/routers/test_chats.py
  54. 0 1
      backend/test/apps/webui/routers/test_documents.py
  55. 0 1
      backend/test/apps/webui/routers/test_models.py
  56. 0 1
      backend/test/apps/webui/routers/test_prompts.py
  57. 0 1
      backend/test/apps/webui/routers/test_users.py
  58. 4 4
      backend/utils/misc.py
  59. 1 1
      backend/utils/schemas.py
  60. 1 2
      backend/utils/task.py
  61. 0 1
      backend/utils/tools.py
  62. 8 9
      backend/utils/utils.py
  63. 3 2
      backend/utils/webhook.py

+ 19 - 32
backend/apps/audio/main.py

@@ -7,46 +7,33 @@ from functools import lru_cache
 from pathlib import Path
 
 import requests
-from fastapi import (
-    FastAPI,
-    Request,
-    Depends,
-    HTTPException,
-    status,
-    UploadFile,
-    File,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import FileResponse
-from pydantic import BaseModel
-
 from config import (
-    SRC_LOG_LEVELS,
-    CACHE_DIR,
-    WHISPER_MODEL,
-    WHISPER_MODEL_DIR,
-    WHISPER_MODEL_AUTO_UPDATE,
-    DEVICE_TYPE,
+    AUDIO_STT_ENGINE,
+    AUDIO_STT_MODEL,
     AUDIO_STT_OPENAI_API_BASE_URL,
     AUDIO_STT_OPENAI_API_KEY,
-    AUDIO_TTS_OPENAI_API_BASE_URL,
-    AUDIO_TTS_OPENAI_API_KEY,
     AUDIO_TTS_API_KEY,
-    AUDIO_STT_ENGINE,
-    AUDIO_STT_MODEL,
     AUDIO_TTS_ENGINE,
     AUDIO_TTS_MODEL,
-    AUDIO_TTS_VOICE,
+    AUDIO_TTS_OPENAI_API_BASE_URL,
+    AUDIO_TTS_OPENAI_API_KEY,
     AUDIO_TTS_SPLIT_ON,
-    AppConfig,
+    AUDIO_TTS_VOICE,
+    CACHE_DIR,
     CORS_ALLOW_ORIGIN,
+    DEVICE_TYPE,
+    WHISPER_MODEL,
+    WHISPER_MODEL_AUTO_UPDATE,
+    WHISPER_MODEL_DIR,
+    AppConfig,
 )
 from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_current_user,
-    get_verified_user,
-    get_admin_user,
-)
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_current_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["AUDIO"])
@@ -211,7 +198,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             body = json.loads(body)
             body["model"] = app.state.config.TTS_MODEL
             body = json.dumps(body).encode("utf-8")
-        except Exception as e:
+        except Exception:
             pass
 
         r = None
@@ -488,7 +475,7 @@ def get_available_voices() -> dict:
     elif app.state.config.TTS_ENGINE == "elevenlabs":
         try:
             ret = get_elevenlabs_voices()
-        except Exception as e:
+        except Exception:
             # Avoided @lru_cache with exception
             pass
 

+ 23 - 35
backend/apps/images/main.py

@@ -1,52 +1,42 @@
-from fastapi import (
-    FastAPI,
-    Request,
-    Depends,
-    HTTPException,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from typing import Optional
-from pydantic import BaseModel
-from pathlib import Path
-import mimetypes
-import uuid
+import asyncio
 import base64
 import json
 import logging
+import mimetypes
 import re
-import requests
-import asyncio
-
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
+import uuid
+from pathlib import Path
+from typing import Optional
 
+import requests
 from apps.images.utils.comfyui import (
-    ComfyUIWorkflow,
     ComfyUIGenerateImageForm,
+    ComfyUIWorkflow,
     comfyui_generate_image,
 )
-
-from constants import ERROR_MESSAGES
 from config import (
-    SRC_LOG_LEVELS,
-    CACHE_DIR,
-    IMAGE_GENERATION_ENGINE,
-    ENABLE_IMAGE_GENERATION,
-    AUTOMATIC1111_BASE_URL,
     AUTOMATIC1111_API_AUTH,
+    AUTOMATIC1111_BASE_URL,
+    CACHE_DIR,
     COMFYUI_BASE_URL,
     COMFYUI_WORKFLOW,
     COMFYUI_WORKFLOW_NODES,
-    IMAGES_OPENAI_API_BASE_URL,
-    IMAGES_OPENAI_API_KEY,
+    CORS_ALLOW_ORIGIN,
+    ENABLE_IMAGE_GENERATION,
+    IMAGE_GENERATION_ENGINE,
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_STEPS,
-    CORS_ALLOW_ORIGIN,
+    IMAGES_OPENAI_API_BASE_URL,
+    IMAGES_OPENAI_API_KEY,
     AppConfig,
 )
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["IMAGES"])
@@ -186,7 +176,7 @@ async def verify_url(user=Depends(get_admin_user)):
             )
             r.raise_for_status()
             return True
-        except Exception as e:
+        except Exception:
             app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
     elif app.state.config.ENGINE == "comfyui":
@@ -194,7 +184,7 @@ async def verify_url(user=Depends(get_admin_user)):
             r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
             r.raise_for_status()
             return True
-        except Exception as e:
+        except Exception:
             app.state.config.ENABLED = False
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.INVALID_URL)
     else:
@@ -397,7 +387,6 @@ def save_url_image(url):
         r = requests.get(url)
         r.raise_for_status()
         if r.headers["content-type"].split("/")[0] == "image":
-
             mime_type = r.headers["content-type"]
             image_format = mimetypes.guess_extension(mime_type)
 
@@ -412,7 +401,7 @@ def save_url_image(url):
                     image_file.write(chunk)
             return image_filename
         else:
-            log.error(f"Url does not point to an image.")
+            log.error("Url does not point to an image.")
             return None
 
     except Exception as e:
@@ -430,7 +419,6 @@ async def image_generations(
     r = None
     try:
         if app.state.config.ENGINE == "openai":
-
             headers = {}
             headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
             headers["Content-Type"] = "application/json"

+ 7 - 9
backend/apps/images/utils/comfyui.py

@@ -1,20 +1,18 @@
 import asyncio
-import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
 import json
-import urllib.request
-import urllib.parse
-import random
 import logging
+import random
+import urllib.parse
+import urllib.request
+from typing import Optional
 
-from config import SRC_LOG_LEVELS
+import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
+from env import SRC_LOG_LEVELS
+from pydantic import BaseModel
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
 
-from pydantic import BaseModel
-
-from typing import Optional
-
 default_headers = {"User-Agent": "Mozilla/5.0"}
 
 

+ 19 - 33
backend/apps/ollama/main.py

@@ -1,54 +1,40 @@
-from fastapi import (
-    FastAPI,
-    Request,
-    HTTPException,
-    Depends,
-    UploadFile,
-    File,
-)
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
-
-from pydantic import BaseModel, ConfigDict
-
-import os
-import re
-import random
-import requests
-import json
-import aiohttp
 import asyncio
+import json
 import logging
+import os
+import random
+import re
 import time
-from urllib.parse import urlparse
 from typing import Optional, Union
+from urllib.parse import urlparse
 
-from starlette.background import BackgroundTask
-
+import aiohttp
+import requests
 from apps.webui.models.models import Models
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
-
 from config import (
-    SRC_LOG_LEVELS,
-    OLLAMA_BASE_URLS,
-    ENABLE_OLLAMA_API,
     AIOHTTP_CLIENT_TIMEOUT,
+    CORS_ALLOW_ORIGIN,
     ENABLE_MODEL_FILTER,
+    ENABLE_OLLAMA_API,
     MODEL_FILTER_LIST,
+    OLLAMA_BASE_URLS,
     UPLOAD_DIR,
     AppConfig,
-    CORS_ALLOW_ORIGIN,
 )
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, HTTPException, Request, UploadFile
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel, ConfigDict
+from starlette.background import BackgroundTask
 from utils.misc import (
-    calculate_sha256,
     apply_model_params_to_body_ollama,
     apply_model_params_to_body_openai,
     apply_model_system_prompt_to_body,
+    calculate_sha256,
 )
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])

+ 21 - 29
backend/apps/openai/main.py

@@ -1,44 +1,36 @@
-from fastapi import FastAPI, Request, HTTPException, Depends
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, FileResponse
-
-import requests
-import aiohttp
 import asyncio
+import hashlib
 import json
 import logging
+from pathlib import Path
+from typing import Literal, Optional, overload
 
-from pydantic import BaseModel
-from starlette.background import BackgroundTask
-
+import aiohttp
+import requests
 from apps.webui.models.models import Models
-from constants import ERROR_MESSAGES
-from utils.utils import (
-    get_verified_user,
-    get_admin_user,
-)
-from utils.misc import (
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
 from config import (
-    SRC_LOG_LEVELS,
-    ENABLE_OPENAI_API,
     AIOHTTP_CLIENT_TIMEOUT,
-    OPENAI_API_BASE_URLS,
-    OPENAI_API_KEYS,
     CACHE_DIR,
+    CORS_ALLOW_ORIGIN,
     ENABLE_MODEL_FILTER,
+    ENABLE_OPENAI_API,
     MODEL_FILTER_LIST,
+    OPENAI_API_BASE_URLS,
+    OPENAI_API_KEYS,
     AppConfig,
-    CORS_ALLOW_ORIGIN,
 )
-from typing import Optional, Literal, overload
-
-
-import hashlib
-from pathlib import Path
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, HTTPException, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse, StreamingResponse
+from pydantic import BaseModel
+from starlette.background import BackgroundTask
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+)
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])

+ 87 - 116
backend/apps/rag/main.py

@@ -1,143 +1,118 @@
-from fastapi import (
-    FastAPI,
-    Depends,
-    HTTPException,
-    status,
-    UploadFile,
-    File,
-    Form,
-)
-from fastapi.middleware.cors import CORSMiddleware
-import requests
-import os, shutil, logging, re
+import json
+import logging
+import mimetypes
+import os
+import shutil
+import socket
+import urllib.parse
+import uuid
 from datetime import datetime
-
 from pathlib import Path
-from typing import Union, Sequence, Iterator, Any
-
-from chromadb.utils.batch_utils import create_batches
-from langchain_core.documents import Document
-
-from langchain_community.document_loaders import (
-    WebBaseLoader,
-    TextLoader,
-    PyPDFLoader,
-    CSVLoader,
-    BSHTMLLoader,
-    Docx2txtLoader,
-    UnstructuredEPubLoader,
-    UnstructuredWordDocumentLoader,
-    UnstructuredMarkdownLoader,
-    UnstructuredXMLLoader,
-    UnstructuredRSTLoader,
-    UnstructuredExcelLoader,
-    UnstructuredPowerPointLoader,
-    YoutubeLoader,
-    OutlookMessageLoader,
-)
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from typing import Iterator, Optional, Sequence, Union
 
+import requests
 import validators
-import urllib.parse
-import socket
-
-
-from pydantic import BaseModel
-from typing import Optional
-import mimetypes
-import uuid
-import json
-
-from apps.webui.models.documents import (
-    Documents,
-    DocumentForm,
-    DocumentResponse,
-)
-from apps.webui.models.files import (
-    Files,
-)
-
-from apps.rag.utils import (
-    get_model_path,
-    get_embedding_function,
-    query_doc,
-    query_doc_with_hybrid_search,
-    query_collection,
-    query_collection_with_hybrid_search,
-)
-
 from apps.rag.search.brave import search_brave
+from apps.rag.search.duckduckgo import search_duckduckgo
 from apps.rag.search.google_pse import search_google_pse
+from apps.rag.search.jina_search import search_jina
 from apps.rag.search.main import SearchResult
+from apps.rag.search.searchapi import search_searchapi
 from apps.rag.search.searxng import search_searxng
 from apps.rag.search.serper import search_serper
-from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.serply import search_serply
-from apps.rag.search.duckduckgo import search_duckduckgo
+from apps.rag.search.serpstack import search_serpstack
 from apps.rag.search.tavily import search_tavily
-from apps.rag.search.jina_search import search_jina
-from apps.rag.search.searchapi import search_searchapi
-
-from utils.misc import (
-    calculate_sha256,
-    calculate_sha256_string,
-    sanitize_filename,
-    extract_folders_after_data_docs,
+from apps.rag.utils import (
+    get_embedding_function,
+    get_model_path,
+    query_collection,
+    query_collection_with_hybrid_search,
+    query_doc,
+    query_doc_with_hybrid_search,
 )
-from utils.utils import get_verified_user, get_admin_user
-
+from apps.webui.models.documents import DocumentForm, Documents
+from apps.webui.models.files import Files
+from chromadb.utils.batch_utils import create_batches
 from config import (
-    AppConfig,
-    ENV,
-    SRC_LOG_LEVELS,
-    UPLOAD_DIR,
-    DOCS_DIR,
+    BRAVE_SEARCH_API_KEY,
+    CHROMA_CLIENT,
+    CHUNK_OVERLAP,
+    CHUNK_SIZE,
     CONTENT_EXTRACTION_ENGINE,
-    TIKA_SERVER_URL,
-    RAG_TOP_K,
-    RAG_RELEVANCE_THRESHOLD,
-    RAG_FILE_MAX_SIZE,
-    RAG_FILE_MAX_COUNT,
+    CORS_ALLOW_ORIGIN,
+    DEVICE_TYPE,
+    DOCS_DIR,
+    ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_LOCAL_WEB_FETCH,
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+    ENABLE_RAG_WEB_SEARCH,
+    ENV,
+    GOOGLE_PSE_API_KEY,
+    GOOGLE_PSE_ENGINE_ID,
+    PDF_EXTRACT_IMAGES,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_FILE_MAX_COUNT,
+    RAG_FILE_MAX_SIZE,
+    RAG_OPENAI_API_BASE_URL,
+    RAG_OPENAI_API_KEY,
+    RAG_RELEVANCE_THRESHOLD,
     RAG_RERANKING_MODEL,
-    PDF_EXTRACT_IMAGES,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    RAG_OPENAI_API_BASE_URL,
-    RAG_OPENAI_API_KEY,
-    DEVICE_TYPE,
-    CHROMA_CLIENT,
-    CHUNK_SIZE,
-    CHUNK_OVERLAP,
     RAG_TEMPLATE,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
-    YOUTUBE_LOADER_LANGUAGE,
-    ENABLE_RAG_WEB_SEARCH,
-    RAG_WEB_SEARCH_ENGINE,
+    RAG_TOP_K,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
+    RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_RESULT_COUNT,
+    SEARCHAPI_API_KEY,
+    SEARCHAPI_ENGINE,
     SEARXNG_QUERY_URL,
-    GOOGLE_PSE_API_KEY,
-    GOOGLE_PSE_ENGINE_ID,
-    BRAVE_SEARCH_API_KEY,
-    SERPSTACK_API_KEY,
-    SERPSTACK_HTTPS,
     SERPER_API_KEY,
     SERPLY_API_KEY,
+    SERPSTACK_API_KEY,
+    SERPSTACK_HTTPS,
     TAVILY_API_KEY,
-    SEARCHAPI_API_KEY,
-    SEARCHAPI_ENGINE,
-    RAG_WEB_SEARCH_RESULT_COUNT,
-    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
-    CORS_ALLOW_ORIGIN,
+    TIKA_SERVER_URL,
+    UPLOAD_DIR,
+    YOUTUBE_LOADER_LANGUAGE,
+    AppConfig,
 )
-
 from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi.middleware.cors import CORSMiddleware
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import (
+    BSHTMLLoader,
+    CSVLoader,
+    Docx2txtLoader,
+    OutlookMessageLoader,
+    PyPDFLoader,
+    TextLoader,
+    UnstructuredEPubLoader,
+    UnstructuredExcelLoader,
+    UnstructuredMarkdownLoader,
+    UnstructuredPowerPointLoader,
+    UnstructuredRSTLoader,
+    UnstructuredXMLLoader,
+    WebBaseLoader,
+    YoutubeLoader,
+)
+from langchain_core.documents import Document
+from pydantic import BaseModel
+from utils.misc import (
+    calculate_sha256,
+    calculate_sha256_string,
+    extract_folders_after_data_docs,
+    sanitize_filename,
+)
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -539,9 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
         app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
         app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
         app.state.config.SEARCHAPI_API_KEY = form_data.web.search.searchapi_api_key
-        app.state.config.SEARCHAPI_ENGINE = (
-            form_data.web.search.searchapi_engine
-        )
+        app.state.config.SEARCHAPI_ENGINE = form_data.web.search.searchapi_engine
         app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
         app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
             form_data.web.search.concurrent_requests
@@ -981,7 +954,6 @@ def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
 def store_data_in_vector_db(
     data, collection_name, metadata: Optional[dict] = None, overwrite: bool = False
 ) -> bool:
-
     text_splitter = RecursiveCharacterTextSplitter(
         chunk_size=app.state.config.CHUNK_SIZE,
         chunk_overlap=app.state.config.CHUNK_OVERLAP,
@@ -1342,7 +1314,6 @@ def store_text(
     form_data: TextRAGForm,
     user=Depends(get_verified_user),
 ):
-
     collection_name = form_data.collection_name
     if collection_name is None:
         collection_name = calculate_sha256_string(form_data.content)

+ 2 - 2
backend/apps/rag/search/brave.py

@@ -1,9 +1,9 @@
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 1
backend/apps/rag/search/duckduckgo.py

@@ -1,8 +1,9 @@
 import logging
 from typing import Optional
+
 from apps.rag.search.main import SearchResult, get_filtered_results
 from duckduckgo_search import DDGS
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/google_pse.py

@@ -1,10 +1,9 @@
-import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 3 - 3
backend/apps/rag/search/jina_search.py

@@ -1,9 +1,9 @@
 import logging
-import requests
-from yarl import URL
 
+import requests
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
+from yarl import URL
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 1 - 0
backend/apps/rag/search/main.py

@@ -1,5 +1,6 @@
 from typing import Optional
 from urllib.parse import urlparse
+
 from pydantic import BaseModel
 
 

+ 2 - 3
backend/apps/rag/search/searxng.py

@@ -1,10 +1,9 @@
 import logging
-import requests
-
 from typing import Optional
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 2
backend/apps/rag/search/serper.py

@@ -1,10 +1,10 @@
 import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/serply.py

@@ -1,11 +1,10 @@
-import json
 import logging
 from typing import Optional
-import requests
 from urllib.parse import urlencode
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 2 - 3
backend/apps/rag/search/serpstack.py

@@ -1,10 +1,9 @@
-import json
 import logging
 from typing import Optional
-import requests
 
+import requests
 from apps.rag.search.main import SearchResult, get_filtered_results
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 1 - 2
backend/apps/rag/search/tavily.py

@@ -1,9 +1,8 @@
 import logging
 
 import requests
-
 from apps.rag.search.main import SearchResult
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+ 14 - 24
backend/apps/rag/utils.py

@@ -1,27 +1,16 @@
-import os
 import logging
-import requests
-
-from typing import Union
-
-from apps.ollama.main import (
-    generate_ollama_embeddings,
-    GenerateEmbeddingsForm,
-)
+import os
+from typing import Optional, Union
 
+import requests
+from apps.ollama.main import GenerateEmbeddingsForm, generate_ollama_embeddings
+from config import CHROMA_CLIENT
+from env import SRC_LOG_LEVELS
 from huggingface_hub import snapshot_download
-
-from langchain_core.documents import Document
+from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
 from langchain_community.retrievers import BM25Retriever
-from langchain.retrievers import (
-    ContextualCompressionRetriever,
-    EnsembleRetriever,
-)
-
-from typing import Optional
-
-from utils.misc import get_last_user_message, add_or_update_system_message
-from config import SRC_LOG_LEVELS, CHROMA_CLIENT
+from langchain_core.documents import Document
+from utils.misc import get_last_user_message
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -261,7 +250,9 @@ def get_rag_context(
         collection_names = (
             file["collection_names"]
             if file["type"] == "collection"
-            else [file["collection_name"]] if file["collection_name"] else []
+            else [file["collection_name"]]
+            if file["collection_name"]
+            else []
         )
 
         collection_names = set(collection_names).difference(extracted_collections)
@@ -401,8 +392,8 @@ def generate_openai_batch_embeddings(
 
 from typing import Any
 
-from langchain_core.retrievers import BaseRetriever
 from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.retrievers import BaseRetriever
 
 
 class ChromaRetriever(BaseRetriever):
@@ -439,11 +430,10 @@ class ChromaRetriever(BaseRetriever):
 
 
 import operator
-
 from typing import Optional, Sequence
 
-from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.callbacks import Callbacks
+from langchain_core.documents import BaseDocumentCompressor, Document
 from langchain_core.pydantic_v1 import Extra
 
 

+ 1 - 2
backend/apps/socket/main.py

@@ -1,7 +1,6 @@
-import socketio
 import asyncio
 
-
+import socketio
 from apps.webui.models.users import Users
 from utils.utils import decode_token
 

+ 9 - 14
backend/apps/webui/internal/db.py

@@ -1,21 +1,16 @@
-import os
-import logging
 import json
+import logging
 from contextlib import contextmanager
+from typing import Any, Optional
 
-
-from typing import Optional, Any
-from typing_extensions import Self
-
-from sqlalchemy import create_engine, types, Dialect
-from sqlalchemy.sql.type_api import _T
-from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker, scoped_session
-
-
-from peewee_migrate import Router
 from apps.webui.internal.wrappers import register_connection
-from env import SRC_LOG_LEVELS, BACKEND_DIR, DATABASE_URL
+from env import BACKEND_DIR, DATABASE_URL, SRC_LOG_LEVELS
+from peewee_migrate import Router
+from sqlalchemy import Dialect, create_engine, types
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import scoped_session, sessionmaker
+from sqlalchemy.sql.type_api import _T
+from typing_extensions import Self
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])

+ 5 - 5
backend/apps/webui/internal/wrappers.py

@@ -1,13 +1,13 @@
+import logging
 from contextvars import ContextVar
-from peewee import *
-from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
 
-import logging
+from env import SRC_LOG_LEVELS
+from peewee import *
+from peewee import InterfaceError as PeeWeeInterfaceError
+from peewee import PostgresqlDatabase
 from playhouse.db_url import connect, parse
 from playhouse.shortcuts import ReconnectMixin
 
-from env import SRC_LOG_LEVELS
-
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])
 

+ 35 - 41
backend/apps/webui/main.py

@@ -1,65 +1,59 @@
-from fastapi import FastAPI
-from fastapi.responses import StreamingResponse
-from fastapi.middleware.cors import CORSMiddleware
+import inspect
+import json
+import logging
+from typing import AsyncGenerator, Generator, Iterator
+
+from apps.socket.main import get_event_call, get_event_emitter
+from apps.webui.models.functions import Functions
+from apps.webui.models.models import Models
 from apps.webui.routers import (
     auths,
-    users,
     chats,
+    configs,
     documents,
-    tools,
+    files,
+    functions,
+    memories,
     models,
     prompts,
-    configs,
-    memories,
+    tools,
+    users,
     utils,
-    files,
-    functions,
 )
-from apps.webui.models.functions import Functions
-from apps.webui.models.models import Models
 from apps.webui.utils import load_function_module_by_id
-
-from utils.misc import (
-    openai_chat_chunk_message_template,
-    openai_chat_completion_message_template,
-    apply_model_params_to_body_openai,
-    apply_model_system_prompt_to_body,
-)
-
-from utils.tools import get_tools
-
 from config import (
-    SHOW_ADMIN_DETAILS,
     ADMIN_EMAIL,
-    WEBUI_AUTH,
+    CORS_ALLOW_ORIGIN,
     DEFAULT_MODELS,
     DEFAULT_PROMPT_SUGGESTIONS,
     DEFAULT_USER_ROLE,
-    ENABLE_SIGNUP,
+    ENABLE_COMMUNITY_SHARING,
     ENABLE_LOGIN_FORM,
+    ENABLE_MESSAGE_RATING,
+    ENABLE_SIGNUP,
+    JWT_EXPIRES_IN,
+    OAUTH_EMAIL_CLAIM,
+    OAUTH_PICTURE_CLAIM,
+    OAUTH_USERNAME_CLAIM,
+    SHOW_ADMIN_DETAILS,
     USER_PERMISSIONS,
     WEBHOOK_URL,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-    JWT_EXPIRES_IN,
+    WEBUI_AUTH,
     WEBUI_BANNERS,
-    ENABLE_COMMUNITY_SHARING,
-    ENABLE_MESSAGE_RATING,
     AppConfig,
-    OAUTH_USERNAME_CLAIM,
-    OAUTH_PICTURE_CLAIM,
-    OAUTH_EMAIL_CLAIM,
-    CORS_ALLOW_ORIGIN,
 )
-
-from apps.socket.main import get_event_call, get_event_emitter
-
-import inspect
-import json
-import logging
-
-from typing import Iterator, Generator, AsyncGenerator
+from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse
 from pydantic import BaseModel
+from utils.misc import (
+    apply_model_params_to_body_openai,
+    apply_model_system_prompt_to_body,
+    openai_chat_chunk_message_template,
+    openai_chat_completion_message_template,
+)
+from utils.tools import get_tools
 
 app = FastAPI()
 

+ 6 - 12
backend/apps/webui/models/auths.py

@@ -1,15 +1,13 @@
-from pydantic import BaseModel
-from typing import Optional
-import uuid
 import logging
-from sqlalchemy import String, Column, Boolean, Text
-
-from utils.utils import verify_password
+import uuid
+from typing import Optional
 
-from apps.webui.models.users import UserModel, Users
 from apps.webui.internal.db import Base, get_db
-
+from apps.webui.models.users import UserModel, Users
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel
+from sqlalchemy import Boolean, Column, String, Text
+from utils.utils import verify_password
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -92,7 +90,6 @@ class AddUserForm(SignupForm):
 
 
 class AuthsTable:
-
     def insert_new_auth(
         self,
         email: str,
@@ -103,7 +100,6 @@ class AuthsTable:
         oauth_sub: Optional[str] = None,
     ) -> Optional[UserModel]:
         with get_db() as db:
-
             log.info("insert_new_auth")
 
             id = str(uuid.uuid4())
@@ -130,7 +126,6 @@ class AuthsTable:
         log.info(f"authenticate_user: {email}")
         try:
             with get_db() as db:
-
                 auth = db.query(Auth).filter_by(email=email, active=True).first()
                 if auth:
                     if verify_password(password, auth.password):
@@ -189,7 +184,6 @@ class AuthsTable:
     def delete_auth_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 # Delete User
                 result = Users.delete_user_by_id(id)
 

+ 6 - 30
backend/apps/webui/models/chats.py

@@ -1,14 +1,11 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-
 import json
-import uuid
 import time
-
-from sqlalchemy import Column, String, BigInteger, Boolean, Text
+import uuid
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text
 
 ####################
 # Chat DB Schema
@@ -77,10 +74,8 @@ class ChatTitleIdResponse(BaseModel):
 
 
 class ChatTable:
-
     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
         with get_db() as db:
-
             id = str(uuid.uuid4())
             chat = ChatModel(
                 **{
@@ -106,7 +101,6 @@ class ChatTable:
     def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat_obj = db.get(Chat, id)
                 chat_obj.chat = json.dumps(chat)
                 chat_obj.title = chat["title"] if "title" in chat else "New Chat"
@@ -115,12 +109,11 @@ class ChatTable:
                 db.refresh(chat_obj)
 
                 return ChatModel.model_validate(chat_obj)
-        except Exception as e:
+        except Exception:
             return None
 
     def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         with get_db() as db:
-
             # Get the existing chat to share
             chat = db.get(Chat, chat_id)
             # Check if the chat is already shared
@@ -154,7 +147,6 @@ class ChatTable:
     def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 print("update_shared_chat_by_id")
                 chat = db.get(Chat, chat_id)
                 print(chat)
@@ -170,7 +162,6 @@ class ChatTable:
     def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
                 db.commit()
 
@@ -183,7 +174,6 @@ class ChatTable:
     ) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 chat.share_id = share_id
                 db.commit()
@@ -195,7 +185,6 @@ class ChatTable:
     def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 chat.archived = not chat.archived
                 db.commit()
@@ -217,7 +206,6 @@ class ChatTable:
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
@@ -297,7 +285,6 @@ class ChatTable:
     def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.get(Chat, id)
                 return ChatModel.model_validate(chat)
         except Exception:
@@ -306,20 +293,18 @@ class ChatTable:
     def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.query(Chat).filter_by(share_id=id).first()
 
                 if chat:
                     return self.get_chat_by_id(id)
                 else:
                     return None
-        except Exception as e:
+        except Exception:
             return None
 
     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
         try:
             with get_db() as db:
-
                 chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                 return ChatModel.model_validate(chat)
         except Exception:
@@ -327,7 +312,6 @@ class ChatTable:
 
     def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 # .limit(limit).offset(skip)
@@ -337,7 +321,6 @@ class ChatTable:
 
     def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id)
@@ -347,7 +330,6 @@ class ChatTable:
 
     def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
         with get_db() as db:
-
             all_chats = (
                 db.query(Chat)
                 .filter_by(user_id=user_id, archived=True)
@@ -358,7 +340,6 @@ class ChatTable:
     def delete_chat_by_id(self, id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(id=id).delete()
                 db.commit()
 
@@ -369,7 +350,6 @@ class ChatTable:
     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                 db.commit()
 
@@ -379,9 +359,7 @@ class ChatTable:
 
     def delete_chats_by_user_id(self, user_id: str) -> bool:
         try:
-
             with get_db() as db:
-
                 self.delete_shared_chats_by_user_id(user_id)
 
                 db.query(Chat).filter_by(user_id=user_id).delete()
@@ -393,9 +371,7 @@ class ChatTable:
 
     def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
         try:
-
             with get_db() as db:
-
                 chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
                 shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
 

+ 5 - 15
backend/apps/webui/models/documents.py

@@ -1,15 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-import time
+import json
 import logging
-
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
-import json
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -70,12 +67,10 @@ class DocumentForm(DocumentUpdateForm):
 
 
 class DocumentsTable:
-
     def insert_new_doc(
         self, user_id: str, form_data: DocumentForm
     ) -> Optional[DocumentModel]:
         with get_db() as db:
-
             document = DocumentModel(
                 **{
                     **form_data.model_dump(),
@@ -99,7 +94,6 @@ class DocumentsTable:
     def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
         try:
             with get_db() as db:
-
                 document = db.query(Document).filter_by(name=name).first()
                 return DocumentModel.model_validate(document) if document else None
         except Exception:
@@ -107,7 +101,6 @@ class DocumentsTable:
 
     def get_docs(self) -> list[DocumentModel]:
         with get_db() as db:
-
             return [
                 DocumentModel.model_validate(doc) for doc in db.query(Document).all()
             ]
@@ -117,7 +110,6 @@ class DocumentsTable:
     ) -> Optional[DocumentModel]:
         try:
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).update(
                     {
                         "title": form_data.title,
@@ -140,7 +132,6 @@ class DocumentsTable:
             doc_content = {**doc_content, **updated}
 
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).update(
                     {
                         "content": json.dumps(doc_content),
@@ -156,7 +147,6 @@ class DocumentsTable:
     def delete_doc_by_name(self, name: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Document).filter_by(name=name).delete()
                 db.commit()
                 return True

+ 5 - 17
backend/apps/webui/models/files.py

@@ -1,15 +1,11 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-import time
 import logging
+import time
+from typing import Optional
 
-from sqlalchemy import Column, String, BigInteger, Text
-
-from apps.webui.internal.db import JSONField, Base, get_db
-
-import json
-
+from apps.webui.internal.db import Base, JSONField, get_db
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -59,10 +55,8 @@ class FileForm(BaseModel):
 
 
 class FilesTable:
-
     def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
         with get_db() as db:
-
             file = FileModel(
                 **{
                     **form_data.model_dump(),
@@ -86,7 +80,6 @@ class FilesTable:
 
     def get_file_by_id(self, id: str) -> Optional[FileModel]:
         with get_db() as db:
-
             try:
                 file = db.get(File, id)
                 return FileModel.model_validate(file)
@@ -95,7 +88,6 @@ class FilesTable:
 
     def get_files(self) -> list[FileModel]:
         with get_db() as db:
-
             return [FileModel.model_validate(file) for file in db.query(File).all()]
 
     def get_files_by_user_id(self, user_id: str) -> list[FileModel]:
@@ -106,9 +98,7 @@ class FilesTable:
             ]
 
     def delete_file_by_id(self, id: str) -> bool:
-
         with get_db() as db:
-
             try:
                 db.query(File).filter_by(id=id).delete()
                 db.commit()
@@ -118,9 +108,7 @@ class FilesTable:
                 return False
 
     def delete_all_files(self) -> bool:
-
         with get_db() as db:
-
             try:
                 db.query(File).delete()
                 db.commit()

+ 5 - 23
backend/apps/webui/models/functions.py

@@ -1,18 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-import time
 import logging
+import time
+from typing import Optional
 
-from sqlalchemy import Column, String, Text, BigInteger, Boolean
-
-from apps.webui.internal.db import JSONField, Base, get_db
+from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.users import Users
-
-import json
-import copy
-
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Boolean, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -87,11 +81,9 @@ class FunctionValves(BaseModel):
 
 
 class FunctionsTable:
-
     def insert_new_function(
         self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
-
         function = FunctionModel(
             **{
                 **form_data.model_dump(),
@@ -119,7 +111,6 @@ class FunctionsTable:
     def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
         try:
             with get_db() as db:
-
                 function = db.get(Function, id)
                 return FunctionModel.model_validate(function)
         except Exception:
@@ -127,7 +118,6 @@ class FunctionsTable:
 
     def get_functions(self, active_only=False) -> list[FunctionModel]:
         with get_db() as db:
-
             if active_only:
                 return [
                     FunctionModel.model_validate(function)
@@ -143,7 +133,6 @@ class FunctionsTable:
         self, type: str, active_only=False
     ) -> list[FunctionModel]:
         with get_db() as db:
-
             if active_only:
                 return [
                     FunctionModel.model_validate(function)
@@ -159,7 +148,6 @@ class FunctionsTable:
 
     def get_global_filter_functions(self) -> list[FunctionModel]:
         with get_db() as db:
-
             return [
                 FunctionModel.model_validate(function)
                 for function in db.query(Function)
@@ -178,7 +166,6 @@ class FunctionsTable:
 
     def get_function_valves_by_id(self, id: str) -> Optional[dict]:
         with get_db() as db:
-
             try:
                 function = db.get(Function, id)
                 return function.valves if function.valves else {}
@@ -190,7 +177,6 @@ class FunctionsTable:
         self, id: str, valves: dict
     ) -> Optional[FunctionValves]:
         with get_db() as db:
-
             try:
                 function = db.get(Function, id)
                 function.valves = valves
@@ -204,7 +190,6 @@ class FunctionsTable:
     def get_user_valves_by_id_and_user_id(
         self, id: str, user_id: str
     ) -> Optional[dict]:
-
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump() if user.settings else {}
@@ -223,7 +208,6 @@ class FunctionsTable:
     def update_user_valves_by_id_and_user_id(
         self, id: str, user_id: str, valves: dict
     ) -> Optional[dict]:
-
         try:
             user = Users.get_user_by_id(user_id)
             user_settings = user.settings.model_dump() if user.settings else {}
@@ -246,7 +230,6 @@ class FunctionsTable:
 
     def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
         with get_db() as db:
-
             try:
                 db.query(Function).filter_by(id=id).update(
                     {
@@ -261,7 +244,6 @@ class FunctionsTable:
 
     def deactivate_all_functions(self) -> Optional[bool]:
         with get_db() as db:
-
             try:
                 db.query(Function).update(
                     {

+ 5 - 16
backend/apps/webui/models/memories.py

@@ -1,12 +1,10 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Union, Optional
-
-from sqlalchemy import Column, String, BigInteger, Text
-
-from apps.webui.internal.db import Base, get_db
-
 import time
 import uuid
+from typing import Optional
+
+from apps.webui.internal.db import Base, get_db
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # Memory DB Schema
@@ -39,13 +37,11 @@ class MemoryModel(BaseModel):
 
 
 class MemoriesTable:
-
     def insert_new_memory(
         self,
         user_id: str,
         content: str,
     ) -> Optional[MemoryModel]:
-
         with get_db() as db:
             id = str(uuid.uuid4())
 
@@ -73,7 +69,6 @@ class MemoriesTable:
         content: str,
     ) -> Optional[MemoryModel]:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id).update(
                     {"content": content, "updated_at": int(time.time())}
@@ -85,7 +80,6 @@ class MemoriesTable:
 
     def get_memories(self) -> list[MemoryModel]:
         with get_db() as db:
-
             try:
                 memories = db.query(Memory).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
@@ -94,7 +88,6 @@ class MemoriesTable:
 
     def get_memories_by_user_id(self, user_id: str) -> list[MemoryModel]:
         with get_db() as db:
-
             try:
                 memories = db.query(Memory).filter_by(user_id=user_id).all()
                 return [MemoryModel.model_validate(memory) for memory in memories]
@@ -103,7 +96,6 @@ class MemoriesTable:
 
     def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
         with get_db() as db:
-
             try:
                 memory = db.get(Memory, id)
                 return MemoryModel.model_validate(memory)
@@ -112,7 +104,6 @@ class MemoriesTable:
 
     def delete_memory_by_id(self, id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id).delete()
                 db.commit()
@@ -124,7 +115,6 @@ class MemoriesTable:
 
     def delete_memories_by_user_id(self, user_id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(user_id=user_id).delete()
                 db.commit()
@@ -135,7 +125,6 @@ class MemoriesTable:
 
     def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
         with get_db() as db:
-
             try:
                 db.query(Memory).filter_by(id=id, user_id=user_id).delete()
                 db.commit()

+ 4 - 7
backend/apps/webui/models/models.py

@@ -1,14 +1,11 @@
 import logging
-from typing import Optional, List
-
-from pydantic import BaseModel, ConfigDict
-from sqlalchemy import Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
-
 from env import SRC_LOG_LEVELS
-
-import time
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 4 - 13
backend/apps/webui/models/prompts.py

@@ -1,12 +1,9 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
 import time
-
-from sqlalchemy import String, Column, BigInteger, Text
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
-import json
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # Prompts DB Schema
@@ -45,7 +42,6 @@ class PromptForm(BaseModel):
 
 
 class PromptsTable:
-
     def insert_new_prompt(
         self, user_id: str, form_data: PromptForm
     ) -> Optional[PromptModel]:
@@ -61,7 +57,6 @@ class PromptsTable:
 
         try:
             with get_db() as db:
-
                 result = Prompt(**prompt.dict())
                 db.add(result)
                 db.commit()
@@ -70,13 +65,12 @@ class PromptsTable:
                     return PromptModel.model_validate(result)
                 else:
                     return None
-        except Exception as e:
+        except Exception:
             return None
 
     def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
         try:
             with get_db() as db:
-
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 return PromptModel.model_validate(prompt)
         except Exception:
@@ -84,7 +78,6 @@ class PromptsTable:
 
     def get_prompts(self) -> list[PromptModel]:
         with get_db() as db:
-
             return [
                 PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
             ]
@@ -94,7 +87,6 @@ class PromptsTable:
     ) -> Optional[PromptModel]:
         try:
             with get_db() as db:
-
                 prompt = db.query(Prompt).filter_by(command=command).first()
                 prompt.title = form_data.title
                 prompt.content = form_data.content
@@ -107,7 +99,6 @@ class PromptsTable:
     def delete_prompt_by_command(self, command: str) -> bool:
         try:
             with get_db() as db:
-
                 db.query(Prompt).filter_by(command=command).delete()
                 db.commit()
 

+ 7 - 17
backend/apps/webui/models/tags.py

@@ -1,16 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-
-import json
-import uuid
-import time
 import logging
-
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+import uuid
+from typing import Optional
 
 from apps.webui.internal.db import Base, get_db
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -77,10 +73,8 @@ class ChatTagsResponse(BaseModel):
 
 
 class TagTable:
-
     def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
         with get_db() as db:
-
             id = str(uuid.uuid4())
             tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
             try:
@@ -92,7 +86,7 @@ class TagTable:
                     return TagModel.model_validate(result)
                 else:
                     return None
-            except Exception as e:
+            except Exception:
                 return None
 
     def get_tag_by_name_and_user_id(
@@ -102,7 +96,7 @@ class TagTable:
             with get_db() as db:
                 tag = db.query(Tag).filter_by(name=name, user_id=user_id).first()
                 return TagModel.model_validate(tag)
-        except Exception as e:
+        except Exception:
             return None
 
     def add_tag_to_chat(
@@ -161,7 +155,6 @@ class TagTable:
         self, chat_id: str, user_id: str
     ) -> list[TagModel]:
         with get_db() as db:
-
             tag_names = [
                 chat_id_tag.tag_name
                 for chat_id_tag in (
@@ -186,7 +179,6 @@ class TagTable:
         self, tag_name: str, user_id: str
     ) -> list[ChatIdTagModel]:
         with get_db() as db:
-
             return [
                 ChatIdTagModel.model_validate(chat_id_tag)
                 for chat_id_tag in (
@@ -201,7 +193,6 @@ class TagTable:
         self, tag_name: str, user_id: str
     ) -> int:
         with get_db() as db:
-
             return (
                 db.query(ChatIdTag)
                 .filter_by(tag_name=tag_name, user_id=user_id)
@@ -236,7 +227,6 @@ class TagTable:
     ) -> bool:
         try:
             with get_db() as db:
-
                 res = (
                     db.query(ChatIdTag)
                     .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)

+ 4 - 15
backend/apps/webui/models/tools.py

@@ -1,17 +1,12 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
-import time
 import logging
-from sqlalchemy import String, Column, BigInteger, Text
+import time
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.users import Users
-
-import json
-import copy
-
-
 from env import SRC_LOG_LEVELS
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -79,13 +74,10 @@ class ToolValves(BaseModel):
 
 
 class ToolsTable:
-
     def insert_new_tool(
         self, user_id: str, form_data: ToolForm, specs: list[dict]
     ) -> Optional[ToolModel]:
-
         with get_db() as db:
-
             tool = ToolModel(
                 **{
                     **form_data.model_dump(),
@@ -112,7 +104,6 @@ class ToolsTable:
     def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
         try:
             with get_db() as db:
-
                 tool = db.get(Tool, id)
                 return ToolModel.model_validate(tool)
         except Exception:
@@ -125,7 +116,6 @@ class ToolsTable:
     def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
         try:
             with get_db() as db:
-
                 tool = db.get(Tool, id)
                 return tool.valves if tool.valves else {}
         except Exception as e:
@@ -135,7 +125,6 @@ class ToolsTable:
     def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
         try:
             with get_db() as db:
-
                 db.query(Tool).filter_by(id=id).update(
                     {"valves": valves, "updated_at": int(time.time())}
                 )

+ 6 - 7
backend/apps/webui/models/users.py

@@ -1,11 +1,10 @@
-from pydantic import BaseModel, ConfigDict
-from typing import Optional
 import time
-
-from sqlalchemy import String, Column, BigInteger, Text
+from typing import Optional
 
 from apps.webui.internal.db import Base, JSONField, get_db
 from apps.webui.models.chats import Chats
+from pydantic import BaseModel, ConfigDict
+from sqlalchemy import BigInteger, Column, String, Text
 
 ####################
 # User DB Schema
@@ -113,7 +112,7 @@ class UsersTable:
             with get_db() as db:
                 user = db.query(User).filter_by(id=id).first()
                 return UserModel.model_validate(user)
-        except Exception as e:
+        except Exception:
             return None
 
     def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
@@ -221,7 +220,7 @@ class UsersTable:
                 user = db.query(User).filter_by(id=id).first()
                 return UserModel.model_validate(user)
                 # return UserModel(**user.dict())
-        except Exception as e:
+        except Exception:
             return None
 
     def delete_user_by_id(self, id: str) -> bool:
@@ -255,7 +254,7 @@ class UsersTable:
             with get_db() as db:
                 user = db.query(User).filter_by(id=id).first()
                 return user.api_key
-        except Exception as e:
+        except Exception:
             return None
 
 

+ 16 - 28
backend/apps/webui/routers/auths.py

@@ -1,43 +1,33 @@
-import logging
-
-from fastapi import Request, UploadFile, File
-from fastapi import Depends, HTTPException, status
-from fastapi.responses import Response
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import re
 import uuid
-import csv
 
 from apps.webui.models.auths import (
+    AddUserForm,
+    ApiKey,
+    Auths,
     SigninForm,
+    SigninResponse,
     SignupForm,
-    AddUserForm,
-    UpdateProfileForm,
     UpdatePasswordForm,
+    UpdateProfileForm,
     UserResponse,
-    SigninResponse,
-    Auths,
-    ApiKey,
 )
 from apps.webui.models.users import Users
-
+from config import WEBUI_AUTH
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from env import WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import Response
+from pydantic import BaseModel
+from utils.misc import parse_duration, validate_email_format
 from utils.utils import (
-    get_password_hash,
-    get_current_user,
-    get_admin_user,
-    create_token,
     create_api_key,
+    create_token,
+    get_admin_user,
+    get_current_user,
+    get_password_hash,
 )
-from utils.misc import parse_duration, validate_email_format
 from utils.webhook import post_webhook
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from config import (
-    WEBUI_AUTH,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-)
 
 router = APIRouter()
 
@@ -273,7 +263,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
 
 @router.post("/add", response_model=SigninResponse)
 async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
-
     if not validate_email_format(form_data.email.lower()):
         raise HTTPException(
             status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
@@ -283,7 +272,6 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
         raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
 
     try:
-
         print(form_data)
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(

+ 8 - 31
backend/apps/webui/routers/chats.py

@@ -1,34 +1,15 @@
-from fastapi import Depends, Request, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-from utils.utils import get_verified_user, get_admin_user
-from fastapi import APIRouter
-from pydantic import BaseModel
 import json
 import logging
+from typing import Optional
 
-from apps.webui.models.users import Users
-from apps.webui.models.chats import (
-    ChatModel,
-    ChatResponse,
-    ChatTitleForm,
-    ChatForm,
-    ChatTitleIdResponse,
-    Chats,
-)
-
-
-from apps.webui.models.tags import (
-    TagModel,
-    ChatIdTagModel,
-    ChatIdTagForm,
-    ChatTagsResponse,
-    Tags,
-)
-
+from apps.webui.models.chats import ChatForm, ChatResponse, Chats, ChatTitleIdResponse
+from apps.webui.models.tags import ChatIdTagForm, ChatIdTagModel, TagModel, Tags
+from config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT
 from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT, ENABLE_ADMIN_CHAT_ACCESS
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -61,7 +42,6 @@ async def get_session_user_chat_list(
 
 @router.delete("/", response_model=bool)
 async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
-
     if (
         user.role == "user"
         and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
@@ -220,7 +200,6 @@ class TagNameForm(BaseModel):
 async def get_user_chat_list_by_tag_name(
     form_data: TagNameForm, user=Depends(get_verified_user)
 ):
-
     chat_ids = [
         chat_id_tag.chat_id
         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
@@ -299,7 +278,6 @@ async def update_chat_by_id(
 
 @router.delete("/{id}", response_model=bool)
 async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
-
     if user.role == "admin":
         result = Chats.delete_chat_by_id(id)
         return result
@@ -323,7 +301,6 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified
 async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
-
         chat_body = json.loads(chat.chat)
         updated_chat = {
             **chat_body,

+ 3 - 21
backend/apps/webui/routers/configs.py

@@ -1,25 +1,7 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import time
-import uuid
-
 from config import BannerModel
-
-from apps.webui.models.users import Users
-
-from utils.utils import (
-    get_password_hash,
-    get_verified_user,
-    get_admin_user,
-    create_token,
-)
-from utils.misc import get_gravatar_url, validate_email_format
-from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, Request
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 6 - 11
backend/apps/webui/routers/documents.py

@@ -1,21 +1,16 @@
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import json
+from typing import Optional
 
 from apps.webui.models.documents import (
-    Documents,
     DocumentForm,
-    DocumentUpdateForm,
-    DocumentModel,
     DocumentResponse,
+    Documents,
+    DocumentUpdateForm,
 )
-
-from utils.utils import get_verified_user, get_admin_user
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 11 - 36
backend/apps/webui/routers/files.py

@@ -1,42 +1,17 @@
-from fastapi import (
-    Depends,
-    FastAPI,
-    HTTPException,
-    status,
-    Request,
-    UploadFile,
-    File,
-    Form,
-)
-
-
-from datetime import datetime, timedelta
-from typing import Union, Optional
-from pathlib import Path
-
-from fastapi import APIRouter
-from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
-
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.files import (
-    Files,
-    FileForm,
-    FileModel,
-    FileModelResponse,
-)
-from utils.utils import get_verified_user, get_admin_user
-from constants import ERROR_MESSAGES
-
-from importlib import util
+import logging
 import os
+import shutil
 import uuid
-import os, shutil, logging, re
-
-
-from config import SRC_LOG_LEVELS, UPLOAD_DIR
+from pathlib import Path
+from typing import Optional
 
+from apps.webui.models.files import FileForm, FileModel, Files
+from config import UPLOAD_DIR
+from constants import ERROR_MESSAGES
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi.responses import FileResponse
+from utils.utils import get_admin_user, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 7 - 17
backend/apps/webui/routers/functions.py

@@ -1,27 +1,18 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
+import os
+from pathlib import Path
+from typing import Optional
 
 from apps.webui.models.functions import (
-    Functions,
     FunctionForm,
     FunctionModel,
     FunctionResponse,
+    Functions,
 )
 from apps.webui.utils import load_function_module_by_id
-from utils.utils import get_verified_user, get_admin_user
+from config import CACHE_DIR, FUNCTIONS_DIR
 from constants import ERROR_MESSAGES
-
-from importlib import util
-import os
-from pathlib import Path
-
-from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
-
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 
@@ -304,7 +295,6 @@ async def update_function_valves_by_id(
 ):
     function = Functions.get_function_by_id(id)
     if function:
-
         if id in request.app.state.FUNCTIONS:
             function_module = request.app.state.FUNCTIONS[id]
         else:

+ 5 - 11
backend/apps/webui/routers/memories.py

@@ -1,18 +1,12 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
 import logging
+from typing import Optional
 
 from apps.webui.models.memories import Memories, MemoryModel
-
+from config import CHROMA_CLIENT
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request
+from pydantic import BaseModel
 from utils.utils import get_verified_user
-from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])

+ 4 - 10
backend/apps/webui/routers/models.py

@@ -1,15 +1,9 @@
-from fastapi import Depends, FastAPI, HTTPException, status, Request
-from datetime import datetime, timedelta
-from typing import Union, Optional
+from typing import Optional
 
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
-
-from utils.utils import get_verified_user, get_admin_user
+from apps.webui.models.models import ModelForm, ModelModel, ModelResponse, Models
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 4 - 10
backend/apps/webui/routers/prompts.py

@@ -1,15 +1,9 @@
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
+from typing import Optional
 
-from fastapi import APIRouter
-from pydantic import BaseModel
-import json
-
-from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
-
-from utils.utils import get_verified_user, get_admin_user
+from apps.webui.models.prompts import PromptForm, PromptModel, Prompts
 from constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, status
+from utils.utils import get_admin_user, get_verified_user
 
 router = APIRouter()
 

+ 7 - 13
backend/apps/webui/routers/tools.py

@@ -1,20 +1,14 @@
-from fastapi import Depends, HTTPException, status, Request
+import os
+from pathlib import Path
 from typing import Optional
 
-from fastapi import APIRouter
-
-from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+from apps.webui.models.tools import ToolForm, ToolModel, ToolResponse, Tools
 from apps.webui.utils import load_toolkit_module_by_id
-
-from utils.utils import get_admin_user, get_verified_user
-from utils.tools import get_tools_specs
+from config import CACHE_DIR, DATA_DIR
 from constants import ERROR_MESSAGES
-
-import os
-from pathlib import Path
-
-from config import DATA_DIR, CACHE_DIR
-
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from utils.tools import get_tools_specs
+from utils.utils import get_admin_user, get_verified_user
 
 TOOLS_DIR = f"{DATA_DIR}/tools"
 os.makedirs(TOOLS_DIR, exist_ok=True)

+ 9 - 24
backend/apps/webui/routers/users.py

@@ -1,33 +1,20 @@
-from fastapi import Response, Request
-from fastapi import Depends, FastAPI, HTTPException, status
-from datetime import datetime, timedelta
-from typing import Union, Optional
-
-from fastapi import APIRouter
-from pydantic import BaseModel
-import time
-import uuid
 import logging
+from typing import Optional
 
+from apps.webui.models.auths import Auths
+from apps.webui.models.chats import Chats
 from apps.webui.models.users import (
     UserModel,
-    UserUpdateForm,
     UserRoleUpdateForm,
-    UserSettings,
     Users,
-)
-from apps.webui.models.auths import Auths
-from apps.webui.models.chats import Chats
-
-from utils.utils import (
-    get_verified_user,
-    get_password_hash,
-    get_current_user,
-    get_admin_user,
+    UserSettings,
+    UserUpdateForm,
 )
 from constants import ERROR_MESSAGES
-
-from config import SRC_LOG_LEVELS
+from env import SRC_LOG_LEVELS
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from pydantic import BaseModel
+from utils.utils import get_admin_user, get_password_hash, get_verified_user
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MODELS"])
@@ -69,7 +56,6 @@ async def update_user_permissions(
 
 @router.post("/update/role", response_model=Optional[UserModel])
 async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
-
     if user.id != form_data.id and form_data.id != Users.get_first_user().id:
         return Users.update_user_role_by_id(form_data.id, form_data.role)
 
@@ -173,7 +159,6 @@ class UserResponse(BaseModel):
 
 @router.get("/{user_id}", response_model=UserResponse)
 async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
-
     # Check if user_id is a shared chat
     # If it is, get the user_id from the chat
     if user_id.startswith("shared-"):

+ 10 - 17
backend/apps/webui/routers/utils.py

@@ -1,23 +1,16 @@
-from pathlib import Path
 import site
+from pathlib import Path
 
-from fastapi import APIRouter, UploadFile, File, Response
-from fastapi import Depends, HTTPException, status
-from starlette.responses import StreamingResponse, FileResponse
-from pydantic import BaseModel
-
-
-from fpdf import FPDF
-import markdown
 import black
-
-
-from utils.utils import get_admin_user
-from utils.misc import calculate_sha256, get_gravatar_url
-
-from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR, ENABLE_ADMIN_EXPORT
+import markdown
+from config import DATA_DIR, ENABLE_ADMIN_EXPORT
 from constants import ERROR_MESSAGES
-
+from fastapi import APIRouter, Depends, HTTPException, Response, status
+from fpdf import FPDF
+from pydantic import BaseModel
+from starlette.responses import FileResponse
+from utils.misc import get_gravatar_url
+from utils.utils import get_admin_user
 
 router = APIRouter()
 
@@ -115,7 +108,7 @@ async def download_chat_as_pdf(
     return Response(
         content=bytes(pdf_bytes),
         media_type="application/pdf",
-        headers={"Content-Disposition": f"attachment;filename=chat.pdf"},
+        headers={"Content-Disposition": "attachment;filename=chat.pdf"},
     )
 
 

+ 4 - 5
backend/apps/webui/utils.py

@@ -1,13 +1,12 @@
-from importlib import util
 import os
 import re
-import sys
 import subprocess
+import sys
+from importlib import util
 
-
-from apps.webui.models.tools import Tools
 from apps.webui.models.functions import Functions
-from config import TOOLS_DIR, FUNCTIONS_DIR
+from apps.webui.models.tools import Tools
+from config import FUNCTIONS_DIR, TOOLS_DIR
 
 
 def extract_frontmatter(file_path):

+ 16 - 44
backend/config.py

@@ -1,58 +1,30 @@
-from sqlalchemy import create_engine, Column, Integer, DateTime, JSON, func
-from contextlib import contextmanager
-
-
-import os
-import sys
+import json
 import logging
-import importlib.metadata
-import pkgutil
-from urllib.parse import urlparse
+import os
+import shutil
 from datetime import datetime
-
-import chromadb
-from chromadb import Settings
-from typing import TypeVar, Generic
-from pydantic import BaseModel
-from typing import Optional
-
 from pathlib import Path
-import json
-import yaml
+from typing import Generic, Optional, TypeVar
+from urllib.parse import urlparse
 
+import chromadb
 import requests
-import shutil
-
-
+import yaml
 from apps.webui.internal.db import Base, get_db
-
-from constants import ERROR_MESSAGES
-
+from chromadb import Settings
 from env import (
-    ENV,
-    VERSION,
-    SAFE_MODE,
-    GLOBAL_LOG_LEVEL,
-    SRC_LOG_LEVELS,
-    BASE_DIR,
-    DATA_DIR,
     BACKEND_DIR,
-    FRONTEND_BUILD_DIR,
-    WEBUI_NAME,
-    WEBUI_URL,
-    WEBUI_FAVICON_URL,
-    WEBUI_BUILD_HASH,
     CONFIG_DATA,
-    DATABASE_URL,
-    CHANGELOG,
+    DATA_DIR,
+    ENV,
+    FRONTEND_BUILD_DIR,
     WEBUI_AUTH,
-    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
-    WEBUI_AUTH_TRUSTED_NAME_HEADER,
-    WEBUI_SECRET_KEY,
-    WEBUI_SESSION_COOKIE_SAME_SITE,
-    WEBUI_SESSION_COOKIE_SECURE,
+    WEBUI_FAVICON_URL,
+    WEBUI_NAME,
     log,
 )
+from pydantic import BaseModel
+from sqlalchemy import JSON, Column, DateTime, Integer, func
 
 
 class EndpointFilter(logging.Filter):
@@ -72,8 +44,8 @@ logging.getLogger("uvicorn.access").addFilter(EndpointFilter())
 def run_migrations():
     print("Running migrations")
     try:
-        from alembic.config import Config
         from alembic import command
+        from alembic.config import Config
 
         alembic_cfg = Config(BACKEND_DIR / "alembic.ini")
         command.upgrade(alembic_cfg, "head")

+ 6 - 12
backend/env.py

@@ -1,19 +1,13 @@
-from pathlib import Path
-import os
-import logging
-import sys
-import json
-
-
 import importlib.metadata
+import json
+import logging
+import os
 import pkgutil
-from urllib.parse import urlparse
-from datetime import datetime
-
+import sys
+from pathlib import Path
 
 import markdown
 from bs4 import BeautifulSoup
-
 from constants import ERROR_MESSAGES
 
 ####################################
@@ -26,7 +20,7 @@ BASE_DIR = BACKEND_DIR.parent  # the path containing the backend/
 print(BASE_DIR)
 
 try:
-    from dotenv import load_dotenv, find_dotenv
+    from dotenv import find_dotenv, load_dotenv
 
     load_dotenv(find_dotenv(str(BASE_DIR / ".env")))
 except ImportError:

+ 92 - 98
backend/main.py

@@ -1,130 +1,124 @@
 import base64
-import uuid
-from contextlib import asynccontextmanager
-from authlib.integrations.starlette_client import OAuth
-from authlib.oidc.core import UserInfo
+import inspect
 import json
-import time
-import os
-import sys
 import logging
-import aiohttp
-import requests
 import mimetypes
+import os
 import shutil
-import inspect
+import sys
+import time
+import uuid
+from contextlib import asynccontextmanager
 from typing import Optional
 
-from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
-from fastapi.staticfiles import StaticFiles
-from fastapi.responses import JSONResponse
-from fastapi import HTTPException
-from fastapi.middleware.cors import CORSMiddleware
-from sqlalchemy import text
-from starlette.exceptions import HTTPException as StarletteHTTPException
-from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.middleware.sessions import SessionMiddleware
-from starlette.responses import StreamingResponse, Response, RedirectResponse
-
-
-from apps.socket.main import app as socket_app, get_event_emitter, get_event_call
+import aiohttp
+import requests
+from apps.audio.main import app as audio_app
+from apps.images.main import app as images_app
+from apps.ollama.main import app as ollama_app
 from apps.ollama.main import (
-    app as ollama_app,
-    get_all_models as get_ollama_models,
     generate_openai_chat_completion as generate_ollama_chat_completion,
 )
-from apps.openai.main import (
-    app as openai_app,
-    get_all_models as get_openai_models,
-    generate_chat_completion as generate_openai_chat_completion,
-)
-
-from apps.audio.main import app as audio_app
-from apps.images.main import app as images_app
+from apps.ollama.main import get_all_models as get_ollama_models
+from apps.openai.main import app as openai_app
+from apps.openai.main import generate_chat_completion as generate_openai_chat_completion
+from apps.openai.main import get_all_models as get_openai_models
 from apps.rag.main import app as rag_app
-from apps.webui.main import (
-    app as webui_app,
-    get_pipe_models,
-    generate_function_chat_completion,
-)
+from apps.rag.utils import get_rag_context, rag_template
+from apps.socket.main import app as socket_app
+from apps.socket.main import get_event_call, get_event_emitter
 from apps.webui.internal.db import Session
-
-
-from pydantic import BaseModel
-
+from apps.webui.main import app as webui_app
+from apps.webui.main import generate_function_chat_completion, get_pipe_models
 from apps.webui.models.auths import Auths
-from apps.webui.models.models import Models
 from apps.webui.models.functions import Functions
-from apps.webui.models.users import Users, UserModel
-
+from apps.webui.models.models import Models
+from apps.webui.models.users import UserModel, Users
 from apps.webui.utils import load_function_module_by_id
-
-from utils.utils import (
-    get_admin_user,
-    get_verified_user,
-    get_current_user,
-    get_http_authorization_cred,
-    get_password_hash,
-    create_token,
-    decode_token,
-)
-from utils.task import (
-    title_generation_template,
-    search_query_generation_template,
-    tools_function_calling_generation_template,
-    moa_response_generation_template,
-)
-
-from utils.tools import get_tools
-from utils.misc import (
-    get_last_user_message,
-    add_or_update_system_message,
-    prepend_to_first_user_message_content,
-    parse_duration,
-)
-
-from apps.rag.utils import get_rag_context, rag_template
-
+from authlib.integrations.starlette_client import OAuth
+from authlib.oidc.core import UserInfo
 from config import (
-    run_migrations,
-    WEBUI_NAME,
-    WEBUI_URL,
-    WEBUI_AUTH,
-    ENV,
-    VERSION,
-    CHANGELOG,
-    FRONTEND_BUILD_DIR,
     CACHE_DIR,
-    STATIC_DIR,
+    CORS_ALLOW_ORIGIN,
     DEFAULT_LOCALE,
-    ENABLE_OPENAI_API,
-    ENABLE_OLLAMA_API,
+    ENABLE_ADMIN_CHAT_ACCESS,
+    ENABLE_ADMIN_EXPORT,
     ENABLE_MODEL_FILTER,
+    ENABLE_OAUTH_SIGNUP,
+    ENABLE_OLLAMA_API,
+    ENABLE_OPENAI_API,
+    ENV,
+    FRONTEND_BUILD_DIR,
     MODEL_FILTER_LIST,
-    GLOBAL_LOG_LEVEL,
-    SRC_LOG_LEVELS,
-    WEBHOOK_URL,
-    ENABLE_ADMIN_EXPORT,
-    WEBUI_BUILD_HASH,
+    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    OAUTH_PROVIDERS,
+    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
+    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+    STATIC_DIR,
     TASK_MODEL,
     TASK_MODEL_EXTERNAL,
     TITLE_GENERATION_PROMPT_TEMPLATE,
-    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
-    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+    WEBHOOK_URL,
+    WEBUI_AUTH,
+    WEBUI_NAME,
+    AppConfig,
+    run_migrations,
+)
+from constants import ERROR_MESSAGES, TASKS, WEBHOOK_MESSAGES
+from env import (
+    CHANGELOG,
+    GLOBAL_LOG_LEVEL,
     SAFE_MODE,
-    OAUTH_PROVIDERS,
-    ENABLE_OAUTH_SIGNUP,
-    OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
+    SRC_LOG_LEVELS,
+    VERSION,
+    WEBUI_BUILD_HASH,
     WEBUI_SECRET_KEY,
     WEBUI_SESSION_COOKIE_SAME_SITE,
     WEBUI_SESSION_COOKIE_SECURE,
-    ENABLE_ADMIN_CHAT_ACCESS,
-    AppConfig,
-    CORS_ALLOW_ORIGIN,
+    WEBUI_URL,
+)
+from fastapi import (
+    Depends,
+    FastAPI,
+    File,
+    Form,
+    HTTPException,
+    Request,
+    UploadFile,
+    status,
+)
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles
+from pydantic import BaseModel
+from sqlalchemy import text
+from starlette.exceptions import HTTPException as StarletteHTTPException
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.middleware.sessions import SessionMiddleware
+from starlette.responses import RedirectResponse, Response, StreamingResponse
+from utils.misc import (
+    add_or_update_system_message,
+    get_last_user_message,
+    parse_duration,
+    prepend_to_first_user_message_content,
+)
+from utils.task import (
+    moa_response_generation_template,
+    search_query_generation_template,
+    title_generation_template,
+    tools_function_calling_generation_template,
+)
+from utils.tools import get_tools
+from utils.utils import (
+    create_token,
+    decode_token,
+    get_admin_user,
+    get_current_user,
+    get_http_authorization_cred,
+    get_password_hash,
+    get_verified_user,
 )
-
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from utils.webhook import post_webhook
 
 if SAFE_MODE:

+ 1 - 16
backend/migrations/env.py

@@ -1,24 +1,9 @@
-import os
 from logging.config import fileConfig
 
-from sqlalchemy import engine_from_config
-from sqlalchemy import pool
-
 from alembic import context
-
 from apps.webui.models.auths import Auth
-from apps.webui.models.chats import Chat
-from apps.webui.models.documents import Document
-from apps.webui.models.memories import Memory
-from apps.webui.models.models import Model
-from apps.webui.models.prompts import Prompt
-from apps.webui.models.tags import Tag, ChatIdTag
-from apps.webui.models.tools import Tool
-from apps.webui.models.users import User
-from apps.webui.models.files import File
-from apps.webui.models.functions import Function
-
 from env import DATABASE_URL
+from sqlalchemy import engine_from_config, pool
 
 # this is the Alembic Config object, which provides
 # access to the values within the .ini file in use.

+ 3 - 3
backend/migrations/versions/7e5b5dc7342b_init.py

@@ -1,16 +1,16 @@
 """init
 
 Revision ID: 7e5b5dc7342b
-Revises: 
+Revises:
 Create Date: 2024-06-24 13:15:33.808998
 
 """
 
 from typing import Sequence, Union
 
-from alembic import op
-import sqlalchemy as sa
 import apps.webui.internal.db
+import sqlalchemy as sa
+from alembic import op
 from migrations.util import get_existing_tables
 
 # revision identifiers, used by Alembic.

+ 1 - 3
backend/migrations/versions/ca81bd47c050_add_config_table.py

@@ -8,10 +8,8 @@ Create Date: 2024-08-25 15:26:35.241684
 
 from typing import Sequence, Union
 
-from alembic import op
 import sqlalchemy as sa
-import apps.webui.internal.db
-
+from alembic import op
 
 # revision identifiers, used by Alembic.
 revision: str = "ca81bd47c050"

+ 1 - 3
backend/test/apps/webui/routers/test_auths.py

@@ -1,5 +1,3 @@
-import pytest
-
 from test.util.abstract_integration_test import AbstractPostgresTest
 from test.util.mock_user import mock_webui_user
 
@@ -9,8 +7,8 @@ class TestAuths(AbstractPostgresTest):
 
     def setup_class(cls):
         super().setup_class()
-        from apps.webui.models.users import Users
         from apps.webui.models.auths import Auths
+        from apps.webui.models.users import Users
 
         cls.users = Users
         cls.auths = Auths

+ 1 - 3
backend/test/apps/webui/routers/test_chats.py

@@ -5,7 +5,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestChats(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/chats"
 
     def setup_class(cls):
@@ -13,8 +12,7 @@ class TestChats(AbstractPostgresTest):
 
     def setup_method(self):
         super().setup_method()
-        from apps.webui.models.chats import ChatForm
-        from apps.webui.models.chats import Chats
+        from apps.webui.models.chats import ChatForm, Chats
 
         self.chats = Chats
         self.chats.insert_new_chat(

+ 0 - 1
backend/test/apps/webui/routers/test_documents.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestDocuments(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/documents"
 
     def setup_class(cls):

+ 0 - 1
backend/test/apps/webui/routers/test_models.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestModels(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/models"
 
     def setup_class(cls):

+ 0 - 1
backend/test/apps/webui/routers/test_prompts.py

@@ -3,7 +3,6 @@ from test.util.mock_user import mock_webui_user
 
 
 class TestPrompts(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/prompts"
 
     def test_prompts(self):

+ 0 - 1
backend/test/apps/webui/routers/test_users.py

@@ -21,7 +21,6 @@ def _assert_user(data, id, **kwargs):
 
 
 class TestUsers(AbstractPostgresTest):
-
     BASE_PATH = "/api/v1/users"
 
     def setup_class(cls):

+ 4 - 4
backend/utils/misc.py

@@ -1,10 +1,10 @@
-from pathlib import Path
 import hashlib
 import re
-from datetime import timedelta
-from typing import Optional, Callable
-import uuid
 import time
+import uuid
+from datetime import timedelta
+from pathlib import Path
+from typing import Callable, Optional
 
 from utils.task import prompt_template
 

+ 1 - 1
backend/utils/schemas.py

@@ -1,7 +1,7 @@
 from ast import literal_eval
+from typing import Any, Literal, Optional, Type
 
 from pydantic import BaseModel, Field, create_model
-from typing import Any, Optional, Type, Literal
 
 
 def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:

+ 1 - 2
backend/utils/task.py

@@ -1,6 +1,5 @@
-import re
 import math
-
+import re
 from datetime import datetime
 from typing import Optional
 

+ 0 - 1
backend/utils/tools.py

@@ -5,7 +5,6 @@ from typing import Awaitable, Callable, get_type_hints
 from apps.webui.models.tools import Tools
 from apps.webui.models.users import UserModel
 from apps.webui.utils import load_toolkit_module_by_id
-
 from utils.schemas import json_schema_to_model
 
 log = logging.getLogger(__name__)

+ 8 - 9
backend/utils/utils.py

@@ -1,16 +1,15 @@
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from fastapi import HTTPException, status, Depends, Request
+import logging
+import uuid
+from datetime import UTC, datetime, timedelta
+from typing import Optional, Union
 
+import jwt
 from apps.webui.models.users import Users
-
-from typing import Union, Optional
 from constants import ERROR_MESSAGES
-from passlib.context import CryptContext
-from datetime import datetime, timedelta, UTC
-import jwt
-import uuid
-import logging
 from env import WEBUI_SECRET_KEY
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
+from passlib.context import CryptContext
 
 logging.getLogger("passlib").setLevel(logging.ERROR)
 

+ 3 - 2
backend/utils/webhook.py

@@ -1,8 +1,9 @@
 import json
-import requests
 import logging
 
-from config import SRC_LOG_LEVELS, VERSION, WEBUI_FAVICON_URL, WEBUI_NAME
+import requests
+from config import WEBUI_FAVICON_URL, WEBUI_NAME
+from env import SRC_LOG_LEVELS, VERSION
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])