Browse Source

Use float instead of double for query vectors. (#46004)

Currently, when using script_score functions like cosineSimilarity, the query
vector is treated as an array of doubles. Since the stored document vectors use
floats, it seems like the least surprising behavior for the query vectors to
also be float arrays.

In addition to improving consistency, this change may help with some
optimizations we have been considering around vector dot product.
Julie Tibshirani 6 years ago
parent
commit
8d16c9bee6

+ 1 - 1
x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/15_dense_vector_l1l2.yml

@@ -65,7 +65,7 @@ setup:
 
   - match: {hits.hits.1._id: "2"}
   - gte: {hits.hits.1._score: 12.29}
-  - lte: {hits.hits.1._score: 12.30}
+  - lte: {hits.hits.1._score: 12.31}
 
   - match: {hits.hits.2._id: "3"}
   - gte: {hits.hits.2._score: 0.00}

+ 1 - 1
x-pack/plugin/src/test/resources/rest-api-spec/test/vectors/35_sparse_vector_l1l2.yml

@@ -63,7 +63,7 @@ setup:
 
   - match: {hits.hits.1._id: "2"}
   - gte: {hits.hits.1._score: 12.29}
-  - lte: {hits.hits.1._score: 12.30}
+  - lte: {hits.hits.1._score: 12.31}
 
   - match: {hits.hits.2._id: "3"}
   - gte: {hits.hits.2._score: 0.00}

+ 2 - 2
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java

@@ -130,7 +130,7 @@ public final class VectorEncoderDecoder {
      * @param values - values for the sparse query vector
      * @param n - number of dimensions
      */
-    public static void sortSparseDimsDoubleValues(int[] dims, double[] values, int n) {
+    public static void sortSparseDimsFloatValues(int[] dims, float[] values, int n) {
         new InPlaceMergeSorter() {
             @Override
             public int compare(int i, int j) {
@@ -143,7 +143,7 @@ public final class VectorEncoderDecoder {
                 dims[i] = dims[j];
                 dims[j] = tempDim;
 
-                double tempValue = values[j];
+                float tempValue = values[j];
                 values[j] = values[i];
                 values[i] = tempValue;
             }

+ 12 - 12
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

@@ -14,7 +14,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
-import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsDoubleValues;
+import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSparseDimsFloatValues;
 
 public class ScoreScriptUtils {
 
@@ -37,7 +37,7 @@ public class ScoreScriptUtils {
         Iterator<Number> queryVectorIter = queryVector.iterator();
         double l1norm = 0;
         for (int dim = 0; dim < docVector.length; dim++){
-            l1norm += Math.abs(queryVectorIter.next().doubleValue() - docVector[dim]);
+            l1norm += Math.abs(queryVectorIter.next().floatValue() - docVector[dim]);
         }
         return l1norm;
     }
@@ -59,7 +59,7 @@ public class ScoreScriptUtils {
         Iterator<Number> queryVectorIter = queryVector.iterator();
         double l2norm = 0;
         for (int dim = 0; dim < docVector.length; dim++){
-            double diff = queryVectorIter.next().doubleValue() - docVector[dim];
+            double diff = queryVectorIter.next().floatValue() - docVector[dim];
             l2norm += diff * diff;
         }
         return Math.sqrt(l2norm);
@@ -97,11 +97,11 @@ public class ScoreScriptUtils {
         // calculate queryVectorMagnitude once per query execution
         public CosineSimilarity(List<Number> queryVector) {
             this.queryVector = queryVector;
-            double doubleValue;
+
             double dotProduct = 0;
             for (Number value : queryVector) {
-                doubleValue = value.doubleValue();
-                dotProduct += doubleValue * doubleValue;
+                float floatValue = value.floatValue();
+                dotProduct += floatValue * floatValue;
             }
             this.queryVectorMagnitude = Math.sqrt(dotProduct);
         }
@@ -130,7 +130,7 @@ public class ScoreScriptUtils {
         double v1v2DotProduct = 0;
         Iterator<Number> v1Iter = v1.iterator();
         for (int dim = 0; dim < v2.length; dim++) {
-            v1v2DotProduct += v1Iter.next().doubleValue() * v2[dim];
+            v1v2DotProduct += v1Iter.next().floatValue() * v2[dim];
         }
         return v1v2DotProduct;
     }
@@ -139,7 +139,7 @@ public class ScoreScriptUtils {
     //**************FUNCTIONS FOR SPARSE VECTORS
 
     public static class VectorSparseFunctions {
-        final double[] queryValues;
+        final float[] queryValues;
         final int[] queryDims;
 
         // prepare queryVector once per script execution
@@ -147,7 +147,7 @@ public class ScoreScriptUtils {
         public VectorSparseFunctions(Map<String, Number> queryVector) {
             //break vector into two arrays dims and values
             int n = queryVector.size();
-            queryValues = new double[n];
+            queryValues = new float[n];
             queryDims = new int[n];
             int i = 0;
             for (Map.Entry<String, Number> dimValue : queryVector.entrySet()) {
@@ -156,11 +156,11 @@ public class ScoreScriptUtils {
                 } catch (final NumberFormatException e) {
                     throw new IllegalArgumentException("Failed to parse a query vector dimension, it must be an integer!", e);
                 }
-                queryValues[i] = dimValue.getValue().doubleValue();
+                queryValues[i] = dimValue.getValue().floatValue();
                 i++;
             }
             // Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
-            sortSparseDimsDoubleValues(queryDims, queryValues, n);
+            sortSparseDimsFloatValues(queryDims, queryValues, n);
         }
     }
 
@@ -317,7 +317,7 @@ public class ScoreScriptUtils {
         }
     }
 
-    private static double intDotProductSparse(double[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
+    private static double intDotProductSparse(float[] v1Values, int[] v1Dims, float[] v2Values, int[] v2Dims) {
         double v1v2DotProduct = 0;
         int v1Index = 0;
         int v2Index = 0;

+ 5 - 5
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtilsTests.java

@@ -36,11 +36,11 @@ public class ScoreScriptUtilsTests extends ESTestCase {
         BytesRef encodedDocVector =  mockEncodeDenseVector(docVector);
         VectorScriptDocValues.DenseVectorScriptDocValues dvs = mock(VectorScriptDocValues.DenseVectorScriptDocValues.class);
         when(dvs.getEncodedValue()).thenReturn(encodedDocVector);
-        List<Number> queryVector = Arrays.asList(0.5, 111.3, -13.0, 14.8, -156.0);
+        List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
 
         // test dotProduct
         double result = dotProduct(queryVector, dvs);
-        assertEquals("dotProduct result is not equal to the expected value!", 65425.626, result, 0.001);
+        assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
 
         // test cosineSimilarity
         CosineSimilarity cosineSimilarity = new CosineSimilarity(queryVector);
@@ -91,7 +91,7 @@ public class ScoreScriptUtilsTests extends ESTestCase {
         // test dotProduct
         DotProductSparse docProductSparse = new DotProductSparse(queryVector);
         double result = docProductSparse.dotProductSparse(dvs);
-        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
+        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
 
         // test cosineSimilarity
         CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
@@ -128,7 +128,7 @@ public class ScoreScriptUtilsTests extends ESTestCase {
         // test dotProduct
         DotProductSparse docProductSparse = new DotProductSparse(queryVector);
         double result = docProductSparse.dotProductSparse(dvs);
-        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
+        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
 
         // test cosineSimilarity
         CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);
@@ -165,7 +165,7 @@ public class ScoreScriptUtilsTests extends ESTestCase {
         // test dotProduct
         DotProductSparse docProductSparse = new DotProductSparse(queryVector);
         double result = docProductSparse.dotProductSparse(dvs);
-        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.626, result, 0.001);
+        assertEquals("dotProductSparse result is not equal to the expected value!", 65425.624, result, 0.001);
 
         // test cosineSimilarity
         CosineSimilaritySparse cosineSimilaritySparse = new CosineSimilaritySparse(queryVector);