瀏覽代碼

[ML] add new normalize_above parameter to p_value significant terms heuristic (#78833)

This commit adds the new normalize_above parameter to the p_value significant
terms heuristic.

This parameter allows for consistent significance results at various scales. When a total count (in or out of the set background set) is above the normalize_above parameter, both the total set and the set including the term are scaled by normalize_above/count where count is term in the set or total set size.
Benjamin Trent 4 年之前
父節點
當前提交
843fa42c1e

+ 5 - 1
docs/reference/aggregations/bucket/significantterms-aggregation.asciidoc

@@ -404,6 +404,10 @@ the foreground set of "ended in failure" versus "NOT ended in failure".
 `"background_is_superset": false` indicates that the background set does 
 not contain the counts of the foreground set as they are filtered out.
 
+`"normalize_above": 1000` facilitates returning consistent significance results
+at various scales. `1000` indicates that term counts greater than `1000` are
+scaled down by a factor of `1000/term_count`.
+
 [source,console]
 --------------------------------------------------
 GET /_search
@@ -466,7 +470,7 @@ GET /_search
             ]
           }
         },
-        "p_value": {"background_is_superset": false}
+        "p_value": {"background_is_superset": false, "normalize_above": 1000}
       }
     }
   }

+ 6 - 2
test/framework/src/main/java/org/elasticsearch/search/aggregations/bucket/AbstractSignificanceHeuristicTestCase.java

@@ -37,6 +37,7 @@ import org.elasticsearch.search.aggregations.bucket.terms.SignificantTermsAggreg
 import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.InternalAggregationTestCase;
+import org.elasticsearch.test.VersionUtils;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
@@ -52,7 +53,6 @@ import static java.util.Collections.emptyList;
 import static java.util.Collections.emptyMap;
 import static java.util.Collections.singletonList;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.significantTerms;
