1
0
Эх сурвалжийг харах

Add new shard_seed parameter for random_sampler agg (#104830)

While it is important to ensure IID sampling via shard hashes, doing so can
become a barrier and add complexity when testing out random_sampler.

So, this commit adds a new optional parameter called `shard_seed`, which,
when combined with `seed`, ensures 100% consistent sampling over shards
where data is exactly the same.
Benjamin Trent 1 жил өмнө
parent
commit
1dd2712bad

+ 5 - 0
docs/changelog/104830.yaml

@@ -0,0 +1,5 @@
+pr: 104830
+summary: Add new `shard_seed` parameter for `random_sampler` agg
+area: Aggregations
+type: enhancement
+issues: []

+ 43 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/aggregations/bucket/RandomSamplerIT.java

@@ -24,6 +24,7 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.histogram;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
 import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder;
+import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.lessThan;
 
 @ESIntegTestCase.SuiteScopeTestCase
@@ -84,6 +85,48 @@ public class RandomSamplerIT extends ESIntegTestCase {
         ensureSearchable();
     }
 
+    public void testRandomSamplerConsistentSeed() {
+        double[] sampleMonotonicValue = new double[1];
+        double[] sampleNumericValue = new double[1];
+        long[] sampledDocCount = new long[1];
+        // initialize the values
+        assertResponse(
+            prepareSearch("idx").setPreference("shard:0")
+                .addAggregation(
+                    new RandomSamplerAggregationBuilder("sampler").setProbability(PROBABILITY)
+                        .setSeed(0)
+                        .subAggregation(avg("mean_monotonic").field(MONOTONIC_VALUE))
+                        .subAggregation(avg("mean_numeric").field(NUMERIC_VALUE))
+                        .setShardSeed(42)
+                ),
+            response -> {
+                InternalRandomSampler sampler = response.getAggregations().get("sampler");
+                sampleMonotonicValue[0] = ((Avg) sampler.getAggregations().get("mean_monotonic")).getValue();
+                sampleNumericValue[0] = ((Avg) sampler.getAggregations().get("mean_numeric")).getValue();
+                sampledDocCount[0] = sampler.getDocCount();
+            }
+        );
+
+        for (int i = 0; i < NUM_SAMPLE_RUNS; i++) {
+            assertResponse(
+                prepareSearch("idx").setPreference("shard:0")
+                    .addAggregation(
+                        new RandomSamplerAggregationBuilder("sampler").setProbability(PROBABILITY)
+                            .setSeed(0)
+                            .subAggregation(avg("mean_monotonic").field(MONOTONIC_VALUE))
+                            .subAggregation(avg("mean_numeric").field(NUMERIC_VALUE))
+                            .setShardSeed(42)
+                    ),
+                response -> {
+                    InternalRandomSampler sampler = response.getAggregations().get("sampler");
+                    assertThat(((Avg) sampler.getAggregations().get("mean_monotonic")).getValue(), equalTo(sampleMonotonicValue[0]));
+                    assertThat(((Avg) sampler.getAggregations().get("mean_numeric")).getValue(), equalTo(sampleNumericValue[0]));
+                    assertThat(sampler.getDocCount(), equalTo(sampledDocCount[0]));
+                }
+            );
+        }
+    }
+
     public void testRandomSampler() {
         double[] sampleMonotonicValue = new double[1];
         double[] sampleNumericValue = new double[1];

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -133,6 +133,7 @@ public class TransportVersions {
     public static final TransportVersion INDEX_REQUEST_NORMALIZED_BYTES_PARSED = def(8_593_00_0);
     public static final TransportVersion INGEST_GRAPH_STRUCTURE_EXCEPTION = def(8_594_00_0);
     public static final TransportVersion ML_MODEL_IN_SERVICE_SETTINGS = def(8_595_00_0);
+    public static final TransportVersion RANDOM_AGG_SHARD_SEED = def(8_596_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 17 - 2
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/InternalRandomSampler.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.search.aggregations.bucket.sampler.random;
 
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Releasables;
@@ -29,18 +30,21 @@ public class InternalRandomSampler extends InternalSingleBucketAggregation imple
     public static final String PARSER_NAME = "random_sampler";
 
     private final int seed;
+    private final Integer shardSeed;
     private final double probability;
 
     InternalRandomSampler(
         String name,
         long docCount,
         int seed,
+        Integer shardSeed,
         double probability,
         InternalAggregations subAggregations,
         Map<String, Object> metadata
     ) {
         super(name, docCount, subAggregations, metadata);
         this.seed = seed;
+        this.shardSeed = shardSeed;
         this.probability = probability;
     }
 
@@ -51,6 +55,11 @@ public class InternalRandomSampler extends InternalSingleBucketAggregation imple
         super(in);
         this.seed = in.readInt();
         this.probability = in.readDouble();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RANDOM_AGG_SHARD_SEED)) {
+            this.shardSeed = in.readOptionalInt();
+        } else {
+            this.shardSeed = null;
+        }
     }
 
     @Override
@@ -58,6 +67,9 @@ public class InternalRandomSampler extends InternalSingleBucketAggregation imple
         super.doWriteTo(out);
         out.writeInt(seed);
         out.writeDouble(probability);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RANDOM_AGG_SHARD_SEED)) {
+            out.writeOptionalInt(shardSeed);
+        }
     }
 
     @Override
