Browse Source

Add details section for dcg ranking metric (#31177)

While the other two ranking evaluation metrics (precision and reciprocal rank)
already provide a more detailed output for how their score is calculated, the
discounted cumulative gain metric (dcg) and its normalized variant are lacking
this until now. It's not really clear which level of detail might be useful for
debugging and understanding the final metric calculation, but this change adds a
`metric_details` section to REST output that contains some information about the
evaluation details.
Christoph Büscher 7 years ago
parent
commit
a0d6c19e75

+ 4 - 2
client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java

@@ -20,6 +20,7 @@
 package org.elasticsearch.client;
 
 import com.fasterxml.jackson.core.JsonParseException;
+
 import org.apache.http.Header;
 import org.apache.http.HttpEntity;
 import org.apache.http.HttpHost;
@@ -607,7 +608,7 @@ public class RestHighLevelClientTests extends ESTestCase {
 
     public void testProvidedNamedXContents() {
         List<NamedXContentRegistry.Entry> namedXContents = RestHighLevelClient.getProvidedNamedXContents();
-        assertEquals(7, namedXContents.size());
+        assertEquals(8, namedXContents.size());
         Map<Class<?>, Integer> categories = new HashMap<>();
         List<String> names = new ArrayList<>();
         for (NamedXContentRegistry.Entry namedXContent : namedXContents) {
@@ -625,9 +626,10 @@ public class RestHighLevelClientTests extends ESTestCase {
         assertTrue(names.contains(PrecisionAtK.NAME));
         assertTrue(names.contains(DiscountedCumulativeGain.NAME));
         assertTrue(names.contains(MeanReciprocalRank.NAME));
-        assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class));
+        assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class));
         assertTrue(names.contains(PrecisionAtK.NAME));
         assertTrue(names.contains(MeanReciprocalRank.NAME));
+        assertTrue(names.contains(DiscountedCumulativeGain.NAME));
     }
 
     private static class TrackingActionListener implements ActionListener<Integer> {

+ 131 - 11
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java

@@ -36,6 +36,7 @@ import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
 import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings;
 
@@ -129,26 +130,31 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
                 .collect(Collectors.toList());
         List<RatedSearchHit> ratedHits = joinHitsWithRatings(hits, ratedDocs);
         List<Integer> ratingsInSearchHits = new ArrayList<>(ratedHits.size());
+        int unratedResults = 0;
         for (RatedSearchHit hit : ratedHits) {
-            // unknownDocRating might be null, which means it will be unrated docs are
-            // ignored in the dcg calculation
-            // we still need to add them as a placeholder so the rank of the subsequent
-            // ratings is correct
+            // unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation.
+            // we still need to add them as a placeholder so the rank of the subsequent ratings is correct
             ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating));
+            if (hit.getRating().isPresent() == false) {
+                unratedResults++;
+            }
         }
-        double dcg = computeDCG(ratingsInSearchHits);
+        final double dcg = computeDCG(ratingsInSearchHits);
+        double result = dcg;
+        double idcg = 0;
 
         if (normalize) {
             Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder()));
-            double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
-            if (idcg > 0) {
-                dcg = dcg / idcg;
+            idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size())));
+            if (idcg != 0) {
+                result = dcg / idcg;
             } else {
-                dcg = 0;
+                result = 0;
             }
         }
-        EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg);
+        EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result);
         evalQueryQuality.addHitsAndRatings(ratedHits);
+        evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults));
         return evalQueryQuality;
     }
 
@@ -167,7 +173,7 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
     private static final ParseField K_FIELD = new ParseField("k");
     private static final ParseField NORMALIZE_FIELD = new ParseField("normalize");
     private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating");
