Browse Source

Add basic implementations of float-byte script comparisons (#122381)

Add implementations of `cosineSimilarity` and `dotProduct` to query byte vector fields using float vectors
Simon Cooper 7 tháng trước cách đây
mục cha
commit
82668b40f4

+ 6 - 0
docs/changelog/122381.yaml

@@ -0,0 +1,6 @@
+pr: 122381
+summary: Adds implementations of dotProduct and cosineSimilarity painless methods to operate on float vectors for byte fields
+area: Vector Search
+type: enhancement
+issues:
+ - 117274

+ 13 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -80,6 +80,19 @@ public class ESVectorUtil {
         return IMPL.ipFloatBit(q, d);
     }
 
+    /**
+     * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a byte vector.
+     * @param q the query vector
+     * @param d the document vector
+     * @return the inner product of the two vectors
+     */
+    public static float ipFloatByte(float[] q, byte[] d) {
+        if (q.length != d.length) {
+            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + d.length);
+        }
+        return IMPL.ipFloatByte(q, d);
+    }
+
     /**
      * AND bit count computed over signed bytes.
      * Copied from Lucene's XOR implementation

+ 13 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

@@ -39,6 +39,11 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         return ipFloatBitImpl(q, d);
     }
 
+    @Override
+    public float ipFloatByte(float[] q, byte[] d) {
+        return ipFloatByteImpl(q, d);
+    }
+
     public static int ipByteBitImpl(byte[] q, byte[] d) {
         assert q.length == d.length * Byte.SIZE;
         int acc0 = 0;
@@ -101,4 +106,12 @@ final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
         }
         return ret;
     }
+
+    public static float ipFloatByteImpl(float[] q, byte[] d) {
+        float ret = 0;
+        for (int i = 0; i < q.length; i++) {
+            ret += q[i] * d[i];
+        }
+        return ret;
+    }
 }

+ 2 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

@@ -18,4 +18,6 @@ public interface ESVectorUtilSupport {
     int ipByteBit(byte[] q, byte[] d);
 
     float ipFloatBit(float[] q, byte[] d);
+
+    float ipFloatByte(float[] q, byte[] d);
 }

+ 5 - 0
libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

@@ -58,6 +58,11 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
         return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
     }
 
+    @Override
+    public float ipFloatByte(float[] q, byte[] d) {
+        return DefaultESVectorUtilSupport.ipFloatByteImpl(q, d);
+    }
+
     private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
     private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
 

+ 18 - 1
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -32,11 +32,28 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
     public void testIpFloatBit() {
         float[] q = new float[16];
         byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
-        random().nextFloat();
+        for (int i = 0; i < q.length; i++) {
+            q[i] = random().nextFloat();
+        }
         float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
         assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
     }
 
+    public void testIpFloatByte() {
+        float[] q = new float[16];
+        byte[] d = new byte[16];
+        for (int i = 0; i < q.length; i++) {
+            q[i] = random().nextFloat();
+        }
+        random().nextBytes(d);
+
+        float expected = 0;
+        for (int i = 0; i < q.length; i++) {
+            expected += q[i] * d[i];
+        }
+        assertEquals(expected, ESVectorUtil.ipFloatByte(q, d), 1e-6);
+    }
+
     public void testBitAndCount() {
         testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
     }

+ 68 - 0
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/145_dense_vector_byte_basic.yml

@@ -107,6 +107,38 @@ setup:
   - match: {hits.hits.2._id: "1"}
   - match: {hits.hits.2._score: 1632.0}
 ---
+"Dot Product float":
+  - requires:
+      capabilities:
+        - path: /_search
+          capabilities: [byte_float_dot_product_capability]
+      test_runner_features: [capabilities]
+      reason: "float vector queries capability added"
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: {match_all: {} }
+              script:
+                source: "dotProduct(params.query_vector, 'vector')"
+                params:
+                  query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
+
+  - match: {hits.total: 3}
+
+  - match: {hits.hits.0._id: "2"}
+  - match: {hits.hits.0._score: 32865.2}
+
+  - match: {hits.hits.1._id: "3"}
+  - match: {hits.hits.1._score: 21413.4}
+
+  - match: {hits.hits.2._id: "1"}
+  - match: {hits.hits.2._score: 1862.3}
+---
 "Cosine Similarity":
   - do:
       headers:
@@ -198,3 +230,39 @@ setup:
   - match: {hits.hits.2._id: "1"}
   - gte: {hits.hits.2._score: 0.509}
   - lte: {hits.hits.2._score: 0.512}
+
+---
+"Cosine Similarity float":
+  - requires:
+      capabilities:
+        - path: /_search
+          capabilities: [byte_float_dot_product_capability]
+      test_runner_features: [capabilities]
+      reason: "float vector queries capability added"
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: {match_all: {} }
+              script:
+                source: "cosineSimilarity(params.query_vector, 'vector')"
+                params:
+                  query_vector: [0.5, 111.3, -13.0, 14.8, -156.0]
+
+  - match: {hits.total: 3}
+
+  - match: {hits.hits.0._id: "2"}
+  - gte: {hits.hits.0._score: 0.989}
+  - lte: {hits.hits.0._score: 0.992}
+
+  - match: {hits.hits.1._id: "3"}
+  - gte: {hits.hits.1._score: 0.885}
+  - lte: {hits.hits.1._score: 0.888}
+
+  - match: {hits.hits.2._id: "1"}
+  - gte: {hits.hits.2._score: 0.505}
+  - lte: {hits.hits.2._score: 0.508}

+ 58 - 25
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -346,16 +346,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             @Override
-            public void checkVectorBounds(float[] vector) {
-                checkNanAndInfinite(vector);
-
-                StringBuilder errorBuilder = null;
+            StringBuilder checkVectorErrors(float[] vector) {
+                StringBuilder errors = checkNanAndInfinite(vector);
+                if (errors != null) {
+                    return errors;
+                }
 
                 for (int index = 0; index < vector.length; ++index) {
                     float value = vector[index];
 
                     if (value % 1.0f != 0.0f) {
-                        errorBuilder = new StringBuilder(
+                        errors = new StringBuilder(
                             "element_type ["
                                 + this
                                 + "] vectors only support non-decimal values but found decimal value ["
@@ -368,7 +369,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     }
 
                     if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
-                        errorBuilder = new StringBuilder(
+                        errors = new StringBuilder(
                             "element_type ["
                                 + this
                                 + "] vectors only support integers between ["
@@ -385,9 +386,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     }
                 }
 
-                if (errorBuilder != null) {
-                    throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
-                }
+                return errors;
             }
 
             @Override
@@ -614,8 +613,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             @Override
-            public void checkVectorBounds(float[] vector) {
-                checkNanAndInfinite(vector);
+            StringBuilder checkVectorErrors(float[] vector) {
+                return checkNanAndInfinite(vector);
             }
 
             @Override
@@ -768,16 +767,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             @Override
-            public void checkVectorBounds(float[] vector) {
-                checkNanAndInfinite(vector);
-
-                StringBuilder errorBuilder = null;
+            StringBuilder checkVectorErrors(float[] vector) {
+                StringBuilder errors = checkNanAndInfinite(vector);
+                if (errors != null) {
+                    return errors;
+                }
 
                 for (int index = 0; index < vector.length; ++index) {
                     float value = vector[index];
 
                     if (value % 1.0f != 0.0f) {
-                        errorBuilder = new StringBuilder(
+                        errors = new StringBuilder(
                             "element_type ["
                                 + this
                                 + "] vectors only support non-decimal values but found decimal value ["
@@ -790,7 +790,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     }
 
                     if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
-                        errorBuilder = new StringBuilder(
+                        errors = new StringBuilder(
                             "element_type ["
                                 + this
                                 + "] vectors only support integers between ["
@@ -807,9 +807,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     }
                 }
 
-                if (errorBuilder != null) {
-                    throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
-                }
+                return errors;
             }
 
             @Override
@@ -993,7 +991,44 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         public abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
 
-        public abstract void checkVectorBounds(float[] vector);
+        /**
+         * Checks the input {@code vector} is one of the {@code possibleTypes},
+         * and returns the first type that it matches
+         */
+        public static ElementType checkValidVector(float[] vector, ElementType... possibleTypes) {
+            assert possibleTypes.length != 0;
+            // we're looking for one valid allowed type
+            // assume the types are in order of specificity
+            StringBuilder[] errors = new StringBuilder[possibleTypes.length];
+            for (int i = 0; i < possibleTypes.length; i++) {
+                StringBuilder error = possibleTypes[i].checkVectorErrors(vector);
+                if (error == null) {
+                    // this one works - use it
+                    return possibleTypes[i];
+                } else {
+                    errors[i] = error;
+                }
+            }
+
+            // oh dear, none of the possible types work with this vector. Generate the error message and throw.
+            StringBuilder message = new StringBuilder();
+            for (int i = 0; i < possibleTypes.length; i++) {
+                if (i > 0) {
+                    message.append(" ");
+                }
+                message.append("Vector is not a ").append(possibleTypes[i]).append(" vector: ").append(errors[i]);
+            }
+            throw new IllegalArgumentException(appendErrorElements(message, vector).toString());
+        }
+
+        public void checkVectorBounds(float[] vector) {
+            StringBuilder errors = checkVectorErrors(vector);
+            if (errors != null) {
+                throw new IllegalArgumentException(appendErrorElements(errors, vector).toString());
+            }
+        }
+
+        abstract StringBuilder checkVectorErrors(float[] vector);
 
         abstract void checkVectorMagnitude(
             VectorSimilarity similarity,
@@ -1017,7 +1052,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return index;
         }
 