@@ -72,7 +84,7 @@ public class InternalRandomSampler extends InternalSingleBucketAggregation imple
 
     @Override
     protected InternalSingleBucketAggregation newAggregation(String name, long docCount, InternalAggregations subAggregations) {
-        return new InternalRandomSampler(name, docCount, seed, probability, subAggregations, metadata);
+        return new InternalRandomSampler(name, docCount, seed, shardSeed, probability, subAggregations, metadata);
     }
 
     @Override
@@ -105,12 +117,15 @@ public class InternalRandomSampler extends InternalSingleBucketAggregation imple
     }
 
     public SamplingContext buildContext() {
-        return new SamplingContext(probability, seed);
+        return new SamplingContext(probability, seed, shardSeed);
     }
 
     @Override
     public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
         builder.field(RandomSamplerAggregationBuilder.SEED.getPreferredName(), seed);
+        if (shardSeed != null) {
+            builder.field(RandomSamplerAggregationBuilder.SHARD_SEED.getPreferredName(), shardSeed);
+        }
         builder.field(RandomSamplerAggregationBuilder.PROBABILITY.getPreferredName(), probability);
         builder.field(CommonFields.DOC_COUNT.getPreferredName(), getDocCount());
         getAggregations().toXContentInternal(builder, params);

+ 21 - 3
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregationBuilder.java

@@ -34,6 +34,7 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
 
     static final ParseField PROBABILITY = new ParseField("probability");
     static final ParseField SEED = new ParseField("seed");
+    static final ParseField SHARD_SEED = new ParseField("shard_seed");
 
     public static final ObjectParser<RandomSamplerAggregationBuilder, String> PARSER = ObjectParser.fromBuilder(
         RandomSamplerAggregationBuilder.NAME,
@@ -41,10 +42,12 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
     );
     static {
         PARSER.declareInt(RandomSamplerAggregationBuilder::setSeed, SEED);
+        PARSER.declareInt(RandomSamplerAggregationBuilder::setShardSeed, SHARD_SEED);
         PARSER.declareDouble(RandomSamplerAggregationBuilder::setProbability, PROBABILITY);
     }
 
     private int seed = Randomness.get().nextInt();
+    private Integer shardSeed;
     private double p;
 
     public RandomSamplerAggregationBuilder(String name) {
@@ -67,10 +70,18 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
         return this;
     }
 
+    public RandomSamplerAggregationBuilder setShardSeed(int shardSeed) {
+        this.shardSeed = shardSeed;
+        return this;
+    }
+
     public RandomSamplerAggregationBuilder(StreamInput in) throws IOException {
         super(in);
         this.p = in.readDouble();
         this.seed = in.readInt();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.RANDOM_AGG_SHARD_SEED)) {
+            this.shardSeed = in.readOptionalInt();
+        }
     }
 
     protected RandomSamplerAggregationBuilder(
@@ -81,12 +92,16 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
         super(clone, factoriesBuilder, metadata);
         this.p = clone.p;
         this.seed = clone.seed;
+        this.shardSeed = clone.shardSeed;
     }
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
         out.writeDouble(p);
         out.writeInt(seed);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.RANDOM_AGG_SHARD_SEED)) {
+            out.writeOptionalInt(shardSeed);
+        }
     }
 
     static void recursivelyCheckSubAggs(Collection<AggregationBuilder> builders, Consumer<AggregationBuilder> aggregationCheck) {
@@ -128,7 +143,7 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
                 );
             }
         });
-        return new RandomSamplerAggregatorFactory(name, seed, p, context, parent, subfactoriesBuilder, metadata);
+        return new RandomSamplerAggregatorFactory(name, seed, shardSeed, p, context, parent, subfactoriesBuilder, metadata);
     }
 
     @Override
@@ -136,6 +151,9 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
         builder.startObject();
         builder.field(PROBABILITY.getPreferredName(), p);
         builder.field(SEED.getPreferredName(), seed);
