Browse Source

Improve brute force vector search speed by using Lucene functions (#96617)

Lucene has integrated hardware accelerated vector calculations. Meaning,
calculations like `dot_product` can be much faster when using the Lucene
defined functions.

When a `dense_vector` is indexed, we already support this. However, when
`index: false` we store float vectors as binary fields in Lucene and
decode them ourselves. Meaning, we don't use the underlying Lucene
structures or functions.

To take advantage of the large performance boost, this PR refactors the
binary vector values in the following way:

 - Eagerly decode the binary blobs when iterated
 - Call the Lucene defined VectorUtil functions when possible

related to: https://github.com/elastic/elasticsearch/issues/96370
Benjamin Trent 2 years ago
parent
commit
763cd149fc

+ 16 - 12
benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

@@ -110,19 +110,20 @@ public class DistanceFunctionBenchmark {
     private abstract static class BinaryFloatBenchmarkFunction extends BenchmarkFunction {
 
         final BytesRef docVector;
+        final float[] docFloatVector;
         final float[] queryVector;
 
         private BinaryFloatBenchmarkFunction(int dims, boolean normalize) {
             super(dims);
 
-            float[] docVector = new float[dims];
+            docFloatVector = new float[dims];
             queryVector = new float[dims];
 
             float docMagnitude = 0f;
             float queryMagnitude = 0f;
 
             for (int i = 0; i < dims; ++i) {
-                docVector[i] = (float) (dims - i);
+                docFloatVector[i] = (float) (dims - i);
                 queryVector[i] = (float) i;
 
                 docMagnitude += (float) (dims - i);
@@ -136,11 +137,11 @@ public class DistanceFunctionBenchmark {
 
             for (int i = 0; i < dims; ++i) {
                 if (normalize) {
-                    docVector[i] /= docMagnitude;
+                    docFloatVector[i] /= docMagnitude;
                     queryVector[i] /= queryMagnitude;
                 }
 
-                byteBuffer.putFloat(docVector[i]);
+                byteBuffer.putFloat(docFloatVector[i]);
             }
 
             byteBuffer.putFloat(docMagnitude);
@@ -178,6 +179,7 @@ public class DistanceFunctionBenchmark {
     private abstract static class BinaryByteBenchmarkFunction extends BenchmarkFunction {
 
         final BytesRef docVector;
+        final byte[] vectorValue;
         final byte[] queryVector;
 
         final float queryMagnitude;
@@ -187,12 +189,14 @@ public class DistanceFunctionBenchmark {
 
             ByteBuffer docVector = ByteBuffer.allocate(dims + 4);
             queryVector = new byte[dims];
+            vectorValue = new byte[dims];
 
             float docMagnitude = 0f;
             float queryMagnitude = 0f;
 
             for (int i = 0; i < dims; ++i) {
                 docVector.put((byte) (dims - i));
+                vectorValue[i] = (byte) (dims - i);
                 queryVector[i] = (byte) i;
 
                 docMagnitude += (float) (dims - i);
@@ -238,7 +242,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new BinaryDenseVector(docVector, dims, Version.CURRENT).dotProduct(queryVector);
+            new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).dotProduct(queryVector);
         }
     }
 
@@ -250,7 +254,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new ByteBinaryDenseVector(docVector, dims).dotProduct(queryVector);
+            new ByteBinaryDenseVector(vectorValue, docVector, dims).dotProduct(queryVector);
         }
     }
 
@@ -286,7 +290,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new BinaryDenseVector(docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
+            new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).cosineSimilarity(queryVector, false);
         }
     }
 
@@ -298,7 +302,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new ByteBinaryDenseVector(docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
+            new ByteBinaryDenseVector(vectorValue, docVector, dims).cosineSimilarity(queryVector, queryMagnitude);
         }
     }
 
@@ -334,7 +338,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
+            new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
         }
     }
 
