
refac: embedding prefix var naming

Timothy Jaeryang Baek, 3 months ago · commit 4b75966401
2 changed files with 63 additions and 41 deletions
  1. backend/open_webui/config.py (+4 -8)
  2. backend/open_webui/retrieval/utils.py (+59 -33)

+ 4 - 8
backend/open_webui/config.py

@@ -1783,16 +1783,12 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
     ),
 )

-RAG_EMBEDDING_QUERY_PREFIX = (
-    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)
-)
+RAG_EMBEDDING_QUERY_PREFIX = os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)

-RAG_EMBEDDING_PASSAGE_PREFIX = ( 
-    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", None)
-)
+RAG_EMBEDDING_CONTENT_PREFIX = os.environ.get("RAG_EMBEDDING_CONTENT_PREFIX", None)

-RAG_EMBEDDING_PREFIX_FIELD_NAME = (
-    os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", None)
+RAG_EMBEDDING_PREFIX_FIELD_NAME = os.environ.get(
+    "RAG_EMBEDDING_PREFIX_FIELD_NAME", None
 )

 RAG_RERANKING_MODEL = PersistentConfig(
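
For context, a minimal sketch of how these three settings might be exercised, assuming e5-style prefixes. The concrete values ("query: ", "passage: ", "input_type") are illustrative assumptions, not defaults introduced by this commit:

    import os

    # Illustrative values: e5-family embedding models expect "query: " /
    # "passage: " prefixes; other models may need none at all.
    os.environ["RAG_EMBEDDING_QUERY_PREFIX"] = "query: "
    os.environ["RAG_EMBEDDING_CONTENT_PREFIX"] = "passage: "

    # Optional: when set, the prefix is sent as a dedicated request field
    # (e.g. "input_type") instead of being prepended to the input text.
    # os.environ["RAG_EMBEDDING_PREFIX_FIELD_NAME"] = "input_type"

    # Must run before open_webui.config is imported, since the module reads
    # os.environ at import time.
    from open_webui.config import (
        RAG_EMBEDDING_QUERY_PREFIX,
        RAG_EMBEDDING_CONTENT_PREFIX,
    )
    assert RAG_EMBEDDING_QUERY_PREFIX == "query: "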

+ 59 - 33
backend/open_webui/retrieval/utils.py

@@ -25,9 +25,9 @@ from open_webui.env import (
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )
 from open_webui.config import (
-    RAG_EMBEDDING_QUERY_PREFIX, 
-    RAG_EMBEDDING_PASSAGE_PREFIX, 
-    RAG_EMBEDDING_PREFIX_FIELD_NAME
+    RAG_EMBEDDING_QUERY_PREFIX,
+    RAG_EMBEDDING_CONTENT_PREFIX,
+    RAG_EMBEDDING_PREFIX_FIELD_NAME,
 )

 log = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ class VectorSearchRetriever(BaseRetriever):
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query,RAG_EMBEDDING_QUERY_PREFIX)],
+            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
             limit=self.top_k,
         )

@@ -334,7 +334,9 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query, prefix, user=None: embedding_function.encode(query, prompt = prefix if prefix else None).tolist()
+        return lambda query, prefix, user=None: embedding_function.encode(
+            query, prompt=prefix if prefix else None
+        ).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         func = lambda query, prefix, user=None: generate_embeddings(
             engine=embedding_engine,
@@ -345,22 +347,29 @@ def get_embedding_function(
             key=key,
             user=user,
         )
+
         def generate_multiple(query, prefix, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
                     embeddings.extend(
-                        func(query[i : i + embedding_batch_size], prefix=prefix, user=user)
+                        func(
+                            query[i : i + embedding_batch_size],
+                            prefix=prefix,
+                            user=user,
+                        )
                     )
                 return embeddings
             else:
                 return func(query, prefix, user)
-        return lambda query, prefix, user=None: generate_multiple(query, prefix, user, func)
+
+        return lambda query, prefix, user=None: generate_multiple(
+            query, prefix, user, func
+        )
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")


-
 def get_sources_from_files(
     request,
     files,
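
Aside: the generate_multiple wrapper reformatted in the hunk above chunks list inputs by embedding_batch_size before delegating to the engine call, so a large document set never becomes one oversized request. A standalone sketch of the same pattern (simplified; not the module's exact code):

    from typing import Any, Callable, Optional

    def batched_embed(
        texts: list[str],
        batch_size: int,
        embed: Callable[..., list[list[float]]],
        prefix: Optional[str] = None,
        user: Any = None,
    ) -> list[list[float]]:
        # Embed in batch_size chunks and concatenate the results in order.
        embeddings: list[list[float]] = []
        for i in range(0, len(texts), batch_size):
            embeddings.extend(
                embed(texts[i : i + batch_size], prefix=prefix, user=user)
            )
        return embeddings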
@@ -579,14 +588,11 @@ def generate_openai_batch_embeddings(
     url: str = "https://api.openai.com/v1",
     key: str = "",
     prefix: str = None,
-    user: UserModel = None
+    user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
-        json_data = {
-            "input": texts, 
-            "model": model
-        }
-        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

         r = requests.post(
@@ -619,21 +625,18 @@ def generate_ollama_batch_embeddings(


 def generate_ollama_batch_embeddings(
-    model: str, 
+    model: str,
     texts: list[str],
     url: str,
-    key: str = "", 
-    prefix: str = None, 
-    user: UserModel = None
+    key: str = "",
+    prefix: str = None,
+    user: UserModel = None,
 ) -> Optional[list[list[float]]]:
     try:
-        json_data = {
-            "input": texts, 
-            "model": model
-        }
-        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+        json_data = {"input": texts, "model": model}
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
             json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
-            
+
         r = requests.post(
             f"{url}/api/embed",
             headers={
@@ -664,32 +667,56 @@ def generate_ollama_batch_embeddings(
         return None


-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], prefix: Union[str , None] = None, **kwargs):
+def generate_embeddings(
+    engine: str,
+    model: str,
+    text: Union[str, list[str]],
+    prefix: Union[str, None] = None,
+    **kwargs,
+):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
     user = kwargs.get("user")

     if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
         if isinstance(text, list):
-            text = [f'{prefix}{text_element}' for text_element in text]
+            text = [f"{prefix}{text_element}" for text_element in text]
         else:
-            text = f'{prefix}{text}'
+            text = f"{prefix}{text}"

     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key, "prefix": prefix, "user": user}
+                **{
+                    "model": model,
+                    "texts": text,
+                    "url": url,
+                    "key": key,
+                    "prefix": prefix,
+                    "user": user,
+                }
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key, "prefix": prefix, "user": user}
+                **{
+                    "model": model,
+                    "texts": [text],
+                    "url": url,
+                    "key": key,
+                    "prefix": prefix,
+                    "user": user,
+                }
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key, prefix, user)
+            embeddings = generate_openai_batch_embeddings(
+                model, text, url, key, prefix, user
+            )
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key, prefix, user)
+            embeddings = generate_openai_batch_embeddings(
+                model, [text], url, key, prefix, user
+            )
         return embeddings[0] if isinstance(text, str) else embeddings


@@ -727,8 +754,7 @@ class RerankCompressor(BaseDocumentCompressor):

             query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
             document_embedding = self.embedding_function(
-                [doc.page_content for doc in documents], 
-                RAG_EMBEDDING_PASSAGE_PREFIX
+                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
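
Taken together, the prefix handling in generate_embeddings above reduces to two modes. A minimal standalone sketch of that decision (names shortened; illustrative, not the actual module code):

    from typing import Optional, Union

    def apply_prefix(
        text: Union[str, list[str]],
        prefix: Optional[str],
        field_name: Optional[str],
        json_data: dict,
    ) -> Union[str, list[str]]:
        # Mode 1: no dedicated field configured -> prepend the prefix to
        # every input string before it is sent to the embedding backend.
        if prefix is not None and field_name is None:
            if isinstance(text, list):
                return [f"{prefix}{t}" for t in text]
            return f"{prefix}{text}"
        # Mode 2: RAG_EMBEDDING_PREFIX_FIELD_NAME is set -> the text stays
        # untouched and the prefix travels as its own request field, e.g.
        # {"input_type": "query"} for APIs that accept such a parameter.
        if isinstance(field_name, str) and isinstance(prefix, str):
            json_data[field_name] = prefix
        return text

Queries are embedded with RAG_EMBEDDING_QUERY_PREFIX and stored chunks with RAG_EMBEDDING_CONTENT_PREFIX (see VectorSearchRetriever and RerankCompressor above), so asymmetric embedding models see consistent prefixes at index time and at search time.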