Browse Source

[8.x] Add support for bitwise inner-product in painless (#116082) (#116285)

* Add support for bitwise inner-product in painless (#116082)

This adds bitwise inner product to painless. 

The idea here is:

 - For two bit arrays, which we determine to be a byte array whose dimensions match `dense_vector.dim/8`, we simply return the bit count of the bitwise `&`
 - For a stored bit array (remember, with `dense_vector.dim/8` bytes), sum up the provided byte or float array using the bit array as a mask.

This is effectively supporting asymmetric quantization. A prime
example of how this works is:
https://github.com/cohere-ai/BinaryVectorDB

Basically, you do your initial search against the binary space and then
rerank with a differently quantized vector allowing for more information
without additional storage space. 

closes:  https://github.com/elastic/elasticsearch/issues/111232

* removing unnecessary task adjustment

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Benjamin Trent 11 months ago
parent
commit
616b3908a0

+ 5 - 0
docs/changelog/116082.yaml

@@ -0,0 +1,5 @@
+pr: 116082
+summary: Add support for bitwise inner-product in painless
+area: Vector Search
+type: enhancement
+issues: []

+ 88 - 2
docs/reference/vectors/vector-functions.asciidoc

@@ -16,7 +16,7 @@ This is the list of available vector functions and vector access methods:
 6. <<vector-functions-accessing-vectors,`doc[<field>].vectorValue`>> – returns a vector's value as an array of floats
 7. <<vector-functions-accessing-vectors,`doc[<field>].magnitude`>> – returns a vector's magnitude
 
-NOTE: The `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
+NOTE: The `cosineSimilarity` function is not supported for `bit` vectors.
 
 NOTE: The recommended way to access dense vectors is through the
 `cosineSimilarity`, `dotProduct`, `l1norm` or `l2norm` functions. Please note
@@ -332,6 +332,92 @@ When using `bit` vectors, not all the vector functions are available. The suppor
 * <<vector-functions-hamming,`hamming`>> – calculates Hamming distance, the sum of the bitwise XOR of the two vectors
 * <<vector-functions-l1,`l1norm`>> – calculates L^1^ distance, this is simply the `hamming` distance
 * <<vector-functions-l2,`l2norm`>> - calculates L^2^ distance, this is the square root of the `hamming` distance
+* <<vector-functions-dot-product,`dotProduct`>> – calculates dot product. When comparing two `bit` vectors,
+this is the sum of the bitwise AND of the two vectors. If a `float[]` or `byte[]` with `dims` number of elements is provided as the query vector, the `dotProduct` is
+the sum of the query vector's values using the stored `bit` vector as a mask.
 
-Currently, the `cosineSimilarity` and `dotProduct` functions are not supported for `bit` vectors.
+Here is an example of using dot-product with bit vectors.
+
+[source,console]
+--------------------------------------------------
+PUT my-index-bit-vectors
+{
+  "mappings": {
+    "properties": {
+      "my_dense_vector": {
+        "type": "dense_vector",
+        "index": false,
+        "element_type": "bit",
+        "dims": 40 <1>
+      }
+    }
+  }
+}
+
+PUT my-index-bit-vectors/_doc/1
+{
+  "my_dense_vector": [8, 5, -15, 1, -7] <2>
+}
+
+PUT my-index-bit-vectors/_doc/2
+{
+  "my_dense_vector": [-1, 115, -3, 4, -128]
+}
+
+PUT my-index-bit-vectors/_doc/3
+{
+  "my_dense_vector": [2, 18, -5, 0, -124]
+}
+
+POST my-index-bit-vectors/_refresh
+--------------------------------------------------
+// TEST[continued]
+<1> The number of dimensions or bits for the `bit` vector.
+<2> This vector represents 5 bytes, or `5 * 8 = 40` bits, which equals the configured dimensions
+
+[source,console]
+--------------------------------------------------
+GET my-index-bit-vectors/_search
+{
+  "query": {
+    "script_score": {
+      "query" : {
+        "match_all": {}
+      },
+      "script": {
+        "source": "dotProduct(params.query_vector, 'my_dense_vector')",
+        "params": {
+          "query_vector": [8, 5, -15, 1, -7] <1>
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[continued]
+<1> This vector is 40 bits, and thus will compute a bitwise `&` operation with the stored vectors.
+
+[source,console]
+--------------------------------------------------
+GET my-index-bit-vectors/_search
+{
+  "query": {
+    "script_score": {
+      "query" : {
+        "match_all": {}
+      },
+      "script": {
+        "source": "dotProduct(params.query_vector, 'my_dense_vector')",
+        "params": {
+          "query_vector": [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67] <1>
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------
+// TEST[continued]
+<1> This vector is 40 individual dimensions, and thus will sum the floating point values using the stored `bit` vector as a mask.
+
+Currently, the `cosineSimilarity` function is not supported for `bit` vectors.
 

+ 122 - 0
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

@@ -9,13 +9,36 @@
 
 package org.elasticsearch.simdvec;
 
+import org.apache.lucene.util.BitUtil;
+import org.apache.lucene.util.Constants;
 import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport;
 import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
 
+import java.lang.invoke.MethodHandle;
+import java.lang.invoke.MethodHandles;
+import java.lang.invoke.MethodType;
+
 import static org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport.B_QUERY;
 
 public class ESVectorUtil {
 
+    private static final MethodHandle BIT_COUNT_MH;
+    static {
+        try {
+            // For andBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time.
+            // On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when
+            // compared to Integer::bitCount. While Long::bitCount is optimal on x64. See
+            // https://bugs.openjdk.org/browse/JDK-8336000
+            BIT_COUNT_MH = Constants.OS_ARCH.equals("aarch64")
+                ? MethodHandles.lookup()
+                    .findStatic(ESVectorUtil.class, "andBitCountInt", MethodType.methodType(int.class, byte[].class, byte[].class))
+                : MethodHandles.lookup()
+                    .findStatic(ESVectorUtil.class, "andBitCountLong", MethodType.methodType(int.class, byte[].class, byte[].class));
+        } catch (NoSuchMethodException | IllegalAccessException e) {
+            throw new AssertionError(e);
+        }
+    }
+
     private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport();
 
     public static long ipByteBinByte(byte[] q, byte[] d) {
@@ -24,4 +47,103 @@ public class ESVectorUtil {
         }
         return IMPL.ipByteBinByte(q, d);
     }
+
+    /**
+     * Compute the inner product of two vectors, where the query vector is a byte vector and the document vector is a bit vector.
+     * This will return the sum of the query vector values using the document vector as a mask.
+     * @param q the query vector, one byte per dimension; an {@link IllegalArgumentException} is thrown unless its length is {@code d.length * Byte.SIZE}
+     * @param d the document vector, as packed bits; bit {@code j} (least significant first) of {@code d[i]} selects {@code q[i * Byte.SIZE + j]}
+     * @return the sum of the query vector values selected by the set bits of the document vector
+     */
+    public static int ipByteBit(byte[] q, byte[] d) {
+        if (q.length != d.length * Byte.SIZE) {
+            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
+        }
+        int result = 0;
+        // now combine the two vectors, summing the byte dimensions where the bit in d is `1`
+        for (int i = 0; i < d.length; i++) {
+            byte mask = d[i];
+            for (int j = 0; j < Byte.SIZE; j++) {
+                if ((mask & (1 << j)) != 0) {
+                    result += q[i * Byte.SIZE + j];
+                }
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Compute the inner product of two vectors, where the query vector is a float vector and the document vector is a bit vector.
+     * This will return the sum of the query vector values using the document vector as a mask.
+     * @param q the query vector, one float per dimension; an {@link IllegalArgumentException} is thrown unless its length is {@code d.length * Byte.SIZE}
+     * @param d the document vector, as packed bits; bit {@code j} (least significant first) of {@code d[i]} selects {@code q[i * Byte.SIZE + j]}
+     * @return the sum of the query vector values selected by the set bits of the document vector
+     */
+    public static float ipFloatBit(float[] q, byte[] d) {
+        if (q.length != d.length * Byte.SIZE) {
+            throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
+        }
+        float result = 0;
+        for (int i = 0; i < d.length; i++) {
+            byte mask = d[i];
+            for (int j = 0; j < Byte.SIZE; j++) {
+                if ((mask & (1 << j)) != 0) {
+                    result += q[i * Byte.SIZE + j];
+                }
+            }
+        }
+        return result;
+    }
+
+    /**
+     * AND bit count computed over signed bytes, delegating to the int- or long-striding implementation bound to {@code BIT_COUNT_MH}.
+     * Adapted from Lucene's XOR bit count implementation.
+     * @param a bytes containing a vector
+     * @param b bytes containing another vector, of the same dimension; an {@link IllegalArgumentException} is thrown if the lengths differ
+     * @return the number of bits that are set in both vectors, i.e. the bit count of the bitwise AND
+     */
+    public static int andBitCount(byte[] a, byte[] b) {
+        if (a.length != b.length) {
+            throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
+        }
+        try {
+            return (int) BIT_COUNT_MH.invokeExact(a, b);
+        } catch (Throwable e) {
+            if (e instanceof Error err) {
+                throw err;
+            } else if (e instanceof RuntimeException re) {
+                throw re;
+            } else {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    /** AND bit count striding over 4 bytes at a time. */
+    static int andBitCountInt(byte[] a, byte[] b) {
+        int distance = 0, i = 0;
+        // round the length down to a multiple of Integer.BYTES and read that prefix of both arrays one int at a time
+        for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
+            distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) & (int) BitUtil.VH_NATIVE_INT.get(b, i));
+        }
+        // tail: the remaining bytes (fewer than Integer.BYTES) are ANDed and counted one at a time
+        for (; i < a.length; i++) {
+            distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
+        }
+        return distance;
+    }
+
+    /** AND bit count striding over 8 bytes at a time. */
+    static int andBitCountLong(byte[] a, byte[] b) {
+        int distance = 0, i = 0;
+        // round the length down to a multiple of Long.BYTES and read that prefix of both arrays one long at a time
+        for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
+            distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) & (long) BitUtil.VH_NATIVE_LONG.get(b, i));
+        }
+        // tail: the remaining bytes (fewer than Long.BYTES) are ANDed and counted one at a time
+        for (; i < a.length; i++) {
+            distance += Integer.bitCount((a[i] & b[i]) & 0xFF);
+        }
+        return distance;
+    }
 }

+ 29 - 0
libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

@@ -21,6 +21,10 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
     static final ESVectorizationProvider defaultedProvider = BaseVectorizationTests.defaultProvider();
     static final ESVectorizationProvider defOrPanamaProvider = BaseVectorizationTests.maybePanamaProvider();
 
+    public void testBitAndCount() {
+        testBasicBitAndImpl(ESVectorUtil::andBitCountLong);
+    }
+
     public void testIpByteBinInvariants() {
         int iterations = atLeast(10);
         for (int i = 0; i < iterations; i++) {
@@ -41,6 +45,23 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         long apply(byte[] q, byte[] d);
     }
 
+    interface BitOps {
+        long apply(byte[] q, byte[] d);
+    }
+
+    void testBasicBitAndImpl(BitOps bitAnd) {
+        assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 0 }));
+        assertEquals(0, bitAnd.apply(new byte[] { 1 }, new byte[] { 0 }));
+        assertEquals(0, bitAnd.apply(new byte[] { 0 }, new byte[] { 1 }));
+        assertEquals(1, bitAnd.apply(new byte[] { 1 }, new byte[] { 1 }));
+        byte[] a = new byte[31];
+        byte[] b = new byte[31];
+        random().nextBytes(a);
+        random().nextBytes(b);
+        int expected = scalarBitAnd(a, b);
+        assertEquals(expected, bitAnd.apply(a, b));
+    }
+
     void testBasicIpByteBinImpl(IpByteBin ipByteBinFunc) {
         assertEquals(15L, ipByteBinFunc.apply(new byte[] { 1, 1, 1, 1 }, new byte[] { 1 }));
         assertEquals(30L, ipByteBinFunc.apply(new byte[] { 1, 2, 1, 2, 1, 2, 1, 2 }, new byte[] { 1, 2 }));
@@ -115,6 +136,14 @@ public class ESVectorUtilTests extends BaseVectorizationTests {
         return res;
     }
 
+    static int scalarBitAnd(byte[] a, byte[] b) {
+        int res = 0;
+        for (int i = 0; i < a.length; i++) {
+            res += Integer.bitCount((a[i] & b[i]) & 0xFF);
+        }
+        return res;
+    }
+
     public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
         int res = 0;
         for (int j = 0; j < length; j++) {

+ 1 - 1
modules/lang-painless/build.gradle

@@ -53,7 +53,7 @@ tasks.named("dependencyLicenses").configure {
 restResources {
     restApi {
         include '_common', 'cluster', 'nodes', 'indices', 'index', 'search', 'get', 'bulk', 'update',
-                'scripts_painless_execute', 'put_script', 'delete_script'
+                'scripts_painless_execute', 'put_script', 'delete_script', 'capabilities'
     }
 }
 

+ 123 - 2
modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml

@@ -101,9 +101,15 @@ setup:
 
   - match: {hits.hits.2._id: "3"}
   - close_to: {hits.hits.2._score: {value: 3.4641016, error: 0.01}}
-
 ---
 "Dot Product is not supported":
+  - skip:
+      features: [capabilities]
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ byte_float_bit_dot_product ]
+      reason: Capability required to run test
   - do:
       catch: bad_request
       headers:
@@ -131,7 +137,6 @@ setup:
                 source: "dotProduct(params.query_vector, 'vector')"
                 params:
                   query_vector: "006ff30e84"
-
 ---
 "Cosine Similarity is not supported":
   - do:
@@ -388,3 +393,119 @@ setup:
 
   - match: {hits.hits.2._id: "3"}
   - match: {hits.hits.2._score: 11.0}
+---
+"Dot product with float":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ byte_float_bit_dot_product ]
+      test_runner_features: [capabilities, close_to]
+      reason: Capability required to run test
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: { match_all: { } }
+              script:
+                source: "dotProduct(params.query_vector, 'vector')"
+                params:
+                  query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]
+
+  - match: { hits.total: 3 }
+
+  - match: {hits.hits.0._id: "2"}
+  - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
+
+  - match: {hits.hits.1._id: "3"}
+  - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
+
+  - match: {hits.hits.2._id: "1"}
+  - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
+
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: { match_all: { } }
+              script:
+                source: "dotProduct(params.query_vector, 'indexed_vector')"
+                params:
+                  query_vector: [0.23, 1.45, 3.67, 4.89, -0.56, 2.34, 3.21, 1.78, -2.45, 0.98, -0.12, 3.45, 4.56, 2.78, 1.23, 0.67, 3.89, 4.12, -2.34, 1.56, 0.78, 3.21, 4.12, 2.45, -1.67, 0.34, -3.45, 4.56, -2.78, 1.23, -0.67, 3.89, -4.34, 2.12, -1.56, 0.78, -3.21, 4.45, 2.12, 1.67]
+
+  - match: { hits.total: 3 }
+
+  - match: {hits.hits.0._id: "2"}
+  - close_to: {hits.hits.0._score: {value: 35.999, error: 0.01}}
+
+  - match: {hits.hits.1._id: "3"}
+  - close_to: {hits.hits.1._score:{value: 27.23, error: 0.01}}
+
+  - match: {hits.hits.2._id: "1"}
+  - close_to: {hits.hits.2._score: {value: 16.57, error: 0.01}}
+---
+"Dot product with byte":
+  - requires:
+      capabilities:
+        - method: POST
+          path: /_search
+          capabilities: [ byte_float_bit_dot_product ]
+      test_runner_features: capabilities
+      reason: Capability required to run test
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: { match_all: { } }
+              script:
+                source: "dotProduct(params.query_vector, 'vector')"
+                params:
+                  query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]
+
+  - match: { hits.total: 3 }
+
+  - match: {hits.hits.0._id: "1"}
+  - match: {hits.hits.0._score: 248}
+
+  - match: {hits.hits.1._id: "2"}
+  - match: {hits.hits.1._score: 136}
+
+  - match: {hits.hits.2._id: "3"}
+  - match: {hits.hits.2._score: 20}
+
+  - do:
+      headers:
+        Content-Type: application/json
+      search:
+        rest_total_hits_as_int: true
+        body:
+          query:
+            script_score:
+              query: { match_all: { } }
+              script:
+                source: "dotProduct(params.query_vector, 'indexed_vector')"
+                params:
+                  query_vector: [12, -34, 56, -78, 90, 12, 34, -56, 78, -90, 23, -45, 67, -89, 12, 34, 56, 78, 90, -12, 34, -56, 78, -90, 23, -45, 67, -89, 12, -34, 56, -78, 90, -12, 34, -56, 78, 90, 23, -45]
+
+  - match: { hits.total: 3 }
+
+  - match: {hits.hits.0._id: "1"}
+  - match: {hits.hits.0._score: 248}
+
+  - match: {hits.hits.1._id: "2"}
+  - match: {hits.hits.1._score: 136}
+
+  - match: {hits.hits.2._id: "3"}
+  - match: {hits.hits.2._score: 20}

+ 4 - 1
server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

@@ -22,9 +22,12 @@ public final class SearchCapabilities {
     private static final String RANGE_REGEX_INTERVAL_QUERY_CAPABILITY = "range_regexp_interval_queries";
     /** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
     private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
+    /** Support Byte and Float with Bit dot product. */
+    private static final String BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY = "byte_float_bit_dot_product";
 
     public static final Set<String> CAPABILITIES = Set.of(
         RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
-        BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY
+        BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
+        BYTE_FLOAT_BIT_DOT_PRODUCT_CAPABILITY
     );
 }

+ 91 - 1
server/src/main/java/org/elasticsearch/script/VectorScoreScriptUtils.java

@@ -307,6 +307,87 @@ public class VectorScoreScriptUtils {
         double dotProduct();
     }
 
+    public static class BitDotProduct extends DenseVectorFunction implements DotProductInterface {
+        private final byte[] byteQueryVector;
+        private final float[] floatQueryVector;
+
+        public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, byte[] queryVector) {
+            super(scoreScript, field);
+            if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
+                throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
+            }
+            int fieldDims = field.get().getDims();
+            if (fieldDims != queryVector.length * Byte.SIZE && fieldDims != queryVector.length) { // accept packed bits or one byte per dimension
+                throw new IllegalArgumentException(
+                    "The query vector has an incorrect number of dimensions. Must be ["
+                        + fieldDims / 8
+                        + "] for bitwise operations, or ["
+                        + fieldDims
+                        + "] for byte wise operations: provided ["
+                        + queryVector.length
+                        + "]."
+                );
+            }
+            this.byteQueryVector = queryVector;
+            this.floatQueryVector = null;
+        }
+
+        public BitDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
+            super(scoreScript, field);
+            if (field.getElementType() != DenseVectorFieldMapper.ElementType.BIT) {
+                throw new IllegalArgumentException("cannot calculate bit dot product for non-bit vectors");
+            }
+            float[] floatQueryVector = new float[queryVector.size()];
+            byte[] byteQueryVector = new byte[queryVector.size()];
+            boolean isFloat = false; // set once any element is non-integral or outside the byte range
+            for (int i = 0; i < queryVector.size(); i++) {
+                Number number = queryVector.get(i);
+                floatQueryVector[i] = number.floatValue();
+                byteQueryVector[i] = number.byteValue(); // only kept if the whole list turns out to be byte-valued
+                if (isFloat
+                    || floatQueryVector[i] % 1.0f != 0.0f
+                    || floatQueryVector[i] < Byte.MIN_VALUE
+                    || floatQueryVector[i] > Byte.MAX_VALUE) {
+                    isFloat = true;
+                }
+            }
+            int fieldDims = field.get().getDims();
+            if (isFloat) {
+                this.floatQueryVector = floatQueryVector;
+                this.byteQueryVector = null;
+                if (fieldDims != floatQueryVector.length) {
+                    throw new IllegalArgumentException(
+                        "The query vector has an incorrect number of dimensions. Must be ["
+                            + fieldDims
+                            + "] for float wise operations: provided ["
+                            + floatQueryVector.length
+                            + "]."
+                    );
+                }
+            } else {
+                this.floatQueryVector = null;
+                this.byteQueryVector = byteQueryVector;
+                if (fieldDims != byteQueryVector.length * Byte.SIZE && fieldDims != byteQueryVector.length) {
+                    throw new IllegalArgumentException(
+                        "The query vector has an incorrect number of dimensions. Must be ["
+                            + fieldDims / 8
+                            + "] for bitwise operations, or ["
+                            + fieldDims
+                            + "] for byte wise operations: provided ["
+                            + byteQueryVector.length
+                            + "]."
+                    );
+                }
+            }
+        }
+
+        @Override
+        public double dotProduct() {
+            setNextVector();
+            return byteQueryVector != null ? field.get().dotProduct(byteQueryVector) : field.get().dotProduct(floatQueryVector);
+        }
+    }
+
     public static class ByteDotProduct extends ByteDenseVectorFunction implements DotProductInterface {
 
         public ByteDotProduct(ScoreScript scoreScript, DenseVectorDocValuesField field, List<Number> queryVector) {
@@ -343,7 +424,16 @@ public class VectorScoreScriptUtils {
         public DotProduct(ScoreScript scoreScript, Object queryVector, String fieldName) {
             DenseVectorDocValuesField field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
             function = switch (field.getElementType()) {
-                case BYTE, BIT -> {
+                case BIT -> {
+                    if (queryVector instanceof List) {
+                        yield new BitDotProduct(scoreScript, field, (List<Number>) queryVector);
+                    } else if (queryVector instanceof String s) {
+                        byte[] parsedQueryVector = HexFormat.of().parseHex(s);
+                        yield new BitDotProduct(scoreScript, field, parsedQueryVector);
+                    }
+                    throw new IllegalArgumentException("Unsupported input object for bit vectors: " + queryVector.getClass().getName());
+                }
+                case BYTE -> {
                     if (queryVector instanceof List) {
                         yield new ByteDotProduct(scoreScript, field, (List<Number>) queryVector);
                     } else if (queryVector instanceof String s) {

+ 10 - 2
server/src/main/java/org/elasticsearch/script/field/vectors/BitBinaryDenseVector.java

@@ -13,6 +13,10 @@ import org.apache.lucene.util.BytesRef;
 
 import java.util.List;
 
+import static org.elasticsearch.simdvec.ESVectorUtil.andBitCount;
+import static org.elasticsearch.simdvec.ESVectorUtil.ipByteBit;
+import static org.elasticsearch.simdvec.ESVectorUtil.ipFloatBit;
+
 public class BitBinaryDenseVector extends ByteBinaryDenseVector {
 
     public BitBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
@@ -54,7 +58,11 @@ public class BitBinaryDenseVector extends ByteBinaryDenseVector {
 
     @Override
     public int dotProduct(byte[] queryVector) {
-        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+        if (queryVector.length == vectorValue.length) {
+            // assume that the query vector is a bit vector and do a bitwise AND
+            return andBitCount(vectorValue, queryVector);
+        }
+        return ipByteBit(queryVector, vectorValue);
     }
 
     @Override
@@ -79,7 +87,7 @@ public class BitBinaryDenseVector extends ByteBinaryDenseVector {
 
     @Override
     public double dotProduct(float[] queryVector) {
-        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+        return ipFloatBit(queryVector, vectorValue);
     }
 
     @Override

+ 10 - 2
server/src/main/java/org/elasticsearch/script/field/vectors/BitKnnDenseVector.java

@@ -11,6 +11,10 @@ package org.elasticsearch.script.field.vectors;
 
 import java.util.List;
 
+import static org.elasticsearch.simdvec.ESVectorUtil.andBitCount;
+import static org.elasticsearch.simdvec.ESVectorUtil.ipByteBit;
+import static org.elasticsearch.simdvec.ESVectorUtil.ipFloatBit;
+
 public class BitKnnDenseVector extends ByteKnnDenseVector {
 
     public BitKnnDenseVector(byte[] vector) {
@@ -61,7 +65,11 @@ public class BitKnnDenseVector extends ByteKnnDenseVector {
 
     @Override
     public int dotProduct(byte[] queryVector) {
-        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+        if (queryVector.length == docVector.length) {
+            // assume that the query vector is a bit vector and do a bitwise AND
+            return andBitCount(docVector, queryVector);
+        }
+        return ipByteBit(queryVector, docVector);
     }
 
     @Override
@@ -86,7 +94,7 @@ public class BitKnnDenseVector extends ByteKnnDenseVector {
 
     @Override
     public double dotProduct(float[] queryVector) {
-        throw new UnsupportedOperationException("dotProduct is not supported for bit vectors.");
+        return ipFloatBit(queryVector, docVector);
     }
 
     @Override

+ 1 - 1
server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java

@@ -21,7 +21,7 @@ public class ByteBinaryDenseVector implements DenseVector {
     public static final int MAGNITUDE_BYTES = 4;
 
     private final BytesRef docVector;
-    private final byte[] vectorValue;
+    protected final byte[] vectorValue;
     protected final int dims;
 
     private float[] floatDocVector;

+ 7 - 3
server/src/test/java/org/elasticsearch/index/mapper/vectors/BinaryDenseVectorScriptDocValuesTests.java

@@ -236,15 +236,19 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
     }
 
     public static BytesRef mockEncodeDenseVector(float[] values, ElementType elementType, IndexVersion indexVersion) {
+        int dims = values.length;
+        if (elementType == ElementType.BIT) {
+            dims *= Byte.SIZE;
+        }
         int numBytes = indexVersion.onOrAfter(DenseVectorFieldMapper.MAGNITUDE_STORED_INDEX_VERSION)
-            ? elementType.getNumBytes(values.length) + DenseVectorFieldMapper.MAGNITUDE_BYTES
-            : elementType.getNumBytes(values.length);
+            ? elementType.getNumBytes(dims) + DenseVectorFieldMapper.MAGNITUDE_BYTES
+            : elementType.getNumBytes(dims);
         double dotProduct = 0f;
         ByteBuffer byteBuffer = elementType.createByteBuffer(indexVersion, numBytes);
         for (float value : values) {
             if (elementType == ElementType.FLOAT) {
                 byteBuffer.putFloat(value);
-            } else if (elementType == ElementType.BYTE) {
+            } else if (elementType == ElementType.BYTE || elementType == ElementType.BIT) {
                 byteBuffer.put((byte) value);
             } else {
                 throw new IllegalStateException("unknown element_type [" + elementType + "]");

+ 57 - 0
server/src/test/java/org/elasticsearch/script/VectorScoreScriptUtilsTests.java

@@ -20,6 +20,8 @@ import org.elasticsearch.script.VectorScoreScriptUtils.Hamming;
 import org.elasticsearch.script.VectorScoreScriptUtils.L1Norm;
 import org.elasticsearch.script.VectorScoreScriptUtils.L2Norm;
 import org.elasticsearch.script.field.vectors.BinaryDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.BitBinaryDenseVectorDocValuesField;
+import org.elasticsearch.script.field.vectors.BitKnnDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.ByteBinaryDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.ByteKnnDenseVectorDocValuesField;
 import org.elasticsearch.script.field.vectors.DenseVectorDocValuesField;
@@ -229,6 +231,61 @@ public class VectorScoreScriptUtilsTests extends ESTestCase {
         }
     }
 
+    public void testBitVectorClassBindingsDotProduct() throws IOException {
+        String fieldName = "vector";
+        int dims = 8;
+        float[] docVector = new float[] { 124 };
+        // 124 in binary is b01111100
+        List<Number> queryVector = Arrays.asList((byte) 1, (byte) 125, (byte) -12, (byte) 2, (byte) 4, (byte) 1, (byte) 125, (byte) -12);
+        List<Number> floatQueryVector = Arrays.asList(1.4f, -1.4f, 0.42f, 0.0f, 1f, -1f, -0.42f, 1.2f);
+        List<Number> invalidQueryVector = Arrays.asList((byte) 1, (byte) 1);
+        String hexidecimalString = HexFormat.of().formatHex(new byte[] { 124 });
+
+        List<DenseVectorDocValuesField> fields = List.of(
+            new BitBinaryDenseVectorDocValuesField(
+                BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, ElementType.BIT, IndexVersion.current()),
+                "test",
+                ElementType.BIT,
+                dims
+            ),
+            new BitKnnDenseVectorDocValuesField(KnnDenseVectorScriptDocValuesTests.wrapBytes(new float[][] { docVector }), "test", dims)
+        );
+        for (DenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
+
+            ScoreScript scoreScript = mock(ScoreScript.class);
+            when(scoreScript.field(fieldName)).thenAnswer(mock -> field);
+
+            // Test dot product with a byte query vector that has one element per bit dimension (the bit vector acts as a mask)
+            DotProduct function = new DotProduct(scoreScript, queryVector, fieldName);
+            assertEquals("dotProduct result is not equal to the expected value!", -12 + 2 + 4 + 1 + 125, function.dotProduct(), 0.001);
+
+            function = new DotProduct(scoreScript, floatQueryVector, fieldName);
+            assertEquals(
+                "dotProduct result is not equal to the expected value!",
+                0.42f + 0f + 1f - 1f - 0.42f,
+                function.dotProduct(),
+                0.001
+            );
+
+            function = new DotProduct(scoreScript, hexidecimalString, fieldName);
+            assertEquals("dotProduct result is not equal to the expected value!", Integer.bitCount(124), function.dotProduct(), 0.0);
+
+            // Check that dot product rejects query vectors with the wrong dimension
+            IllegalArgumentException e = expectThrows(
+                IllegalArgumentException.class,
+                () -> new DotProduct(scoreScript, invalidQueryVector, fieldName)
+            );
+            assertThat(
+                e.getMessage(),
+                containsString(
+                    "query vector has an incorrect number of dimensions. "
+                        + "Must be [1] for bitwise operations, or [8] for byte wise operations: provided [2]."
+                )
+            );
+        }
+    }
+
     public void testByteVsFloatSimilarity() throws IOException {
         int dims = 5;
         float[] docVector = new float[] { 1f, 127f, -128f, 5f, -10f };