Browse Source

Correct bit * byte and bit * float script comparisons (#117404) (#117507)

I goofed on the bit * byte and bit * float comparisons. Naturally, these
should be bigendian and compare the dimensions with the binary ones
appropriately.

Additionally, I added a test to ensure that this is handled correctly.

(cherry picked from commit 374c88a832edb53525a1c52873db185c2acdfb23)
Benjamin Trent 10 months ago
parent
commit
8d16529169

+ 5 - 0
docs/changelog/117404.yaml

@@ -0,0 +1,5 @@
+pr: 117404
+summary: Correct bit * byte and bit * float script comparisons
+area: Vector Search
+type: bug
+issues: []

+ 4 - 0
docs/reference/vectors/vector-functions.asciidoc

@@ -336,6 +336,10 @@ When using `bit` vectors, not all the vector functions are available. The suppor
 this is the sum of the bitwise AND of the two vectors. If providing `float[]` or `byte[]`, who has `dims` number of elements, as a query vector, the `dotProduct` is
 the sum of the floating point values using the stored `bit` vector as a mask.
 
+NOTE: When comparing `floats` and `bytes` with `bit` vectors, the `bit` vector is treated as a mask in big-endian order.
+For example, if the `bit` vector is `10100001` (i.e. the single byte value `161`) and it is compared
+with an array of values `[1, 2, 3, 4, 5, 6, 7, 8]`, the `dotProduct` will be `1 + 3 + 8 = 12`.
+
 Here is an example of using dot-product with bit vectors.
 
 [source,console]

+ 8 - 4
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -51,6 +51,8 @@ public class ESVectorUtil {
     /**
      * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
      * This will return the sum of the query vector values using the document vector as a mask.
+     * The bits are matched against the bytes in "big endian" order. For example, if the byte vector
+     * is [1, 2, 3, 4, 5, 6, 7, 8] and the bit vector is [0b10000000], the inner product will be 1.
      * @param q the query vector
      * @param d the document vector
      * @return the inner product of the two vectors
@@ -63,9 +65,9 @@ public class ESVectorUtil {
         // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
         for (int i = 0; i < d.length; i++) {
             byte mask = d[i];
-            for (int j = 0; j < Byte.SIZE; j++) {
+            for (int j = Byte.SIZE - 1; j >= 0; j--) {
                 if ((mask & (1 << j)) != 0) {
-                    result += q[i * Byte.SIZE + j];
+                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
                 }
             }
         }
@@ -75,6 +77,8 @@ public class ESVectorUtil {
     /**
      * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
      * This will return the sum of the query vector values using the document vector as a mask.
+     * When comparing the bits with the floats, they are done in "big endian" order. For example, if the float vector
+     * is [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] and the bit vector is [0b10000000], the inner product will be 1.0.
      * @param q the query vector
      * @param d the document vector
      * @return the inner product of the two vectors
@@ -86,9 +90,9 @@ public class ESVectorUtil {
         float result = 0;
         for (int i = 0; i < d.length; i++) {
             byte mask = d[i];
-            for (int j = 0; j < Byte.SIZE; j++) {
+            for (int j = Byte.SIZE - 1; j >= 0; j--) {
                 if ((mask & (1 << j)) != 0) {
-                    result += q[i * Byte.SIZE + j];
+                    result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
                 }
             }
         }

+ 16 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -21,6 +21,22 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
     static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
     static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
 
+    public void testIpByteBit() {
+        byte[] q = new byte[16];
+        byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
+        random().nextBytes(q);
+        int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
+        assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
+    }
+
+    public void testIpFloatBit() {
+        float[] q = new float[16]; // one value per bit of the two-byte mask below
+        byte[] d = new byte[] { (byte) Integer.parseInt("01100010", 2), (byte) Integer.parseInt("10100111", 2) };
+        for (int i = 0; i < q.length; i++) { q[i] = random().nextFloat(); } // randomize every element so the masked sum is non-trivial
+        float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15]; // indices of the set bits, read MSB-first
+        assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
+    }
+
     public void testBitAndCount() {
         testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
     }

+ 3 - 3
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/141_multi_dense_vector_max_sim.yml

@@ -3,7 +3,7 @@ setup:
       capabilities:
         - method: POST
           path: /_search
-          capabilities: [ multi_dense_vector_script_max_sim ]
+          capabilities: [ multi_dense_vector_script_max_sim_with_bugfix ]
       test_runner_features: capabilities
       reason: "Support for multi dense vector max-sim functions capability required"
   - skip:
@@ -136,10 +136,10 @@ setup:
   - match: {hits.total: 2}
 
   - match: {hits.hits.0._id: "1"}
-  - close_to: {hits.hits.0._score: {value: 190, error: 0.01}}
+  - close_to: {hits.hits.0._score: {value: 220, error: 0.01}}
 
   - match: {hits.hits.1._id: "3"}
-  - close_to: {hits.hits.1._score: {value: 125, error: 0.01}}
+  - close_to: {hits.hits.1._score: {value: 147, error: 0.01}}
 ---
 "Test max-sim inv hamming scoring":
   - skip:

+ 21 - 21
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml

@@ -108,7 +108,7 @@ setup:
       capabilities:
         - method: POST
           path: /_search
-          capabilities: [ byte_float_bit_dot_product ]
+          capabilities: [ byte_float_bit_dot_product_with_bugfix ]
       reason: Capability required to run test
   - do:
       catch: bad_request
@@ -399,7 +399,7 @@ setup:
       capabilities:
         - method: POST
           path: /_search
-          capabilities: [ byte_float_bit_dot_product ]
+          capabilities: [ byte_float_bit_dot_product_with_bugfix ]
       test_runner_features: [capabilities, close_to]
       reason: Capability required to run test
   - do:
@@ -419,13 +419,13 @@ setup:
   - match: { hits.total: 3 }
 
   - match: {hits.hits.0._id: "2"}
-  - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
+  - close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}
 
   - match: {hits.hits.1._id: "3"}
-  - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
+  - close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}
 
   - match: {hits.hits.2._id: "1"}
-  - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
+  - close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}
 
   - do:
       headers:
@@ -444,20 +444,20 @@ setup:
   - match: { hits.total: 3 }
 
   - match: {hits.hits.0._id: "2"}
-  - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
+  - close_to: {hits.hits.0._score: {value: 33.78, error: 0.01}}
 
   - match: {hits.hits.1._id: "3"}
-  - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
+  - close_to: {hits.hits.1._score:{value: 22.579, error: 0.01}}
 
   - match: {hits.hits.2._id: "1"}
-  - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
+  - close_to: {hits.hits.2._score: {value: 11.919, error: 0.01}}
 ---
 "Dot product with byte":
   - requires:
       capabilities:
         - method: POST
           path: /_search
-          capabilities: [ byte_float_bit_dot_product ]
+          capabilities: [ byte_float_bit_dot_product_with_bugfix ]
       test_runner_features: capabilities
       reason: Capability required to run test
   - do:
@@ -476,14 +476,14 @@ setup:
 
   - match: { hits.total: 3 }
 
-  - match: {hits.hits.0._id: "1"}
-  - match: {hits.hits.0._score: 248}
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0._score: 415}
 
-  - match: {hits.hits.1._id: "2"}
-  - match: {hits.hits.1._score: 136}
+  - match: {hits.hits.1._id: "1"}
+  - match: {hits.hits.1._score: 168}
 
-  - match: {hits.hits.2._id: "3"}
-  - match: {hits.hits.2._score: 20}
+  - match: {hits.hits.2._id: "2"}
+  - match: {hits.hits.2._score: 126}
 
   - do:
       headers:
@@ -501,11 +501,11 @@ setup:
 
   - match: { hits.total: 3 }
 
-  - match: {hits.hits.0._id: "1"}
-  - match: {hits.hits.0._score: 248}
+  - match: {hits.hits.0._id: "3"}
+  - match: {hits.hits.0._score: 415}
 
-  - match: {hits.hits.1._id: "2"}
-  - match: {hits.hits.1._score: 136}
+  - match: {hits.hits.1._id: "1"}
+  - match: {hits.hits.1._score: 168}
 
-  - match: {hits.hits.2._id: "3"}
-  - match: {hits.hits.2._score: 20}
+  - match: {hits.hits.2._id: "2"}
+  - match: {hits.hits.2._score: 126}

+ 2 - 2
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -27,7 +27,7 @@ public final class SearchCapabilities {
     /** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
     private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
     /** Support Byte and Float with Bit dot product. */
-    private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
+    private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product_with_bugfix";
     /** Support docvalue_fields parameter for `dense_vector` field. */
     private static final String DENSE_VECTOR_DOCVALUE_FIELDS = "dense_vector_docvalue_fields";
     /** Support kql query. */
@@ -39,7 +39,7 @@ public final class SearchCapabilities {
     /** Support multi-dense-vector script field access. */
     private static final String MULTI_DENSE_VECTOR_SCRIPT_ACCESS = "multi_dense_vector_script_access";
     /** Initial support for multi-dense-vector maxSim functions access. */
-    private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim";
+    private static final String MULTI_DENSE_VECTOR_SCRIPT_MAX_SIM = "multi_dense_vector_script_max_sim_with_bugfix";
 
     private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
 

+ 1 - 1
server/src/test/java/org/elasticsearch/script/MultiVectorScoreScriptUtilsTests.java

@@ -200,7 +200,7 @@ public class MultiVectorScoreScriptUtilsTests extends ESTestCase {
             function = new MaxSimDotProduct(scoreScript, floatQueryVector, fieldName);
             assertEquals(
                 "maxSimDotProduct result is not equal to the expected value!",
-                0.42f + 0f + 1f - 1f - 0.42f,
+                -1.4f + 0.42f + 0f + 1f - 1f,
                 function.maxSimDotProduct(),
                 0.001
             );

+ 1 - 1
server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java

@@ -263,7 +263,7 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
             function = new DotProduct(scoreScript, floatQueryVector, fieldName);
             assertEquals(
                 "dotProduct result is not equal to the expected value!",
-                0.42f + 0f + 1f - 1f - 0.42f,
+                -1.4f + 0.42f + 0f + 1f - 1f,
                 function.dotProduct(),
                 0.001
             );