-import static org.elasticsearch.test.VersionUtils.randomVersion;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThan;
@@ -69,9 +69,13 @@ public abstract class AbstractSignificanceHeuristicTestCase extends ESTestCase {
      */
     protected abstract SignificanceHeuristic getHeuristic();
 
+    protected Version randomVersion() {
+        return VersionUtils.randomVersion(random());
+    }
+
     // test that stream output can actually be read - does not replace bwc test
     public void testStreamResponse() throws Exception {
-        Version version = randomVersion(random());
+        Version version = randomVersion();
         InternalMappedSignificantTerms<?, ?> sigTerms = getRandomSignificantTerms(getHeuristic());
 
         // write

+ 72 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScore.java

@@ -10,9 +10,11 @@ package org.elasticsearch.xpack.ml.aggs.heuristic;
 
 
 import org.apache.commons.math3.util.FastMath;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.search.aggregations.AggregationExecutionException;
@@ -20,47 +22,80 @@ import org.elasticsearch.search.aggregations.bucket.terms.heuristic.NXYSignifica
 import org.elasticsearch.search.aggregations.bucket.terms.heuristic.SignificanceHeuristic;
 
 import java.io.IOException;
+import java.util.Objects;
 
 import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
+/**
+ * Significant terms heuristic that calculates the p-value between the term existing in foreground and background sets.
+ *
+ * The p-value is the probability of obtaining test results at least as extreme as
+ * the results actually observed, under the assumption that the null hypothesis is
+ * correct. The p-value is calculated assuming that the foreground set and the
+ * background set are independent https://en.wikipedia.org/wiki/Bernoulli_trial, with the null
+ * hypothesis that the probabilities are the same.
+ */
 public class PValueScore extends NXYSignificanceHeuristic {
     public static final String NAME = "p_value";
+    public static final ParseField NORMALIZE_ABOVE = new ParseField("normalize_above");
     public static final ConstructingObjectParser<PValueScore, Void> PARSER = new ConstructingObjectParser<>(NAME, args -> {
         boolean backgroundIsSuperset = args[0] == null || (boolean) args[0];
-        return new PValueScore(backgroundIsSuperset);
+        return new PValueScore(backgroundIsSuperset, (Long)args[1]);
     });
     static {
         PARSER.declareBoolean(optionalConstructorArg(), BACKGROUND_IS_SUPERSET);
+        PARSER.declareLong(optionalConstructorArg(), NORMALIZE_ABOVE);
     }
 
     private static final MlChiSquaredDistribution CHI_SQUARED_DISTRIBUTION = new MlChiSquaredDistribution(1);
 
-    public PValueScore(boolean backgroundIsSuperset) {
+    // NOTE: `0` is a magic value indicating no normalization occurs
+    private final long normalizeAbove;
+
+    /**
+     * @param backgroundIsSuperset Does the background contain the foreground docs?
+     * @param normalizeAbove Should the results be normalized when above the given value.
+     *                       Note: `0` is a special value which means no normalization (set as such when `null` is provided)
+     */
+    public PValueScore(boolean backgroundIsSuperset, Long normalizeAbove) {
         super(true, backgroundIsSuperset);
+        if (normalizeAbove != null && normalizeAbove <= 0) {
+            throw new IllegalArgumentException(
+                "[" + NORMALIZE_ABOVE.getPreferredName() + "] must be a positive value, provided [" + normalizeAbove + "]"
+            );
+        }
+        this.normalizeAbove = normalizeAbove == null ? 0L : normalizeAbove;
     }
 
     public PValueScore(StreamInput in) throws IOException {
         super(true, in.readBoolean());
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            normalizeAbove = in.readVLong();
+        } else {
+            normalizeAbove = 0L;
+        }
     }
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeBoolean(backgroundIsSuperset);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            out.writeVLong(normalizeAbove);
+        }
     }
 
     @Override
-    public boolean equals(Object obj) {
-        if ((obj instanceof PValueScore) == false) {
-            return false;
-        }
-        return super.equals(obj);
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        if (super.equals(o) == false) return false;
+        PValueScore that = (PValueScore) o;
+        return normalizeAbove == that.normalizeAbove;
     }
 
     @Override
     public int hashCode() {
-        int result = NAME.hashCode();
-        result = 31 * result + super.hashCode();
-        return result;
+        return Objects.hash(super.hashCode(), normalizeAbove);
     }
 
     @Override
@@ -72,6 +107,9 @@ public class PValueScore extends NXYSignificanceHeuristic {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject(NAME);
         builder.field(BACKGROUND_IS_SUPERSET.getPreferredName(), backgroundIsSuperset);
+        if (normalizeAbove > 0) {
+            builder.field(NORMALIZE_ABOVE.getPreferredName(), normalizeAbove);
+        }
         builder.endObject();
         return builder;
     }
@@ -113,6 +151,19 @@ public class PValueScore extends NXYSignificanceHeuristic {
             return 0.0;
         }
 
+        if (normalizeAbove > 0L) {
+            if (allDocsInClass > normalizeAbove) {
+                double factor = (double) normalizeAbove / allDocsInClass;
+                allDocsInClass = (long)(allDocsInClass * factor);
+                docsContainTermInClass = (long)(docsContainTermInClass * factor);
+            }
+            if (allDocsNotInClass > normalizeAbove) {
+                double factor = (double) normalizeAbove / allDocsNotInClass;
+                allDocsNotInClass = (long)(allDocsNotInClass * factor);
+                docsContainTermNotInClass = (long)(docsContainTermNotInClass * factor);
+            }
+        }
+
         // casting to `long` to round down to nearest whole number
         double epsAllDocsInClass = (long)eps(allDocsInClass);
         double epsAllDocsNotInClass = (long)eps(allDocsNotInClass);
@@ -164,15 +215,25 @@ public class PValueScore extends NXYSignificanceHeuristic {
     }
 
     public static class PValueScoreBuilder extends NXYBuilder {
+        private final long normalizeAbove;
 
-        public PValueScoreBuilder(boolean backgroundIsSuperset) {
+        public PValueScoreBuilder(boolean backgroundIsSuperset, Long normalizeAbove) {
             super(true, backgroundIsSuperset);
+            this.normalizeAbove = normalizeAbove == null ? 0L : normalizeAbove;
+            if (normalizeAbove != null && normalizeAbove <= 0) {
+                throw new IllegalArgumentException(
+                    "[" + NORMALIZE_ABOVE.getPreferredName() + "] must be a positive value, provided [" + normalizeAbove + "]"
+                );
+            }
         }
 
         @Override
         public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
             builder.startObject(NAME);
             builder.field(BACKGROUND_IS_SUPERSET.getPreferredName(), backgroundIsSuperset);
+            if (normalizeAbove > 0) {
+                builder.field(NORMALIZE_ABOVE.getPreferredName(), normalizeAbove);
+            }
             builder.endObject();
             return builder;
         }

+ 68 - 21
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/aggs/heuristic/PValueScoreTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.ml.aggs.heuristic;
 
 import org.apache.commons.math3.util.FastMath;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -29,19 +30,27 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
 
     private static final double eps = 1e-9;
 
+    @Override
+    protected Version randomVersion() {
+        return Version.V_8_0_0;
+    }
+
     @Override
     protected SignificanceHeuristic getHeuristic() {
-        return new PValueScore(randomBoolean());
+        return new PValueScore(randomBoolean(), randomBoolean() ? null : randomLongBetween(1, 10000000L));
     }
 
     @Override
     protected SignificanceHeuristic getHeuristic(boolean includeNegatives, boolean backgroundIsSuperset) {
-        return new PValueScore(backgroundIsSuperset);
+        return new PValueScore(backgroundIsSuperset, randomBoolean() ? null : randomLongBetween(1, 10000000L));
     }
 
     @Override
     public void testAssertions() {
-        testBackgroundAssertions(new PValueScore(true), new PValueScore(false));
+        testBackgroundAssertions(
+            new PValueScore(true, randomBoolean() ? null : randomLongBetween(1, 10000000L)),
+            new PValueScore(false, randomBoolean() ? null : randomLongBetween(1, 10000000L))
+        );
     }
 
     @Override
@@ -59,7 +68,7 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
     }
 
     public void testPValueScore_WhenAllDocsContainTerm() {
-        PValueScore pValueScore = new PValueScore(randomBoolean());
+        PValueScore pValueScore = new PValueScore(randomBoolean(), null);
         long supersetCount = randomNonNegativeLong();
         long subsetCount = randomLongBetween(0L, supersetCount);
         assertThat(pValueScore.getScore(subsetCount, subsetCount, supersetCount, supersetCount), equalTo(0.0));
@@ -78,7 +87,7 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
             supersetFreqCount += subsetFreqCount;
         }
 
-        PValueScore pValueScore = new PValueScore(backgroundIsSuperset);
+        PValueScore pValueScore = new PValueScore(backgroundIsSuperset, null);
         assertThat(pValueScore.getScore(subsetFreqCount, subsetCount, supersetFreqCount, supersetCount), greaterThanOrEqualTo(700.0));
     }
 
