|
@@ -260,23 +260,47 @@ def query_collection(
|
|
|
k: int,
|
|
|
) -> dict:
|
|
|
results = []
|
|
|
- for query in queries:
|
|
|
- log.debug(f"query_collection:query {query}")
|
|
|
- query_embedding = embedding_function(query, prefix=RAG_EMBEDDING_QUERY_PREFIX)
|
|
|
- for collection_name in collection_names:
|
|
|
+ error = False
|
|
|
+
|
|
|
+ def process_query_collection(collection_name, query_embedding):
|
|
|
+ try:
|
|
|
if collection_name:
|
|
|
- try:
|
|
|
- result = query_doc(
|
|
|
- collection_name=collection_name,
|
|
|
- k=k,
|
|
|
- query_embedding=query_embedding,
|
|
|
- )
|
|
|
- if result is not None:
|
|
|
- results.append(result.model_dump())
|
|
|
- except Exception as e:
|
|
|
- log.exception(f"Error when querying the collection: {e}")
|
|
|
- else:
|
|
|
- pass
|
|
|
+ result = query_doc(
|
|
|
+ collection_name=collection_name,
|
|
|
+ k=k,
|
|
|
+ query_embedding=query_embedding,
|
|
|
+ )
|
|
|
+ if result is not None:
|
|
|
+ return result.model_dump(), None
|
|
|
+ return None, None
|
|
|
+ except Exception as e:
|
|
|
+ log.exception(f"Error when querying the collection: {e}")
|
|
|
+ return None, e
|
|
|
+
|
|
|
+ # Generate all query embeddings (in one call)
|
|
|
+ query_embeddings = embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
|
|
|
+ log.debug(
|
|
|
+ f"query_collection: processing {len(queries)} queries across {len(collection_names)} collections"
|
|
|
+ )
|
|
|
+
|
|
|
+ with ThreadPoolExecutor() as executor:
|
|
|
+ future_results = []
|
|
|
+ for query_embedding in query_embeddings:
|
|
|
+ for collection_name in collection_names:
|
|
|
+ result = executor.submit(
|
|
|
+ process_query_collection, collection_name, query_embedding
|
|
|
+ )
|
|
|
+ future_results.append(result)
|
|
|
+ task_results = [future.result() for future in future_results]
|
|
|
+
|
|
|
+ for result, err in task_results:
|
|
|
+ if err is not None:
|
|
|
+ error = True
|
|
|
+ elif result is not None:
|
|
|
+ results.append(result)
|
|
|
+
|
|
|
+ if error and not results:
|
|
|
+ log.warning("All collection queries failed. No results returned.")
|
|
|
|
|
|
return merge_and_sort_query_results(results, k=k)
|
|
|
|