-    private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg_at", false,
+    private static final ConstructingObjectParser<DiscountedCumulativeGain, Void> PARSER = new ConstructingObjectParser<>("dcg", false,
             args -> {
                 Boolean normalized = (Boolean) args[0];
                 Integer optK = (Integer) args[2];
@@ -217,4 +223,118 @@ public class DiscountedCumulativeGain implements EvaluationMetric {
     public final int hashCode() {
         return Objects.hash(normalize, unknownDocRating, k);
     }
+
+    /**
+     * Details of a single discounted cumulative gain (dcg) evaluation: the raw dcg, the ideal dcg
+     * (0 when it was not computed) and the number of unrated documents in the search results.
+     * Instances are immutable and can be written to / read from the transport stream and rendered
+     * to / parsed from XContent.
+     */
+    public static final class Detail implements MetricDetail {
+
+        private static final ParseField DCG_FIELD = new ParseField("dcg");
+        private static final ParseField IDCG_FIELD = new ParseField("ideal_dcg");
+        private static final ParseField NDCG_FIELD = new ParseField("normalized_dcg");
+        private static final ParseField UNRATED_FIELD = new ParseField("unrated_docs");
+        private final double dcg;
+        private final double idcg;
+        private final int unratedDocs;
+
+        /**
+         * @param dcg the discounted cumulative gain
+         * @param idcg the ideal discounted cumulative gain, 0 if it was not computed
+         * @param unratedDocs the number of unrated documents in the search results
+         */
+        Detail(double dcg, double idcg, int unratedDocs) {
+            this.dcg = dcg;
+            this.idcg = idcg;
+            this.unratedDocs = unratedDocs;
+        }
+
+        /**
+         * Reads the details from a stream; the field order mirrors {@link #writeTo(StreamOutput)}.
+         */
+        Detail(StreamInput in) throws IOException {
+            this.dcg = in.readDouble();
+            this.idcg = in.readDouble();
+            this.unratedDocs = in.readVInt();
+        }
+
+        @Override
+        public String getMetricName() {
+            return NAME;
+        }
+
+        /**
+         * Writes "dcg" and "unrated_docs" always; "ideal_dcg" and "normalized_dcg" are only written
+         * when the ideal dcg is non-zero, which also avoids a division by zero.
+         */
+        @Override
+        public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(DCG_FIELD.getPreferredName(), this.dcg);
+            if (this.idcg != 0) {
+                builder.field(IDCG_FIELD.getPreferredName(), this.idcg);
+                builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg);
+            }
+            builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs);
+            return builder;
+        }
+
+        // "normalized_dcg" is derived from dcg and ideal_dcg, so it is not declared as a parser field.
+        // "ideal_dcg" is optional because innerToXContent omits it when it is 0; a missing value maps back to 0.
+        private static final ConstructingObjectParser<Detail, Void> PARSER = new ConstructingObjectParser<>(NAME, true, args -> {
+            Double idealDcg = (Double) args[1];
+            return new Detail((Double) args[0], idealDcg != null ? idealDcg : 0.0d, (Integer) args[2]);
+        });
+
+        static {
+            PARSER.declareDouble(constructorArg(), DCG_FIELD);
+            PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD);
+            PARSER.declareInt(constructorArg(), UNRATED_FIELD);
+        }
+
+        /**
+         * Parses the details from their XContent representation.
+         */
+        public static Detail fromXContent(XContentParser parser) {
+            return PARSER.apply(parser, null);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeDouble(this.dcg);
+            out.writeDouble(this.idcg);
+            out.writeVInt(this.unratedDocs);
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        /**
+         * @return the discounted cumulative gain
+         */
+        public double getDCG() {
+            return this.dcg;
+        }
+
+        /**
+         * @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
+         */
+        public double getIDCG() {
+            return this.idcg;
+        }
+
+        /**
+         * @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required
+         */
+        public double getNDCG() {
+            return (this.idcg != 0) ? this.dcg / this.idcg : 0;
+        }
+
+        /**
+         * @return the number of unrated documents in the search results; the declared return type is
+         *         {@code Object}, the returned value is the boxed {@code int} count
+         */
+        public Object getUnratedDocs() {
+            return this.unratedDocs;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj) {
+                return true;
+            }
+            if (obj == null || getClass() != obj.getClass()) {
+                return false;
+            }
+            DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj;
+            // Double.compare (not ==) keeps equals consistent with hashCode, which hashes the boxed doubles:
+            // with == the values 0.0 and -0.0 would be equal but produce different hash codes.
+            return Double.compare(this.dcg, other.dcg) == 0 &&
+                    Double.compare(this.idcg, other.idcg) == 0 &&
+                    this.unratedDocs == other.unratedDocs;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(this.dcg, this.idcg, this.unratedDocs);
+        }
+    }
 }
