Explorar o código

Bug fix to allow access to top level params in reduce script (#42096)

Jack Conradson %!s(int64=6) %!d(string=hai) anos
pai
achega
c87ea81573

+ 3 - 6
server/src/main/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorFactory.java

@@ -89,7 +89,7 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAg
         final ScriptedMetricAggContexts.CombineScript combineScript = this.combineScript.newInstance(
             mergeParams(aggParams, combineScriptParams), aggState);
 
-        final Script reduceScript = deepCopyScript(this.reduceScript, context);
+        final Script reduceScript = deepCopyScript(this.reduceScript, context, aggParams);
         if (initScript != null) {
             initScript.execute();
             CollectionUtils.ensureNoSelfReferences(aggState, "Scripted metric aggs init script");
@@ -99,12 +99,9 @@ class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAg
                 pipelineAggregators, metaData);
     }
 
-    private static Script deepCopyScript(Script script, SearchContext context) {
+    private static Script deepCopyScript(Script script, SearchContext context, Map<String, Object> aggParams) {
         if (script != null) {
-            Map<String, Object> params = script.getParams();
-            if (params != null) {
-                params = deepCopyParams(params, context);
-            }
+            Map<String, Object> params = mergeParams(aggParams, deepCopyParams(script.getParams(), context));
             return new Script(script.getType(), script.getLang(), script.getIdOrCode(), params);
         } else {
             return null;

+ 37 - 4
server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java

@@ -71,7 +71,9 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
     private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
             Collections.singletonMap("itemValue", 12));
     private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
-            Collections.singletonMap("divisor", 4));
+            Collections.singletonMap("multiplier", 4));
+    private static final Script REDUCE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "reduceScriptParams",
+            Collections.singletonMap("additional", 2));
     private static final String CONFLICTING_PARAM_NAME = "initialValue";
 
     private static final Script INIT_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptSelfRef",
@@ -140,9 +142,14 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
         });
         SCRIPTS.put("combineScriptParams", params -> {
             Map<String, Object> state = (Map<String, Object>) params.get("state");
-            int divisor = ((Integer) params.get("divisor"));
-            return ((List<Integer>) state.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
+            int multiplier = ((Integer) params.get("multiplier"));
+            return ((List<Integer>) state.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i * multiplier).sum();
         });
+        SCRIPTS.put("reduceScriptParams", params ->
+            ((List)params.get("states")).stream().mapToInt(i -> (int)i).sum() +
+                    (int)params.get("aggs_param") + (int)params.get("additional") -
+                    ((List)params.get("states")).size()*24*4
+        );
 
         SCRIPTS.put("initScriptSelfRef", params -> {
             Map<String, Object> state = (Map<String, Object>) params.get("state");
@@ -279,7 +286,33 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
                 ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
 
                 // The result value depends on the script params.
-                assertEquals(306, scriptedMetric.aggregation());
+                assertEquals(4896, scriptedMetric.aggregation());
+            }
+        }
+    }
+
+    public void testAggParamsPassedToReduceScript() throws IOException {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+        ScriptService scriptService =  new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+
+        try (Directory directory = newDirectory()) {
+            try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+                for (int i = 0; i < 100; i++) {
+                    indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
+                }
+            }
+
+            try (IndexReader indexReader = DirectoryReader.open(directory)) {
+                ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
+                aggregationBuilder.params(Collections.singletonMap("aggs_param", 1))
+                        .initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS)
+                        .combineScript(COMBINE_SCRIPT_PARAMS).reduceScript(REDUCE_SCRIPT_PARAMS);
+                ScriptedMetric scriptedMetric = searchAndReduce(
+                        newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder, 0, scriptService);
+
+                // The result value depends on the script params.
+                assertEquals(4803, scriptedMetric.aggregation());
             }
         }
     }