@@ -346,7 +350,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new ByteBinaryDenseVector(docVector, dims).l1Norm(queryVector);
+            new ByteBinaryDenseVector(vectorValue, docVector, dims).l1Norm(queryVector);
         }
     }
 
@@ -382,7 +386,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            new BinaryDenseVector(docVector, dims, Version.CURRENT).l1Norm(queryVector);
+            new BinaryDenseVector(docFloatVector, docVector, dims, Version.CURRENT).l1Norm(queryVector);
         }
     }
 
@@ -394,7 +398,7 @@ public class DistanceFunctionBenchmark {
 
         @Override
         public void execute(Consumer<Object> consumer) {
-            consumer.accept(new ByteBinaryDenseVector(docVector, dims).l2Norm(queryVector));
+            consumer.accept(new ByteBinaryDenseVector(vectorValue, docVector, dims).l2Norm(queryVector));
         }
     }
 

+ 5 - 0
docs/changelog/96617.yaml

@@ -0,0 +1,5 @@
+pr: 96617
+summary: Improve brute force vector search speed by using Lucene functions
+area: Search
+type: enhancement
+issues: []

+ 6 - 9
server/src/main/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoder.java

@@ -36,26 +36,23 @@ public final class VectorEncoderDecoder {
     /**
      * Calculates vector magnitude
      */
-    private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
-        final int length = denseVectorLength(indexVersion, vectorBR);
-        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
+    private static float calculateMagnitude(float[] decodedVector) {
         double magnitude = 0.0f;
-        for (int i = 0; i < length; i++) {
-            float value = byteBuffer.getFloat();
-            magnitude += value * value;
+        for (int i = 0; i < decodedVector.length; i++) {
+            magnitude += decodedVector[i] * decodedVector[i];
         }
         magnitude = Math.sqrt(magnitude);
         return (float) magnitude;
     }
 
-    public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
+    public static float getMagnitude(Version indexVersion, BytesRef vectorBR, float[] decodedVector) {
         if (vectorBR == null) {
             throw new IllegalArgumentException(DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE);
         }
         if (indexVersion.onOrAfter(Version.V_7_5_0)) {
             return decodeMagnitude(indexVersion, vectorBR);
         } else {
-            return calculateMagnitude(indexVersion, vectorBR);
+            return calculateMagnitude(decodedVector);
         }
     }
 
@@ -70,7 +67,7 @@ public final class VectorEncoderDecoder {
         }
         ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
         for (int dim = 0; dim < vector.length; dim++) {
-            vector[dim] = byteBuffer.getFloat();
+            vector[dim] = byteBuffer.getFloat((dim * Float.BYTES) + vectorBR.offset);
         }
     }
 

+ 17 - 42
server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVector.java

@@ -9,21 +9,23 @@
 package org.elasticsearch.script.field.vectors;
 
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.Version;
 import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
 
-import java.nio.ByteBuffer;
 import java.util.List;
 
 public class BinaryDenseVector implements DenseVector {
 
-    protected final BytesRef docVector;
-    protected final int dims;
-    protected final Version indexVersion;
+    private final BytesRef docVector;
 
-    protected float[] decodedDocVector;
+    private final int dims;
+    private final Version indexVersion;
 
-    public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
+    private final float[] decodedDocVector;
+
+    public BinaryDenseVector(float[] decodedDocVector, BytesRef docVector, int dims, Version indexVersion) {
+        this.decodedDocVector = decodedDocVector;
         this.docVector = docVector;
         this.indexVersion = indexVersion;
         this.dims = dims;
@@ -31,16 +33,12 @@ public class BinaryDenseVector implements DenseVector {
 
     @Override
     public float[] getVector() {
-        if (decodedDocVector == null) {
-            decodedDocVector = new float[dims];
-            VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
-        }
         return decodedDocVector;
     }
 
     @Override
     public float getMagnitude() {
-        return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
+        return VectorEncoderDecoder.getMagnitude(indexVersion, docVector, decodedDocVector);
     }
 
     @Override
@@ -50,22 +48,14 @@ public class BinaryDenseVector implements DenseVector {
 
     @Override
     public double dotProduct(float[] queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
-
-        double dotProduct = 0;
-        for (float v : queryVector) {
-            dotProduct += byteBuffer.getFloat() * v;
-        }
-        return dotProduct;
+        return VectorUtil.dotProduct(decodedDocVector, queryVector);
     }
 
     @Override
     public double dotProduct(List<Number> queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
-
         double dotProduct = 0;
         for (int i = 0; i < queryVector.size(); i++) {
-            dotProduct += byteBuffer.getFloat() * queryVector.get(i).floatValue();
+            dotProduct += decodedDocVector[i] * queryVector.get(i).floatValue();
         }
         return dotProduct;
     }
@@ -77,22 +67,18 @@ public class BinaryDenseVector implements DenseVector {
 
     @Override
     public double l1Norm(float[] queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
-
         double l1norm = 0;
-        for (float v : queryVector) {
-            l1norm += Math.abs(v - byteBuffer.getFloat());
+        for (int i = 0; i < queryVector.length; i++) {
+            l1norm += Math.abs(queryVector[i] - decodedDocVector[i]);
         }
         return l1norm;
     }
 
     @Override
     public double l1Norm(List<Number> queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
-
         double l1norm = 0;
         for (int i = 0; i < queryVector.size(); i++) {
-            l1norm += Math.abs(queryVector.get(i).floatValue() - byteBuffer.getFloat());
+            l1norm += Math.abs(queryVector.get(i).floatValue() - decodedDocVector[i]);
         }
         return l1norm;
     }
@@ -104,21 +90,14 @@ public class BinaryDenseVector implements DenseVector {
 
     @Override
     public double l2Norm(float[] queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
-        double l2norm = 0;
-        for (float queryValue : queryVector) {
-            double diff = byteBuffer.getFloat() - queryValue;
-            l2norm += diff * diff;
-        }
-        return Math.sqrt(l2norm);
+        return Math.sqrt(VectorUtil.squareDistance(queryVector, decodedDocVector));
     }
 
     @Override
     public double l2Norm(List<Number> queryVector) {
-        ByteBuffer byteBuffer = wrap(docVector);
         double l2norm = 0;
-        for (Number number : queryVector) {
-            double diff = byteBuffer.getFloat() - number.floatValue();
+        for (int i = 0; i < queryVector.size(); i++) {
+            double diff = decodedDocVector[i] - queryVector.get(i).floatValue();
             l2norm += diff * diff;
         }
         return Math.sqrt(l2norm);
@@ -156,8 +135,4 @@ public class BinaryDenseVector implements DenseVector {
     public int getDims() {
         return dims;
     }
-
-    private static ByteBuffer wrap(BytesRef dv) {
-        return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
-    }
 }

+ 20 - 7
server/src/main/java/org/elasticsearch/script/field/vectors/BinaryDenseVectorDocValuesField.java

@@ -13,25 +13,30 @@ import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.Version;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
 import org.elasticsearch.index.mapper.vectors.DenseVectorScriptDocValues;
+import org.elasticsearch.index.mapper.vectors.VectorEncoderDecoder;
 
 import java.io.IOException;
 
 public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
 
-    protected final BinaryDocValues input;
-    protected final Version indexVersion;
-    protected final int dims;
-    protected BytesRef value;
+    private final BinaryDocValues input;
+    private final float[] vectorValue;
+    private final Version indexVersion;
+    private boolean decoded;
+    private final int dims;
+    private BytesRef value;
 
     public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims, Version indexVersion) {
         super(name, elementType);
         this.input = input;
         this.indexVersion = indexVersion;
         this.dims = dims;
+        this.vectorValue = new float[dims];
     }
 
     @Override
     public void setNextDocId(int docId) throws IOException {
+        decoded = false;
         if (input.advanceExact(docId)) {
             value = input.binaryValue();
         } else {
@@ -54,8 +59,8 @@ public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
-
-        return new BinaryDenseVector(value, dims, indexVersion);
+        decodeVectorIfNecessary();
+        return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
     }
 
     @Override
@@ -63,11 +68,19 @@ public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
         if (isEmpty()) {
             return defaultValue;
         }
-        return new BinaryDenseVector(value, dims, indexVersion);
+        decodeVectorIfNecessary();
+        return new BinaryDenseVector(vectorValue, value, dims, indexVersion);
     }
 
     @Override
     public DenseVector getInternal() {
         return get(null);
     }
+
+    private void decodeVectorIfNecessary() {
+        if (decoded == false && value != null) {
+            VectorEncoderDecoder.decodeDenseVector(value, vectorValue);
+            decoded = true;
+        }
+    }
 }

+ 21 - 43
server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.script.field.vectors;
 
 import org.apache.lucene.util.BytesRef;
+import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.core.SuppressForbidden;
 
 import java.nio.ByteBuffer;
@@ -18,30 +19,28 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     public static final int MAGNITUDE_BYTES = 4;
 
-    protected final BytesRef docVector;
-    protected final int dims;
+    private final BytesRef docVector;
+    private final byte[] vectorValue;
+    private final int dims;
 
-    protected float[] floatDocVector;
-    protected boolean magnitudeDecoded;
-    protected float magnitude;
+    private float[] floatDocVector;
+    private boolean magnitudeDecoded;
+    private float magnitude;
 
-    public ByteBinaryDenseVector(BytesRef docVector, int dims) {
+    public ByteBinaryDenseVector(byte[] vectorValue, BytesRef docVector, int dims) {
         this.docVector = docVector;
         this.dims = dims;
+        this.vectorValue = vectorValue;
     }
 
     @Override
     public float[] getVector() {
         if (floatDocVector == null) {
             floatDocVector = new float[dims];
-
-            int i = 0;
-            int j = docVector.offset;
-            while (i < dims) {
-                floatDocVector[i++] = docVector.bytes[j++];
+            for (int i = 0; i < dims; i++) {
+                floatDocVector[i] = vectorValue[i];
             }
         }
-
         return floatDocVector;
     }
 
@@ -56,13 +55,7 @@ public class ByteBinaryDenseVector implements DenseVector {
 
     @Override
     public int dotProduct(byte[] queryVector) {
-        int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            result += docVector.bytes[j++] * queryVector[i++];
-        }
-        return result;
+        return VectorUtil.dotProduct(queryVector, vectorValue);
     }
 
     @Override
@@ -73,10 +66,8 @@ public class ByteBinaryDenseVector implements DenseVector {
     @Override
     public double dotProduct(List<Number> queryVector) {
         int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            result += docVector.bytes[j++] * queryVector.get(i++).intValue();
+        for (int i = 0; i < queryVector.size(); i++) {
+            result += vectorValue[i] * queryVector.get(i).intValue();
         }
         return result;
     }
@@ -89,10 +80,8 @@ public class ByteBinaryDenseVector implements DenseVector {
     @Override
     public int l1Norm(byte[] queryVector) {
         int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            result += abs(docVector.bytes[j++] - queryVector[i++]);
+        for (int i = 0; i < queryVector.length; i++) {
+            result += abs(vectorValue[i] - queryVector[i]);
         }
         return result;
     }
@@ -105,24 +94,15 @@ public class ByteBinaryDenseVector implements DenseVector {
     @Override
     public double l1Norm(List<Number> queryVector) {
         int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            result += abs(docVector.bytes[j++] - queryVector.get(i++).intValue());
+        for (int i = 0; i < queryVector.size(); i++) {
+            result += abs(vectorValue[i] - queryVector.get(i).intValue());
         }
         return result;
     }
 
     @Override
     public double l2Norm(byte[] queryVector) {
-        int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            int diff = docVector.bytes[j++] - queryVector[i++];
-            result += diff * diff;
-        }
-        return Math.sqrt(result);
+        return Math.sqrt(VectorUtil.squareDistance(queryVector, vectorValue));
     }
 
     @Override
@@ -133,10 +113,8 @@ public class ByteBinaryDenseVector implements DenseVector {
     @Override
     public double l2Norm(List<Number> queryVector) {
         int result = 0;
-        int i = 0;
-        int j = docVector.offset;
-        while (i < dims) {
-            int diff = docVector.bytes[j++] - queryVector.get(i++).intValue();
+        for (int i = 0; i < queryVector.size(); i++) {
+            int diff = vectorValue[i] - queryVector.get(i).intValue();
             result += diff * diff;
         }
         return Math.sqrt(result);

+ 18 - 6
server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVectorDocValuesField.java

@@ -17,18 +17,22 @@ import java.io.IOException;
 
 public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
 
-    protected final BinaryDocValues input;
-    protected final int dims;
-    protected BytesRef value;
+    private final BinaryDocValues input;
+    private final int dims;
+    private final byte[] vectorValue;
+    private boolean decoded;
+    private BytesRef value;
 
     public ByteBinaryDenseVectorDocValuesField(BinaryDocValues input, String name, ElementType elementType, int dims) {
         super(name, elementType);
         this.input = input;
         this.dims = dims;
+        this.vectorValue = new byte[dims];
     }
 
     @Override
     public void setNextDocId(int docId) throws IOException {
+        decoded = false;
         if (input.advanceExact(docId)) {
             value = input.binaryValue();
         } else {
@@ -51,8 +55,8 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesFie
         if (isEmpty()) {
             return DenseVector.EMPTY;
         }
-
-        return new ByteBinaryDenseVector(value, dims);
+        decodeVectorIfNecessary();
+        return new ByteBinaryDenseVector(vectorValue, value, dims);
     }
 
     @Override
@@ -60,11 +64,19 @@ public class ByteBinaryDenseVectorDocValuesField extends DenseVectorDocValuesFie
         if (isEmpty()) {
             return defaultValue;
         }
-        return new ByteBinaryDenseVector(value, dims);
+        decodeVectorIfNecessary();
+        return new ByteBinaryDenseVector(vectorValue, value, dims);
     }
 
     @Override
     public DenseVector getInternal() {
         return get(null);
     }
+
+    private void decodeVectorIfNecessary() {
+        if (decoded == false && value != null) {
+            System.arraycopy(value.bytes, value.offset, vectorValue, 0, dims);
+            decoded = true;
+        }
+    }
 }

+ 1 - 8
server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java

@@ -103,14 +103,7 @@ public class ByteKnnDenseVector implements DenseVector {
 
     @Override
     public double l2Norm(byte[] queryVector) {
-        int result = 0;
-        int i = 0;
-        while (i < docVector.length) {
-            int diff = docVector[i] - queryVector[i];
-            result += diff * diff;
-            i++;
-        }
-        return Math.sqrt(result);
+        return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));
     }
 
     @Override

+ 39 - 0
server/src/test/java/org/elasticsearch/index/mapper/vectors/VectorEncoderDecoderTests.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.index.mapper.vectors;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.test.ESTestCase;
+
+import java.nio.ByteBuffer;
+
+public class VectorEncoderDecoderTests extends ESTestCase {
+
+    public void testVectorDecodingWithOffset() {
+        float[] inputFloats = new float[] { 1f, 2f, 3f, 4f };
+        ByteBuffer byteBuffer = ByteBuffer.allocate(20);
+        double magnitude = 0.0;
+        for (float f : inputFloats) {
+            byteBuffer.putFloat(f);
+            magnitude += f * f;
+        }
+        // Binary documents store magnitude in a float at the end of the buffer array
+        magnitude /= 4;
+        byteBuffer.putFloat((float) magnitude);
+        BytesRef floatBytes = new BytesRef(byteBuffer.array());
+        // adjust so that we have an offset ignoring the first float
+        floatBytes.length = 16;
+        floatBytes.offset = 4;
+        // since we are ignoring the first float to mock an offset, our dimensions can be assumed to be 3
+        float[] outputFloats = new float[3];
+        VectorEncoderDecoder.decodeDenseVector(floatBytes, outputFloats);
+        assertArrayEquals(outputFloats, new float[] { 2f, 3f, 4f }, 0f);
+    }
+
+}

+ 6 - 4
server/src/test/java/org/elasticsearch/script/field/vectors/DenseVectorTests.java

@@ -70,7 +70,7 @@ public class DenseVectorTests extends ESTestCase {
 
         for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
             BytesRef value = BinaryDenseVectorScriptDocValuesTests.mockEncodeDenseVector(docVector, ElementType.FLOAT, indexVersion);
-            BinaryDenseVector bdv = new BinaryDenseVector(value, dims, indexVersion);
+            BinaryDenseVector bdv = new BinaryDenseVector(docVector, value, dims, indexVersion);
 
             assertEquals(bdv.dotProduct(arrayQV), bdv.dotProduct(listQV), 0.001f);
             assertEquals(bdv.dotProduct((Object) listQV), bdv.dotProduct((Object) arrayQV), 0.001f);
@@ -115,7 +115,9 @@ public class DenseVectorTests extends ESTestCase {
         assertEquals(knn.cosineSimilarity((Object) listQV), knn.cosineSimilarity((Object) arrayQV), 0.001f);
 
         BytesRef value = BinaryDenseVectorScriptDocValuesTests.mockEncodeDenseVector(floatVector, ElementType.BYTE, Version.CURRENT);
-        ByteBinaryDenseVector bdv = new ByteBinaryDenseVector(value, dims);
+        byte[] byteVectorValue = new byte[dims];
+        System.arraycopy(value.bytes, value.offset, byteVectorValue, 0, dims);
+        ByteBinaryDenseVector bdv = new ByteBinaryDenseVector(byteVectorValue, value, dims);
 
         assertEquals(bdv.dotProduct(arrayQV), bdv.dotProduct(listQV), 0.001f);
         assertEquals(bdv.dotProduct((Object) listQV), bdv.dotProduct((Object) arrayQV), 0.001f);
@@ -162,7 +164,7 @@ public class DenseVectorTests extends ESTestCase {
         e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
         assertEquals(e.getMessage(), "use [double cosineSimilarity(byte[] queryVector, float qvMagnitude)] instead");
 
-        ByteBinaryDenseVector binary = new ByteBinaryDenseVector(new BytesRef(docVector), dims);
+        ByteBinaryDenseVector binary = new ByteBinaryDenseVector(docVector, new BytesRef(docVector), dims);
 
         e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
         assertEquals(e.getMessage(), "use [int dotProduct(byte[] queryVector)] instead");
@@ -219,7 +221,7 @@ public class DenseVectorTests extends ESTestCase {
         e = expectThrows(UnsupportedOperationException.class, () -> knn.cosineSimilarity((Object) queryVector));
         assertEquals(e.getMessage(), "use [double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector)] instead");
 
-        BinaryDenseVector binary = new BinaryDenseVector(new BytesRef(docBuffer.array()), dims, Version.CURRENT);
+        BinaryDenseVector binary = new BinaryDenseVector(docVector, new BytesRef(docBuffer.array()), dims, Version.CURRENT);
 
         e = expectThrows(UnsupportedOperationException.class, () -> binary.dotProduct(queryVector));
         assertEquals(e.getMessage(), "use [double dotProduct(float[] queryVector)] instead");