@@ -95,7 +104,7 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
             supersetFreqCount += subsetFreqCount;
         }
 
-        PValueScore pValueScore = new PValueScore(backgroundIsSuperset);
+        PValueScore pValueScore = new PValueScore(backgroundIsSuperset, null);
         assertThat(
             pValueScore.getScore(subsetFreqCount, subsetCount, supersetFreqCount, supersetCount),
             allOf(lessThanOrEqualTo(5.0), greaterThanOrEqualTo(0.0))
@@ -104,66 +113,104 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
 
     public void testPValueScore() {
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(10, 100, 100, 1000)),
+            FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 100, 1000)),
             closeTo(1.0, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(10, 100, 10, 1000)),
+            FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 100, 1000)),
+            closeTo(1.0, eps)
+        );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 10, 1000)),
             closeTo(0.003972388976814195, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(10, 100, 200, 1000)),
+            FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 10, 1000)),
+            closeTo(0.020890782016496683, eps)
+        );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, null).getScore(10, 100, 200, 1000)),
+            closeTo(1.0, eps)
+        );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, 200L).getScore(10, 100, 200, 1000)),
+            closeTo(1.0, eps)
+        );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, null).getScore(20, 10000, 5, 10000)),
             closeTo(1.0, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(20, 10000, 5, 10000)),
