View Source

fix: integration

Timothy J. Baek 1 year ago
parent
commit
36ce157907
3 changed files with 28 additions and 7 deletions
  1. backend/apps/ollama/main.py: +5 -0
  2. backend/apps/rag/main.py: +20 -7
  3. backend/apps/rag/utils.py: +3 -0

+ 5 - 0
backend/apps/ollama/main.py

@@ -658,6 +658,9 @@ def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
+
+    log.info(f"generate_ollama_embeddings {form_data}")
+
     if url_idx == None:
         model = form_data.model
 
@@ -685,6 +688,8 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
+        log.info(f"generate_ollama_embeddings {data}")
+
         if "embedding" in data:
             return data["embedding"]
         else:

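
A note on the logging calls added above: with Python's stdlib logging, extra positional arguments are lazily %-formatted into the message when the record is emitted, so a call like log.info("generate_ollama_embeddings", form_data) with no placeholder in the message fails at emit time with a "--- Logging error ---" traceback. A minimal sketch of the failure mode and the two working spellings, assuming log is a stdlib logging.Logger:

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    data = {"embedding": [0.1, 0.2]}

    # Broken: logging does "msg % args" when the record is emitted, and a
    # message with no %-placeholder cannot absorb the extra argument.
    # log.info("generate_ollama_embeddings", data)

    # Works: lazy %-formatting with an explicit placeholder.
    log.info("generate_ollama_embeddings %s", data)

    # Also works: eager interpolation with an f-string, as in the diff above.
    log.info(f"generate_ollama_embeddings {data}")
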
+ 20 - 7
backend/apps/rag/main.py

@@ -39,7 +39,7 @@ import uuid
 import json
 
 
-from apps.ollama.main import generate_ollama_embeddings
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
 
 from apps.web.models.documents import (
     Documents,
@@ -277,7 +277,12 @@ def query_doc_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_doc(
@@ -314,7 +319,12 @@ def query_collection_handler(
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
-                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": app.state.RAG_EMBEDDING_MODEL,
+                        "prompt": form_data.query,
+                    }
+                )
             )
 
             return query_embeddings_collection(
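
The substance of the two hunks above: generate_ollama_embeddings is annotated with the GenerateEmbeddingsForm pydantic model, and the RAG handlers call it directly as a plain Python function, so FastAPI's request validation never converts a raw dict into the model and attribute access like form_data.model fails. Constructing the model explicitly restores that. A minimal sketch, assuming the model has roughly this shape (the real one is defined in backend/apps/ollama/main.py):

    from typing import Optional

    from pydantic import BaseModel

    # Assumed shape, for illustration only.
    class GenerateEmbeddingsForm(BaseModel):
        model: str
        prompt: str
        options: Optional[dict] = None

    def generate_ollama_embeddings(form_data: GenerateEmbeddingsForm):
        # Attribute access only works on the model, not on a plain dict.
        return f"embedding for {form_data.prompt!r} via {form_data.model}"

    # Before the fix the call site passed a bare dict:
    #   generate_ollama_embeddings({"model": "m", "prompt": "p"})
    # -> AttributeError: 'dict' object has no attribute 'model'
    print(generate_ollama_embeddings(
        GenerateEmbeddingsForm(**{"model": "all-minilm", "prompt": "hello"})
    ))
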
@@ -373,6 +383,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
+        log.info("store_data_in_vector_db: calling store_docs_in_vector_db")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -390,9 +401,8 @@ def store_text_in_vector_db(
     return store_docs_in_vector_db(docs, collection_name, overwrite)
 
 
-async def store_docs_in_vector_db(
-    docs, collection_name, overwrite: bool = False
-) -> bool:
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
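
The async-to-def change above is load-bearing: store_docs_in_vector_db is called synchronously from store_data_in_vector_db, and calling an async def without awaiting it returns a coroutine object instead of running the body, so the caller would have returned (coroutine, None) and nothing would ever reach the vector store. A small sketch of the pitfall:

    import asyncio

    async def store_async(docs) -> bool:
        return True  # body never runs unless the coroutine is awaited

    def store_sync(docs) -> bool:
        return True

    result = store_async(["chunk"])   # a coroutine object, not True
    print(type(result).__name__)      # coroutine
    asyncio.run(result)               # close it out; avoids "never awaited" warning

    print(store_sync(["chunk"]))      # True, which is what the caller expects
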
@@ -413,13 +423,16 @@ async def store_docs_in_vector_db(
                 metadatas=metadatas,
                 embeddings=[
                     generate_ollama_embeddings(
-                        {"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        GenerateEmbeddingsForm(
+                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                        )
                     )
                     for text in texts
                 ],
             ):
                 collection.add(*batch)
         else:
+
             collection = CHROMA_CLIENT.create_collection(
                 name=collection_name,
                 embedding_function=app.state.sentence_transformer_ef,
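
For context on the two branches above: when the embedding engine is ollama, the collection is populated with explicitly supplied vectors, whereas the else branch lets Chroma compute them via an embedding_function. A minimal sketch of the explicit-vectors pattern with the chromadb client, using an in-memory client and a hypothetical embed() stand-in for generate_ollama_embeddings:

    import chromadb

    client = chromadb.Client()  # in-memory client, for illustration only

    def embed(text: str) -> list[float]:
        # Hypothetical stand-in for generate_ollama_embeddings.
        return [float(len(text)), 1.0]

    texts = ["first chunk", "second chunk"]
    collection = client.create_collection(name="demo")
    collection.add(
        ids=[str(i) for i in range(len(texts))],
        documents=texts,
        metadatas=[{"source": "demo.txt"} for _ in texts],
        embeddings=[embed(t) for t in texts],  # precomputed, so no embedding_function
    )
    print(collection.count())  # 2
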

+ 3 - 0
backend/apps/rag/utils.py

@@ -32,6 +32,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -117,6 +118,8 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
+    log.info(f"query_embeddings_collection {query_embeddings}")
+
     for collection_name in collection_names:
         try:
             collection = CHROMA_CLIENT.get_collection(name=collection_name)
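
Both query helpers take precomputed vectors rather than raw text, which maps directly onto chromadb's query API: collection.query accepts query_embeddings alongside n_results. A minimal, self-contained sketch of what query_embeddings_doc reduces to:

    import chromadb

    client = chromadb.Client()
    collection = client.create_collection(name="demo_query")
    collection.add(
        ids=["0", "1"],
        documents=["first chunk", "second chunk"],
        embeddings=[[1.0, 0.0], [0.0, 1.0]],
    )

    # A precomputed query vector plus a result count k; no embedding
    # function is attached to the collection.
    result = collection.query(query_embeddings=[[0.9, 0.1]], n_results=1)
    print(result["documents"])  # [['first chunk']]
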