Browse Source

Add scripting support to AggregatorTestCase (#43494)

This refactors AggregatorTestCase to allow testing mock scripts.
The main change is to QueryShardContext.  This was previously mocked,
but to get the ScriptService you have to invoke a final method
which can't be mocked.

Instead, we just create a mostly-empty QueryShardContext and populate
the fields that are needed for testing.  It also introduces a few
new helper methods that can be overridden to change the default
behavior a bit.

Most tests should be able to override getMockScriptService() to supply
a ScriptService to the context, which is later used by the aggs.
More complicated tests can override queryShardContextMock() as before.

Adds a test to MaxAggregatorTests to test out the new functionality.
Zachary Tong 6 years ago
parent
commit
a814135177

+ 39 - 0
server/src/test/java/org/elasticsearch/search/aggregations/metrics/MaxAggregatorTests.java

@@ -46,8 +46,15 @@ import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FutureArrays;
 import org.elasticsearch.common.CheckedConsumer;
 import org.elasticsearch.common.collect.Tuple;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.script.MockScriptEngine;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.script.ScriptEngine;
+import org.elasticsearch.script.ScriptModule;
+import org.elasticsearch.script.ScriptService;
+import org.elasticsearch.script.ScriptType;
 import org.elasticsearch.search.aggregations.AggregatorTestCase;
 import org.elasticsearch.search.aggregations.support.AggregationInspectionHelper;
 
@@ -57,6 +64,7 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.List;
+import java.util.Map;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
@@ -65,6 +73,19 @@ import static java.util.Collections.singleton;
 import static org.hamcrest.Matchers.equalTo;
 
 public class MaxAggregatorTests extends AggregatorTestCase {
+    private final String SCRIPT_NAME = "script_name";
+    private final long SCRIPT_VALUE = 19L;
+
+    @Override
+    protected ScriptService getMockScriptService() {
+        MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME,
+            Collections.singletonMap(SCRIPT_NAME, script -> SCRIPT_VALUE), // return 19 from script
+            Collections.emptyMap());
+        Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
+
+        return new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
+    }
+
     public void testNoDocs() throws IOException {
         testCase(new MatchAllDocsQuery(), iw -> {
             // Intentionally not writing any docs
@@ -147,6 +168,23 @@ public class MaxAggregatorTests extends AggregatorTestCase {
         }, null);
     }
 
+    public void testScript() throws IOException {
+        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.INTEGER);
+        fieldType.setName("number");
+
+        MaxAggregationBuilder aggregationBuilder = new MaxAggregationBuilder("_name")
+            .field("number")
+            .script(new Script(ScriptType.INLINE, MockScriptEngine.NAME, SCRIPT_NAME, Collections.emptyMap()));
+
+        testCase(aggregationBuilder, new DocValuesFieldExistsQuery("number"), iw -> {
+            iw.addDocument(singleton(new NumericDocValuesField("number", 7)));
+            iw.addDocument(singleton(new NumericDocValuesField("number", 1)));
+        }, max -> {
+            assertEquals(max.getValue(), SCRIPT_VALUE, 0); // Note this is the script value (19L), not the doc values above
+            assertTrue(AggregationInspectionHelper.hasValue(max));
+        }, fieldType);
+    }
+
     private void testCase(Query query,
                           CheckedConsumer<RandomIndexWriter, IOException> buildIndex,
                           Consumer<InternalMax> verify) throws IOException {
@@ -282,4 +320,5 @@ public class MaxAggregatorTests extends AggregatorTestCase {
         });
         assertTrue(seen[0]);
     }
+
 }

+ 5 - 2
server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricAggregatorTests.java

@@ -26,8 +26,10 @@ import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.search.MatchAllDocsQuery;
 import org.apache.lucene.store.Directory;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.query.QueryShardContext;
+import org.elasticsearch.indices.breaker.CircuitBreakerService;
 import org.elasticsearch.script.MockScriptEngine;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptEngine;
@@ -403,11 +405,12 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
      * is final and cannot be mocked
      */
     @Override
