
Initialize support for prefixing embeddings

jvinolus 4 months ago
parent commit 47b8412695

+ 12 - 0
backend/open_webui/config.py

@@ -1330,6 +1330,18 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
     ),
 )
 
+RAG_EMBEDDING_PASSAGE_PREFIX = PersistentConfig(
+    "RAG_EMBEDDING_PASSAGE_PREFIX",
+    "rag.embedding_passage_prefix",
+    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", False),
+)
+
+RAG_EMBEDDING_QUERY_PREFIX = PersistentConfig(
+    "RAG_EMBEDDING_QUERY_PREFIX",
+    "rag.embedding_query_prefix",
+    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", False),
+)
+
 RAG_RERANKING_MODEL = PersistentConfig(
     "RAG_RERANKING_MODEL",
     "rag.reranking_model",

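How these options are meant to be used: E5-style embedding models are trained with asymmetric instruction prefixes, so queries and passages must be prefixed differently at embedding time. A minimal sketch of configuring the new options through the environment (the "query: " / "passage: " values are illustrative for an E5-style model, not defaults shipped by this commit):

    # Illustrative values; leaving the variables unset keeps prefixing disabled.
    import os

    os.environ["RAG_EMBEDDING_QUERY_PREFIX"] = "query: "      # prepended to search queries
    os.environ["RAG_EMBEDDING_PASSAGE_PREFIX"] = "passage: "  # prepended to chunks at index time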
+ 22 - 20
backend/open_webui/retrieval/utils.py

@@ -15,7 +15,7 @@ from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 
 from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
-
+from open_webui.config import RAG_EMBEDDING_QUERY_PREFIX, RAG_EMBEDDING_PASSAGE_PREFIX
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -39,7 +39,7 @@ class VectorSearchRetriever(BaseRetriever):
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query)],
+            vectors=[self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)],
             limit=self.top_k,
         )
 
@@ -183,7 +183,7 @@ def query_collection(
 ) -> dict:
     results = []
     for query in queries:
-        query_embedding = embedding_function(query)
+        query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
         for collection_name in collection_names:
             if collection_name:
                 try:
@@ -247,26 +247,27 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, prefix: embedding_function.encode(query, prompt=prefix if prefix else None).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, prefix: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
+            prefix=prefix,
             url=url,
             key=key,
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, prefix, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], prefix))
                 return embeddings
             else:
-                return func(query)
+                return func(query, prefix)
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, prefix: generate_multiple(query, prefix, func)
 
 
 def get_sources_from_files(
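After this hunk, the callable returned by get_embedding_function takes a prefix as its second argument on every engine path; in the local branch the prefix is forwarded as the prompt argument of SentenceTransformer.encode, which prepends it to each input (available in sentence-transformers 2.4+). A rough sketch of the new calling convention, where ef stands for the returned callable and the other names are hypothetical:

    # ef = get_embedding_function(...)  # the two-argument callable built above
    query_vec = ef("what is retrieval-augmented generation?", RAG_EMBEDDING_QUERY_PREFIX)
    doc_vecs = ef(chunk_texts, RAG_EMBEDDING_PASSAGE_PREFIX)  # a list[str] is embedded in batches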
@@ -411,7 +412,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", prefix: Optional[str] = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -420,7 +421,7 @@ def generate_openai_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
         )
         r.raise_for_status()
         data = r.json()
@@ -434,7 +435,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", prefix: Optional[str] = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -443,7 +444,7 @@ def generate_ollama_batch_embeddings(
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
             },
-            json={"input": texts, "model": model},
+            json={"input": texts, "model": model} if not prefix else {"input": texts, "model": model, "prefix": prefix},
         )
         r.raise_for_status()
         data = r.json()
@@ -457,25 +458,25 @@ def generate_ollama_batch_embeddings(
         return None
 
 
-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], prefix: Optional[str] = None, **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "prefix": prefix}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "prefix": prefix}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, prefix)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, prefix)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 
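The effect on the wire format, sketched below: when a prefix is configured it is sent as an extra field in the JSON body. The official OpenAI embeddings API does not document a "prefix" parameter, so this field is mainly useful for OpenAI-compatible or self-hosted backends that choose to honor it:

    # Request bodies built by generate_openai_batch_embeddings / generate_ollama_batch_embeddings
    # prefix unset: {"input": ["chunk one", "chunk two"], "model": "my-embedder"}
    # prefix set:   {"input": ["chunk one", "chunk two"], "model": "my-embedder", "prefix": "passage: "}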
@@ -512,9 +513,10 @@ class RerankCompressor(BaseDocumentCompressor):
         else:
             from sentence_transformers import util
 
-            query_embedding = self.embedding_function(query)
+            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
             document_embedding = self.embedding_function(
-                [doc.page_content for doc in documents]
+                [doc.page_content for doc in documents],
+                RAG_EMBEDDING_PASSAGE_PREFIX,
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
 
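In this non-reranking fallback, both sides are now embedded with their respective prefixes before cosine scoring. A self-contained sketch of the scoring step itself, assuming sentence-transformers is installed (the toy vectors are made up):

    from sentence_transformers import util
    import torch

    query_embedding = torch.tensor([[0.1, 0.9]])                   # shape (1, d)
    document_embedding = torch.tensor([[0.2, 0.8], [0.9, 0.1]])    # shape (n, d)
    scores = util.cos_sim(query_embedding, document_embedding)[0]  # n similarity scores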

+ 2 - 1
backend/open_webui/routers/retrieval.py

@@ -79,6 +79,7 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     UPLOAD_DIR,
     DEFAULT_LOCALE,
+    RAG_EMBEDDING_PASSAGE_PREFIX,
 )
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -775,7 +776,7 @@ def save_docs_to_vector_db(
         )
 
         embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
+            list(map(lambda x: x.replace("\n", " "), texts)), RAG_EMBEDDING_PASSAGE_PREFIX
         )
 
         items = [
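A consistency note on the change above: chunks are embedded with RAG_EMBEDDING_PASSAGE_PREFIX at index time, so queries must be embedded with RAG_EMBEDDING_QUERY_PREFIX at search time (as the utils.py changes do); for prefix-trained models, mixing prefixed passages with unprefixed queries can noticeably degrade retrieval quality. Sketch of the intended asymmetric pairing:

    # Index time (save_docs_to_vector_db):
    doc_vecs = embedding_function(chunk_texts, RAG_EMBEDDING_PASSAGE_PREFIX)
    # Query time (query_collection / VectorSearchRetriever):
    query_vec = embedding_function("what do llamas eat?", RAG_EMBEDDING_QUERY_PREFIX)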