Timothy Jaeryang Baek, 1 month ago
commit cafc5413f5

+ 4 - 4
backend/open_webui/retrieval/utils.py

@@ -256,7 +256,7 @@ def query_collection(
 ) -> dict:
     results = []
     for query in queries:
-        query_embedding = embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
+        query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
         for collection_name in collection_names:
             if collection_name:
                 try:
@@ -334,11 +334,11 @@ def get_embedding_function(
     embedding_batch_size,
 ):
     if embedding_engine == "":
-        return lambda query, prefix, user=None: embedding_function.encode(
+        return lambda query, prefix=None, user=None: embedding_function.encode(
             query, prompt=prefix if prefix else None
         ).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query, prefix, user=None: generate_embeddings(
+        func = lambda query, prefix=None, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
@@ -363,7 +363,7 @@ def get_embedding_function(
             else:
                 return func(query, prefix, user)
 
-        return lambda query, prefix, user=None: generate_multiple(
+        return lambda query, prefix=None, user=None: generate_multiple(
             query, prefix, user, func
         )
     else:
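
Note on the utils.py hunks above: every embedding lambda now takes prefix as an optional keyword argument. A minimal, runnable sketch of why that matters, with a stand-in encoder rather than Open WebUI's real embedding model:

from typing import Callable, List, Optional

def make_embedding_fn(encode: Callable[..., List[float]]):
    # Mirrors the shape of the lambdas in get_embedding_function:
    # prefix defaults to None, so call sites that omit it keep working.
    return lambda query, prefix=None, user=None: encode(
        query, prompt=prefix if prefix else None
    )

def fake_encode(text: str, prompt: Optional[str] = None) -> List[float]:
    # Stand-in encoder (assumption, not the real SentenceTransformer.encode):
    # returns a tiny "vector" so the sketch runs end to end.
    return [float(len(text)), 1.0 if prompt else 0.0]

embed = make_embedding_fn(fake_encode)
embed("hello")                            # old call sites: no prefix
embed("hello", prefix="query: ")          # retrieval path with a query prefix
embed("hello", user=None)                 # user passed without any prefix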

+ 8 - 4
backend/open_webui/routers/memories.py

@@ -57,7 +57,9 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
+                "vector": request.app.state.EMBEDDING_FUNCTION(
+                    memory.content, user=user
+                ),
                 "metadata": {"created_at": memory.created_at},
                 "metadata": {"created_at": memory.created_at},
             }
             }
         ],
         ],
@@ -82,7 +84,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)],
         limit=form_data.k,
     )
 
@@ -105,7 +107,9 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
+                "vector": request.app.state.EMBEDDING_FUNCTION(
+                    memory.content, user=user
+                ),
                 "metadata": {
                 "metadata": {
                     "created_at": memory.created_at,
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
                     "updated_at": memory.updated_at,
@@ -161,7 +165,7 @@ async def update_memory_by_id(
                     "id": memory.id,
                     "id": memory.id,
                     "text": memory.content,
                     "text": memory.content,
                     "vector": request.app.state.EMBEDDING_FUNCTION(
                     "vector": request.app.state.EMBEDDING_FUNCTION(
-                        memory.content, user
+                        memory.content, user=user
                     ),
                     "metadata": {
                         "created_at": memory.created_at,

+ 7 - 7
backend/open_webui/routers/retrieval.py

@@ -1518,8 +1518,8 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
-                    query, user=user
+                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                    query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
@@ -1569,8 +1569,8 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
-                    query, user=user
+                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                    query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
@@ -1586,8 +1586,8 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(
-                    query, user=user
+                embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
+                    query, prefix=prefix, user=user
                 ),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
@@ -1666,7 +1666,7 @@ if ENV == "dev":
     async def get_embeddings(request: Request, text: Optional[str] = "Hello World!"):
         return {
             "result": request.app.state.EMBEDDING_FUNCTION(
-                text, RAG_EMBEDDING_QUERY_PREFIX
+                text, prefix=RAG_EMBEDDING_QUERY_PREFIX
             )
         }
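
Finally, the retrieval.py handlers now hand the search helpers a two-argument wrapper so the query prefix reaches the shared EMBEDDING_FUNCTION. A runnable sketch of that calling convention; every object below is a stand-in for the real Open WebUI pieces, and the prefix value is an assumption:

RAG_EMBEDDING_QUERY_PREFIX = "query"  # assumed placeholder value

def app_embedding_function(query, prefix=None, user=None):
    # Stand-in for request.app.state.EMBEDDING_FUNCTION.
    return [float(len(query)), 1.0 if prefix else 0.0]

def query_collection(collection_names, queries, embedding_function, k):
    # Simplified version of the loop in retrieval/utils.py: the prefix is
    # forwarded as a keyword, matching the first hunk of this commit.
    results = []
    for query in queries:
        query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
        results.append({"collections": collection_names, "vector": query_embedding[:k]})
    return results

user = None  # whatever user object the handler received
wrapper = lambda query, prefix: app_embedding_function(query, prefix=prefix, user=user)
query_collection(["docs"], ["what is open webui?"], wrapper, k=2)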