-    protected QueryShardContext queryShardContextMock(MapperService mapperService) {
+    protected QueryShardContext queryShardContextMock(MapperService mapperService, IndexSettings indexSettings,
+                                                      CircuitBreakerService circuitBreakerService) {
         MockScriptEngine scriptEngine = new MockScriptEngine(MockScriptEngine.NAME, SCRIPTS, Collections.emptyMap());
         Map<String, ScriptEngine> engines = Collections.singletonMap(scriptEngine.getType(), scriptEngine);
         ScriptService scriptService =  new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
-        return new QueryShardContext(0, mapperService.getIndexSettings(), null, null, null, mapperService, null, scriptService,
+        return new QueryShardContext(0, indexSettings, null, null, null, mapperService, null, scriptService,
                 xContentRegistry(), writableRegistry(), null, null, System::currentTimeMillis, null);
     }
 }

+ 31 - 17
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -48,6 +48,7 @@ import org.elasticsearch.index.cache.bitset.BitsetFilterCache;
 import org.elasticsearch.index.cache.bitset.BitsetFilterCache.Listener;
 import org.elasticsearch.index.cache.query.DisabledQueryCache;
 import org.elasticsearch.index.engine.Engine;
+import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.fielddata.IndexFieldDataCache;
 import org.elasticsearch.index.fielddata.IndexFieldDataService;
 import org.elasticsearch.index.mapper.ContentPath;
@@ -58,7 +59,6 @@ import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.ObjectMapper;
 import org.elasticsearch.index.mapper.ObjectMapper.Nested;
 import org.elasticsearch.index.query.QueryShardContext;
-import org.elasticsearch.index.query.support.NestedScope;
 import org.elasticsearch.index.shard.IndexShard;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.indices.breaker.CircuitBreakerService;
@@ -83,10 +83,10 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.function.BiFunction;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -157,8 +157,8 @@ public abstract class AggregatorTestCase extends ESTestCase {
         SearchLookup searchLookup = new SearchLookup(mapperService, ifds::getForField);
         when(searchContext.lookup()).thenReturn(searchLookup);
 
-        QueryShardContext queryShardContext = queryShardContextMock(mapperService);
-        when(queryShardContext.getIndexSettings()).thenReturn(indexSettings);
+        QueryShardContext queryShardContext = queryShardContextMock(mapperService, indexSettings, circuitBreakerService);
+
         when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
         Map<String, MappedFieldType> fieldNameToType = new HashMap<>();
         fieldNameToType.putAll(Arrays.stream(fieldTypes)
@@ -189,16 +189,11 @@ public abstract class AggregatorTestCase extends ESTestCase {
             String fieldName = entry.getKey();
             MappedFieldType fieldType = entry.getValue();
 
-            when(queryShardContext.fieldMapper(fieldName)).thenReturn(fieldType);
+            when(mapperService.fullName(fieldName)).thenReturn(fieldType);
             when(searchContext.smartNameFieldType(fieldName)).thenReturn(fieldType);
         }
 
-        for (MappedFieldType fieldType : new HashSet<>(fieldNameToType.values())) {
-            when(queryShardContext.getForField(fieldType)).then(invocation ->
-                fieldType.fielddataBuilder(mapperService.getIndexSettings().getIndex().getName())
-                    .build(mapperService.getIndexSettings(), fieldType,
-                        new IndexFieldDataCache.None(), circuitBreakerService, mapperService));
-        }
+
     }
 
     protected <A extends Aggregator> A createAggregator(AggregationBuilder aggregationBuilder,
@@ -304,12 +299,31 @@ public abstract class AggregatorTestCase extends ESTestCase {
     /**
      * sub-tests that need a more complex mock can overwrite this
      */
-    protected QueryShardContext queryShardContextMock(MapperService mapperService) {
-        QueryShardContext queryShardContext = mock(QueryShardContext.class);
-        when(queryShardContext.getMapperService()).thenReturn(mapperService);
-        NestedScope nestedScope = new NestedScope();
-        when(queryShardContext.nestedScope()).thenReturn(nestedScope);
-        return queryShardContext;
+    protected QueryShardContext queryShardContextMock(MapperService mapperService, IndexSettings indexSettings,
+                                                      CircuitBreakerService circuitBreakerService) {
+
+        return new QueryShardContext(0, indexSettings, null, null,
+            getIndexFieldDataLookup(mapperService, circuitBreakerService),
+            mapperService, null, getMockScriptService(), xContentRegistry(),
+            writableRegistry(), null, null, System::currentTimeMillis, null);
+    }
+
+    /**
+     * Sub-tests that need a more complex index field data provider can override this
+     */
+    protected BiFunction<MappedFieldType, String, IndexFieldData<?>> getIndexFieldDataLookup(MapperService mapperService,
+                                                                                             CircuitBreakerService circuitBreakerService) {
+        return (fieldType, s) -> fieldType.fielddataBuilder(mapperService.getIndexSettings().getIndex().getName())
+            .build(mapperService.getIndexSettings(), fieldType,
+                new IndexFieldDataCache.None(), circuitBreakerService, mapperService);
+
+    }
+
+    /**
+     * Sub-tests that need scripting can override this method to provide a script service and pre-baked scripts
+     */
+    protected ScriptService getMockScriptService() {
+        return null;
     }
 
     protected <A extends InternalAggregation, C extends Aggregator> A search(IndexSearcher searcher,