Browse Source

Added unit tests for MatrixStatsAggregator

Martijn van Groningen 8 years ago
parent
commit
34093735e3

+ 4 - 0
modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java

@@ -139,6 +139,10 @@ public class InternalMatrixStats extends InternalAggregation implements MatrixSt
         return results.getCorrelation(fieldX, fieldY);
     }
 
+    RunningStats getStats() {
+        return stats;
+    }
+
     MatrixStatsResults getResults() {
         return results;
     }

+ 3 - 3
modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java

@@ -41,14 +41,14 @@ import java.util.Map;
 /**
  * Metric Aggregation for computing the pearson product correlation coefficient between multiple fields
  **/
-public class MatrixStatsAggregator extends MetricsAggregator {
+final class MatrixStatsAggregator extends MetricsAggregator {
     /** Multiple ValuesSource with field names */
-    final NumericMultiValuesSource valuesSources;
+    private final NumericMultiValuesSource valuesSources;
 
     /** array of descriptive stats, per shard, needed to compute the correlation */
     ObjectArray<RunningStats> stats;
 
-    public MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
+    MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
                                  Aggregator parent, MultiValueMode multiValueMode, List<PipelineAggregator> pipelineAggregators,
                                  Map<String,Object> metaData) throws IOException {
         super(name, context, parent, pipelineAggregators, metaData);

+ 2 - 2
modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java

@@ -32,12 +32,12 @@ import java.io.IOException;
 import java.util.List;
 import java.util.Map;
 
-public class MatrixStatsAggregatorFactory
+final class MatrixStatsAggregatorFactory
     extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {
 
     private final MultiValueMode multiValueMode;
 
-    public MatrixStatsAggregatorFactory(String name,
+    MatrixStatsAggregatorFactory(String name,
             Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, MultiValueMode multiValueMode,
             SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
             Map<String, Object> metaData) throws IOException {

+ 96 - 0
modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorTests.java

@@ -0,0 +1,96 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.elasticsearch.search.aggregations.matrix.stats;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.SortedNumericDocValuesField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.NumericUtils;
+import org.elasticsearch.index.mapper.MappedFieldType;
+import org.elasticsearch.index.mapper.NumberFieldMapper;
+import org.elasticsearch.search.aggregations.AggregatorTestCase;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+public class MatrixStatsAggregatorTests extends AggregatorTestCase {
+
+    public void testNoData() throws Exception {
+        MappedFieldType ft =
+            new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
+        ft.setName("field");
+
+        try (Directory directory = newDirectory();
+            RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+            if (randomBoolean()) {
+                indexWriter.addDocument(Collections.singleton(new StringField("another_field", "value", Field.Store.NO)));
+            }
+            try (IndexReader reader = indexWriter.getReader()) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+                MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
+                    .fields(Collections.singletonList("field"));
+                InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ft);
+                assertNull(stats.getStats());
+            }
+        }
+    }
+
+    public void testTwoFields() throws Exception {
+        String fieldA = "a";
+        MappedFieldType ftA = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
+        ftA.setName(fieldA);
+        String fieldB = "b";
+        MappedFieldType ftB = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
+        ftB.setName(fieldB);
+
+        try (Directory directory = newDirectory();
+            RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
+
+            int numDocs = scaledRandomIntBetween(8192, 16384);
+            Double[] fieldAValues = new Double[numDocs];
+            Double[] fieldBValues = new Double[numDocs];
+            for (int docId = 0; docId < numDocs; docId++) {
+                Document document = new Document();
+                fieldAValues[docId] = randomDouble();
+                document.add(new SortedNumericDocValuesField(fieldA, NumericUtils.doubleToSortableLong(fieldAValues[docId])));
+
+                fieldBValues[docId] = randomDouble();
+                document.add(new SortedNumericDocValuesField(fieldB, NumericUtils.doubleToSortableLong(fieldBValues[docId])));
+                indexWriter.addDocument(document);
+            }
+
+            MultiPassStats multiPassStats = new MultiPassStats(fieldA, fieldB);
+            multiPassStats.computeStats(Arrays.asList(fieldAValues), Arrays.asList(fieldBValues));
+            try (IndexReader reader = indexWriter.getReader()) {
+                IndexSearcher searcher = new IndexSearcher(reader);
+                MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg")
+                    .fields(Arrays.asList(fieldA, fieldB));
+                InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ftA, ftB);
+                multiPassStats.assertNearlyEqual(new MatrixStatsResults(stats.getStats()));
+            }
+        }
+    }
+
+}

+ 3 - 0
test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java

@@ -110,6 +110,9 @@ public abstract class AggregatorTestCase extends ESTestCase {
 
         QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);
         when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
+        for (MappedFieldType fieldType : fieldTypes) {
+            when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType);
+        }
 
         return aggregationBuilder.build(searchContext, null);
     }