Преглед изворни кода

[ML] Refactor doc value format into ExtractedField (#35053)

This commit moves the knowledge of which doc value format
to be used down to the `ExtractedField` instead of being
in the data extractor.
Dimitris Athanasiou пре 7 година
родитељ
комит
d85a654ebb

+ 12 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ExtractedField.java

@@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
 
 import org.elasticsearch.common.document.DocumentField;
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.joda.time.base.BaseDateTime;
 
 import java.util.List;
@@ -51,6 +52,10 @@ abstract class ExtractedField {
 
     public abstract Object[] value(SearchHit hit);
 
+    public String getDocValueFormat() {
+        return DocValueFieldsContext.USE_DEFAULT_FORMAT;
+    }
+
     public static ExtractedField newTimeField(String name, ExtractionMethod extractionMethod) {
         if (extractionMethod == ExtractionMethod.SOURCE) {
             throw new IllegalArgumentException("time field cannot be extracted from source");
@@ -93,6 +98,8 @@ abstract class ExtractedField {
 
     private static class TimeField extends FromFields {
 
+        private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
+
         TimeField(String name, ExtractionMethod extractionMethod) {
             super(name, name, extractionMethod);
         }
@@ -112,6 +119,11 @@ abstract class ExtractedField {
             }
             return value;
         }
+
+        @Override
+        public String getDocValueFormat() {
+            return EPOCH_MILLIS_FORMAT;
+        }
     }
 
     private static class FromSource extends ExtractedField {

+ 6 - 11
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ExtractedFields.java

@@ -31,7 +31,7 @@ class ExtractedFields {
 
     private final ExtractedField timeField;
     private final List<ExtractedField> allFields;
-    private final String[] docValueFields;
+    private final List<ExtractedField> docValueFields;
     private final String[] sourceFields;
 
     ExtractedFields(ExtractedField timeField, List<ExtractedField> allFields) {
@@ -41,7 +41,8 @@ class ExtractedFields {
         this.timeField = Objects.requireNonNull(timeField);
         this.allFields = Collections.unmodifiableList(allFields);
         this.docValueFields = filterFields(ExtractedField.ExtractionMethod.DOC_VALUE, allFields);
-        this.sourceFields = filterFields(ExtractedField.ExtractionMethod.SOURCE, allFields);
+        this.sourceFields = filterFields(ExtractedField.ExtractionMethod.SOURCE, allFields).stream().map(ExtractedField::getName)
+            .toArray(String[]::new);
     }
 
     public List<ExtractedField> getAllFields() {
@@ -52,18 +53,12 @@ class ExtractedFields {
         return sourceFields;
     }
 
-    public String[] getDocValueFields() {
+    public List<ExtractedField> getDocValueFields() {
         return docValueFields;
     }
 
-    private static String[] filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
-        List<String> result = new ArrayList<>();
-        for (ExtractedField field : fields) {
-            if (field.getExtractionMethod() == method) {
-                result.add(field.getName());
-            }
-        }
-        return result.toArray(new String[result.size()]);
+    private static List<ExtractedField> filterFields(ExtractedField.ExtractionMethod method, List<ExtractedField> fields) {
+        return fields.stream().filter(field -> field.getExtractionMethod() == method).collect(Collectors.toList());
     }
 
     public String timeField() {

+ 2 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ScrollDataExtractor.java

@@ -19,7 +19,6 @@ import org.elasticsearch.client.Client;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.fetch.StoredFieldsContext;
-import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.datafeed.extractor.DataExtractor;
@@ -44,7 +43,6 @@ class ScrollDataExtractor implements DataExtractor {
 
     private static final Logger LOGGER = LogManager.getLogger(ScrollDataExtractor.class);
     private static final TimeValue SCROLL_TIMEOUT = new TimeValue(30, TimeUnit.MINUTES);
-    private static final String EPOCH_MILLIS_FORMAT = "epoch_millis";
 
     private final Client client;
     private final ScrollDataExtractorContext context;
@@ -112,12 +110,8 @@ class ScrollDataExtractor implements DataExtractor {
                 .setQuery(ExtractorUtils.wrapInTimeRangeQuery(
                         context.query, context.extractedFields.timeField(), start, context.end));
 
-        for (String docValueField : context.extractedFields.getDocValueFields()) {
-            if (docValueField.equals(context.extractedFields.timeField())) {
-                searchRequestBuilder.addDocValueField(docValueField, EPOCH_MILLIS_FORMAT);
-            } else {
-                searchRequestBuilder.addDocValueField(docValueField, DocValueFieldsContext.USE_DEFAULT_FORMAT);
-            }
+        for (ExtractedField docValueField : context.extractedFields.getDocValueFields()) {
+            searchRequestBuilder.addDocValueField(docValueField.getName(), docValueField.getDocValueFormat());
         }
         String[] sourceFields = context.extractedFields.getSourceFields();
         if (sourceFields.length == 0) {

+ 11 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ExtractedFieldTests.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.ml.datafeed.extractor.scroll;
 
 import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.ml.test.SearchHitBuilder;
 import org.joda.time.DateTime;
@@ -140,4 +141,14 @@ public class ExtractedFieldTests extends ESTestCase {
         assertThat(field.getName(), equalTo("b"));
         assertThat(field.value(hit), equalTo(new Integer[] { 2 }));
     }
+
+    public void testGetDocValueFormat() {
+        for (ExtractedField.ExtractionMethod method : ExtractedField.ExtractionMethod.values()) {
+            assertThat(ExtractedField.newField("f", method).getDocValueFormat(), equalTo(DocValueFieldsContext.USE_DEFAULT_FORMAT));
+        }
+        assertThat(ExtractedField.newTimeField("doc_value_time", ExtractedField.ExtractionMethod.DOC_VALUE).getDocValueFormat(),
+            equalTo("epoch_millis"));
+        assertThat(ExtractedField.newTimeField("source_time", ExtractedField.ExtractionMethod.SCRIPT_FIELD).getDocValueFormat(),
+            equalTo("epoch_millis"));
+    }
 }

+ 13 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/ExtractedFieldsTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.fieldcaps.FieldCapabilitiesResponse;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.search.fetch.subphase.DocValueFieldsContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
@@ -43,7 +44,8 @@ public class ExtractedFieldsTests extends ESTestCase {
 
         assertThat(extractedFields.getAllFields(), equalTo(Arrays.asList(timeField)));
         assertThat(extractedFields.timeField(), equalTo("time"));
-        assertThat(extractedFields.getDocValueFields(), equalTo(new String[] { timeField.getName() }));
+        assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
+            equalTo(new String[] { timeField.getName() }));
         assertThat(extractedFields.getSourceFields().length, equalTo(0));
     }
 
@@ -59,7 +61,8 @@ public class ExtractedFieldsTests extends ESTestCase {
 
         assertThat(extractedFields.getAllFields().size(), equalTo(7));
         assertThat(extractedFields.timeField(), equalTo("time"));
-        assertThat(extractedFields.getDocValueFields(), equalTo(new String[] {"time", "doc1", "doc2"}));
+        assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new),
+            equalTo(new String[] {"time", "doc1", "doc2"}));
         assertThat(extractedFields.getSourceFields(), equalTo(new String[] {"src1", "src2"}));
     }
 
@@ -138,9 +141,11 @@ public class ExtractedFieldsTests extends ESTestCase {
                 fieldCapabilitiesResponse);
 
         assertThat(extractedFields.timeField(), equalTo("time"));
-        assertThat(extractedFields.getDocValueFields().length, equalTo(2));
-        assertThat(extractedFields.getDocValueFields()[0], equalTo("time"));
-        assertThat(extractedFields.getDocValueFields()[1], equalTo("value"));
+        assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
+        assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
+        assertThat(extractedFields.getDocValueFields().get(0).getDocValueFormat(), equalTo("epoch_millis"));
+        assertThat(extractedFields.getDocValueFields().get(1).getName(), equalTo("value"));
+        assertThat(extractedFields.getDocValueFields().get(1).getDocValueFormat(), equalTo(DocValueFieldsContext.USE_DEFAULT_FORMAT));
         assertThat(extractedFields.getSourceFields().length, equalTo(1));
         assertThat(extractedFields.getSourceFields()[0], equalTo("airline"));
         assertThat(extractedFields.getAllFields().size(), equalTo(4));
@@ -174,9 +179,9 @@ public class ExtractedFieldsTests extends ESTestCase {
                 fieldCapabilitiesResponse);
 
         assertThat(extractedFields.timeField(), equalTo("time"));
-        assertThat(extractedFields.getDocValueFields().length, equalTo(2));
-        assertThat(extractedFields.getDocValueFields()[0], equalTo("time"));
-        assertThat(extractedFields.getDocValueFields()[1], equalTo("airport.keyword"));
+        assertThat(extractedFields.getDocValueFields().size(), equalTo(2));
+        assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time"));
+        assertThat(extractedFields.getDocValueFields().get(1).getName(), equalTo("airport.keyword"));
         assertThat(extractedFields.getSourceFields().length, equalTo(1));
         assertThat(extractedFields.getSourceFields()[0], equalTo("airline"));
         assertThat(extractedFields.getAllFields().size(), equalTo(3));