Steven Kreitzer, 1 year ago
parent commit 4e0b32b505

+ 8 - 3
Dockerfile

@@ -8,8 +8,9 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2), you won't be able to use RAG Chat with documents you previously loaded in the WebUI; you need to re-embed them.
 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+ARG USE_RERANKING_MODEL=BAAI/bge-reranker-base

 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@@ -30,6 +31,7 @@ ARG USE_CUDA
 ARG USE_OLLAMA
 ARG USE_CUDA_VER
 ARG USE_EMBEDDING_MODEL
+ARG USE_RERANKING_MODEL

 ## Basis ##
 ENV ENV=prod \
@@ -38,7 +40,8 @@ ENV ENV=prod \
     USE_OLLAMA_DOCKER=${USE_OLLAMA} \
     USE_CUDA_DOCKER=${USE_CUDA} \
     USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
-    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL}
+    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
+    USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}

 ## Basis URL Config ##
 ENV OLLAMA_BASE_URL="/ollama" \
@@ -62,7 +65,7 @@ ENV WHISPER_MODEL="base" \

 ## RAG Embedding model settings ##
 ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
-    RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \
+    RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
     SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
 #### Other models ##########################################################

@@ -99,11 +102,13 @@ RUN pip3 install uv && \
         pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
         uv pip install --system -r requirements.txt --no-cache-dir && \
         python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+        python -c "import os; from sentence_transformers import CrossEncoder; CrossEncoder(os.environ['RAG_RERANKING_MODEL'], device='cpu')" && \
         python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     else \
         pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
         uv pip install --system -r requirements.txt --no-cache-dir && \
         python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+        python -c "import os; from sentence_transformers import CrossEncoder; CrossEncoder(os.environ['RAG_RERANKING_MODEL'], device='cpu')" && \
         python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
     fi

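For reference, the two model warm-up commands above boil down to this short script (a sketch, assuming the RAG_EMBEDDING_MODEL and RAG_RERANKING_MODEL variables set by the ENV block); running it at build time means the image ships with both models already cached:

import os

from sentence_transformers import CrossEncoder, SentenceTransformer

# Download and cache both models on CPU; SENTENCE_TRANSFORMERS_HOME
# determines where the weights land inside the image.
SentenceTransformer(os.environ["RAG_EMBEDDING_MODEL"], device="cpu")
CrossEncoder(os.environ["RAG_RERANKING_MODEL"], device="cpu")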
+ 77 - 65
backend/apps/rag/main.py

@@ -49,8 +49,8 @@ from apps.web.models.documents import (

 from apps.rag.utils import (
     query_embeddings_doc,
+    query_embeddings_function,
     query_embeddings_collection,
-    generate_openai_embeddings,
 )

 from utils.misc import (
@@ -67,6 +67,8 @@ from config import (
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
@@ -91,6 +93,7 @@ app.state.CHUNK_OVERLAP = CHUNK_OVERLAP

 app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE

 app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
@@ -105,6 +108,12 @@ if app.state.RAG_EMBEDDING_ENGINE == "":
         trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     )

+app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+    app.state.RAG_RERANKING_MODEL,
+    device=DEVICE_TYPE,
+    trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+)
+

 origins = ["*"]

@@ -134,6 +143,7 @@ async def get_status():
         "template": app.state.RAG_TEMPLATE,
         "template": app.state.RAG_TEMPLATE,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "reranking_model": app.state.RAG_RERANKING_MODEL,
     }

@@ -150,6 +160,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
     }

+@app.get("/reranking")
+async def get_reranking_config(user=Depends(get_admin_user)):
+    return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
+
+
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
@@ -205,6 +220,36 @@ async def update_embedding_config(
         )

+class RerankingModelUpdateForm(BaseModel):
+    reranking_model: str
+
+
+@app.post("/reranking/update")
+async def update_reranking_config(
+    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
+):
+    log.info(
+        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
+    )
+    try:
+        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
+        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+            app.state.RAG_RERANKING_MODEL,
+            device=DEVICE_TYPE,
+        )
+
+        return {
+            "status": True,
+            "reranking_model": app.state.RAG_RERANKING_MODEL,
+        }
+    except Exception as e:
+        log.exception(f"Problem updating reranking model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
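The two new endpoints can be exercised like this (an illustrative sketch, not part of the commit; the base URL and token are placeholders for your deployment):

import requests

BASE = "http://localhost:8080/rag/api/v1"  # hypothetical; adjust to your deployment
HEADERS = {"Authorization": "Bearer <admin-token>"}

# Read the currently configured reranking model.
print(requests.get(f"{BASE}/reranking", headers=HEADERS).json())

# Swap in a different cross-encoder; the server reloads it immediately.
resp = requests.post(
    f"{BASE}/reranking/update",
    headers=HEADERS,
    json={"reranking_model": "BAAI/bge-reranker-base"},
)
print(resp.json())  # expected: {"status": True, "reranking_model": "..."}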
 @app.get("/config")
 @app.get("/config")
 async def get_rag_config(user=Depends(get_admin_user)):
 async def get_rag_config(user=Depends(get_admin_user)):
     return {
     return {
@@ -286,34 +331,21 @@ def query_doc_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
+        embeddings_function = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )

         return query_embeddings_doc(
             collection_name=form_data.collection_name,
             query=form_data.query,
-            query_embeddings=query_embeddings,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            embeddings_function=embeddings_function,
+            reranking_function=app.state.sentence_transformer_rf,
         )
-
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -334,33 +366,21 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
+        embeddings_function = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )

         return query_embeddings_collection(
             collection_names=form_data.collection_names,
-            query_embeddings=query_embeddings,
+            query=form_data.query,
             k=form_data.k if form_data.k else app.state.TOP_K,
+            embeddings_function=embeddings_function,
+            reranking_function=app.state.sentence_transformer_rf,
         )
-
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -427,8 +447,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
 
     texts = [doc.page_content for doc in docs]
     texts = [doc.page_content for doc in docs]
-    texts = list(map(lambda x: x.replace("\n", " "), texts))
-
     metadatas = [doc.metadata for doc in docs]

     try:
@@ -440,26 +458,20 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b

         collection = CHROMA_CLIENT.create_collection(name=collection_name)

+        embedding_func = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
+
+        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
         if app.state.RAG_EMBEDDING_ENGINE == "":
         if app.state.RAG_EMBEDDING_ENGINE == "":
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            embeddings = [
-                generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                    )
-                )
-                for text in texts
-            ]
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+            embeddings = embedding_func(embedding_texts)
+        else:
             embeddings = [
-                generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=text,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-                for text in texts
+                embedding_func(text) for text in embedding_texts
             ]

         for batch in create_batches(

+ 146 - 41
backend/apps/rag/utils.py

@@ -1,5 +1,8 @@
 import logging
 import requests
+import operator
+
+import sentence_transformers

 from typing import List

@@ -8,6 +11,11 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )

+from langchain.retrievers import (
+    BM25Retriever,
+    EnsembleRetriever,
+)
+
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT

@@ -15,60 +23,96 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

-def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
+def query_embeddings_doc(
+    collection_name: str,
+    query: str,
+    k: int,
+    embeddings_function,
+    reranking_function,
+):
     try:
         # if you use docker use the model from the environment variable
-        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(name=collection_name)

-        result = collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=k,
+        # keyword search
+        documents = collection.get() # get all documents
+        bm25_retriever = BM25Retriever.from_texts(
+            texts=documents.get("documents"),
+            metadatas=documents.get("metadatas"),
+        )
+        bm25_retriever.k = k
+
+        # semantic search (vector)
+        chroma_retriever = ChromaRetriever(
+            collection=collection,
+            k=k,
+            embeddings_function=embeddings_function,
+        )
+
+        # hybrid search (ensemble)
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, chroma_retriever],
+            weights=[0.6, 0.4]
         )
         )

-        log.info(f"query_embeddings_doc:result {result}")
+        result = query_results_rank(
+            query=query,
+            documents=documents,
+            k=k,
+            reranking_function=reranking_function,
+        )
+        result = {
+            "distances": [[d[1].item() for d in result]],
+            "documents": [[d[0].page_content for d in result]],
+            "metadatas": [[d[0].metadata for d in result]],
+        }
+
         return result
         return result
     except Exception as e:
         raise e

+def query_results_rank(query: str, documents, k: int, reranking_function):
+    scores = reranking_function.predict([(query, doc.page_content) for doc in documents])
+    docs_with_scores = list(zip(documents, scores))
+    result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+    return result[:k]
+
+
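For intuition, here is a self-contained sketch of the retrieve-then-rerank flow the two functions above implement; the vector leg of the ensemble is omitted, and the model name and passages are illustrative (BM25Retriever also needs the rank_bm25 package installed):

import operator

from langchain.retrievers import BM25Retriever
from sentence_transformers import CrossEncoder

query = "how does hybrid search work?"
passages = [
    "BM25 ranks documents by lexical overlap with the query.",
    "Vector search ranks documents by embedding similarity.",
    "Bananas are rich in potassium.",
]

# Keyword leg of the hybrid search (the Chroma vector leg is skipped here).
retriever = BM25Retriever.from_texts(texts=passages)
retriever.k = 3
docs = retriever.invoke(query)

# Rerank step, mirroring query_results_rank: score (query, passage) pairs
# with a cross-encoder and keep the top k.
reranker = CrossEncoder("BAAI/bge-reranker-base", device="cpu")
scores = reranker.predict([(query, d.page_content) for d in docs])
top = sorted(zip(docs, scores), key=operator.itemgetter(1), reverse=True)[:2]
print([(d.page_content, float(s)) for d, s in top])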
 def merge_and_sort_query_results(query_results, k):
     # Initialize lists to store combined data
-    combined_ids = []
     combined_distances = []
-    combined_metadatas = []
     combined_documents = []
+    combined_metadatas = []

     # Combine data from each dictionary
     for data in query_results:
-        combined_ids.extend(data["ids"][0])
         combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
         combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])

-    # Create a list of tuples (distance, id, metadata, document)
+    # Create a list of tuples (distance, document, metadata)
     combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
+        zip(combined_distances, combined_documents, combined_metadatas)
     )

     # Sort the list based on distances
     combined.sort(key=lambda x: x[0])

     # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+    sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)

     # Slicing the lists to include only k elements
     sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
-    sorted_metadatas = list(sorted_metadatas)[:k]
     sorted_documents = list(sorted_documents)[:k]
+    sorted_metadatas = list(sorted_metadatas)[:k]

     # Create the output dictionary
     merged_query_results = {
-        "ids": [sorted_ids],
         "distances": [sorted_distances],
         "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
         "documents": [sorted_documents],
         "documents": [sorted_documents],
+        "metadatas": [sorted_metadatas],
         "embeddings": None,
         "embeddings": None,
         "uris": None,
         "uris": None,
         "data": None,
         "data": None,
@@ -78,19 +122,23 @@ def merge_and_sort_query_results(query_results, k):


 def query_embeddings_collection(
-    collection_names: List[str], query: str, query_embeddings, k: int
+    collection_names: List[str],
+    query: str,
+    k: int,
+    embeddings_function,
+    reranking_function,
 ):

     results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")

     for collection_name in collection_names:
         try:
             result = query_embeddings_doc(
                 collection_name=collection_name,
                 query=query,
-                query_embeddings=query_embeddings,
                 k=k,
+                embeddings_function=embeddings_function,
+                reranking_function=reranking_function,
             )
             results.append(result)
         except:
@@ -105,6 +153,33 @@ def rag_template(template: str, context: str, query: str):
     return template

+def query_embeddings_function(
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
+    if embedding_engine == "":
+        return lambda query: embedding_function.encode(query).tolist()
+    elif embedding_engine == "ollama":
+        return lambda query: generate_ollama_embeddings(
+            GenerateEmbeddingsForm(
+                **{
+                    "model": embedding_model,
+                    "prompt": query,
+                }
+            )
+        )
+    elif embedding_engine == "openai":
+        return lambda query: generate_openai_embeddings(
+            model=embedding_model,
+            text=query,
+            key=openai_key,
+            url=openai_url,
+        )
+
+
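A usage sketch for this factory (engine "" selects the local sentence-transformers path; the model is the project default, everything else is illustrative):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

# Build the query-embedding closure once, then call it per query.
embed = query_embeddings_function(
    embedding_engine="",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_function=model,
    openai_key=None,
    openai_url=None,
)
vector = embed("hello world")  # list[float]; 384 dimensions for this model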
 def rag_messages(
     docs,
     messages,
@@ -113,11 +188,12 @@ def rag_messages(
     embedding_engine,
     embedding_model,
     embedding_function,
+    reranking_function,
     openai_key,
     openai_url,
 ):
     log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
     )

     last_user_message_idx = None
@@ -155,38 +231,29 @@ def rag_messages(
             if doc["type"] == "text":
             if doc["type"] == "text":
                 context = doc["content"]
                 context = doc["content"]
             else:
             else:
-                if embedding_engine == "":
-                    query_embeddings = embedding_function.encode(query).tolist()
-                elif embedding_engine == "ollama":
-                    query_embeddings = generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{
-                                "model": embedding_model,
-                                "prompt": query,
-                            }
-                        )
-                    )
-                elif embedding_engine == "openai":
-                    query_embeddings = generate_openai_embeddings(
-                        model=embedding_model,
-                        text=query,
-                        key=openai_key,
-                        url=openai_url,
-                    )
+                embeddings_function = query_embeddings_function(
+                    embedding_engine,
+                    embedding_model,
+                    embedding_function,
+                    openai_key,
+                    openai_url,
+                )

                 if doc["type"] == "collection":
                     context = query_embeddings_collection(
                         collection_names=doc["collection_names"],
                         query=query,
-                        query_embeddings=query_embeddings,
                         k=k,
+                        embeddings_function=embeddings_function,
+                        reranking_function=reranking_function,
                     )
                 else:
                     context = query_embeddings_doc(
                         collection_name=doc["collection_name"],
                         query=query,
-                        query_embeddings=query_embeddings,
                         k=k,
+                        embeddings_function=embeddings_function,
+                        reranking_function=reranking_function,
                     )

         except Exception as e:
@@ -250,3 +317,41 @@ def generate_openai_embeddings(
     except Exception as e:
         print(e)
         return None
+
+
+from typing import Any
+
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+from langchain_core.documents import Document
+from langchain_core.retrievers import BaseRetriever
+
+
+class ChromaRetriever(BaseRetriever):
+    collection: Any
+    k: int
+    embeddings_function: Any
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        query_embeddings = self.embeddings_function(query)
+
+        results = self.collection.query(
+            query_embeddings=[query_embeddings],
+            n_results=self.k,
+        )
+
+        ids = results["ids"][0]
+        metadatas = results["metadatas"][0]
+        documents = results["documents"][0]
+
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            )
+            for idx in range(len(ids))
+        ]
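+
+# A hedged usage sketch for the retriever (client and collection names are
+# illustrative; embeddings_function can be any str -> list[float] callable
+# matching the collection's dimensionality):
+#
+#   import chromadb
+#   from sentence_transformers import SentenceTransformer
+#
+#   client = chromadb.Client()
+#   collection = client.get_or_create_collection("demo")
+#   model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
+#
+#   retriever = ChromaRetriever(
+#       collection=collection,
+#       k=4,
+#       embeddings_function=lambda q: model.encode(q).tolist(),
+#   )
+#   docs = retriever.invoke("hybrid search")  # List[Document]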

+ 9 - 0
backend/config.py

@@ -424,6 +424,15 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 )
 
 
+RAG_RERANKING_MODEL = os.environ.get(
+    "RAG_RERANKING_MODEL", "BAAI/bge-reranker-v2-m3"
+)
+log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
+
+RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
+)
+
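Both settings are read straight from the environment, so an override looks like this (values illustrative; set them before backend/config.py is imported):

import os

os.environ["RAG_RERANKING_MODEL"] = "BAAI/bge-reranker-base"
os.environ["RAG_RERANKING_MODEL_TRUST_REMOTE_CODE"] = "true"  # compared case-insensitively against "true"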
 # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")

+ 1 - 0
backend/main.py

@@ -117,6 +117,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     rag_app.state.RAG_EMBEDDING_ENGINE,
                     rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.sentence_transformer_rf,
                     rag_app.state.RAG_OPENAI_API_KEY,
                     rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )

+ 61 - 0
src/lib/apis/rag/index.ts

@@ -413,3 +413,64 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod

 	return res;
 };
+
+export const getRerankingConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type RerankingModelUpdateForm = {
+	reranking_model: string;
+};
+
+export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...payload
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 104 - 1
src/lib/components/documents/Settings/General.svelte

@@ -8,7 +8,9 @@
 		updateQuerySettings,
 		resetVectorDB,
 		getEmbeddingConfig,
-		updateEmbeddingConfig
+		updateEmbeddingConfig,
+		getRerankingConfig,
+		updateRerankingConfig
 	} from '$lib/apis/rag';

 	import { documents, models } from '$lib/stores';
@@ -23,11 +25,13 @@

 	let scanDirLoading = false;
 	let updateEmbeddingModelLoading = false;
+	let updateRerankingModelLoading = false;

 	let showResetConfirm = false;

 	let embeddingEngine = '';
 	let embeddingModel = '';
+	let rerankingModel = '';

 	let OpenAIKey = '';
 	let OpenAIUrl = '';
@@ -115,6 +119,29 @@
 		}
 	};

+	const rerankingModelUpdateHandler = async () => {
+		console.log('Update reranking model attempt:', rerankingModel);
+
+		updateRerankingModelLoading = true;
+		const res = await updateRerankingConfig(localStorage.token, {
+			reranking_model: rerankingModel,
+		}).catch(async (error) => {
+			toast.error(error);
+			await setRerankingConfig();
+			return null;
+		});
+		updateRerankingModelLoading = false;
+
+		if (res) {
+			console.log('rerankingModelUpdateHandler:', res);
+			if (res.status === true) {
+				toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), {
+					duration: 1000 * 10
+				});
+			}
+		}
+	};
+
 	const submitHandler = async () => {
 		const res = await updateRAGConfig(localStorage.token, {
 			pdf_extract_images: pdfExtractImages,
@@ -138,6 +165,14 @@
 		}
 	};

+	const setRerankingConfig = async () => {
+		const rerankingConfig = await getRerankingConfig(localStorage.token);
+
+		if (rerankingConfig) {
+			rerankingModel = rerankingConfig.reranking_model;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);

@@ -149,6 +184,7 @@
 		}

 		await setEmbeddingConfig();
+		await setRerankingConfig();

 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -349,6 +385,73 @@

 				<hr class=" dark:border-gray-700 my-3" />

+				<div class=" ">
+					<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div>
+
+					<div class="flex w-full">
+						<div class="flex-1 mr-2">
+							<input
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								placeholder={$i18n.t('Update reranking model (e.g. {{model}})', {
+									model: rerankingModel.slice(-40)
+								})}
+								bind:value={rerankingModel}
+							/>
+						</div>
+						<button
+							class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+							on:click={() => {
+								rerankingModelUpdateHandler();
+							}}
+							disabled={updateRerankingModelLoading}
+						>
+							{#if updateRerankingModelLoading}
+								<div class="self-center">
+									<svg
+										class=" w-4 h-4"
+										viewBox="0 0 24 24"
+										fill="currentColor"
+										xmlns="http://www.w3.org/2000/svg"
+										><style>
+											.spinner_ajPY {
+												transform-origin: center;
+												animation: spinner_AtaB 0.75s infinite linear;
+											}
+											@keyframes spinner_AtaB {
+												100% {
+													transform: rotate(360deg);
+												}
+											}
+										</style><path
+											d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+											opacity=".25"
+										/><path
+											d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+											class="spinner_ajPY"
+										/></svg
+									>
+								</div>
+							{:else}
+								<svg
+									xmlns="http://www.w3.org/2000/svg"
+									viewBox="0 0 16 16"
+									fill="currentColor"
+									class="w-4 h-4"
+								>
+									<path
+										d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
+									/>
+									<path
+										d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
+									/>
+								</svg>
+							{/if}
+						</button>
+					</div>
+				</div>
+
+				<hr class=" dark:border-gray-700 my-3" />
+
 				<div class="  flex w-full justify-between">
 					<div class=" self-center text-xs font-medium">
 						{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })}