|
@@ -952,6 +952,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
|
|
) -> Sequence[Document]:
|
|
|
reranking = self.reranking_function is not None
|
|
|
|
|
|
+ scores = None
|
|
|
if reranking:
|
|
|
scores = self.reranking_function(
|
|
|
[(query, doc.page_content) for doc in documents]
|
|
@@ -965,22 +966,31 @@ class RerankCompressor(BaseDocumentCompressor):
|
|
|
)
|
|
|
scores = util.cos_sim(query_embedding, document_embedding)[0]
|
|
|
|
|
|
- docs_with_scores = list(
|
|
|
- zip(documents, scores.tolist() if not isinstance(scores, list) else scores)
|
|
|
- )
|
|
|
- if self.r_score:
|
|
|
- docs_with_scores = [
|
|
|
- (d, s) for d, s in docs_with_scores if s >= self.r_score
|
|
|
- ]
|
|
|
-
|
|
|
- result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
|
|
- final_results = []
|
|
|
- for doc, doc_score in result[: self.top_n]:
|
|
|
- metadata = doc.metadata
|
|
|
- metadata["score"] = doc_score
|
|
|
- doc = Document(
|
|
|
- page_content=doc.page_content,
|
|
|
- metadata=metadata,
|
|
|
+ if scores:
|
|
|
+ docs_with_scores = list(
|
|
|
+ zip(
|
|
|
+ documents,
|
|
|
+ scores.tolist() if not isinstance(scores, list) else scores,
|
|
|
+ )
|
|
|
+ )
|
|
|
+ if self.r_score:
|
|
|
+ docs_with_scores = [
|
|
|
+ (d, s) for d, s in docs_with_scores if s >= self.r_score
|
|
|
+ ]
|
|
|
+
|
|
|
+ result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
|
|
|
+ final_results = []
|
|
|
+ for doc, doc_score in result[: self.top_n]:
|
|
|
+ metadata = doc.metadata
|
|
|
+ metadata["score"] = doc_score
|
|
|
+ doc = Document(
|
|
|
+ page_content=doc.page_content,
|
|
|
+ metadata=metadata,
|
|
|
+ )
|
|
|
+ final_results.append(doc)
|
|
|
+ return final_results
|
|
|
+ else:
|
|
|
+ log.warning(
|
|
|
+ "No valid scores found, check your reranking function. Returning original documents."
|
|
|
)
|
|
|
- final_results.append(doc)
|
|
|
- return final_results
|
|
|
+ return documents
|