Browse Source

[ML] fix random sampling background query consistency (#83676)

There was a consistency bug where the documents returned by the created scorer could change while looking at the same shard. This can occur if multiple weights are created from the same query.

For scenarios like Significant Terms/Text, we need a consistent view of each shard when using the same probability and seed.

This commit ensures this by creating a new random value supplier seeded by the shard hash & seed.
Benjamin Trent 3 years ago
parent
commit
7d1eb52253

+ 1 - 3
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java

@@ -32,7 +32,6 @@ import java.util.function.IntSupplier;
 public final class RandomSamplingQuery extends Query {
 
     private final double p;
-    private final SplittableRandom splittableRandom;
     private final int seed;
     private final int hash;
 
@@ -49,7 +48,6 @@ public final class RandomSamplingQuery extends Query {
         this.p = p;
         this.seed = seed;
         this.hash = hash;
-        this.splittableRandom = new SplittableRandom(BitMixer.mix(hash, seed));
     }
 
     @Override
@@ -78,7 +76,7 @@ public final class RandomSamplingQuery extends Query {
 
             @Override
             public Scorer scorer(LeafReaderContext context) {
-                final SplittableRandom random = splittableRandom.split();
+                final SplittableRandom random = new SplittableRandom(BitMixer.mix(hash, seed));
                 int maxDoc = context.reader().maxDoc();
                 return new ConstantScoreScorer(
                     this,

+ 26 - 0
server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomDocIDSetIteratorTests.java

@@ -11,8 +11,12 @@ package org.elasticsearch.search.aggregations.bucket.sampler.random;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.elasticsearch.test.ESTestCase;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.SplittableRandom;
 
+import static org.hamcrest.Matchers.equalTo;
+
 public class RandomDocIDSetIteratorTests extends ESTestCase {
 
     public void testRandomSampler() {
@@ -43,4 +47,26 @@ public class RandomDocIDSetIteratorTests extends ESTestCase {
         }
     }
 
+    public void testRandomSamplerConsistency() {
+        int maxDoc = 10000;
+        int seed = randomInt();
+
+        for (int i = 1; i < 100; i++) {
+            double p = i / 100.0;
+            SplittableRandom random = new SplittableRandom(seed);
+            List<Integer> iterationOne = new ArrayList<>();
+            RandomSamplingQuery.RandomSamplingIterator iter = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, p, random::nextInt);
+            while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+                iterationOne.add(iter.docID());
+            }
+            random = new SplittableRandom(seed);
+            List<Integer> iterationTwo = new ArrayList<>();
+            iter = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, p, random::nextInt);
+            while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+                iterationTwo.add(iter.docID());
+            }
+            assertThat(iterationOne, equalTo(iterationTwo));
+        }
+    }
+
 }