|
@@ -21,14 +21,22 @@ import org.elasticsearch.test.ESSingleNodeTestCase;
|
|
|
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
|
|
|
+import org.hamcrest.Matcher;
|
|
|
import org.junit.Before;
|
|
|
|
|
|
import java.util.Collection;
|
|
|
import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
-import java.util.Objects;
|
|
|
|
|
|
+import static org.elasticsearch.index.query.QueryBuilders.boolQuery;
|
|
|
+import static org.elasticsearch.index.query.QueryBuilders.constantScoreQuery;
|
|
|
+import static org.elasticsearch.index.query.QueryBuilders.matchQuery;
|
|
|
+import static org.elasticsearch.test.LambdaMatchers.transformedMatch;
|
|
|
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasRank;
|
|
|
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.hasScore;
|
|
|
+import static org.hamcrest.Matchers.allOf;
|
|
|
+import static org.hamcrest.Matchers.arrayContaining;
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.hamcrest.Matchers.equalTo;
|
|
|
|
|
@@ -49,7 +57,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
Float minScore,
|
|
|
int topN
|
|
|
) {
|
|
|
- super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore);
|
|
|
+ super(field, inferenceId + "-task-settings-top-" + topN, inferenceText, rankWindowSize, minScore, false);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -68,7 +76,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
Float minScore,
|
|
|
int inferenceResultCount
|
|
|
) {
|
|
|
- super(field, inferenceId, inferenceText, rankWindowSize, minScore);
|
|
|
+ super(field, inferenceId, inferenceText, rankWindowSize, minScore, false);
|
|
|
this.inferenceResultCount = inferenceResultCount;
|
|
|
}
|
|
|
|
|
@@ -81,7 +89,8 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
client,
|
|
|
inferenceId,
|
|
|
inferenceText,
|
|
|
- minScore
|
|
|
+ minScore,
|
|
|
+ failuresAllowed()
|
|
|
) {
|
|
|
@Override
|
|
|
protected InferenceAction.Request generateRequest(List<String> docFeatures) {
|
|
@@ -125,18 +134,21 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
ElasticsearchAssertions.assertNoFailuresAndResponse(
|
|
|
// Execute search with text similarity reranking
|
|
|
client.prepareSearch()
|
|
|
- .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f))
|
|
|
+ .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 0.0f, false))
|
|
|
.setQuery(QueryBuilders.matchAllQuery()),
|
|
|
response -> {
|
|
|
// Verify order, rank and score of results
|
|
|
- SearchHit[] hits = response.getHits().getHits();
|
|
|
- assertEquals(5, hits.length);
|
|
|
- // we add + 1 to all expected scores due to the default normalization being applied which shifts positive scores to by 1
|
|
|
- assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
|
|
|
- assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
|
|
|
- assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
|
|
|
- assertHitHasRankScoreAndText(hits[3], 4, 1.0f + 1f, "1");
|
|
|
- assertHitHasRankScoreAndText(hits[4], 5, 0.0f + 1f, "0");
|
|
|
+ assertThat(
|
|
|
+ response.getHits().getHits(),
|
|
|
+ arrayContaining(
|
|
|
+ // add 1 to all expected scores due to the default normalization being applied which shifts positive scores by 1
|
|
|
+ searchHitWith(1, 4.0f + 1f, "4"),
|
|
|
+ searchHitWith(2, 3.0f + 1f, "3"),
|
|
|
+ searchHitWith(3, 2.0f + 1f, "2"),
|
|
|
+ searchHitWith(4, 1.0f + 1f, "1"),
|
|
|
+ searchHitWith(5, 0.0f + 1f, "0")
|
|
|
+ )
|
|
|
+ );
|
|
|
}
|
|
|
);
|
|
|
}
|
|
@@ -145,15 +157,14 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
ElasticsearchAssertions.assertNoFailuresAndResponse(
|
|
|
// Execute search with text similarity reranking
|
|
|
client.prepareSearch()
|
|
|
- .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f))
|
|
|
+ .setRankBuilder(new TextSimilarityRankBuilder("text", "my-rerank-model", "my query", 100, 1.5f, false))
|
|
|
.setQuery(QueryBuilders.matchAllQuery()),
|
|
|
response -> {
|
|
|
// Verify order, rank and score of results
|
|
|
- SearchHit[] hits = response.getHits().getHits();
|
|
|
- assertEquals(3, hits.length);
|
|
|
- assertHitHasRankScoreAndText(hits[0], 1, 4.0f + 1f, "4");
|
|
|
- assertHitHasRankScoreAndText(hits[1], 2, 3.0f + 1f, "3");
|
|
|
- assertHitHasRankScoreAndText(hits[2], 3, 2.0f + 1f, "2");
|
|
|
+ assertThat(
|
|
|
+ response.getHits().getHits(),
|
|
|
+ arrayContaining(searchHitWith(1, 4.0f + 1f, "4"), searchHitWith(2, 3.0f + 1f, "3"), searchHitWith(3, 2.0f + 1f, "2"))
|
|
|
+ );
|
|
|
}
|
|
|
);
|
|
|
}
|
|
@@ -169,6 +180,7 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
"my-rerank-model",
|
|
|
"my query",
|
|
|
0.7f,
|
|
|
+ false,
|
|
|
AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
|
|
|
)
|
|
|
)
|
|
@@ -178,6 +190,44 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ public void testRerankInferenceAllowedFailure() {
|
|
|
+ ElasticsearchAssertions.assertNoFailuresAndResponse(
|
|
|
+ // Execute search with text similarity reranking that fails, but it is allowed
|
|
|
+ client.prepareSearch()
|
|
|
+ .setRankBuilder(
|
|
|
+ new TextSimilarityTestPlugin.ThrowingMockRequestActionBasedRankBuilder(
|
|
|
+ 100,
|
|
|
+ "text",
|
|
|
+ "my-rerank-model",
|
|
|
+ "my query",
|
|
|
+ null,
|
|
|
+ true,
|
|
|
+ AbstractRerankerIT.ThrowingRankBuilderType.THROWING_RANK_FEATURE_PHASE_COORDINATOR_CONTEXT.name()
|
|
|
+ )
|
|
|
+ )
|
|
|
+ .setQuery(
|
|
|
+ boolQuery().should(constantScoreQuery(matchQuery("text", "0")).boost(50))
|
|
|
+ .should(constantScoreQuery(matchQuery("text", "1")).boost(40))
|
|
|
+ .should(constantScoreQuery(matchQuery("text", "2")).boost(30))
|
|
|
+ .should(constantScoreQuery(matchQuery("text", "3")).boost(20))
|
|
|
+ .should(constantScoreQuery(matchQuery("text", "4")).boost(10))
|
|
|
+ ),
|
|
|
+ response -> {
|
|
|
+ // these will all have the scores from the constant score clauses
|
|
|
+ assertThat(
|
|
|
+ response.getHits().getHits(),
|
|
|
+ arrayContaining(
|
|
|
+ searchHitWith(1, 50, "0"),
|
|
|
+ searchHitWith(2, 40, "1"),
|
|
|
+ searchHitWith(3, 30, "2"),
|
|
|
+ searchHitWith(4, 20, "3"),
|
|
|
+ searchHitWith(5, 10, "4")
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
public void testRerankTopNConfigurationAndRankWindowSizeMismatch() {
|
|
|
SearchPhaseExecutionException ex = expectThrows(
|
|
|
SearchPhaseExecutionException.class,
|
|
@@ -212,10 +262,11 @@ public class TextSimilarityRankTests extends ESSingleNodeTestCase {
|
|
|
assertThat(ex.getDetailedMessage(), containsString("Reranker input document count and returned score count mismatch"));
|
|
|
}
|
|
|
|
|
|
- private static void assertHitHasRankScoreAndText(SearchHit hit, int expectedRank, float expectedScore, String expectedText) {
|
|
|
- assertEquals(expectedRank, hit.getRank());
|
|
|
- assertEquals(expectedScore, hit.getScore(), 0.0f);
|
|
|
- assertEquals(expectedText, Objects.requireNonNull(hit.getSourceAsMap()).get("text"));
|
|
|
+ private static Matcher<SearchHit> searchHitWith(int expectedRank, float expectedScore, String expectedText) {
|
|
|
+ return allOf(
|
|
|
+ hasRank(expectedRank),
|
|
|
+ hasScore(expectedScore),
|
|
|
+ transformedMatch(hit -> hit.getSourceAsMap().get("text"), equalTo(expectedText))
|
|
|
+ );
|
|
|
}
|
|
|
-
|
|
|
}
|