|
@@ -7,14 +7,20 @@
|
|
|
|
|
|
package org.elasticsearch.xpack.inference.rank.textsimilarity;
|
|
|
|
|
|
+import org.apache.lucene.search.ScoreDoc;
|
|
|
import org.elasticsearch.common.ParsingException;
|
|
|
import org.elasticsearch.features.NodeFeature;
|
|
|
+import org.elasticsearch.index.query.BoolQueryBuilder;
|
|
|
import org.elasticsearch.index.query.QueryBuilder;
|
|
|
-import org.elasticsearch.index.query.QueryRewriteContext;
|
|
|
import org.elasticsearch.license.LicenseUtils;
|
|
|
+import org.elasticsearch.search.builder.PointInTimeBuilder;
|
|
|
import org.elasticsearch.search.builder.SearchSourceBuilder;
|
|
|
+import org.elasticsearch.search.fetch.StoredFieldsContext;
|
|
|
+import org.elasticsearch.search.rank.RankDoc;
|
|
|
+import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
|
|
|
import org.elasticsearch.search.retriever.RetrieverBuilder;
|
|
|
import org.elasticsearch.search.retriever.RetrieverParserContext;
|
|
|
+import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
|
|
|
import org.elasticsearch.xcontent.ConstructingObjectParser;
|
|
|
import org.elasticsearch.xcontent.ParseField;
|
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
@@ -32,11 +38,14 @@ import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstr
|
|
|
/**
|
|
|
* A {@code RetrieverBuilder} for parsing and constructing a text similarity reranker retriever.
|
|
|
*/
|
|
|
-public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
+public class TextSimilarityRankRetrieverBuilder extends CompoundRetrieverBuilder<TextSimilarityRankRetrieverBuilder> {
|
|
|
|
|
|
public static final NodeFeature TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED = new NodeFeature(
|
|
|
"text_similarity_reranker_retriever_supported"
|
|
|
);
|
|
|
+ public static final NodeFeature TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED = new NodeFeature(
|
|
|
+ "text_similarity_reranker_retriever_composition_supported"
|
|
|
+ );
|
|
|
|
|
|
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
|
|
|
public static final ParseField INFERENCE_ID_FIELD = new ParseField("inference_id");
|
|
@@ -51,7 +60,6 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
String inferenceText = (String) args[2];
|
|
|
String field = (String) args[3];
|
|
|
int rankWindowSize = args[4] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[4];
|
|
|
-
|
|
|
return new TextSimilarityRankRetrieverBuilder(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize);
|
|
|
});
|
|
|
|
|
@@ -70,17 +78,20 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED) == false) {
|
|
|
throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + TextSimilarityRankBuilder.NAME + "]");
|
|
|
}
|
|
|
+ if (context.clusterSupportsFeature(TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED) == false) {
|
|
|
+ throw new UnsupportedOperationException(
|
|
|
+ "[text_similarity_reranker] retriever composition feature is not supported by all nodes in the cluster"
|
|
|
+ );
|
|
|
+ }
|
|
|
if (TextSimilarityRankBuilder.TEXT_SIMILARITY_RERANKER_FEATURE.check(XPackPlugin.getSharedLicenseState()) == false) {
|
|
|
throw LicenseUtils.newComplianceException(TextSimilarityRankBuilder.NAME);
|
|
|
}
|
|
|
return PARSER.apply(parser, context);
|
|
|
}
|
|
|
|
|
|
- private final RetrieverBuilder retrieverBuilder;
|
|
|
private final String inferenceId;
|
|
|
private final String inferenceText;
|
|
|
private final String field;
|
|
|
- private final int rankWindowSize;
|
|
|
|
|
|
public TextSimilarityRankRetrieverBuilder(
|
|
|
RetrieverBuilder retrieverBuilder,
|
|
@@ -89,15 +100,14 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
String field,
|
|
|
int rankWindowSize
|
|
|
) {
|
|
|
- this.retrieverBuilder = retrieverBuilder;
|
|
|
+ super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
|
|
|
this.inferenceId = inferenceId;
|
|
|
this.inferenceText = inferenceText;
|
|
|
this.field = field;
|
|
|
- this.rankWindowSize = rankWindowSize;
|
|
|
}
|
|
|
|
|
|
public TextSimilarityRankRetrieverBuilder(
|
|
|
- RetrieverBuilder retrieverBuilder,
|
|
|
+ List<RetrieverSource> retrieverSource,
|
|
|
String inferenceId,
|
|
|
String inferenceText,
|
|
|
String field,
|
|
@@ -106,66 +116,75 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
String retrieverName,
|
|
|
List<QueryBuilder> preFilterQueryBuilders
|
|
|
) {
|
|
|
- this.retrieverBuilder = retrieverBuilder;
|
|
|
+ super(retrieverSource, rankWindowSize);
|
|
|
+ if (retrieverSource.size() != 1) {
|
|
|
+ throw new IllegalArgumentException("[" + getName() + "] retriever should have exactly one inner retriever");
|
|
|
+ }
|
|
|
this.inferenceId = inferenceId;
|
|
|
this.inferenceText = inferenceText;
|
|
|
this.field = field;
|
|
|
- this.rankWindowSize = rankWindowSize;
|
|
|
this.minScore = minScore;
|
|
|
this.retrieverName = retrieverName;
|
|
|
this.preFilterQueryBuilders = preFilterQueryBuilders;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public QueryBuilder topDocsQuery() {
|
|
|
- // the original matching set of the TextSimilarityRank retriever is specified by its nested retriever
|
|
|
- return retrieverBuilder.topDocsQuery();
|
|
|
+ protected TextSimilarityRankRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) { // copy of this builder around new children
|
|
|
+ return new TextSimilarityRankRetrieverBuilder(
|
|
|
+ newChildRetrievers, // the only argument that differs from this instance
|
|
|
+ inferenceId,
|
|
|
+ inferenceText,
|
|
|
+ field,
|
|
|
+ rankWindowSize,
|
|
|
+ minScore,
|
|
|
+ retrieverName,
|
|
|
+ preFilterQueryBuilders // passed through as the same list reference, not copied
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
|
|
|
- // rewrite prefilters
|
|
|
- boolean hasChanged = false;
|
|
|
- var newPreFilters = rewritePreFilters(ctx);
|
|
|
- hasChanged |= newPreFilters != preFilterQueryBuilders;
|
|
|
-
|
|
|
- // rewrite nested retriever
|
|
|
- RetrieverBuilder newRetriever = retrieverBuilder.rewrite(ctx);
|
|
|
- hasChanged |= newRetriever != retrieverBuilder;
|
|
|
- if (hasChanged) {
|
|
|
- return new TextSimilarityRankRetrieverBuilder(
|
|
|
- newRetriever,
|
|
|
- field,
|
|
|
- inferenceText,
|
|
|
- inferenceId,
|
|
|
- rankWindowSize,
|
|
|
- minScore,
|
|
|
- this.retrieverName,
|
|
|
- newPreFilters
|
|
|
- );
|
|
|
+ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) { // converts the single result set to rank docs
|
|
|
+ assert rankResults.size() == 1; // both constructors of this class only accept exactly one inner retriever
|
|
|
+ ScoreDoc[] scoreDocs = rankResults.get(0);
|
|
|
+ TextSimilarityRankDoc[] textSimilarityRankDocs = new TextSimilarityRankDoc[scoreDocs.length];
|
|
|
+ for (int i = 0; i < scoreDocs.length; i++) { // one output per hit, same order; doc, score and shardIndex are carried over
|
|
|
+ ScoreDoc scoreDoc = scoreDocs[i];
|
|
|
+ textSimilarityRankDocs[i] = new TextSimilarityRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, inferenceId, field);
|
|
|
}
|
|
|
- return this;
|
|
|
+ return textSimilarityRankDocs;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
|
|
|
- retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
|
|
|
- retrieverBuilder.extractToSearchSourceBuilder(searchSourceBuilder, compoundUsed);
|
|
|
- // Combining with other rank builder (such as RRF) is not supported yet
|
|
|
- if (searchSourceBuilder.rankBuilder() != null) {
|
|
|
- throw new IllegalArgumentException("text similarity rank builder cannot be combined with other rank builders");
|
|
|
- }
|
|
|
+ public QueryBuilder explainQuery() {
|
|
|
+ // the matching set is specified by the single nested retriever, so its explain query is wrapped together with rankDocs
|
|
|
+ return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.get(0).retriever().explainQuery() }, true);
|
|
|
+ }
|
|
|
|
|
|
- searchSourceBuilder.rankBuilder(
|
|
|
+ @Override
|
|
|
+ protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) { // PIT + filters + rank
|
|
|
+ var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
|
|
|
+ .trackTotalHits(false)
|
|
|
+ .storedFields(new StoredFieldsContext(false))
|
|
|
+ .size(rankWindowSize);
|
|
|
+ if (preFilterQueryBuilders.isEmpty() == false) {
|
|
|
+ retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders); // NOTE(review): same filters are applied again below - confirm both are needed
|
|
|
+ }
|
|
|
+ retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
|
|
|
+
|
|
|
+ // apply the pre-filters: wrap the extracted query (when present) as a must clause and add every pre-filter as a filter clause
|
|
|
+ if (preFilterQueryBuilders.size() > 0) {
|
|
|
+ QueryBuilder query = sourceBuilder.query();
|
|
|
+ BoolQueryBuilder newQuery = new BoolQueryBuilder();
|
|
|
+ if (query != null) {
|
|
|
+ newQuery.must(query);
|
|
|
+ }
|
|
|
+ preFilterQueryBuilders.forEach(newQuery::filter);
|
|
|
+ sourceBuilder.query(newQuery);
|
|
|
+ }
|
|
|
+ sourceBuilder.rankBuilder(
|
|
|
new TextSimilarityRankBuilder(this.field, this.inferenceId, this.inferenceText, this.rankWindowSize, this.minScore)
|
|
|
);
|
|
|
- }
|
|
|
-
|
|
|
- /**
|
|
|
- * Determines if this retriever contains sub-retrievers that need to be executed prior to search.
|
|
|
- */
|
|
|
- public boolean isCompound() {
|
|
|
- return retrieverBuilder.isCompound();
|
|
|
+ return sourceBuilder;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -179,7 +198,7 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
|
|
|
@Override
|
|
|
protected void doToXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
- builder.field(RETRIEVER_FIELD.getPreferredName(), retrieverBuilder);
|
|
|
+ builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.get(0).retriever());
|
|
|
builder.field(INFERENCE_ID_FIELD.getPreferredName(), inferenceId);
|
|
|
builder.field(INFERENCE_TEXT_FIELD.getPreferredName(), inferenceText);
|
|
|
builder.field(FIELD_FIELD.getPreferredName(), field);
|
|
@@ -187,9 +206,9 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- protected boolean doEquals(Object other) {
|
|
|
+ public boolean doEquals(Object other) {
|
|
|
TextSimilarityRankRetrieverBuilder that = (TextSimilarityRankRetrieverBuilder) other;
|
|
|
- return Objects.equals(retrieverBuilder, that.retrieverBuilder)
|
|
|
+ return super.doEquals(other)
|
|
|
&& Objects.equals(inferenceId, that.inferenceId)
|
|
|
&& Objects.equals(inferenceText, that.inferenceText)
|
|
|
&& Objects.equals(field, that.field)
|
|
@@ -198,7 +217,7 @@ public class TextSimilarityRankRetrieverBuilder extends RetrieverBuilder {
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
- protected int doHashCode() {
|
|
|
- return Objects.hash(retrieverBuilder, inferenceId, inferenceText, field, rankWindowSize, minScore);
|
|
|
+ public int doHashCode() {
|
|
|
+ return Objects.hash(inferenceId, inferenceText, field, rankWindowSize, minScore); // NOTE(review): inner retrievers are not hashed here although doEquals calls super.doEquals - confirm intended
|
|
|
}
|
|
|
}
|