|
@@ -22,7 +22,11 @@ package org.elasticsearch.client;
|
|
|
import org.elasticsearch.action.search.SearchRequest;
|
|
|
import org.elasticsearch.action.support.IndicesOptions;
|
|
|
import org.elasticsearch.index.query.MatchAllQueryBuilder;
|
|
|
+import org.elasticsearch.index.rankeval.DiscountedCumulativeGain;
|
|
|
import org.elasticsearch.index.rankeval.EvalQueryQuality;
|
|
|
+import org.elasticsearch.index.rankeval.EvaluationMetric;
|
|
|
+import org.elasticsearch.index.rankeval.ExpectedReciprocalRank;
|
|
|
+import org.elasticsearch.index.rankeval.MeanReciprocalRank;
|
|
|
import org.elasticsearch.index.rankeval.PrecisionAtK;
|
|
|
import org.elasticsearch.index.rankeval.RankEvalRequest;
|
|
|
import org.elasticsearch.index.rankeval.RankEvalResponse;
|
|
@@ -35,8 +39,10 @@ import org.junit.Before;
|
|
|
|
|
|
import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import java.util.function.Supplier;
|
|
|
import java.util.stream.Collectors;
|
|
|
import java.util.stream.Stream;
|
|
|
|
|
@@ -64,15 +70,7 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
|
|
|
* calculation where all unlabeled documents are treated as not relevant.
|
|
|
*/
|
|
|
public void testRankEvalRequest() throws IOException {
|
|
|
- SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
|
|
- testQuery.query(new MatchAllQueryBuilder());
|
|
|
- List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
|
|
|
- amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0"));
|
|
|
- RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery);
|
|
|
- RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery);
|
|
|
- List<RatedRequest> specifications = new ArrayList<>();
|
|
|
- specifications.add(amsterdamRequest);
|
|
|
- specifications.add(berlinRequest);
|
|
|
+ List<RatedRequest> specifications = createTestEvaluationSpec();
|
|
|
PrecisionAtK metric = new PrecisionAtK(1, false, 10);
|
|
|
RankEvalSpec spec = new RankEvalSpec(specifications, metric);
|
|
|
|
|
@@ -114,6 +112,38 @@ public class RankEvalIT extends ESRestHighLevelClientTestCase {
|
|
|
response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
|
|
|
}
|
|
|
|
|
|
+ private static List<RatedRequest> createTestEvaluationSpec() {
|
|
|
+ SearchSourceBuilder testQuery = new SearchSourceBuilder();
|
|
|
+ testQuery.query(new MatchAllQueryBuilder());
|
|
|
+ List<RatedDocument> amsterdamRatedDocs = createRelevant("index" , "amsterdam1", "amsterdam2", "amsterdam3", "amsterdam4");
|
|
|
+ amsterdamRatedDocs.addAll(createRelevant("index2", "amsterdam0"));
|
|
|
+ RatedRequest amsterdamRequest = new RatedRequest("amsterdam_query", amsterdamRatedDocs, testQuery);
|
|
|
+ RatedRequest berlinRequest = new RatedRequest("berlin_query", createRelevant("index", "berlin"), testQuery);
|
|
|
+ List<RatedRequest> specifications = new ArrayList<>();
|
|
|
+ specifications.add(amsterdamRequest);
|
|
|
+ specifications.add(berlinRequest);
|
|
|
+ return specifications;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Test case checks that the default metrics are registered and usable
|
|
|
+ */
|
|
|
+ public void testMetrics() throws IOException {
|
|
|
+ List<RatedRequest> specifications = createTestEvaluationSpec();
|
|
|
+ List<Supplier<EvaluationMetric>> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new,
|
|
|
+ () -> new ExpectedReciprocalRank(1));
|
|
|
+ double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095};
|
|
|
+ int i = 0;
|
|
|
+ for (Supplier<EvaluationMetric> metricSupplier : metrics) {
|
|
|
+ RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get());
|
|
|
+
|
|
|
+ RankEvalRequest rankEvalRequest = new RankEvalRequest(spec, new String[] { "index", "index2" });
|
|
|
+ RankEvalResponse response = execute(rankEvalRequest, highLevelClient()::rankEval, highLevelClient()::rankEvalAsync);
|
|
|
+ assertEquals(expectedScores[i], response.getMetricScore(), Double.MIN_VALUE);
|
|
|
+ i++;
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
private static List<RatedDocument> createRelevant(String indexName, String... docs) {
|
|
|
return Stream.of(docs).map(s -> new RatedDocument(indexName, s, 1)).collect(Collectors.toList());
|
|
|
}
|