|
|
@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.ml.integration;
|
|
|
|
|
|
import org.elasticsearch.client.Request;
|
|
|
import org.elasticsearch.client.Response;
|
|
|
+import org.elasticsearch.client.ResponseException;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
|
|
import org.junit.Before;
|
|
|
@@ -196,7 +197,6 @@ public class LearningToRankRescorerIT extends InferenceTestCase {
|
|
|
adminClient().performRequest(new Request("POST", INDEX_NAME + "/_refresh"));
|
|
|
}
|
|
|
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/103072")
|
|
|
public void testLearningToRankRescore() throws Exception {
|
|
|
Request request = new Request("GET", "store/_search?size=3&error_trace");
|
|
|
request.setJsonEntity("""
|
|
|
@@ -232,7 +232,6 @@ public class LearningToRankRescorerIT extends InferenceTestCase {
|
|
|
assertHitScores(client().performRequest(request), List.of(9.0, 9.0, 6.0));
|
|
|
}
|
|
|
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/103072")
|
|
|
public void testLearningToRankRescoreSmallWindow() throws Exception {
|
|
|
Request request = new Request("GET", "store/_search?size=5");
|
|
|
request.setJsonEntity("""
|
|
|
@@ -242,30 +241,33 @@ public class LearningToRankRescorerIT extends InferenceTestCase {
|
|
|
"learning_to_rank": { "model_id": "ltr-model" }
|
|
|
}
|
|
|
}""");
|
|
|
- assertHitScores(client().performRequest(request), List.of(20.0, 20.0, 1.0, 1.0, 1.0));
|
|
|
+ assertThrows(
|
|
|
+ "Rescore window is too small and should be at least the value of from + size but was [2]",
|
|
|
+ ResponseException.class,
|
|
|
+ () -> client().performRequest(request)
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
- @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/103072")
|
|
|
public void testLearningToRankRescorerWithChainedRescorers() throws IOException {
|
|
|
Request request = new Request("GET", "store/_search?size=5");
|
|
|
request.setJsonEntity("""
|
|
|
{
|
|
|
"rescore": [
|
|
|
{
|
|
|
- "window_size": 4,
|
|
|
+ "window_size": 15,
|
|
|
"query": { "rescore_query" : { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 4" } } } }
|
|
|
},
|
|
|
{
|
|
|
- "window_size": 3,
|
|
|
+ "window_size": 25,
|
|
|
"learning_to_rank": { "model_id": "ltr-model" }
|
|
|
},
|
|
|
{
|
|
|
- "window_size": 2,
|
|
|
+ "window_size": 35,
|
|
|
"query": { "rescore_query": { "script_score": { "query": { "match_all": {} }, "script": { "source": "return 20"} } } }
|
|
|
}
|
|
|
]
|
|
|
}""");
|
|
|
- assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 17.0, 5.0, 1.0));
|
|
|
+ assertHitScores(client().performRequest(request), List.of(40.0, 40.0, 37.0, 29.0, 29.0));
|
|
|
}
|
|
|
|
|
|
private void indexData(String data) throws IOException {
|