+        if (shardSeed != null) {
+            builder.field(SHARD_SEED.getPreferredName(), shardSeed);
+        }
         builder.endObject();
         return null;
     }
@@ -162,7 +180,7 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
 
     @Override
     public int hashCode() {
-        return Objects.hash(super.hashCode(), p, seed);
+        return Objects.hash(super.hashCode(), p, seed, shardSeed);
     }
 
     @Override
@@ -171,6 +189,6 @@ public class RandomSamplerAggregationBuilder extends AbstractAggregationBuilder<
         if (obj == null || getClass() != obj.getClass()) return false;
         if (super.equals(obj) == false) return false;
         RandomSamplerAggregationBuilder other = (RandomSamplerAggregationBuilder) obj;
-        return Objects.equals(p, other.p) && Objects.equals(seed, other.seed);
+        return Objects.equals(p, other.p) && Objects.equals(seed, other.seed) && Objects.equals(shardSeed, other.shardSeed);
     }
 }

+ 5 - 1
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

@@ -30,12 +30,14 @@ import java.util.Map;
 public class RandomSamplerAggregator extends BucketsAggregator implements SingleBucketAggregator {
 
     private final int seed;
+    private final Integer shardSeed;
     private final double probability;
     private final CheckedSupplier<Weight, IOException> weightSupplier;
 
     RandomSamplerAggregator(
         String name,
         int seed,
+        Integer shardSeed,
         double probability,
         CheckedSupplier<Weight, IOException> weightSupplier,
         AggregatorFactories factories,
@@ -53,6 +55,7 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
             );
         }
         this.weightSupplier = weightSupplier;
+        this.shardSeed = shardSeed;
     }
 
     @Override
@@ -63,6 +66,7 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
                 name,
                 bucketDocCount(owningBucketOrd),
                 seed,
+                shardSeed,
                 probability,
                 subAggregationResults,
                 metadata()
