
Merge pull request #8594 from jayteaftw/main

feat: Support for instruct/prefixing embeddings
Timothy Jaeryang Baek 1 month ago
commit 433b5bddc1

+ 12 - 0
backend/open_webui/config.py

@@ -1783,6 +1783,18 @@ RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
     ),
 )
 
+RAG_EMBEDDING_QUERY_PREFIX = (
+    os.environ.get("RAG_EMBEDDING_QUERY_PREFIX", None)
+)
+
+RAG_EMBEDDING_PASSAGE_PREFIX = ( 
+    os.environ.get("RAG_EMBEDDING_PASSAGE_PREFIX", None)
+)
+
+RAG_EMBEDDING_PREFIX_FIELD_NAME = (
+    os.environ.get("RAG_EMBEDDING_PREFIX_FIELD_NAME", None)
+)
+
 RAG_RERANKING_MODEL = PersistentConfig(
     "RAG_RERANKING_MODEL",
     "rag.reranking_model",

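Note: the three new settings are plain os.environ reads rather than PersistentConfig entries, so they are set through the environment only. A minimal sketch of values for an E5-style model, which was trained with "query: "/"passage: " prefixes; the exact strings depend on the embedding model and are not part of this commit:

import os

# Illustrative E5-family prefixes; nomic-embed-text, for example, instead
# expects "search_query: " / "search_document: ".
os.environ["RAG_EMBEDDING_QUERY_PREFIX"] = "query: "
os.environ["RAG_EMBEDDING_PASSAGE_PREFIX"] = "passage: "
# Optional: name of a JSON field to carry the prefix separately instead of
# concatenating it into the text ("input_type" here is hypothetical).
os.environ["RAG_EMBEDDING_PREFIX_FIELD_NAME"] = "input_type"
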
+ 54 - 28
backend/open_webui/retrieval/utils.py

@@ -18,11 +18,17 @@ from open_webui.models.files import Files
 
 from open_webui.retrieval.vector.main import GetResult
 
+
 from open_webui.env import (
     SRC_LOG_LEVELS,
     OFFLINE_MODE,
     ENABLE_FORWARD_USER_INFO_HEADERS,
 )
+from open_webui.config import (
+    RAG_EMBEDDING_QUERY_PREFIX, 
+    RAG_EMBEDDING_PASSAGE_PREFIX, 
+    RAG_EMBEDDING_PREFIX_FIELD_NAME
+)
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -47,7 +53,7 @@ class VectorSearchRetriever(BaseRetriever):
     ) -> list[Document]:
         result = VECTOR_DB_CLIENT.search(
             collection_name=self.collection_name,
-            vectors=[self.embedding_function(query)],
+            vectors=[self.embedding_function(query,RAG_EMBEDDING_QUERY_PREFIX)],
             limit=self.top_k,
         )
 
@@ -250,7 +256,7 @@ def query_collection(
 ) -> dict:
     results = []
     for query in queries:
-        query_embedding = embedding_function(query)
+        query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
         for collection_name in collection_names:
             if collection_name:
                 try:
@@ -328,33 +334,33 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query, user=None: embedding_function.encode(query).tolist()
+        return lambda query, prefix, user=None: embedding_function.encode(query, prompt = prefix if prefix else None).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query, user=None: generate_embeddings(
+        func = lambda query, prefix, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
+            prefix=prefix,
             url=url,
             key=key,
             user=user,
         )
-
-        def generate_multiple(query, user, func):
+        def generate_multiple(query, prefix, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
                     embeddings.extend(
-                        func(query[i : i + embedding_batch_size], user=user)
+                        func(query[i : i + embedding_batch_size], prefix=prefix, user=user)
                     )
                 return embeddings
             else:
-                return func(query, user)
-
-        return lambda query, user=None: generate_multiple(query, user, func)
+                return func(query, prefix, user)
+        return lambda query, prefix, user=None: generate_multiple(query, prefix, user, func)
     else:
         raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
+
 def get_sources_from_files(
     request,
     files,
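
The function returned by get_embedding_function now takes the prefix as its second argument. A minimal sketch of the new calling convention for the local sentence-transformers path (engine ""); the model choice and the keyword names other than embedding_batch_size are inferred, not taken from this diff:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")  # illustrative model
ef = get_embedding_function(
    embedding_engine="",  # local sentence-transformers path
    embedding_model="intfloat/e5-small-v2",
    embedding_function=model,
    url=None,
    key=None,
    embedding_batch_size=32,
)

query_vec = ef("what is retrieval?", RAG_EMBEDDING_QUERY_PREFIX)
doc_vecs = ef(["chunk one", "chunk two"], RAG_EMBEDDING_PASSAGE_PREFIX)
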
@@ -572,9 +578,17 @@ def generate_openai_batch_embeddings(
     texts: list[str],
     url: str = "https://api.openai.com/v1",
     key: str = "",
-    user: UserModel = None,
+    prefix: str = None,
+    user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
+        json_data = {
+            "input": texts, 
+            "model": model
+        }
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+
         r = requests.post(
             f"{url}/embeddings",
             headers={
@@ -591,7 +605,7 @@ def generate_openai_batch_embeddings(
                     else {}
                 ),
             },
-            json={"input": texts, "model": model},
+            json=json_data,
         )
         r.raise_for_status()
         data = r.json()
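
When RAG_EMBEDDING_PREFIX_FIELD_NAME is set, the prefix travels as an extra top-level field in the request body instead of being concatenated into the text. A sketch of the resulting payload; the field and model names are hypothetical and the target API must actually accept the field:

# RAG_EMBEDDING_PREFIX_FIELD_NAME = "input_type", prefix = "query"
json_data = {
    "input": ["what is retrieval?"],
    "model": "my-embedding-model",
    "input_type": "query",  # forwarded verbatim to the /embeddings endpoint
}
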
@@ -605,9 +619,21 @@
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
+    model: str, 
+    texts: list[str],
+    url: str,
+    key: str = "", 
+    prefix: str = None, 
+    user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
+        json_data = {
+            "input": texts, 
+            "model": model
+        }
+        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME,str) and isinstance(prefix,str):
+            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix
+            
         r = requests.post(
             f"{url}/api/embed",
             headers={
@@ -624,7 +650,7 @@ def generate_ollama_batch_embeddings(
                     else {}
                 ),
             },
-            json={"input": texts, "model": model},
+            json=json_data,
         )
         r.raise_for_status()
         data = r.json()
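
The Ollama path mirrors the OpenAI one. In the likely more common setup, RAG_EMBEDDING_PREFIX_FIELD_NAME stays unset and generate_embeddings (below) has already prepended the prefix to each text, so the /api/embed body carries no extra field. Sketch, with an illustrative prefix-trained model:

# RAG_EMBEDDING_PREFIX_FIELD_NAME unset, prefix = "search_document: "
json_data = {
    "input": ["search_document: chunk one"],  # prefix already concatenated upstream
    "model": "nomic-embed-text",
}
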
@@ -638,33 +664,32 @@ def generate_ollama_batch_embeddings(
         return None
 
 
-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], prefix: Union[str , None] = None, **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
     user = kwargs.get("user")
 
+    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
+        if isinstance(text, list):
+            text = [f'{prefix}{text_element}' for text_element in text]
+        else:
+            text = f'{prefix}{text}'
+
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
+                **{"model": model, "texts": text, "url": url, "key": key, "prefix": prefix, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{
-                    "model": model,
-                    "texts": [text],
-                    "url": url,
-                    "key": key,
-                    "user": user,
-                }
+                **{"model": model, "texts": [text], "url": url, "key": key, "prefix": prefix, "user": user}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, prefix, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
-
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, prefix, user)
         return embeddings[0] if isinstance(text, str) else embeddings
 
 
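Note the precedence here: texts are prefixed in place only when RAG_EMBEDDING_PREFIX_FIELD_NAME is None; when a field name is configured, the texts are sent untouched and the prefix travels as the extra JSON field shown earlier. A minimal sketch of the concatenation branch:

texts = ["a", "b"]
prefix = "passage: "  # illustrative value
if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
    texts = [f"{prefix}{t}" for t in texts]
# texts == ["passage: a", "passage: b"]
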
@@ -700,9 +725,10 @@ class RerankCompressor(BaseDocumentCompressor):
         else:
             from sentence_transformers import util
 
-            query_embedding = self.embedding_function(query)
+            query_embedding = self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
             document_embedding = self.embedding_function(
-                [doc.page_content for doc in documents]
+                [doc.page_content for doc in documents], 
+                RAG_EMBEDDING_PASSAGE_PREFIX
             )
             scores = util.cos_sim(query_embedding, document_embedding)[0]
 

+ 6 - 4
backend/open_webui/routers/retrieval.py

@@ -74,7 +74,6 @@ from open_webui.utils.misc import (
 )
 from open_webui.utils.auth import get_admin_user, get_verified_user
 
-
 from open_webui.config import (
     ENV,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
@@ -83,6 +82,8 @@ from open_webui.config import (
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     UPLOAD_DIR,
     DEFAULT_LOCALE,
+    RAG_EMBEDDING_PASSAGE_PREFIX,
+    RAG_EMBEDDING_QUERY_PREFIX
 )
 from open_webui.env import (
     SRC_LOG_LEVELS,
@@ -891,7 +892,7 @@ def save_docs_to_vector_db(
         )
 
         embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts)), user=user
+            list(map(lambda x: x.replace("\n", " "), texts)), prefix=RAG_EMBEDDING_PASSAGE_PREFIX, user=user
         )
 
         items = [
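
Ingestion now embeds documents with the passage prefix, while the query paths (query_doc_handler below, plus query_collection and VectorSearchRetriever in utils.py) use the query prefix, matching the asymmetric convention prefix-trained models expect. The two call sites pair up as follows; note, as an assumption rather than something this diff handles, that collections embedded before enabling a passage prefix will not line up perfectly with prefixed queries, so re-indexing is advisable:

# Ingestion (save_docs_to_vector_db): passage prefix
embeddings = embedding_function(texts, prefix=RAG_EMBEDDING_PASSAGE_PREFIX, user=user)
# Retrieval (query_doc_handler): query prefix
query_embedding = request.app.state.EMBEDDING_FUNCTION(
    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
)
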
@@ -1533,8 +1534,9 @@ def query_doc_handler(
             return query_doc(
                 collection_name=form_data.collection_name,
                 query_embedding=request.app.state.EMBEDDING_FUNCTION(
-                    form_data.query, user=user
+                    form_data.query, prefix=RAG_EMBEDDING_QUERY_PREFIX, user=user
                 ),
+
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 user=user,
             )
@@ -1661,7 +1663,7 @@ if ENV == "dev":
 
     @router.get("/ef/{text}")
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
-        return {"result": request.app.state.EMBEDDING_FUNCTION(text)}
+        return {"result": request.app.state.EMBEDDING_FUNCTION(text, RAG_EMBEDDING_QUERY_PREFIX)}
 
 
 class BatchProcessFilesForm(BaseModel):
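
The dev-only /ef/{text} route now embeds with the query prefix applied, which gives a quick way to eyeball the effect of a prefix change. A sketch; the mount path and port are assumptions about a local dev deployment:

import requests

# Assumes the retrieval router is mounted under /api/v1/retrieval on :8080.
r = requests.get("http://localhost:8080/api/v1/retrieval/ef/hello")
print(r.json()["result"])  # embedding computed with RAG_EMBEDDING_QUERY_PREFIX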