+            FastMath.exp(-new PValueScore(false, 200L).getScore(20, 10000, 5, 10000)),
             closeTo(1.0, eps)
         );
     }
 
     public void testSmallChanges() {
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(1, 4205, 0, 821496)),
+            FastMath.exp(-new PValueScore(false, null).getScore(1, 4205, 0, 821496)),
             closeTo(0.9999037287868853, eps)
         );
+
         // Same(ish) ratios
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(10, 4205, 195, 82149)),
+            FastMath.exp(-new PValueScore(false, null).getScore(10, 4205, 195, 82149)),
             closeTo(0.9995943820612134, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(10, 4205, 1950, 821496)),
+            FastMath.exp(-new PValueScore(false, 100L).getScore(10, 4205, 195, 82149)),
+            closeTo(0.9876284079864467, eps)
+        );
+
+        assertThat(
+            FastMath.exp(-new PValueScore(false, null).getScore(10, 4205, 1950, 821496)),
             closeTo(0.9999942565428899, eps)
         );
+        assertThat(
+            FastMath.exp(-new PValueScore(false,  100L).getScore(10, 4205, 1950, 821496)),
+            closeTo(1.0, eps)
+        );
 
         // 4% vs 0%
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(168, 4205, 0, 821496)),
+            FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 0, 821496)),
             closeTo(1.2680918648731284e-26, eps)
         );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 0, 821496)),
+            closeTo(0.3882951183744724, eps)
+        );
         // 4% vs 2%
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(168, 4205, 16429, 821496)),
+            FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 16429, 821496)),
             closeTo(8.542608559219833e-5, eps)
         );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 16429, 821496)),
+            closeTo(0.579463586350363, eps)
+        );
         // 4% vs 3.5%
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(168, 4205, 28752, 821496)),
+            FastMath.exp(-new PValueScore(false, null).getScore(168, 4205, 28752, 821496)),
             closeTo(0.8833950526957098, eps)
         );
+        assertThat(
+            FastMath.exp(-new PValueScore(false, 100L).getScore(168, 4205, 28752, 821496)),
+            closeTo(1.0, eps)
+        );
     }
 
     public void testLargerValues() {
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(101000, 1000000, 500000, 5000000)),
+            FastMath.exp(-new PValueScore(false, null).getScore(101000, 1000000, 500000, 5000000)),
             closeTo(1.0, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(102000, 1000000, 500000, 5000000)),
+            FastMath.exp(-new PValueScore(false, null).getScore(102000, 1000000, 500000, 5000000)),
             closeTo(1.0, eps)
         );
         assertThat(
-            FastMath.exp(-new PValueScore(false).getScore(103000, 1000000, 500000, 5000000)),
+            FastMath.exp(-new PValueScore(false, null).getScore(103000, 1000000, 500000, 5000000)),
             closeTo(1.0, eps)
         );
     }
@@ -171,7 +218,7 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
     public void testScoreIsZero() {
         for (int j = 0; j < 10; j++) {
             assertThat(
-                new PValueScore(false).getScore((j + 1)*5, (j + 10)*100, (j + 1)*10, (j + 10)*100),
+                new PValueScore(false, null).getScore((j + 1)*5, (j + 10)*100, (j + 1)*10, (j + 10)*100),
                 equalTo(0.0)
             );
         }
@@ -179,7 +226,7 @@ public class PValueScoreTests extends AbstractNXYSignificanceHeuristicTestCase {
 
     public void testIncreasedSubsetIncreasedScore() {
         final Function<Long, Double> getScore = (subsetFreq) ->
-            new PValueScore(false).getScore(subsetFreq, 5000, 5, 5000);
+            new PValueScore(false, null).getScore(subsetFreq, 5000, 5, 5000);
         double priorScore = getScore.apply(5L);
         assertThat(priorScore, greaterThanOrEqualTo(0.0));
         for (int j = 1; j < 11; j++) {