@@ -72,7 +76,7 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
 
     @Override
     public InternalAggregation buildEmptyAggregation() {
-        return new InternalRandomSampler(name, 0, seed, probability, buildEmptySubAggregations(), metadata());
+        return new InternalRandomSampler(name, 0, seed, shardSeed, probability, buildEmptySubAggregations(), metadata());
     }
 
     /**

+ 21 - 3
server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java

@@ -26,6 +26,7 @@ import java.util.Optional;
 public class RandomSamplerAggregatorFactory extends AggregatorFactory {
 
     private final int seed;
+    private final Integer shardSeed;
     private final double probability;
     private final SamplingContext samplingContext;
     private Weight weight;
@@ -33,6 +34,7 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
     RandomSamplerAggregatorFactory(
         String name,
         int seed,
+        Integer shardSeed,
         double probability,
         AggregationContext context,
         AggregatorFactory parent,
@@ -42,7 +44,8 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
         super(name, context, parent, subFactories, metadata);
         this.probability = probability;
         this.seed = seed;
-        this.samplingContext = new SamplingContext(probability, seed);
+        this.samplingContext = new SamplingContext(probability, seed, shardSeed);
+        this.shardSeed = shardSeed;
     }
 
     @Override
@@ -53,7 +56,18 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
     @Override
     public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
         throws IOException {
-        return new RandomSamplerAggregator(name, seed, probability, this::getWeight, factories, context, parent, cardinality, metadata);
+        return new RandomSamplerAggregator(
+            name,
+            seed,
+            shardSeed,
+            probability,
+            this::getWeight,
+            factories,
+            context,
+            parent,
+            cardinality,
+            metadata
+        );
     }
 
     /**
@@ -66,7 +80,11 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
      */
     private Weight getWeight() throws IOException {
         if (weight == null) {
-            RandomSamplingQuery query = new RandomSamplingQuery(probability, seed, context.shardRandomSeed());
+            RandomSamplingQuery query = new RandomSamplingQuery(
+                probability,
+                seed,
+                shardSeed == null ? context.shardRandomSeed() : shardSeed
+            );
             BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
                 .add(context.query(), BooleanClause.Occur.FILTER)
                 .build();

+ 9 - 6
server/src/main/java/org/elasticsearch/search/aggregations/support/SamplingContext.java

@@ -20,8 +20,9 @@ import java.util.Optional;
 /**
  * This provides information around the current sampling context for aggregations
  */
-public record SamplingContext(double probability, int seed) {
-    public static final SamplingContext NONE = new SamplingContext(1.0, 0);
+public record SamplingContext(double probability, int seed, Integer shardSeed) {
+
+    public static final SamplingContext NONE = new SamplingContext(1.0, 0, null);
 
     public boolean isSampled() {
         return probability < 1.0;
@@ -97,20 +98,22 @@ public record SamplingContext(double probability, int seed) {
         }
         BooleanQuery.Builder queryBuilder = new BooleanQuery.Builder();
         queryBuilder.add(rewritten, BooleanClause.Occur.FILTER);
-        queryBuilder.add(new RandomSamplingQuery(probability(), seed(), context.shardRandomSeed()), BooleanClause.Occur.FILTER);
+        queryBuilder.add(
+            new RandomSamplingQuery(probability(), seed(), shardSeed == null ? context.shardRandomSeed() : shardSeed),
+            BooleanClause.Occur.FILTER
+        );
         return queryBuilder.build();
     }
 
     /**
      * @param context The current aggregation context
      * @return the sampling query if the sampling context indicates that sampling is required
-     * @throws IOException thrown on query build failure
      */
-    public Optional<Query> buildSamplingQueryIfNecessary(AggregationContext context) throws IOException {
+    public Optional<Query> buildSamplingQueryIfNecessary(AggregationContext context) {
         if (isSampled() == false) {
             return Optional.empty();
         }
-        return Optional.of(new RandomSamplingQuery(probability(), seed(), context.shardRandomSeed()));
+        return Optional.of(new RandomSamplingQuery(probability(), seed(), shardSeed == null ? context.shardRandomSeed() : shardSeed));
     }
 
 }

+ 3 - 0
server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregationBuilderTests.java

@@ -19,6 +19,9 @@ public class RandomSamplerAggregationBuilderTests extends BaseAggregationTestCas
         if (randomBoolean()) {
             builder.setSeed(randomInt());
         }
+        if (randomBoolean()) {
+            builder.setShardSeed(randomInt());
+        }
         builder.setProbability(randomFrom(1.0, randomDoubleBetween(0.0, 0.5, false)));
         builder.subAggregation(AggregationBuilders.max(randomAlphaOfLength(10)).field(randomAlphaOfLength(10)));
         return builder;

+ 2 - 3
server/src/test/java/org/elasticsearch/search/aggregations/support/SamplingContextTests.java

@@ -14,10 +14,9 @@ import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.equalTo;
 
 public class SamplingContextTests extends ESTestCase {
-    protected static final int NUMBER_OF_TEST_RUNS = 20;
 
     private static SamplingContext randomContext() {
-        return new SamplingContext(randomDoubleBetween(1e-6, 0.1, false), randomInt());
+        return new SamplingContext(randomDoubleBetween(1e-6, 0.1, false), randomInt(), randomBoolean() ? null : randomInt());
     }
 
     public void testScaling() {
@@ -41,7 +40,7 @@ public class SamplingContextTests extends ESTestCase {
     }
 
     public void testNoScaling() {
-        SamplingContext samplingContext = new SamplingContext(1.0, randomInt());
+        SamplingContext samplingContext = new SamplingContext(1.0, randomInt(), randomBoolean() ? null : randomInt());
         long randomLong = randomLong();
         double randomDouble = randomDouble();
         assertThat(randomLong, equalTo(samplingContext.scaleDown(randomLong)));

+ 5 - 1
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -1114,7 +1114,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
                         // We should make sure if the builder says it supports sampling, that the internal aggregations returned override
                         // finalizeSampling
                         if (aggregationBuilder.supportsSampling()) {
-                            SamplingContext randomSamplingContext = new SamplingContext(randomDoubleBetween(1e-8, 0.1, false), randomInt());
+                            SamplingContext randomSamplingContext = new SamplingContext(
+                                randomDoubleBetween(1e-8, 0.1, false),
+                                randomInt(),
+                                randomBoolean() ? null : randomInt()
+                            );
                             InternalAggregation sampledResult = internalAggregation.finalizeSampling(randomSamplingContext);
                             assertThat(sampledResult.getClass(), equalTo(internalAggregation.getClass()));
                         }

+ 5 - 1
test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java

@@ -283,7 +283,11 @@ public abstract class InternalAggregationTestCase<T extends InternalAggregation>
         doAssertReducedMultiBucketConsumer(reduced, bucketConsumer);
         assertReduced(reduced, inputs.toReduce());
         if (supportsSampling()) {
-            SamplingContext randomContext = new SamplingContext(randomDoubleBetween(1e-8, 0.1, false), randomInt());
+            SamplingContext randomContext = new SamplingContext(
+                randomDoubleBetween(1e-8, 0.1, false),
+                randomInt(),
+                randomBoolean() ? null : randomInt()
+            );
             @SuppressWarnings("unchecked")
             T sampled = (T) reduced.finalizeSampling(randomContext);
             assertSampled(sampled, reduced, randomContext);