
feat(trace): opentelemetry instrument

orenzhang 1 month ago
parent
current commit
c761e4fd08

+ 3 - 1
backend/open_webui/config.py

@@ -1580,7 +1580,9 @@ QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None)
 # OpenSearch
 OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200")
 OPENSEARCH_SSL = os.environ.get("OPENSEARCH_SSL", "true").lower() == "true"
-OPENSEARCH_CERT_VERIFY = os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
+OPENSEARCH_CERT_VERIFY = (
+    os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"
+)
 OPENSEARCH_USERNAME = os.environ.get("OPENSEARCH_USERNAME", None)
 OPENSEARCH_PASSWORD = os.environ.get("OPENSEARCH_PASSWORD", None)
 
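
A note on the wrapped flag above: only the literal string "true" (any case) enables certificate verification. A minimal sketch of the convention, with illustrative values:

    import os

    os.environ["OPENSEARCH_CERT_VERIFY"] = "TRUE"
    assert os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() == "true"

    os.environ["OPENSEARCH_CERT_VERIFY"] = "1"  # not "true", so verification stays off
    assert os.environ.get("OPENSEARCH_CERT_VERIFY", "false").lower() != "true"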

+ 9 - 0
backend/open_webui/env.py

@@ -442,3 +442,12 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
 )
 AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
 AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
+
+####################################
+# OPENTELEMETRY
+####################################
+
+OT_ENABLED = os.environ.get("OT_ENABLED", "false").lower() == "true"
+OT_SERVICE_NAME = os.environ.get("OT_SERVICE_NAME", "open-webui")
+OT_HOST = os.environ.get("OT_HOST", "http://localhost:4317")
+OT_TOKEN = os.environ.get("OT_TOKEN", "")
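
A hedged sketch of how these flags are consumed elsewhere in this diff (the values shown are illustrative): tracing stays off unless OT_ENABLED is the literal string "true", and OT_HOST must point at an OTLP/gRPC collector.

    import os

    os.environ["OT_ENABLED"] = "true"                # opt in; anything else reads as false
    os.environ["OT_SERVICE_NAME"] = "open-webui"     # reported as service.name on every span
    os.environ["OT_HOST"] = "http://localhost:4317"  # OTLP/gRPC endpoint (see setup.py below)
    os.environ["OT_TOKEN"] = "secret"                # hypothetical token; attached as a resource attribute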

+ 13 - 1
backend/open_webui/main.py

@@ -330,6 +330,7 @@ from open_webui.env import (
     BYPASS_MODEL_ACCESS_CONTROL,
     RESET_CONFIG_ON_START,
     OFFLINE_MODE,
+    OT_ENABLED,
 )
 
 
@@ -356,7 +357,7 @@ from open_webui.utils.oauth import OAuthManager
 from open_webui.utils.security_headers import SecurityHeadersMiddleware
 
 from open_webui.tasks import stop_task, list_tasks  # Import from tasks.py
-
+from open_webui.utils.trace.setup import setup
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
@@ -426,6 +427,17 @@ app.state.config = AppConfig(redis_url=REDIS_URL)
 app.state.WEBUI_NAME = WEBUI_NAME
 app.state.LICENSE_METADATA = None
 
+
+########################################
+#
+# OPENTELEMETRY
+#
+########################################
+
+if OT_ENABLED:
+    setup(app)
+
+
 ########################################
 #
 # OLLAMA
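
Once setup(app) has run, spans flow through the global tracer provider; a hedged smoke test (the tracer and span names are illustrative):

    from opentelemetry import trace

    tracer = trace.get_tracer("open-webui.smoke-test")
    with tracer.start_as_current_span("startup-check"):
        pass  # exported in batches to OT_HOST by the span processor configured in setup()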

+ 13 - 18
backend/open_webui/retrieval/loaders/tavily.py

@@ -9,18 +9,20 @@ from open_webui.env import SRC_LOG_LEVELS
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
+
 class TavilyLoader(BaseLoader):
     """Extract web page content from URLs using Tavily Extract API.
-    
+
     This is a LangChain document loader that uses Tavily's Extract API to
     retrieve content from web pages and return it as Document objects.
-    
+
     Args:
         urls: URL or list of URLs to extract content from.
         api_key: The Tavily API key.
         extract_depth: Depth of extraction, either "basic" or "advanced".
         continue_on_failure: Whether to continue if extraction of a URL fails.
     """
+
     def __init__(
         self,
         urls: Union[str, List[str]],
@@ -29,13 +31,13 @@ class TavilyLoader(BaseLoader):
         continue_on_failure: bool = True,
     ) -> None:
         """Initialize Tavily Extract client.
-        
+
         Args:
             urls: URL or list of URLs to extract content from.
             api_key: The Tavily API key.
             include_images: Whether to include images in the extraction.
             extract_depth: Depth of extraction, either "basic" or "advanced".
-                advanced extraction retrieves more data, including tables and 
+                advanced extraction retrieves more data, including tables and
                 embedded content, with higher success but may increase latency.
                 basic costs 1 credit per 5 successful URL extractions,
                 advanced costs 2 credits per 5 successful URL extractions.
@@ -43,35 +45,28 @@ class TavilyLoader(BaseLoader):
         """
         if not urls:
             raise ValueError("At least one URL must be provided.")
-            
+
         self.api_key = api_key
         self.urls = urls if isinstance(urls, list) else [urls]
         self.extract_depth = extract_depth
         self.continue_on_failure = continue_on_failure
         self.api_url = "https://api.tavily.com/extract"
-        
+
     def lazy_load(self) -> Iterator[Document]:
         """Extract and yield documents from the URLs using Tavily Extract API."""
         batch_size = 20
         for i in range(0, len(self.urls), batch_size):
-            batch_urls = self.urls[i:i + batch_size]
+            batch_urls = self.urls[i : i + batch_size]
             try:
                 headers = {
                     "Content-Type": "application/json",
-                    "Authorization": f"Bearer {self.api_key}"
+                    "Authorization": f"Bearer {self.api_key}",
                 }
                 # Use string for single URL, array for multiple URLs
                 urls_param = batch_urls[0] if len(batch_urls) == 1 else batch_urls
-                payload = {
-                    "urls": urls_param,
-                    "extract_depth": self.extract_depth
-                }
+                payload = {"urls": urls_param, "extract_depth": self.extract_depth}
                 # Make the API call
-                response = requests.post(
-                    self.api_url,
-                    headers=headers,
-                    json=payload
-                )
+                response = requests.post(self.api_url, headers=headers, json=payload)
                 response.raise_for_status()
                 response_data = response.json()
                 # Process successful results
@@ -95,4 +90,4 @@ class TavilyLoader(BaseLoader):
                 if self.continue_on_failure:
                     log.error(f"Error extracting content from batch {batch_urls}: {e}")
                 else:
-                    raise e
+                    raise e
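
A hedged usage sketch for the loader above (the API key and URLs are placeholders):

    loader = TavilyLoader(
        urls=["https://example.com/a", "https://example.com/b"],
        api_key="tvly-...",  # hypothetical key
        extract_depth="basic",
        continue_on_failure=True,
    )
    for doc in loader.lazy_load():  # URLs are sent in batches of up to 20 per request
        print(doc.page_content[:80])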

+ 40 - 51
backend/open_webui/retrieval/vector/dbs/opensearch.py

@@ -21,14 +21,14 @@ class OpenSearchClient:
             verify_certs=OPENSEARCH_CERT_VERIFY,
             http_auth=(OPENSEARCH_USERNAME, OPENSEARCH_PASSWORD),
         )
-    
+
     def _get_index_name(self, collection_name: str) -> str:
         return f"{self.index_prefix}_{collection_name}"
 
     def _result_to_get_result(self, result) -> GetResult:
         if not result["hits"]["hits"]:
             return None
-        
+
         ids = []
         documents = []
         metadatas = []
@@ -43,7 +43,7 @@ class OpenSearchClient:
     def _result_to_search_result(self, result) -> SearchResult:
         if not result["hits"]["hits"]:
             return None
-        
+
         ids = []
         distances = []
         documents = []
@@ -56,16 +56,15 @@ class OpenSearchClient:
             metadatas.append(hit["_source"].get("metadata"))
 
         return SearchResult(
-            ids=[ids], distances=[distances], documents=[documents], metadatas=[metadatas]
+            ids=[ids],
+            distances=[distances],
+            documents=[documents],
+            metadatas=[metadatas],
         )
 
     def _create_index(self, collection_name: str, dimension: int):
         body = {
-            "settings": {
-                "index": {
-                "knn": True
-                }
-            },
+            "settings": {"index": {"knn": True}},
             "mappings": {
                 "properties": {
                     "id": {"type": "keyword"},
@@ -81,13 +80,13 @@ class OpenSearchClient:
                             "parameters": {
                                 "ef_construction": 128,
                                 "m": 16,
-                            }
+                            },
                         },
                     },
                     "text": {"type": "text"},
                     "metadata": {"type": "object"},
                 }
-            }
+            },
         }
         self.client.indices.create(
             index=self._get_index_name(collection_name), body=body
@@ -100,9 +99,7 @@ class OpenSearchClient:
     def has_collection(self, collection_name: str) -> bool:
         # has_collection here means has index.
         # We are simply adapting to the norms of the other DBs.
-        return self.client.indices.exists(
-            index=self._get_index_name(collection_name)
-        )
+        return self.client.indices.exists(index=self._get_index_name(collection_name))
 
     def delete_collection(self, collection_name: str):
         # delete_collection here means delete index.
@@ -115,33 +112,30 @@ class OpenSearchClient:
         try:
             if not self.has_collection(collection_name):
                 return None
-            
+
             query = {
                 "size": limit,
                 "_source": ["text", "metadata"],
                 "query": {
                     "script_score": {
-                        "query": {
-                            "match_all": {}
-                        },
+                        "query": {"match_all": {}},
                         "script": {
                             "source": "cosineSimilarity(params.query_value, doc[params.field]) + 1.0",
                             "params": {
-                            "field": "vector",
-                            "query_value": vectors[0]
+                                "field": "vector",
+                                "query_value": vectors[0],
                             },  # Assuming single query vector
                         },
                     }
                 },
             }
-            
+
             result = self.client.search(
-                index=self._get_index_name(collection_name),
-                body=query
+                index=self._get_index_name(collection_name), body=query
             )
 
             return self._result_to_search_result(result)
-        
+
         except Exception as e:
             return None
 
@@ -152,20 +146,14 @@ class OpenSearchClient:
             return None
 
         query_body = {
-            "query": {
-                "bool": {
-                    "filter": []
-                }
-            },
+            "query": {"bool": {"filter": []}},
             "_source": ["text", "metadata"],
         }
 
         for field, value in filter.items():
-            query_body["query"]["bool"]["filter"].append({
-                "match": {
-                    "metadata." + str(field): value
-                }
-            })
+            query_body["query"]["bool"]["filter"].append(
+                {"match": {"metadata." + str(field): value}}
+            )
 
         size = limit if limit else 10
 
@@ -201,9 +189,9 @@ class OpenSearchClient:
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "_op_type": "index", 
+                    "_op_type": "index",
                     "_index": self._get_index_name(collection_name),
-                    "_id": item["id"], 
+                    "_id": item["id"],
                     "_source": {
                         "vector": item["vector"],
                         "text": item["text"],
@@ -222,9 +210,9 @@ class OpenSearchClient:
         for batch in self._create_batches(items):
             actions = [
                 {
-                    "_op_type": "update", 
+                    "_op_type": "update",
                     "_index": self._get_index_name(collection_name),
-                    "_id": item["id"], 
+                    "_id": item["id"],
                     "doc": {
                         "vector": item["vector"],
                         "text": item["text"],
@@ -236,7 +224,12 @@ class OpenSearchClient:
             ]
             bulk(self.client, actions)
 
-    def delete(self, collection_name: str, ids: Optional[list[str]] = None, filter: Optional[dict] = None):
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[list[str]] = None,
+        filter: Optional[dict] = None,
+    ):
         if ids:
             actions = [
                 {
@@ -249,20 +242,16 @@ class OpenSearchClient:
             bulk(self.client, actions)
         elif filter:
             query_body = {
-                "query": {
-                    "bool": {
-                        "filter": []
-                    }
-                },
+                "query": {"bool": {"filter": []}},
             }
             for field, value in filter.items():
-                query_body["query"]["bool"]["filter"].append({
-                    "match": {
-                        "metadata." + str(field): value
-                    }
-                })
-            self.client.delete_by_query(index=self._get_index_name(collection_name), body=query_body)
-                
+                query_body["query"]["bool"]["filter"].append(
+                    {"match": {"metadata." + str(field): value}}
+                )
+            self.client.delete_by_query(
+                index=self._get_index_name(collection_name), body=query_body
+            )
+
     def reset(self):
         indices = self.client.indices.get(index=f"{self.index_prefix}_*")
         for index in indices:
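
A hedged usage sketch for the client above (the collection name and 384-dimensional query vector are illustrative; the search(collection_name, vectors, limit) signature is inferred from the query construction shown in the hunk):

    client = OpenSearchClient()  # assumes the no-argument constructor shown above
    if client.has_collection("docs"):
        result = client.search("docs", vectors=[[0.1] * 384], limit=5)
        # result is a SearchResult (ids/distances/documents/metadatas) ranked by
        # cosineSimilarity(query_value, doc["vector"]) + 1.0, or None on any error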

+ 8 - 9
backend/open_webui/retrieval/web/utils.py

@@ -136,18 +136,18 @@ class RateLimitMixin:
         self.last_request_time = datetime.now()
 
 
-class URLProcessingMixin:  
+class URLProcessingMixin:
     def _verify_ssl_cert(self, url: str) -> bool:
         """Verify SSL certificate for a URL."""
         return verify_ssl_cert(url)
-        
+
     async def _safe_process_url(self, url: str) -> bool:
         """Perform safety checks before processing a URL."""
         if self.verify_ssl and not self._verify_ssl_cert(url):
             raise ValueError(f"SSL certificate verification failed for {url}")
         await self._wait_for_rate_limit()
         return True
-    
+
     def _safe_process_url_sync(self, url: str) -> bool:
         """Synchronous version of safety checks."""
         if self.verify_ssl and not self._verify_ssl_cert(url):
@@ -286,7 +286,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                     proxy["server"] = env_proxy_server
                 else:
                     proxy = {"server": env_proxy_server}
-                    
+
         # Store parameters for creating TavilyLoader instances
         self.web_paths = web_paths if isinstance(web_paths, list) else [web_paths]
         self.api_key = api_key
@@ -295,7 +295,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
         self.verify_ssl = verify_ssl
         self.trust_env = trust_env
         self.proxy = proxy
-        
+
         # Add rate limiting
         self.requests_per_second = requests_per_second
         self.last_request_time = None
@@ -329,7 +329,7 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                 log.exception(e, "Error extracting content from URLs")
             else:
                 raise e
-    
+
     async def alazy_load(self) -> AsyncIterator[Document]:
         """Async version with rate limiting and SSL verification."""
         valid_urls = []
@@ -341,13 +341,13 @@ class SafeTavilyLoader(BaseLoader, RateLimitMixin, URLProcessingMixin):
                 log.warning(f"SSL verification failed for {url}: {str(e)}")
                 if not self.continue_on_failure:
                     raise e
-        
+
         if not valid_urls:
             if self.continue_on_failure:
                 log.warning("No valid URLs to process after SSL verification")
                 return
             raise ValueError("No valid URLs to process after SSL verification")
-        
+
         try:
             loader = TavilyLoader(
                 urls=valid_urls,
@@ -477,7 +477,6 @@ class SafePlaywrightURLLoader(PlaywrightURLLoader, RateLimitMixin, URLProcessing
             await browser.close()
 
 
-
 class SafeWebBaseLoader(WebBaseLoader):
     """WebBaseLoader with enhanced error handling for URLs."""
 

+ 0 - 0
backend/open_webui/utils/trace/__init__.py


+ 26 - 0
backend/open_webui/utils/trace/constants.py

@@ -0,0 +1,26 @@
+from opentelemetry.semconv.trace import SpanAttributes as _SpanAttributes
+
+# Span Tags
+SPAN_DB_TYPE = "mysql"
+SPAN_REDIS_TYPE = "redis"
+SPAN_DURATION = "duration"
+SPAN_SQL_STR = "sql"
+SPAN_SQL_EXPLAIN = "explain"
+SPAN_ERROR_TYPE = "error"
+
+
+class SpanAttributes(_SpanAttributes):
+    """
+    Span Attributes
+    """
+
+    DB_INSTANCE = "db.instance"
+    DB_TYPE = "db.type"
+    DB_IP = "db.ip"
+    DB_PORT = "db.port"
+    ERROR_KIND = "error.kind"
+    ERROR_OBJECT = "error.object"
+    ERROR_MESSAGE = "error.message"
+    RESULT_CODE = "result.code"
+    RESULT_MESSAGE = "result.message"
+    RESULT_ERRORS = "result.errors"
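
The subclass keeps every standard semantic-convention attribute and adds the DB/ERROR/RESULT keys above; a hedged example of tagging a span with them (the tracer name and port are illustrative):

    from opentelemetry import trace

    tracer = trace.get_tracer("open-webui.example")
    with tracer.start_as_current_span("redis-call") as span:
        span.set_attribute(SpanAttributes.DB_TYPE, SPAN_REDIS_TYPE)
        span.set_attribute(SpanAttributes.DB_PORT, 6379)  # hypothetical port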

+ 31 - 0
backend/open_webui/utils/trace/exporters.py

@@ -0,0 +1,31 @@
+import threading
+
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
+
+class LazyBatchSpanProcessor(BatchSpanProcessor):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.done = True
+        with self.condition:
+            self.condition.notify_all()
+        self.worker_thread.join()
+        self.done = False
+        self.worker_thread = None
+
+    def on_end(self, span: ReadableSpan) -> None:
+        if self.worker_thread is None:
+            self.worker_thread = threading.Thread(
+                name=self.__class__.__name__, target=self.worker, daemon=True
+            )
+            self.worker_thread.start()
+        super().on_end(span)
+
+    def shutdown(self) -> None:
+        self.done = True
+        with self.condition:
+            self.condition.notify_all()
+        if self.worker_thread:
+            self.worker_thread.join()
+        self.span_exporter.shutdown()
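
The processor above changes only when the export thread runs: the constructor immediately shuts down the worker the parent BatchSpanProcessor started, and on_end() lazily restarts it when the first span finishes, so an idle process carries no exporter thread. Wiring is otherwise unchanged; a hedged sketch (the endpoint is illustrative):

    from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.trace import TracerProvider

    provider = TracerProvider()
    provider.add_span_processor(
        LazyBatchSpanProcessor(OTLPSpanExporter(endpoint="http://localhost:4317"))
    )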

+ 155 - 0
backend/open_webui/utils/trace/instrumentors.py

@@ -0,0 +1,155 @@
+import logging
+import traceback
+from typing import Collection
+
+from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
+from opentelemetry.instrumentation.httpx import (
+    HTTPXClientInstrumentor,
+    RequestInfo,
+    ResponseInfo,
+)
+from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
+from opentelemetry.instrumentation.logging import LoggingInstrumentor
+from opentelemetry.instrumentation.redis import RedisInstrumentor
+from opentelemetry.instrumentation.requests import RequestsInstrumentor
+from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor
+from opentelemetry.trace import Span, StatusCode
+from redis import Redis
+from requests import PreparedRequest, Response
+
+from open_webui.utils.trace.constants import SPAN_REDIS_TYPE, SpanAttributes
+
+from open_webui.env import SRC_LOG_LEVELS
+
+logger = logging.getLogger(__name__)
+logger.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+def requests_hook(span: Span, request: PreparedRequest):
+    """
+    Http Request Hook
+    """
+
+    span.update_name(f"{request.method} {request.url}")
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_URL: request.url,
+            SpanAttributes.HTTP_METHOD: request.method,
+        }
+    )
+
+
+def response_hook(span: Span, request: PreparedRequest, response: Response):
+    """
+    HTTP Response Hook
+    """
+
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_STATUS_CODE: response.status_code,
+        }
+    )
+    span.set_status(StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK)
+
+
+def redis_request_hook(span: Span, instance: Redis, args, kwargs):
+    """
+    Redis Request Hook
+    """
+
+    try:
+        connection_kwargs: dict = instance.connection_pool.connection_kwargs
+        host = connection_kwargs.get("host")
+        port = connection_kwargs.get("port")
+        db = connection_kwargs.get("db")
+        span.set_attributes(
+            {
+                SpanAttributes.DB_INSTANCE: f"{host}/{db}",
+                SpanAttributes.DB_NAME: f"{host}/{db}",
+                SpanAttributes.DB_TYPE: SPAN_REDIS_TYPE,
+                SpanAttributes.DB_PORT: port,
+                SpanAttributes.DB_IP: host,
+                SpanAttributes.DB_STATEMENT: " ".join([str(i) for i in args]),
+                SpanAttributes.DB_OPERATION: str(args[0]),
+            }
+        )
+    except Exception:  # pylint: disable=W0718
+        logger.error(traceback.format_exc())
+
+
+def httpx_request_hook(span: Span, request: RequestInfo):
+    """
+    HTTPX Request Hook
+    """
+
+    span.update_name(f"{request.method.decode()} {str(request.url)}")
+    span.set_attributes(
+        attributes={
+            SpanAttributes.HTTP_URL: str(request.url),
+            SpanAttributes.HTTP_METHOD: request.method.decode(),
+        }
+    )
+
+
+def httpx_response_hook(span: Span, request: RequestInfo, response: ResponseInfo):
+    """
+    HTTPX Response Hook
+    """
+
+    span.set_attribute(SpanAttributes.HTTP_STATUS_CODE, response.status_code)
+    span.set_status(
+        StatusCode.ERROR if response.status_code >= 400 else StatusCode.OK
+    )
+
+
+async def httpx_async_request_hook(span, request):
+    """
+    Async Request Hook
+    """
+
+    httpx_request_hook(span, request)
+
+
+async def httpx_async_response_hook(span, request, response):
+    """
+    Async Response Hook
+    """
+
+    httpx_response_hook(span, request, response)
+
+
+class Instrumentor(BaseInstrumentor):
+    """
+    Instrument OT
+    """
+
+    def __init__(self, app):
+        self.app = app
+
+    def instrumentation_dependencies(self) -> Collection[str]:
+        return []
+
+    def _instrument(self, **kwargs):
+        instrument_fastapi(app=self.app)
+        SQLAlchemyInstrumentor().instrument()
+        RedisInstrumentor().instrument(request_hook=redis_request_hook)
+        RequestsInstrumentor().instrument(
+            request_hook=requests_hook, response_hook=response_hook
+        )
+        LoggingInstrumentor().instrument()
+        HTTPXClientInstrumentor().instrument(
+            request_hook=httpx_request_hook,
+            response_hook=httpx_response_hook,
+            async_request_hook=httpx_async_request_hook,
+            async_response_hook=httpx_async_response_hook,
+        )
+        AioHttpClientInstrumentor().instrument()
+
+    def _uninstrument(self, **kwargs):
+        # Note: _instrument() never assigns self.instrumentors, so this guard
+        # currently always returns without uninstrumenting anything.
+        if getattr(self, "instrumentors", None) is None:
+            return
+        for instrumentor in self.instrumentors:
+            instrumentor.uninstrument()
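
With Instrumentor(app).instrument() active, outbound calls are traced without further code changes; a hedged illustration (the URL is a placeholder):

    import requests

    # requests_hook renames the span to "GET https://example.com/health" and sets
    # HTTP_URL/HTTP_METHOD; response_hook records the status code and span status.
    requests.get("https://example.com/health")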

+ 24 - 0
backend/open_webui/utils/trace/setup.py

@@ -0,0 +1,24 @@
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.sampling import ALWAYS_ON
+
+from open_webui.utils.trace.exporters import LazyBatchSpanProcessor
+from open_webui.utils.trace.instrumentors import Instrumentor
+from open_webui.env import OT_SERVICE_NAME, OT_HOST, OT_TOKEN
+
+
+def setup(app):
+    trace.set_tracer_provider(
+        TracerProvider(
+            resource=Resource.create(
+                {SERVICE_NAME: OT_SERVICE_NAME, "token": OT_TOKEN}
+            ),
+            sampler=ALWAYS_ON,
+        )
+    )
+    # otlp
+    exporter = OTLPSpanExporter(endpoint=OT_HOST)
+    trace.get_tracer_provider().add_span_processor(LazyBatchSpanProcessor(exporter))
+    Instrumentor(app=app).instrument()
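
A hedged end-to-end sketch: because env.py reads the flags at import time, they must be set before the app module is imported; a local OTLP collector on the default endpoint is assumed.

    import os

    os.environ["OT_ENABLED"] = "true"
    os.environ["OT_HOST"] = "http://localhost:4317"

    # Importing the app module runs setup(app) at import time when OT_ENABLED is set:
    from open_webui.main import app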

+ 15 - 1
backend/requirements.txt

@@ -37,7 +37,7 @@ asgiref==3.8.1
 # AI libraries
 openai
 anthropic
-google-generativeai==0.7.2
+google-generativeai==0.8.4
 tiktoken
 
 langchain==0.3.19
@@ -118,3 +118,17 @@ ldap3==2.9.1
 
 ## Firecrawl
 firecrawl-py==1.12.0
+
+## Trace
+opentelemetry-api==1.30.0
+opentelemetry-sdk==1.30.0
+opentelemetry-exporter-otlp==1.30.0
+opentelemetry-instrumentation==0.51b0
+opentelemetry-instrumentation-fastapi==0.51b0
+opentelemetry-instrumentation-sqlalchemy==0.51b0
+opentelemetry-instrumentation-redis==0.51b0
+opentelemetry-instrumentation-requests==0.51b0
+opentelemetry-instrumentation-logging==0.51b0
+opentelemetry-instrumentation-httpx==0.51b0
+opentelemetry-instrumentation-aiohttp-client==0.51b0
+opentelemetry-instrumentation-loguru==0.51b0