Browse Source

Add half float mapping to the scripting fields API (#82294)

This adds the mapped type half float the scripting fields API. This also adds additional methods for 
asDouble to reach parity with old-style doc access.
Jack Conradson 3 years ago
parent
commit
4f37df15f4

+ 5 - 0
docs/changelog/82294.yaml

@@ -0,0 +1,5 @@
+pr: 82294
+summary: Add half float mapping to the scripting fields API
+area: Infra/Scripting
+type: enhancement
+issues: []

+ 9 - 0
modules/lang-painless/src/main/resources/org/elasticsearch/painless/org.elasticsearch.script.fields.txt

@@ -61,6 +61,15 @@ class org.elasticsearch.script.field.ScaledFloatDocValuesField @dynamic_type {
   double get(int, double)
 }
 
+# defaults are cast to float, taking an double facilitates resolution with constants without casting
+class org.elasticsearch.script.field.HalfFloatDocValuesField @dynamic_type {
+  float get(double)
+  float get(int, double)
+  List asDoubles()
+  double asDouble(double)
+  double asDouble(int, double)
+}
+
 # defaults are cast to byte, taking an int facilitates resolution with constants without casting
 class org.elasticsearch.script.field.ByteDocValuesField @dynamic_type {
   byte get(int)

+ 71 - 2
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/50_script_doc_values.yml

@@ -90,6 +90,7 @@ setup:
               byte: [16, 32, 64, 8, 4]
               double: [3.141592653588, 2.141592653587]
               float: [1.123, 2.234]
+              half_float: [1.123, 2.234]
               scaled_float: [-3.5, 2.5]
 
 
@@ -1306,6 +1307,9 @@ setup:
 
 ---
 "half_float":
+    - skip:
+        features: close_to
+
     - do:
         search:
             rest_total_hits_as_int: true
@@ -1315,7 +1319,7 @@ setup:
                     field:
                         script:
                             source: "doc['half_float'].get(0)"
-    - match: { hits.hits.0.fields.field.0: 3.140625 }
+    - close_to: { hits.hits.0.fields.field.0: { value: 3.140625, error: 0.001 } }
 
     - do:
         search:
@@ -1326,7 +1330,72 @@ setup:
                     field:
                         script:
                             source: "doc['half_float'].value"
-    - match: { hits.hits.0.fields.field.0: 3.140625 }
+    - close_to: { hits.hits.0.fields.field.0: { value: 3.140625, error: 0.001 } }
+
+    - do:
+        search:
+          rest_total_hits_as_int: true
+          body:
+            sort: [ { rank: asc } ]
+            script_fields:
+              field:
+                script:
+                  source: "field('half_float').get(0.0)"
+    - close_to: { hits.hits.0.fields.field.0: { value: 3.140625, error: 0.001 } }
+    - match: { hits.hits.1.fields.field.0: 0.0 }
+    - close_to: { hits.hits.2.fields.field.0: { value: 1.123, error: 0.001 } }
+
+    - do:
+        search:
+          rest_total_hits_as_int: true
+          body:
+            sort: [ { rank: asc } ]
+            script_fields:
+              field:
+                script:
+                  source: "/* avoid stash */ $('half_float', 0.0)"
+    - close_to: { hits.hits.0.fields.field.0: { value: 3.140625, error: 0.001 } }
+    - match: { hits.hits.1.fields.field.0: 0.0 }
+    - close_to: { hits.hits.2.fields.field.0: { value: 1.123, error: 0.001 } }
+
+    - do:
+        search:
+          rest_total_hits_as_int: true
+          body:
+            sort: [ { rank: asc } ]
+            script_fields:
+              field:
+                script:
+                  source: "field('half_float').get(1, 0.0)"
+    - match: { hits.hits.0.fields.field.0: 0.0 }
+    - match: { hits.hits.1.fields.field.0: 0.0 }
+    - close_to: { hits.hits.2.fields.field.0: { value: 2.234, error: 0.001 } }
+
+    - do:
+        search:
+          rest_total_hits_as_int: true
+          body:
+            sort: [ { rank: asc } ]
+            script_fields:
+              field:
+                script:
+                  source: "field('half_float').asDouble(0.0)"
+    - close_to: { hits.hits.0.fields.field.0: { value: 3.140625, error: 0.001 } }
+    - match: { hits.hits.1.fields.field.0: 0.0 }
+    - close_to: { hits.hits.2.fields.field.0: { value: 1.123, error: 0.0001 } }
+
+    - do:
+        search:
+          rest_total_hits_as_int: true
+          body:
+            sort: [ { rank: asc } ]
+            script_fields:
+              field:
+                script:
+                  source: "field('half_float').asDouble(1, 0.0)"
+    - match: { hits.hits.0.fields.field.0: 0.0 }
+    - match: { hits.hits.1.fields.field.0: 0.0 }
+    - close_to: { hits.hits.2.fields.field.0: { value: 2.234, error: 0.001 } }
 
 ---
 "scaled_float":

+ 2 - 8
server/src/main/java/org/elasticsearch/index/mapper/NumberFieldMapper.java

@@ -35,8 +35,6 @@ import org.elasticsearch.common.settings.Setting.Property;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.fielddata.IndexFieldData;
 import org.elasticsearch.index.fielddata.IndexNumericFieldData.NumericType;
-import org.elasticsearch.index.fielddata.ScriptDocValues.Doubles;
-import org.elasticsearch.index.fielddata.ScriptDocValues.DoublesSupplier;
 import org.elasticsearch.index.fielddata.plain.SortedDoublesIndexFieldData;
 import org.elasticsearch.index.fielddata.plain.SortedNumericIndexFieldData;
 import org.elasticsearch.index.mapper.TimeSeriesParams.MetricType;
@@ -46,9 +44,9 @@ import org.elasticsearch.script.LongFieldScript;
 import org.elasticsearch.script.Script;
 import org.elasticsearch.script.ScriptCompiler;
 import org.elasticsearch.script.field.ByteDocValuesField;
-import org.elasticsearch.script.field.DelegateDocValuesField;
 import org.elasticsearch.script.field.DoubleDocValuesField;
 import org.elasticsearch.script.field.FloatDocValuesField;
+import org.elasticsearch.script.field.HalfFloatDocValuesField;
 import org.elasticsearch.script.field.IntegerDocValuesField;
 import org.elasticsearch.script.field.LongDocValuesField;
 import org.elasticsearch.script.field.ShortDocValuesField;
@@ -338,11 +336,7 @@ public class NumberFieldMapper extends FieldMapper {
 
             @Override
             public IndexFieldData.Builder getFieldDataBuilder(String name) {
-                return new SortedDoublesIndexFieldData.Builder(
-                    name,
-                    numericType(),
-                    (dv, n) -> new DelegateDocValuesField(new Doubles(new DoublesSupplier(dv)), n)
-                );
+                return new SortedDoublesIndexFieldData.Builder(name, numericType(), HalfFloatDocValuesField::new);
             }
 
             private void validateParsed(float value) {

+ 153 - 0
server/src/main/java/org/elasticsearch/script/field/HalfFloatDocValuesField.java

@@ -0,0 +1,153 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.script.field;
+
+import org.apache.lucene.util.ArrayUtil;
+import org.elasticsearch.index.fielddata.ScriptDocValues;
+import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+public class HalfFloatDocValuesField implements DocValuesField<Float>, ScriptDocValues.Supplier<Double> {
+
+    protected final SortedNumericDoubleValues input;
+    protected final String name;
+
+    protected double[] values = new double[0];
+    protected int count;
+
+    private ScriptDocValues.Doubles doubles = null;
+
+    public HalfFloatDocValuesField(SortedNumericDoubleValues input, String name) {
+        this.input = input;
+        this.name = name;
+    }
+
+    @Override
+    public void setNextDocId(int docId) throws IOException {
+        if (input.advanceExact(docId)) {
+            resize(input.docValueCount());
+            for (int i = 0; i < count; i++) {
+                values[i] = input.nextValue();
+            }
+        } else {
+            resize(0);
+        }
+    }
+
+    protected void resize(int newSize) {
+        count = newSize;
+
+        assert count >= 0 : "size must be positive (got " + count + "): likely integer overflow?";
+        values = ArrayUtil.grow(values, count);
+    }
+
+    @Override
+    public ScriptDocValues<Double> getScriptDocValues() {
+        if (doubles == null) {
+            doubles = new ScriptDocValues.Doubles(this);
+        }
+
+        return doubles;
+    }
+
+    @Override
+    public Double getInternal(int index) {
+        return values[index];
+    }
+
+    @Override
+    public String getName() {
+        return name;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return count == 0;
+    }
+
+    @Override
+    public int size() {
+        return count;
+    }
+
+    /**
+     * Does a downcast for defaultValue from a double to a float
+     * to allow users to avoid explicit casting.
+     */
+    public float get(double defaultValue) {
+        return get(0, defaultValue);
+    }
+
+    /**
+     * Does a downcast for defaultValue from a double to a float
+     * to allow users to avoid explicit casting.
+     */
+    public float get(int index, double defaultValue) {
+        if (isEmpty() || index < 0 || index >= count) {
+            return (float) defaultValue;
+        }
+
+        return (float) values[index];
+    }
+
+    @Override
+    public Iterator<Float> iterator() {
+        return new Iterator<Float>() {
+            private int index = 0;
+
+            @Override
+            public boolean hasNext() {
+                return index < count;
+            }
+
+            @Override
+            public Float next() {
+                if (hasNext() == false) {
+                    throw new NoSuchElementException();
+                }
+                return (float) values[index++];
+            }
+        };
+    }
+
+    /** Converts all the values to {@code Double} and returns them as a {@code List}. */
+    public List<Double> asDoubles() {
+        if (isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<Double> doubleValues = new ArrayList<>(count);
+
+        for (int index = 0; index < count; ++index) {
+            doubleValues.add(values[index]);
+        }
+
+        return doubleValues;
+    }
+
+    /** Returns the 0th index value as a {@code double} if it exists, otherwise {@code defaultValue}. */
+    public double asDouble(double defaultValue) {
+        return asDouble(0, defaultValue);
+    }
+
+    /** Returns the value at {@code index} as a {@code double} if it exists, otherwise {@code defaultValue}. */
+    public double asDouble(int index, double defaultValue) {
+        if (isEmpty() || index < 0 || index >= count) {
+            return defaultValue;
+        }
+
+        return values[index];
+    }
+}

+ 29 - 0
server/src/test/java/org/elasticsearch/index/fielddata/FloatDocValuesFieldTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.index.fielddata;
 
 import org.elasticsearch.script.field.DoubleDocValuesField;
 import org.elasticsearch.script.field.FloatDocValuesField;
+import org.elasticsearch.script.field.HalfFloatDocValuesField;
 import org.elasticsearch.script.field.ScaledFloatDocValuesField;
 import org.elasticsearch.test.ESTestCase;
 
@@ -80,6 +81,34 @@ public class FloatDocValuesFieldTests extends ESTestCase {
         }
     }
 
+    public void testHalfFloatField() throws IOException {
+        double[][] values = generate(ESTestCase::randomDouble);
+        HalfFloatDocValuesField halfFloatField = new HalfFloatDocValuesField(wrap(values), "test");
+        for (int round = 0; round < 10; round++) {
+            int d = between(0, values.length - 1);
+            halfFloatField.setNextDocId(d);
+            if (values[d].length > 0) {
+                assertEquals((float) values[d][0], halfFloatField.get(Float.MIN_VALUE), 0.0f);
+                assertEquals((float) values[d][0], halfFloatField.get(0, Float.MIN_VALUE), 0.0f);
+                assertEquals(values[d][0], halfFloatField.asDouble(Double.MIN_VALUE), 0.0);
+                assertEquals(values[d][0], halfFloatField.asDouble(0, Double.MIN_VALUE), 0.0);
+            }
+            assertEquals(values[d].length, halfFloatField.size());
+            for (int i = 0; i < values[d].length; i++) {
+                assertEquals((float) values[d][i], halfFloatField.get(i, Float.MIN_VALUE), 0.0f);
+                assertEquals(values[d][i], halfFloatField.asDouble(i, Double.MIN_VALUE), 0.0);
+            }
+            int i = 0;
+            for (float flt : halfFloatField) {
+                assertEquals((float) values[d][i++], flt, 0.0f);
+            }
+            i = 0;
+            for (double dbl : halfFloatField.asDoubles()) {
+                assertEquals(values[d][i++], dbl, 0.0);
+            }
+        }
+    }
+
     protected double[][] generate(DoubleSupplier supplier) {
         double[][] values = new double[between(3, 10)][];
         for (int d = 0; d < values.length; d++) {