+

+ 2 - 0
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java

@@ -41,6 +41,8 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider {
                 PrecisionAtK.Detail::fromXContent));
         namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME),
                 MeanReciprocalRank.Detail::fromXContent));
+        namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME),
+                DiscountedCumulativeGain.Detail::fromXContent));
         return namedXContent;
     }
 }

+ 3 - 2
modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java

@@ -61,8 +61,9 @@ public class RankEvalPlugin extends Plugin implements ActionPlugin {
         namedWriteables.add(
                 new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new));
         namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new));
-        namedWriteables
-                .add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
+        namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new));
+        namedWriteables.add(
+                new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new));
         return namedWriteables;
     }
 

+ 22 - 4
modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java

@@ -19,6 +19,7 @@
 
 package org.elasticsearch.index.rankeval;
 
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.text.Text;
@@ -254,9 +255,8 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
 
     public static DiscountedCumulativeGain createTestItem() {
         boolean normalize = randomBoolean();
-        Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000));
-
-        return new DiscountedCumulativeGain(normalize, unknownDocRating, 10);
+        Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 1000)) : null;
+        return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10));
     }
 
     public void testXContentRoundtrip() throws IOException {
@@ -283,7 +283,25 @@ public class DiscountedCumulativeGainTests extends ESTestCase {
             parser.nextToken();
             XContentParseException exception = expectThrows(XContentParseException.class,
                     () -> DiscountedCumulativeGain.fromXContent(parser));
-            assertThat(exception.getMessage(), containsString("[dcg_at] unknown field"));
+            assertThat(exception.getMessage(), containsString("[dcg] unknown field"));
+        }
+    }
+
+    public void testMetricDetails() {
+        double dcg = randomDoubleBetween(0, 1, true);
+        double idcg = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true);
+        double expectedNdcg = idcg != 0 ? dcg / idcg : 0.0;
+        int unratedDocs = randomIntBetween(0, 100);
+        DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(dcg, idcg, unratedDocs);
+        assertEquals(dcg, detail.getDCG(), 0.0);
+        assertEquals(idcg, detail.getIDCG(), 0.0);
+        assertEquals(expectedNdcg, detail.getNDCG(), 0.0);
+        assertEquals(unratedDocs, detail.getUnratedDocs());
+        if (idcg != 0) {
+            assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"ideal_dcg\":" + idcg + ",\"normalized_dcg\":" + expectedNdcg
+                    + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
+        } else {
+            assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail));
         }
     }
 

+ 12 - 2
modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java

@@ -68,10 +68,20 @@ public class EvalQueryQualityTests extends ESTestCase {
         EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10),
                 randomDoubleBetween(0.0, 1.0, true));
         if (randomBoolean()) {
-            if (randomBoolean()) {
+            int metricDetail = randomIntBetween(0, 2);
+            switch (metricDetail) {
+            case 0:
                 evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000)));
-            } else {
+                break;
+            case 1:
                 evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000)));
+                break;
+            case 2:
+                evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true),
+                        randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt()));
+                break;
+            default:
+                throw new IllegalArgumentException("illegal randomized value in test");
             }
         }
         evalQueryQuality.addHitsAndRatings(ratedHits);