-        void checkNanAndInfinite(float[] vector) {
+        StringBuilder checkNanAndInfinite(float[] vector) {
             StringBuilder errorBuilder = null;
 
             for (int index = 0; index < vector.length; ++index) {
@@ -1044,9 +1079,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
             }
 
-            if (errorBuilder != null) {
-                throw new IllegalArgumentException(appendErrorElements(errorBuilder, vector).toString());
-            }
+            return errorBuilder;
         }
 
         static StringBuilder appendErrorElements(StringBuilder errorBuilder, float[] vector) {

+ 3 - 0
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -25,6 +25,8 @@ public final class SearchCapabilities {
     private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
     /** Support Byte and Float with Bit dot product. */
     private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
+    /** Support float query vectors on byte vectors */
+    private static final String BYTE_FLOAT_DOT_PRODUCT_CAPABILITY = "byte_float_dot_product_capability";
     /** Support docvalue_fields parameter for `dense_vector` field. */
     private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
     /** Support transforming rank rrf queries to the corresponding rrf retriever. */
@@ -50,6 +52,7 @@ public final class SearchCapabilities {
         capabilities.add(RANGE_REGEX_INTERVAL_QUERY_CAPABILITY);
         capabilities.add(BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY);
         capabilities.add(BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY);
+        capabilities.add(BYTE_FLOAT_DOT_PRODUCT_CAPABILITY);
         capabilities.add(DENSE_VECTOR_DOCVALUE_FIELDS);
         capabilities.add(TRANSFORM_RANK_RRF_TO_RETRIEVER);
         capabilities.add(NESTED_RETRIEVER_INNER_HITS_SUPPORT);

+ 66 - 23
server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

@@ -11,6 +11,7 @@ package org.elasticsearch.script;
 
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.script.field.vectors.DenseVector;
 import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
 
@@ -42,7 +43,10 @@ public class VectorScoreScriptUtils {
     }
 
     public static class ByteDenseVectorFunction extends DenseVectorFunction {
-        protected final byte[] queryVector;
+        // either byteQueryVector or floatQueryVector will be non-null
+        protected final byte[] byteQueryVector;
+        protected final float[] floatQueryVector;
+        // only valid if byteQueryVector is used
         protected final float qvMagnitude;
 
         /**
@@ -51,22 +55,51 @@ public class VectorScoreScriptUtils {
          * @param scoreScript The script in which this function was referenced.
          * @param field The vector field.
          * @param queryVector The query vector.
+         * @param normalizeFloatQuery {@code true} if the query vector is a float vector, then normalize it.
+         * @param allowedTypes The types the vector is allowed to be.
          */
-        public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
+        public ByteDenseVectorFunction(
+            ScoreScript scoreScript,
+            DenseVectorDocValuesField field,
+            List<Number> queryVector,
+            boolean normalizeFloatQuery,
+            ElementType... allowedTypes
+        ) {
             super(scoreScript, field);
             field.getElementType().checkDimensions(field.get().getDims(), queryVector.size());
-            this.queryVector = new byte[queryVector.size()];
-            float[] validateValues = new float[queryVector.size()];
-            int queryMagnitude = 0;
+            float[] floatValues = new float[queryVector.size()];
+            double queryMagnitude = 0;
             for (int i = 0; i < queryVector.size(); i++) {
-                final Number number = queryVector.get(i);
-                byte value = number.byteValue();
-                this.queryVector[i] = value;
+                float value = queryVector.get(i).floatValue();
+                floatValues[i] = value;
                 queryMagnitude += value * value;
-                validateValues[i] = number.floatValue();
             }
-            this.qvMagnitude = (float) Math.sqrt(queryMagnitude);
-            field.getElementType().checkVectorBounds(validateValues);
+            queryMagnitude = Math.sqrt(queryMagnitude);
+
+            switch (ElementType.checkValidVector(floatValues, allowedTypes)) {
+                case FLOAT:
+                    byteQueryVector = null;
+                    floatQueryVector = floatValues;
+                    qvMagnitude = -1;   // invalid valid, not used for float vectors
+
+                    if (normalizeFloatQuery) {
+                        for (int i = 0; i < floatQueryVector.length; i++) {
+                            floatQueryVector[i] /= (float) queryMagnitude;
+                        }
+                    }
+                    break;
+                case BYTE:
+                    floatQueryVector = null;
+                    byteQueryVector = new byte[floatValues.length];
+                    for (int i = 0; i < floatValues.length; i++) {
+                        byteQueryVector[i] = (byte) floatValues[i];
+                    }
+                    this.qvMagnitude = (float) queryMagnitude;
+                    break;
+                default:
+                    throw new AssertionError("Unexpected element type");
+            }
+
         }
 
         /**
@@ -78,8 +111,9 @@ public class VectorScoreScriptUtils {
          */
         public ByteDenseVectorFunction(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
             super(scoreScript, field);
-            this.queryVector = queryVector;
-            float queryMagnitude = 0.0f;
+            byteQueryVector = queryVector;
+            floatQueryVector = null;
+            double queryMagnitude = 0.0f;
             for (byte value : queryVector) {
                 queryMagnitude += value * value;
             }
@@ -133,7 +167,7 @@ public class VectorScoreScriptUtils {
     public static class ByteL1Norm extends ByteDenseVectorFunction implements L1NormInterface {
 
         public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
-            super(scoreScript, field, queryVector);
+            super(scoreScript, field, queryVector, false, ElementType.BYTE);
         }
 
         public ByteL1Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -142,7 +176,7 @@ public class VectorScoreScriptUtils {
 
         public double l1norm() {
             setNextVector();
-            return field.get().l1Norm(queryVector);
+            return field.get().l1Norm(byteQueryVector);
         }
     }
 
@@ -197,7 +231,7 @@ public class VectorScoreScriptUtils {
     public static class ByteHammingDistance extends ByteDenseVectorFunction implements HammingDistanceInterface {
 
         public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
-            super(scoreScript, field, queryVector);
+            super(scoreScript, field, queryVector, false, ElementType.BYTE);
         }
 
         public ByteHammingDistance(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -206,7 +240,7 @@ public class VectorScoreScriptUtils {
 
         public int hamming() {
             setNextVector();
-            return field.get().hamming(queryVector);
+            return field.get().hamming(byteQueryVector);
         }
     }
 
@@ -243,7 +277,7 @@ public class VectorScoreScriptUtils {
     public static class ByteL2Norm extends ByteDenseVectorFunction implements L2NormInterface {
 
         public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
-            super(scoreScript, field, queryVector);
+            super(scoreScript, field, queryVector, false, ElementType.BYTE);
         }
 
         public ByteL2Norm(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -252,7 +286,7 @@ public class VectorScoreScriptUtils {
 
         public double l2norm() {
             setNextVector();
-            return field.get().l2Norm(queryVector);
+            return field.get().l2Norm(byteQueryVector);
         }
     }
 
@@ -388,7 +422,7 @@ public class VectorScoreScriptUtils {
     public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
 
         public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
-            super(scoreScript, field, queryVector);
+            super(scoreScript, field, queryVector, false, ElementType.BYTE, ElementType.FLOAT);
         }
 
         public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -397,7 +431,11 @@ public class VectorScoreScriptUtils {
 
         public double dotProduct() {
             setNextVector();
-            return field.get().dotProduct(queryVector);
+            if (floatQueryVector != null) {
+                return field.get().dotProduct(floatQueryVector);
+            } else {
+                return field.get().dotProduct(byteQueryVector);
+            }
         }
     }
 
@@ -461,7 +499,7 @@ public class VectorScoreScriptUtils {
     public static class ByteCosineSimilarity extends ByteDenseVectorFunction implements CosineSimilarityInterface {
 
         public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
-            super(scoreScript, field, queryVector);
+            super(scoreScript, field, queryVector, true, ElementType.BYTE, ElementType.FLOAT);
         }
 
         public ByteCosineSimilarity(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
@@ -470,7 +508,12 @@ public class VectorScoreScriptUtils {
 
         public double cosineSimilarity() {
             setNextVector();
-            return field.get().cosineSimilarity(queryVector, qvMagnitude);
+            if (floatQueryVector != null) {
+                // float vector is already normalized by the superclass constructor
+                return field.get().cosineSimilarity(floatQueryVector, false);
+            } else {
+                return field.get().cosineSimilarity(byteQueryVector, qvMagnitude);
+            }
         }
     }
 

+ 7 - 2
server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java

@@ -12,6 +12,7 @@ package org.elasticsearch.script.field.vectors;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.simdvec.ESVectorUtil;
 
 import java.nio.ByteBuffer;
 import java.util.List;
@@ -61,7 +62,7 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     @Override
     public double dotProduct(float[] queryVector) {
-        throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
+        return ESVectorUtil.ipFloatByte(queryVector, vectorValue);
     }
 
     @Override
@@ -142,7 +143,11 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     @Override
     public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
-        throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
+        if (normalizeQueryVector) {
+            return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+        }
+
+        return dotProduct(queryVector) / getMagnitude();
     }
 
     @Override

+ 8 - 3
server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java

@@ -11,6 +11,7 @@ package org.elasticsearch.script.field.vectors;
 
 import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.simdvec.ESVectorUtil;
 
 import java.util.List;
 
@@ -51,12 +52,12 @@ public class ByteKnnDenseVector implements DenseVector {
 
     @Override
     public int dotProduct(byte[] queryVector) {
-        return VectorUtil.dotProduct(docVector, queryVector);
+        return VectorUtil.dotProduct(queryVector, docVector);
     }
 
     @Override
     public double dotProduct(float[] queryVector) {
-        throw new UnsupportedOperationException("use [int dotProduct(byte[] queryVector)] instead");
+        return ESVectorUtil.ipFloatByte(queryVector, docVector);
     }
 
     @Override
@@ -145,7 +146,11 @@ public class ByteKnnDenseVector implements DenseVector {
 
     @Override
     public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
-        throw new UnsupportedOperationException("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
+        if (normalizeQueryVector) {
+            return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+        }
+
+        return dotProduct(queryVector) / getMagnitude();
     }
 
     @Override

+ 58 - 95
server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java

@@ -29,7 +29,6 @@ import org.elasticsearch.script.field.vectors.KnnDenseVectorDocValuesField;
 import org.elasticsearch.test.ESTestCase;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.HexFormat;
 import java.util.List;
 
@@ -43,8 +42,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         String fieldName = "vector";
         int dims = 5;
         float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
-        List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
-        List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
+        List<Number> queryVector = List.of(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
+        List<Number> invalidQueryVector = List.of(0.5, 111.3);
 
         List<DenseVectorDocValuesField> fields = List.of(
             new BinaryDenseVectorDocValuesField(
@@ -141,8 +140,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         String fieldName = "vector";
         int dims = 5;
         float[] docVector = new float[] { 1, 127, -128, 5, -10 };
-        List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
-        List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
+        List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
+        List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
         String hexidecimalString = HexFormat.of().formatHex(new byte[] { 1, 125, -12, 2, 4 });
 
         List<DenseVectorDocValuesField> fields = List.of(
@@ -183,11 +182,12 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
             for (int i = 0; i < queryVectorArray.length; i++) {
                 queryVectorArray[i] = queryVector.get(i).floatValue();
             }
-            UnsupportedOperationException uoe = expectThrows(
-                UnsupportedOperationException.class,
-                () -> field.getInternal().cosineSimilarity(queryVectorArray, true)
+            assertEquals(
+                "cosineSimilarity result is not equal to the expected value!",
+                cosineSimilarityExpected,
+                field.getInternal().cosineSimilarity(queryVectorArray, true),
+                0.001
             );
-            assertThat(uoe.getMessage(), containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead"));
 
             // Check each function rejects query vectors with the wrong dimension
             IllegalArgumentException e = expectThrows(
@@ -240,9 +240,9 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         int dims = 8;
         float[] docVector = new float[] { 124 };
         // 124 in binary is b01111100
-        List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
-        List<Number> floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
-        List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
+        List<Number> queryVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
+        List<Number> floatQueryVector = List.of(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
+        List<Number> invalidQueryVector = List.of((byte) 1, (byte) 1);
         String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 });
 
         List<DenseVectorDocValuesField> fields = List.of(
@@ -293,8 +293,8 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
     public void testByteVsFloatSimilarity() throws IOException {
         int dims = 5;
         float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };
-        List<Number> listFloatVector = Arrays.asList(1f, 125f, -12f, 2f, 4f);
-        List<Number> listByteVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
+        List<Number> listFloatVector = List.of(1f, 125f, -12f, 2f, 4f);
+        List<Number> listByteVector = List.of((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4);
         float[] floatVector = new float[] { 1f, 125f, -12f, 2f, 4f };
         byte[] byteVector = new byte[] { (byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4 };
 
@@ -342,11 +342,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
             switch (field.getElementType()) {
                 case BYTE -> {
                     assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(byteVector));
-                    UnsupportedOperationException e = expectThrows(
-                        UnsupportedOperationException.class,
-                        () -> field.get().dotProduct(floatVector)
-                    );
-                    assertThat(e.getMessage(), containsString("use [int dotProduct(byte[] queryVector)] instead"));
+                    assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
                 }
                 case FLOAT -> {
                     assertEquals(field.getName(), dotProductExpected, field.get().dotProduct(floatVector), 0.001);
@@ -423,14 +419,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
             switch (field.getElementType()) {
                 case BYTE -> {
                     assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(byteVector), 0.001);
-                    UnsupportedOperationException e = expectThrows(
-                        UnsupportedOperationException.class,
-                        () -> field.get().cosineSimilarity(floatVector)
-                    );
-                    assertThat(
-                        e.getMessage(),
-                        containsString("use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead")
-                    );
+                    assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
                 }
                 case FLOAT -> {
                     assertEquals(field.getName(), cosineSimilarityExpected, field.get().cosineSimilarity(floatVector), 0.001);
@@ -471,81 +460,55 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
             ScoreScript scoreScript = mock(ScoreScript.class);
             when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
 
-            IllegalArgumentException e;
-
-            e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, greaterThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
-                    + "Preview of invalid vector: [128.0]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, greaterThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
-                    + "Preview of invalid vector: [128.0]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, greaterThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
-                    + "Preview of invalid vector: [128.0]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+                        + "Preview of invalid vector: [128.0]"
+                ),
+                () -> new L1Norm(scoreScript, greaterThanVector, fieldName)
             );
-            e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, greaterThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
-                    + "Preview of invalid vector: [128.0]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support integers between [-128, 127] but found [128.0] at dim [0]; "
+                        + "Preview of invalid vector: [128.0]"
+                ),
+                () -> new L2Norm(scoreScript, greaterThanVector, fieldName)
             );
 
-            e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, lessThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
-                    + "Preview of invalid vector: [-129.0]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, lessThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
-                    + "Preview of invalid vector: [-129.0]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, lessThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
-                    + "Preview of invalid vector: [-129.0]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+                        + "Preview of invalid vector: [-129.0]"
+                ),
+                () -> new L1Norm(scoreScript, lessThanVector, fieldName)
             );
-            e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, lessThanVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
-                    + "Preview of invalid vector: [-129.0]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support integers between [-128, 127] but found [-129.0] at dim [0]; "
+                        + "Preview of invalid vector: [-129.0]"
+                ),
+                () -> new L2Norm(scoreScript, lessThanVector, fieldName)
             );
 
-            e = expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, decimalVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
-                    + "Preview of invalid vector: [0.5]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L1Norm(scoreScript, decimalVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
-                    + "Preview of invalid vector: [0.5]"
-            );
-            e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, decimalVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
-                    + "Preview of invalid vector: [0.5]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+                        + "Preview of invalid vector: [0.5]"
+                ),
+                () -> new L1Norm(scoreScript, decimalVector, fieldName)
             );
-            e = expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, decimalVector, fieldName));
-            assertEquals(
-                e.getMessage(),
-                "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
-                    + "Preview of invalid vector: [0.5]"
+            expectThrows(
+                IllegalArgumentException.class,
+                containsString(
+                    "element_type [byte] vectors only support non-decimal values but found decimal value [0.5] at dim [0]; "
+                        + "Preview of invalid vector: [0.5]"
+                ),
+                () -> new L2Norm(scoreScript, decimalVector, fieldName)
             );
         }
     }

+ 0 - 20
server/src/test/java/org/elasticsearch/script/field/vectors/DenseVectorTests.java

@@ -149,11 +149,6 @@ public class DenseVectorTests extends ESTestCase {
         ByteKnnDenseVector knn = new ByteKnnDenseVector(docVector);
         UnsupportedOperationException e;
 
-        e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct(queryVector));
-        assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
-        e = expectThrows(UnsupportedOperationException.class, () -> knn.dotProduct((Object) queryVector));
-        assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
-
         e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm(queryVector));
         assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
         e = expectThrows(UnsupportedOperationException.class, () -> knn.l1Norm((Object) queryVector));
@@ -164,18 +159,8 @@ public class DenseVectorTests extends ESTestCase {
         e = expectThrows(UnsupportedOperationException.class, () -> knn.l2Norm((Object) queryVector));
         assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
 
-        e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity(queryVector));
-        assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
-        e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
-        assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
-
         ByteBinaryDenseVector binary = new ByteBinaryDenseVector(docVector, new BytesRef(docVector), dims);
 
-        e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
-        assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
-        e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct((Object) queryVector));
-        assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
-
         e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm(queryVector));
         assertEquals(e.getMessage(), "use [int l1Norm(byte[] queryVector)] instead");
         e = expectThrows(UnsupportedOperationException.class, () -> binary.l1Norm((Object) queryVector));
@@ -185,11 +170,6 @@ public class DenseVectorTests extends ESTestCase {
         assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
         e = expectThrows(UnsupportedOperationException.class, () -> binary.l2Norm((Object) queryVector));
         assertEquals(e.getMessage(), "use [double l2Norm(byte[] queryVector)] instead");
-
-        e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity(queryVector));
-        assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
-        e = expectThrows(UnsupportedOperationException.class, () -> binary.cosineSimilarity((Object) queryVector));
-        assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
     }
 
     public void testFloatUnsupported() {