1
0
Эх сурвалжийг харах

Adding support for hex-encoded byte vectors on knn-search (#105393)

Panagiotis Bailis 1 жил өмнө
parent
commit
d471ccb5bb
24 өөрчлөгдсөн 1136 нэмэгдсэн , 209 устгасан
  1. 5 0
      docs/changelog/105393.yaml
  2. 2 2
      docs/reference/query-dsl/knn-query.asciidoc
  3. 1 1
      docs/reference/rest-api/common-parms.asciidoc
  4. 1 1
      docs/reference/search/knn-search.asciidoc
  5. 21 0
      docs/reference/search/search-your-data/knn-search.asciidoc
  6. 163 0
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml
  7. 162 0
      rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml
  8. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  9. 14 0
      server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java
  10. 13 0
      server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java
  11. 189 97
      server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java
  12. 9 1
      server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java
  13. 24 7
      server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java
  14. 24 6
      server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java
  15. 49 27
      server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java
  16. 42 30
      server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java
  17. 168 0
      server/src/main/java/org/elasticsearch/search/vectors/VectorData.java
  18. 3 4
      server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java
  19. 24 29
      server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java
  20. 9 0
      server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java
  21. 10 1
      server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java
  22. 2 2
      server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java
  23. 199 0
      server/src/test/java/org/elasticsearch/search/vectors/VectorDataTests.java
  24. 1 1
      test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java

+ 5 - 0
docs/changelog/105393.yaml

@@ -0,0 +1,5 @@
+pr: 105393
+summary: Adding support for hex-encoded byte vectors on knn-search
+area: Vector Search
+type: feature
+issues: []

+ 2 - 2
docs/reference/query-dsl/knn-query.asciidoc

@@ -87,8 +87,8 @@ the top `size` results.
 `query_vector`::
 +
 --
-(Required, array of floats) Query vector. Must have the same number of dimensions
-as the vector field you are searching against.
+(Required, array of floats or string) Query vector. Must have the same number of dimensions
+as the vector field you are searching against. Must be either an array of floats or a hex-encoded byte vector.
 --
 
 `num_candidates`::

+ 1 - 1
docs/reference/rest-api/common-parms.asciidoc

@@ -597,7 +597,7 @@ end::knn-num-candidates[]
 
 tag::knn-query-vector[]
 Query vector. Must have the same number of dimensions as the vector field you
-are searching against.
+are searching against. Must be either an array of floats or a hex-encoded byte vector.
 end::knn-query-vector[]
 
 tag::knn-similarity[]

+ 1 - 1
docs/reference/search/knn-search.asciidoc

@@ -121,7 +121,7 @@ include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-k]
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-num-candidates]
 
 `query_vector`::
-(Required, array of floats)
+(Required, array of floats or string)
 include::{es-repo-dir}/rest-api/common-parms.asciidoc[tag=knn-query-vector]
 ====
 

+ 21 - 0
docs/reference/search/search-your-data/knn-search.asciidoc

@@ -242,6 +242,27 @@ POST byte-image-index/_search
 // TEST[s/"k": 10/"k": 3/]
 // TEST[s/"num_candidates": 100/"num_candidates": 3/]
 
+
+_Note_: In addition to the standard byte array, you can also provide a hex-encoded string value
+for the `query_vector` parameter. For example, the search request above can also be expressed as follows,
+which yields the same results:
+[source,console]
+----
+POST byte-image-index/_search
+{
+  "knn": {
+    "field": "byte-image-vector",
+    "query_vector": "fb09",
+    "k": 10,
+    "num_candidates": 100
+  },
+  "fields": [ "title" ]
+}
+----
+// TEST[continued]
+// TEST[s/"k": 10/"k": 3/]
+// TEST[s/"num_candidates": 100/"num_candidates": 3/]
+
 [discrete]
 [[knn-search-quantized-example]]
 ==== Byte quantized kNN search

+ 163 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/170_knn_search_hex_encoded_byte_vectors.yml

@@ -0,0 +1,163 @@
+setup:
+  - skip:
+      version: ' - 8.13.99'
+      reason: 'hex encoding for byte vectors was added in 8.14'
+
+  - do:
+      indices.create:
+        index: knn_hex_vector_index
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                dims: 3
+                index : true
+                similarity : l2_norm
+                element_type: byte
+              my_vector_float:
+                type: dense_vector
+                dims: 3
+                index: true
+                element_type: float
+                similarity : l2_norm
+
+  # [-128, 127, 10] - is encoded as '807f0a'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "1"
+        body:
+          my_vector_byte: "807f0a"
+
+
+  # [0, 1, 0] - is encoded as '000100'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "2"
+        body:
+          my_vector_byte: "000100"
+
+  # [64, -10, -30] - is encoded as '40f6e2'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "3"
+        body:
+          my_vector_byte: "40f6e2"
+
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "4"
+        body:
+          my_vector_float: [10.5, -10, 1024]
+
+  - do:
+      indices.refresh: {}
+
+---
+"Fail to index hex-encoded vector on float field":
+
+  # [-128, 127, 10] - is encoded as '807f0a'
+  - do:
+      catch: /Failed to parse object./
+      index:
+        index: knn_hex_vector_index
+        id: "5"
+        body:
+          my_vector_float: "807f0a"
+
+---
+"Knn search with hex string for float field" :
+  # [64, 10, -30] - is encoded as '400ae2'
+  # this will be properly decoded but only because:
+  # (i) the provided input is compatible as the values are within [Byte.MIN_VALUE, Byte.MAX_VALUE] range
+  # (ii) we do not differentiate between byte and float fields when initially parsing a query even for hex
+  # (iii) we support expansion from byte to float
+
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          knn:
+            field: my_vector_float
+            query_vector: "400ae2"
+            num_candidates: 100
+            k: 10
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "4" }
+
+---
+"Knn search with hex string for byte field" :
+  # [64, 10, -30] - is encoded as '400ae2'
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          knn:
+            field: my_vector_byte
+            query_vector: "400ae2"
+            num_candidates: 100
+            k: 10
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "1" }
+
+---
+"Knn search with hex string for byte field - dimensions mismatch" :
+  # [64, 10, -30, 10] - is encoded as '400ae20a'
+  - do:
+      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          knn:
+            field: my_vector_byte
+            query_vector: "400ae20a"
+            num_candidates: 100
+            k: 10
+
+
+---
+"Knn search with hex string for byte field - cannot decode string" :
+  # '40af20a' has an odd number of hex digits (7), so it cannot be decoded into bytes
+  - do:
+      catch: /failed to parse field \[query_vector\]/
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          knn:
+            field: my_vector_byte
+            query_vector: "40af20a"
+            num_candidates: 100
+            k: 10
+
+---
+"Knn search with standard byte vector matching against hex-encoded indexed docs" :
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          knn:
+            field: my_vector_byte
+            query_vector: [64, 10, -30]
+            num_candidates: 100
+            k: 10
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "1" }

+ 162 - 0
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/175_knn_query_hex_encoded_byte_vectors.yml

@@ -0,0 +1,162 @@
+setup:
+  - skip:
+      version: ' - 8.13.99'
+      reason: 'hex encoding for byte vectors was added in 8.14'
+
+  - do:
+      indices.create:
+        index: knn_hex_vector_index
+        body:
+          settings:
+            number_of_shards: 1
+          mappings:
+            dynamic: false
+            properties:
+              my_vector_byte:
+                type: dense_vector
+                dims: 3
+                index : true
+                similarity : l2_norm
+                element_type: byte
+              my_vector_float:
+                type: dense_vector
+                dims: 3
+                index: true
+                element_type: float
+                similarity : l2_norm
+
+  # [-128, 127, 10] - is encoded as '807f0a'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "1"
+        body:
+          my_vector_byte: "807f0a"
+
+
+  # [0, 1, 0] - is encoded as '000100'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "2"
+        body:
+          my_vector_byte: "000100"
+
+  # [64, -10, -30] - is encoded as '40f6e2'
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "3"
+        body:
+          my_vector_byte: "40f6e2"
+
+  - do:
+      index:
+        index: knn_hex_vector_index
+        id: "4"
+        body:
+          my_vector_float: [10.5, -10, 1024]
+
+  - do:
+      indices.refresh: {}
+
+---
+"Fail to index hex-encoded vector on float field":
+
+  # [-128, 127, 10] - is encoded as '807f0a'
+  - do:
+      catch: /Failed to parse object./
+      index:
+        index: knn_hex_vector_index
+        id: "5"
+        body:
+          my_vector_float: "807f0a"
+
+---
+"Knn query with hex string for float field" :
+  # [64, 10, -30] - is encoded as '400ae2'
+  # this will be properly decoded but only because:
+  # (i) the provided input is compatible as the values are within [Byte.MIN_VALUE, Byte.MAX_VALUE] range
+  # (ii) we do not differentiate between byte and float fields when initially parsing a query even for hex
+  # (iii) we support expansion from byte to float
+
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector_float
+              query_vector: "400ae2"
+              num_candidates: 100
+
+  - match: { hits.total.value: 1 }
+  - match: { hits.hits.0._id: "4" }
+
+---
+"Knn query with hex string for byte field" :
+  # [64, 10, -30] - is encoded as '400ae2'
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector_byte
+              query_vector: "400ae2"
+              num_candidates: 100
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "1" }
+
+---
+"Knn query with hex string for byte field - dimensions mismatch" :
+  # [64, 10, -30, 10] - is encoded as '400ae20a'
+  - do:
+      catch: /the query vector has a different dimension \[4\] than the index vectors \[3\]/
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector_byte
+              query_vector: "400ae20a"
+              num_candidates: 100
+
+---
+"Knn query with hex string for byte field - cannot decode string" :
+  # '40af20a' has an odd number of hex digits (7), so it cannot be decoded into bytes
+  - do:
+      catch: /failed to parse field \[query_vector\]/
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          query:
+            knn:
+                field: my_vector_byte
+                query_vector: "40af20a"
+                num_candidates: 100
+
+---
+"Knn query with standard byte vector matching against hex-encoded indexed docs" :
+  - do:
+      search:
+        index: knn_hex_vector_index
+        body:
+          size: 3
+          query:
+            knn:
+              field: my_vector_byte
+              query_vector: [64, 10, -30]
+              num_candidates: 100
+
+  - match: { hits.total.value: 3 }
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.2._id: "1" }

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -143,6 +143,7 @@ public class TransportVersions {
     public static final TransportVersion ADD_DATA_STREAM_GLOBAL_RETENTION = def(8_603_00_0);
     public static final TransportVersion ALLOCATION_STATS = def(8_604_00_0);
     public static final TransportVersion ESQL_EXTENDED_ENRICH_TYPES = def(8_605_00_0);
+    public static final TransportVersion KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING = def(8_606_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 14 - 0
server/src/main/java/org/elasticsearch/common/io/stream/StreamInput.java

@@ -693,6 +693,20 @@ public abstract class StreamInput extends InputStream {
         return null;
     }
 
+    /**
+     * Reads an optional float array. It's effectively the same as readFloatArray, except
+     * it supports null.
+     * @return a float array or null
+     * @throws IOException
+     */
+    @Nullable
+    public float[] readOptionalFloatArray() throws IOException {
+        if (readBoolean()) {
+            return readFloatArray();
+        }
+        return null;
+    }
+
     /**
      * Same as {@link #readMap(Writeable.Reader, Writeable.Reader)} but always reading string keys.
      */

+ 13 - 0
server/src/main/java/org/elasticsearch/common/io/stream/StreamOutput.java

@@ -534,6 +534,19 @@ public abstract class StreamOutput extends OutputStream {
         }
     }
 
+    /**
+     * Writes an optional float array: a leading boolean marks presence, so a null array is written as a single {@code false}.
+     * @param array an array or null
+     */
+    public void writeOptionalFloatArray(@Nullable float[] array) throws IOException {
+        if (array == null) {
+            writeBoolean(false);
+        } else {
+            writeBoolean(true);
+            writeFloatArray(array);
+        }
+    }
+
     public void writeGenericMap(@Nullable Map<String, Object> map) throws IOException {
         writeGenericValue(map);
     }

+ 189 - 97
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -43,6 +43,7 @@ import org.apache.lucene.search.Query;
 import org.apache.lucene.search.join.BitSetProducer;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.IndexVersions;
@@ -69,6 +70,7 @@ import org.elasticsearch.search.vectors.ESDiversifyingChildrenByteKnnVectorQuery
 import org.elasticsearch.search.vectors.ESDiversifyingChildrenFloatKnnVectorQuery;
 import org.elasticsearch.search.vectors.ESKnnByteVectorQuery;
 import org.elasticsearch.search.vectors.ESKnnFloatVectorQuery;
+import org.elasticsearch.search.vectors.VectorData;
 import org.elasticsearch.search.vectors.VectorSimilarityQuery;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -80,6 +82,7 @@ import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.time.ZoneId;
 import java.util.Arrays;
+import java.util.HexFormat;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
@@ -88,6 +91,7 @@ import java.util.function.Function;
 import java.util.function.Supplier;
 import java.util.stream.Stream;
 
+import static org.elasticsearch.common.Strings.format;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 
 /**
@@ -338,11 +342,16 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             @Override
-            public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+            public double computeDotProduct(VectorData vectorData) {
+                return VectorUtil.dotProduct(vectorData.asByteVector(), vectorData.asByteVector());
+            }
+
+            private VectorData parseVectorArray(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
                 int index = 0;
                 byte[] vector = new byte[fieldMapper.fieldType().dims];
                 float squaredMagnitude = 0;
-                for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
+                for (XContentParser.Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser()
+                    .nextToken()) {
                     fieldMapper.checkDimensionExceeded(index, context);
                     ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
                     final int value;
@@ -383,44 +392,49 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 fieldMapper.checkDimensionMatches(index, context);
                 checkVectorMagnitude(fieldMapper.fieldType().similarity, errorByteElementsAppender(vector), squaredMagnitude);
+                return VectorData.fromBytes(vector);
+            }
+
+            private VectorData parseHexEncodedVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                byte[] decodedVector = HexFormat.of().parseHex(context.parser().text());
+                fieldMapper.checkDimensionMatches(decodedVector.length, context);
+                VectorData vectorData = VectorData.fromBytes(decodedVector);
+                double squaredMagnitude = computeDotProduct(vectorData);
+                checkVectorMagnitude(
+                    fieldMapper.fieldType().similarity,
+                    errorByteElementsAppender(decodedVector),
+                    (float) squaredMagnitude
+                );
+                return vectorData;
+            }
+
+            @Override
+            VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                XContentParser.Token token = context.parser().currentToken();
+                return switch (token) {
+                    case START_ARRAY -> parseVectorArray(context, fieldMapper);
+                    case VALUE_STRING -> parseHexEncodedVector(context, fieldMapper);
+                    default -> throw new ParsingException(
+                        context.parser().getTokenLocation(),
+                        format("Unsupported type [%s] for provided value [%s]", token, context.parser().text())
+                    );
+                };
+            }
+
+            @Override
+            public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
+                VectorData vectorData = parseKnnVector(context, fieldMapper);
                 Field field = createKnnVectorField(
                     fieldMapper.fieldType().name(),
-                    vector,
+                    vectorData.asByteVector(),
                     fieldMapper.fieldType().similarity.vectorSimilarityFunction(fieldMapper.indexCreatedVersion, this)
                 );
                 context.doc().addWithKey(fieldMapper.fieldType().name(), field);
             }
 
             @Override
-            double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer)
-                throws IOException {
-                double dotProduct = 0f;
-                int index = 0;
-                for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
-                    fieldMapper.checkDimensionExceeded(index, context);
-                    ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
-                    int value = context.parser().intValue(true);
-                    if (value < Byte.MIN_VALUE || value > Byte.MAX_VALUE) {
-                        throw new IllegalArgumentException(
-                            "element_type ["
-                                + this
-                                + "] vectors only support integers between ["
-                                + Byte.MIN_VALUE
-                                + ", "
-                                + Byte.MAX_VALUE
-                                + "] but found ["
-                                + value
-                                + "] at dim ["
-                                + index
-                                + "];"
-                        );
-                    }
-                    byteBuffer.put((byte) value);
-                    dotProduct += value * value;
-                    index++;
-                }
-                fieldMapper.checkDimensionMatches(index, context);
-                return dotProduct;
+            int getNumBytes(int dimensions) {
+                return dimensions * elementBytes;
             }
 
             @Override
@@ -530,6 +544,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
             }
 
+            @Override
+            public double computeDotProduct(VectorData vectorData) {
+                return VectorUtil.dotProduct(vectorData.asFloatVector(), vectorData.asFloatVector());
+            }
+
             @Override
             public void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
                 int index = 0;
@@ -566,23 +585,27 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             @Override
-            double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer)
-                throws IOException {
-                double dotProduct = 0f;
+            VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException {
                 int index = 0;
+                float squaredMagnitude = 0;
                 float[] vector = new float[fieldMapper.fieldType().dims];
                 for (Token token = context.parser().nextToken(); token != Token.END_ARRAY; token = context.parser().nextToken()) {
                     fieldMapper.checkDimensionExceeded(index, context);
                     ensureExpectedToken(Token.VALUE_NUMBER, token, context.parser());
                     float value = context.parser().floatValue(true);
                     vector[index] = value;
-                    byteBuffer.putFloat(value);
-                    dotProduct += value * value;
+                    squaredMagnitude += value * value;
                     index++;
                 }
                 fieldMapper.checkDimensionMatches(index, context);
                 checkVectorBounds(vector);
-                return dotProduct;
+                checkVectorMagnitude(fieldMapper.fieldType().similarity, errorFloatElementsAppender(vector), squaredMagnitude);
+                return VectorData.fromFloats(vector);
+            }
+
+            @Override
+            int getNumBytes(int dimensions) {
+                return dimensions * elementBytes;
             }
 
             @Override
@@ -607,8 +630,9 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         abstract void parseKnnVectorAndIndex(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
 
-        abstract double parseKnnVectorToByteBuffer(DocumentParserContext context, DenseVectorFieldMapper fieldMapper, ByteBuffer byteBuffer)
-            throws IOException;
+        abstract VectorData parseKnnVector(DocumentParserContext context, DenseVectorFieldMapper fieldMapper) throws IOException;
+
+        abstract int getNumBytes(int dimensions);
 
         abstract ByteBuffer createByteBuffer(IndexVersion indexVersion, int numBytes);
 
@@ -699,6 +723,8 @@ public class DenseVectorFieldMapper extends FieldMapper {
         static Function<StringBuilder, StringBuilder> errorByteElementsAppender(byte[] vector) {
             return sb -> appendErrorElements(sb, vector);
         }
+
+        public abstract double computeDotProduct(VectorData vectorData);
     }
 
     static final Map<String, ElementType> namesToElementType = Map.of(
@@ -1158,66 +1184,120 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return knnQuery;
         }
 
-        public Query createExactKnnQuery(float[] queryVector) {
-            queryVector = validateAndNormalize(queryVector);
-            VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
+        public Query createExactKnnQuery(VectorData queryVector) {
+            if (isIndexed() == false) {
+                throw new IllegalArgumentException(
+                    "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
+                );
+            }
             return switch (elementType) {
-                case BYTE -> {
-                    byte[] bytes = new byte[queryVector.length];
+                case BYTE -> createExactKnnByteQuery(queryVector.asByteVector());
+                case FLOAT -> createExactKnnFloatQuery(queryVector.asFloatVector());
+            };
+        }
+
+        private Query createExactKnnByteQuery(byte[] queryVector) {
+            if (queryVector.length != dims) {
+                throw new IllegalArgumentException(
+                    "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
+                );
+            }
+            if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
+                float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
+                elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
+            }
+            VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
+            return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
+                .add(
+                    new FunctionQuery(
+                        new ByteVectorSimilarityFunction(
+                            vectorSimilarityFunction,
+                            new ByteKnnVectorFieldSource(name()),
+                            new ConstKnnByteVectorValueSource(queryVector)
+                        )
+                    ),
+                    BooleanClause.Occur.SHOULD
+                )
+                .build();
+        }
+
+        private Query createExactKnnFloatQuery(float[] queryVector) {
+            if (queryVector.length != dims) {
+                throw new IllegalArgumentException(
+                    "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
+                );
+            }
+            elementType.checkVectorBounds(queryVector);
+            if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
+                float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
+                elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
+                if (similarity == VectorSimilarity.COSINE
+                    && indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
+                    && isNotUnitVector(squaredMagnitude)) {
+                    float length = (float) Math.sqrt(squaredMagnitude);
+                    queryVector = Arrays.copyOf(queryVector, queryVector.length);
                     for (int i = 0; i < queryVector.length; i++) {
-                        bytes[i] = (byte) queryVector[i];
+                        queryVector[i] /= length;
                     }
-                    yield new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
-                        .add(
-                            new FunctionQuery(
-                                new ByteVectorSimilarityFunction(
-                                    vectorSimilarityFunction,
-                                    new ByteKnnVectorFieldSource(name()),
-                                    new ConstKnnByteVectorValueSource(bytes)
-                                )
-                            ),
-                            BooleanClause.Occur.SHOULD
-                        )
-                        .build();
                 }
-                case FLOAT -> new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
-                    .add(
-                        new FunctionQuery(
-                            new FloatVectorSimilarityFunction(
-                                vectorSimilarityFunction,
-                                new FloatKnnVectorFieldSource(name()),
-                                new ConstKnnFloatValueSource(queryVector)
-                            )
-                        ),
-                        BooleanClause.Occur.SHOULD
-                    )
-                    .build();
-            };
+            }
+            VectorSimilarityFunction vectorSimilarityFunction = similarity.vectorSimilarityFunction(indexVersionCreated, elementType);
+            return new BooleanQuery.Builder().add(new FieldExistsQuery(name()), BooleanClause.Occur.FILTER)
+                .add(
+                    new FunctionQuery(
+                        new FloatVectorSimilarityFunction(
+                            vectorSimilarityFunction,
+                            new FloatKnnVectorFieldSource(name()),
+                            new ConstKnnFloatValueSource(queryVector)
+                        )
+                    ),
+                    BooleanClause.Occur.SHOULD
+                )
+                .build();
+        }
+
+        Query createKnnQuery(float[] queryVector, int numCands, Query filter, Float similarityThreshold, BitSetProducer parentFilter) {
+            return createKnnQuery(VectorData.fromFloats(queryVector), numCands, filter, similarityThreshold, parentFilter);
         }
 
         public Query createKnnQuery(
-            float[] queryVector,
+            VectorData queryVector,
             int numCands,
             Query filter,
             Float similarityThreshold,
             BitSetProducer parentFilter
         ) {
-            queryVector = validateAndNormalize(queryVector);
-            Query knnQuery = switch (elementType) {
-                case BYTE -> {
-                    byte[] bytes = new byte[queryVector.length];
-                    for (int i = 0; i < queryVector.length; i++) {
-                        bytes[i] = (byte) queryVector[i];
-                    }
-                    yield parentFilter != null
-                        ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), bytes, filter, numCands, parentFilter)
-                        : new ESKnnByteVectorQuery(name(), bytes, numCands, filter);
-                }
-                case FLOAT -> parentFilter != null
-                    ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
-                    : new ESKnnFloatVectorQuery(name(), queryVector, numCands, filter);
+            if (isIndexed() == false) {
+                throw new IllegalArgumentException(
+                    "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
+                );
+            }
+            return switch (getElementType()) {
+                case BYTE -> createKnnByteQuery(queryVector.asByteVector(), numCands, filter, similarityThreshold, parentFilter);
+                case FLOAT -> createKnnFloatQuery(queryVector.asFloatVector(), numCands, filter, similarityThreshold, parentFilter);
             };
+        }
+
+        private Query createKnnByteQuery(
+            byte[] queryVector,
+            int numCands,
+            Query filter,
+            Float similarityThreshold,
+            BitSetProducer parentFilter
+        ) {
+            if (queryVector.length != dims) {
+                throw new IllegalArgumentException(
+                    "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
+                );
+            }
 
+            if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
+                float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
+                elementType.checkVectorMagnitude(similarity, ElementType.errorByteElementsAppender(queryVector), squaredMagnitude);
+            }
+            Query knnQuery = parentFilter != null
+                ? new ESDiversifyingChildrenByteKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
+                : new ESKnnByteVectorQuery(name(), queryVector, numCands, filter);
             if (similarityThreshold != null) {
                 knnQuery = new VectorSimilarityQuery(
                     knnQuery,
@@ -1228,12 +1308,13 @@ public class DenseVectorFieldMapper extends FieldMapper {
             return knnQuery;
         }
 
-        private float[] validateAndNormalize(float[] queryVector) {
-            if (isIndexed() == false) {
-                throw new IllegalArgumentException(
-                    "to perform knn search on field [" + name() + "], its mapping must have [index] set to [true]"
-                );
-            }
+        private Query createKnnFloatQuery(
+            float[] queryVector,
+            int numCands,
+            Query filter,
+            Float similarityThreshold,
+            BitSetProducer parentFilter
+        ) {
             if (queryVector.length != dims) {
                 throw new IllegalArgumentException(
                     "the query vector has a different dimension [" + queryVector.length + "] than the index vectors [" + dims + "]"
@@ -1244,7 +1325,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
                 elementType.checkVectorMagnitude(similarity, ElementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
                 if (similarity == VectorSimilarity.COSINE
-                    && ElementType.FLOAT.equals(elementType)
                     && indexVersionCreated.onOrAfter(NORMALIZE_COSINE)
                     && isNotUnitVector(squaredMagnitude)) {
                     float length = (float) Math.sqrt(squaredMagnitude);
@@ -1254,7 +1334,17 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     }
                 }
             }
-            return queryVector;
+            Query knnQuery = parentFilter != null
+                ? new ESDiversifyingChildrenFloatKnnVectorQuery(name(), queryVector, filter, numCands, parentFilter)
+                : new ESKnnFloatVectorQuery(name(), queryVector, numCands, filter);
+            if (similarityThreshold != null) {
+                knnQuery = new VectorSimilarityQuery(
+                    knnQuery,
+                    similarityThreshold,
+                    similarity.score(similarityThreshold, elementType, dims)
+                );
+            }
+            return knnQuery;
         }
 
         VectorSimilarity getSimilarity() {
@@ -1349,13 +1439,15 @@ public class DenseVectorFieldMapper extends FieldMapper {
         int dims = fieldType().dims;
         ElementType elementType = fieldType().elementType;
         int numBytes = indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)
-            ? dims * elementType.elementBytes + MAGNITUDE_BYTES
-            : dims * elementType.elementBytes;
+            ? elementType.getNumBytes(dims) + MAGNITUDE_BYTES
+            : elementType.getNumBytes(dims);
 
         ByteBuffer byteBuffer = elementType.createByteBuffer(indexCreatedVersion, numBytes);
-        double dotProduct = elementType.parseKnnVectorToByteBuffer(context, this, byteBuffer);
+        VectorData vectorData = elementType.parseKnnVector(context, this);
+        vectorData.addToBuffer(byteBuffer);
         if (indexCreatedVersion.onOrAfter(MAGNITUDE_STORED_INDEX_VERSION)) {
             // encode vector magnitude at the end
+            double dotProduct = elementType.computeDotProduct(vectorData);
             float vectorMagnitude = (float) Math.sqrt(dotProduct);
             byteBuffer.putFloat(vectorMagnitude);
         }

+ 9 - 1
server/src/main/java/org/elasticsearch/search/retriever/KnnRetrieverBuilder.java

@@ -13,6 +13,7 @@ import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.search.builder.SearchSourceBuilder;
 import org.elasticsearch.search.vectors.KnnSearchBuilder;
 import org.elasticsearch.search.vectors.QueryVectorBuilder;
+import org.elasticsearch.search.vectors.VectorData;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -121,7 +122,14 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
 
     @Override
     public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
-        KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(field, queryVector, queryVectorBuilder, k, numCands, similarity);
+        KnnSearchBuilder knnSearchBuilder = new KnnSearchBuilder(
+            field,
+            VectorData.fromFloats(queryVector),
+            queryVectorBuilder,
+            k,
+            numCands,
+            similarity
+        );
         if (preFilterQueryBuilders != null) {
             knnSearchBuilder.addFilterQueries(preFilterQueryBuilders);
         }

+ 24 - 7
server/src/main/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilder.java

@@ -22,7 +22,6 @@ import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
-import java.util.Arrays;
 import java.util.Objects;
 
 /**
@@ -32,7 +31,7 @@ import java.util.Objects;
 public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuilder> {
     public static final String NAME = "exact_knn";
     private final String field;
-    private final float[] query;
+    private final VectorData query;
 
     /**
      * Creates a query builder.
@@ -41,13 +40,27 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
      * @param field    the field that was used for the kNN query
      */
     public ExactKnnQueryBuilder(float[] query, String field) {
+        this(VectorData.fromFloats(query), field);
+    }
+
+    /**
+     * Creates a query builder.
+     *
+     * @param query    the query vector
+     * @param field    the field that was used for the kNN query
+     */
+    public ExactKnnQueryBuilder(VectorData query, String field) {
         this.query = query;
         this.field = field;
     }
 
     public ExactKnnQueryBuilder(StreamInput in) throws IOException {
         super(in);
-        this.query = in.readFloatArray();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            this.query = in.readOptionalWriteable(VectorData::new);
+        } else {
+            this.query = VectorData.fromFloats(in.readFloatArray());
+        }
         this.field = in.readString();
     }
 
@@ -55,7 +68,7 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
         return field;
     }
 
-    float[] getQuery() {
+    VectorData getQuery() {
         return query;
     }
 
@@ -66,7 +79,11 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
-        out.writeFloatArray(query);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            out.writeOptionalWriteable(query);
+        } else {
+            out.writeFloatArray(query.asFloatVector());
+        }
         out.writeString(field);
     }
 
@@ -96,12 +113,12 @@ public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuil
 
     @Override
     protected boolean doEquals(ExactKnnQueryBuilder other) {
-        return field.equals(other.field) && Arrays.equals(query, other.query);
+        return field.equals(other.field) && Objects.equals(query, other.query);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(field, Arrays.hashCode(query));
+        return Objects.hash(field, Objects.hashCode(query));
     }
 
     @Override

+ 24 - 6
server/src/main/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilder.java

@@ -36,7 +36,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
     public static final String NAME = "knn_score_doc";
     private final ScoreDoc[] scoreDocs;
     private final String fieldName;
-    private final float[] queryVector;
+    private final VectorData queryVector;
 
     /**
      * Creates a query builder.
@@ -45,6 +45,16 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
      *                  sorted in order of ascending doc IDs.
      */
     public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, float[] queryVector) {
+        this(scoreDocs, fieldName, VectorData.fromFloats(queryVector));
+    }
+
+    /**
+     * Creates a query builder.
+     *
+     * @param scoreDocs the docs and scores this query should match. The array must be
+     *                  sorted in order of ascending doc IDs.
+     */
+    public KnnScoreDocQueryBuilder(ScoreDoc[] scoreDocs, String fieldName, VectorData queryVector) {
         this.scoreDocs = scoreDocs;
         this.fieldName = fieldName;
         this.queryVector = queryVector;
@@ -56,7 +66,11 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         if (in.getTransportVersion().onOrAfter(TransportVersions.NESTED_KNN_MORE_INNER_HITS)) {
             this.fieldName = in.readOptionalString();
             if (in.readBoolean()) {
-                this.queryVector = in.readFloatArray();
+                if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+                    this.queryVector = in.readOptionalWriteable(VectorData::new);
+                } else {
+                    this.queryVector = VectorData.fromFloats(in.readFloatArray());
+                }
             } else {
                 this.queryVector = null;
             }
@@ -79,7 +93,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
         return fieldName;
     }
 
-    float[] queryVector() {
+    VectorData queryVector() {
         return queryVector;
     }
 
@@ -90,7 +104,11 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
             out.writeOptionalString(fieldName);
             if (queryVector != null) {
                 out.writeBoolean(true);
-                out.writeFloatArray(queryVector);
+                if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+                    out.writeOptionalWriteable(queryVector);
+                } else {
+                    out.writeFloatArray(queryVector.asFloatVector());
+                }
             } else {
                 out.writeBoolean(false);
             }
@@ -175,7 +193,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
                 return false;
             }
         }
-        return Objects.equals(fieldName, other.fieldName) && Arrays.equals(queryVector, other.queryVector);
+        return Objects.equals(fieldName, other.fieldName) && Objects.equals(queryVector, other.queryVector);
     }
 
     @Override
@@ -185,7 +203,7 @@ public class KnnScoreDocQueryBuilder extends AbstractQueryBuilder<KnnScoreDocQue
             int hashCode = Objects.hash(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
             result = 31 * result + hashCode;
         }
-        return Objects.hash(result, fieldName, Arrays.hashCode(queryVector));
+        return Objects.hash(result, fieldName, Objects.hashCode(queryVector));
     }
 
     @Override

+ 49 - 27
server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java

@@ -27,7 +27,6 @@ import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
@@ -59,18 +58,8 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     @SuppressWarnings("unchecked")
     private static final ConstructingObjectParser<KnnSearchBuilder.Builder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
         // TODO optimize parsing for when BYTE values are provided
-        List<Float> vector = (List<Float>) args[1];
-        final float[] vectorArray;
-        if (vector != null) {
-            vectorArray = new float[vector.size()];
-            for (int i = 0; i < vector.size(); i++) {
-                vectorArray[i] = vector.get(i);
-            }
-        } else {
-            vectorArray = null;
-        }
         return new Builder().field((String) args[0])
-            .queryVector(vectorArray)
+            .queryVector((VectorData) args[1])
             .queryVectorBuilder((QueryVectorBuilder) args[4])
             .k((Integer) args[2])
             .numCandidates((Integer) args[3])
@@ -79,9 +68,15 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
 
     static {
         PARSER.declareString(constructorArg(), FIELD_FIELD);
-        PARSER.declareFloatArray(optionalConstructorArg(), QUERY_VECTOR_FIELD);
+        PARSER.declareField(
+            optionalConstructorArg(),
+            (p, c) -> VectorData.parseXContent(p),
+            QUERY_VECTOR_FIELD,
+            ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
+        );
         PARSER.declareInt(optionalConstructorArg(), K_FIELD);
         PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
+
         PARSER.declareNamedObject(
             optionalConstructorArg(),
             (p, c, n) -> p.namedObject(QueryVectorBuilder.class, n, c),
@@ -108,7 +103,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     }
 
     final String field;
-    final float[] queryVector;
+    final VectorData queryVector;
     final QueryVectorBuilder queryVectorBuilder;
     private final Supplier<float[]> querySupplier;
     final int k;
@@ -127,7 +122,26 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
      * @param numCands    the number of nearest neighbor candidates to consider per shard
      */
     public KnnSearchBuilder(String field, float[] queryVector, int k, int numCands, Float similarity) {
-        this(field, Objects.requireNonNull(queryVector, format("[%s] cannot be null", QUERY_VECTOR_FIELD)), null, k, numCands, similarity);
+        this(
+            field,
+            Objects.requireNonNull(VectorData.fromFloats(queryVector), format("[%s] cannot be null", QUERY_VECTOR_FIELD)),
+            null,
+            k,
+            numCands,
+            similarity
+        );
+    }
+
+    /**
+     * Defines a kNN search.
+     *
+     * @param field       the name of the vector field to search against
+     * @param queryVector the query vector
+     * @param k           the final number of nearest neighbors to return as top hits
+     * @param numCands    the number of nearest neighbor candidates to consider per shard
+     */
+    public KnnSearchBuilder(String field, VectorData queryVector, int k, int numCands, Float similarity) {
+        this(field, queryVector, null, k, numCands, similarity);
     }
 
     /**
@@ -151,7 +165,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
 
     public KnnSearchBuilder(
         String field,
-        float[] queryVector,
+        VectorData queryVector,
         QueryVectorBuilder queryVectorBuilder,
         int k,
         int numCands,
@@ -169,7 +183,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         Float similarity
     ) {
         this.field = field;
-        this.queryVector = new float[0];
+        this.queryVector = VectorData.fromFloats(new float[0]);
         this.queryVectorBuilder = null;
         this.k = k;
         this.numCands = numCands;
@@ -181,7 +195,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     private KnnSearchBuilder(
         String field,
         QueryVectorBuilder queryVectorBuilder,
-        float[] queryVector,
+        VectorData queryVector,
         List<QueryBuilder> filterQueries,
         int k,
         int numCandidates,
@@ -219,7 +233,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             );
         }
         this.field = field;
-        this.queryVector = queryVector == null ? new float[0] : queryVector;
+        this.queryVector = queryVector == null ? VectorData.fromFloats(new float[0]) : queryVector;
         this.queryVectorBuilder = queryVectorBuilder;
         this.k = k;
         this.numCands = numCandidates;
@@ -234,7 +248,11 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         this.field = in.readString();
         this.k = in.readVInt();
         this.numCands = in.readVInt();
-        this.queryVector = in.readFloatArray();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            this.queryVector = in.readOptionalWriteable(VectorData::new);
+        } else {
+            this.queryVector = VectorData.fromFloats(in.readFloatArray());
+        }
         this.filterQueries = in.readNamedWriteableCollectionAsList(QueryBuilder.class);
         this.boost = in.readFloat();
         if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_7_0)) {
@@ -262,7 +280,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     }
 
     // for testing only
-    public float[] getQueryVector() {
+    public VectorData getQueryVector() {
         return queryVector;
     }
 
@@ -365,7 +383,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         return k == that.k
             && numCands == that.numCands
             && Objects.equals(field, that.field)
-            && Arrays.equals(queryVector, that.queryVector)
+            && Objects.equals(queryVector, that.queryVector)
             && Objects.equals(queryVectorBuilder, that.queryVectorBuilder)
             && Objects.equals(querySupplier, that.querySupplier)
             && Objects.equals(filterQueries, that.filterQueries)
@@ -383,7 +401,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             querySupplier,
             queryVectorBuilder,
             similarity,
-            Arrays.hashCode(queryVector),
+            Objects.hashCode(queryVector),
             Objects.hashCode(filterQueries),
             innerHitBuilder,
             boost
@@ -401,7 +419,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             builder.field(queryVectorBuilder.getWriteableName(), queryVectorBuilder);
             builder.endObject();
         } else {
-            builder.array(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
+            builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector);
         }
         if (similarity != null) {
             builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity);
@@ -434,7 +452,11 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
         out.writeString(field);
         out.writeVInt(k);
         out.writeVInt(numCands);
-        out.writeFloatArray(queryVector);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            out.writeOptionalWriteable(queryVector);
+        } else {
+            out.writeFloatArray(queryVector.asFloatVector());
+        }
         out.writeNamedWriteableCollection(filterQueries);
         out.writeFloat(boost);
         if (out.getTransportVersion().before(TransportVersions.V_8_7_0) && queryVectorBuilder != null) {
@@ -460,7 +482,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
     public static class Builder {
 
         private String field;
-        private float[] queryVector;
+        private VectorData queryVector;
         private QueryVectorBuilder queryVectorBuilder;
         private Integer k;
         private Integer numCandidates;
@@ -490,7 +512,7 @@ public class KnnSearchBuilder implements Writeable, ToXContentFragment, Rewritea
             return this;
         }
 
-        public Builder queryVector(float[] queryVector) {
+        public Builder queryVector(VectorData queryVector) {
             this.queryVector = queryVector;
             return this;
         }

+ 42 - 30
server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java

@@ -37,7 +37,6 @@ import org.elasticsearch.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
@@ -62,23 +61,19 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     public static final ParseField FILTER_FIELD = new ParseField("filter");
 
     @SuppressWarnings("unchecked")
-    public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>("knn", args -> {
-        List<Float> vector = (List<Float>) args[1];
-        final float[] vectorArray;
-        if (vector != null) {
-            vectorArray = new float[vector.size()];
-            for (int i = 0; i < vector.size(); i++) {
-                vectorArray[i] = vector.get(i);
-            }
-        } else {
-            vectorArray = null;
-        }
-        return new KnnVectorQueryBuilder((String) args[0], vectorArray, (Integer) args[2], (Float) args[3]);
-    });
+    public static final ConstructingObjectParser<KnnVectorQueryBuilder, Void> PARSER = new ConstructingObjectParser<>(
+        "knn",
+        args -> new KnnVectorQueryBuilder((String) args[0], (VectorData) args[1], (Integer) args[2], (Float) args[3])
+    );
 
     static {
         PARSER.declareString(constructorArg(), FIELD_FIELD);
-        PARSER.declareFloatArray(constructorArg(), QUERY_VECTOR_FIELD);
+        PARSER.declareField(
+            optionalConstructorArg(),
+            (p, c) -> VectorData.parseXContent(p),
+            QUERY_VECTOR_FIELD,
+            ObjectParser.ValueType.OBJECT_ARRAY_STRING_OR_NUMBER
+        );
         PARSER.declareInt(optionalConstructorArg(), NUM_CANDS_FIELD);
         PARSER.declareFloat(optionalConstructorArg(), VECTOR_SIMILARITY_FIELD);
         PARSER.declareFieldArray(
@@ -95,12 +90,20 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     }
 
     private final String fieldName;
-    private final float[] queryVector;
+    private final VectorData queryVector;
     private Integer numCands;
     private final List<QueryBuilder> filterQueries = new ArrayList<>();
     private final Float vectorSimilarity;
 
     public KnnVectorQueryBuilder(String fieldName, float[] queryVector, Integer numCands, Float vectorSimilarity) {
+        this(fieldName, VectorData.fromFloats(queryVector), numCands, vectorSimilarity);
+    }
+
+    public KnnVectorQueryBuilder(String fieldName, byte[] queryVector, Integer numCands, Float vectorSimilarity) {
+        this(fieldName, VectorData.fromBytes(queryVector), numCands, vectorSimilarity);
+    }
+
+    public KnnVectorQueryBuilder(String fieldName, VectorData queryVector, Integer numCands, Float vectorSimilarity) {
         if (numCands != null && numCands > NUM_CANDS_LIMIT) {
             throw new IllegalArgumentException("[" + NUM_CANDS_FIELD.getPreferredName() + "] cannot exceed [" + NUM_CANDS_LIMIT + "]");
         }
@@ -121,12 +124,17 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
         } else {
             this.numCands = in.readVInt();
         }
-        if (in.getTransportVersion().before(TransportVersions.V_8_7_0) || in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            this.queryVector = in.readFloatArray();
+        if (in.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            this.queryVector = in.readOptionalWriteable(VectorData::new);
         } else {
-            in.readBoolean();
-            this.queryVector = in.readFloatArray();
-            in.readBoolean(); // used for byteQueryVector, which was always null
+            if (in.getTransportVersion().before(TransportVersions.V_8_7_0)
+                || in.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
+                this.queryVector = VectorData.fromFloats(in.readFloatArray());
+            } else {
+                in.readBoolean();
+                this.queryVector = VectorData.fromFloats(in.readFloatArray());
+                in.readBoolean(); // used for byteQueryVector, which was always null
+            }
         }
         if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
             this.filterQueries.addAll(readQueries(in));
@@ -143,7 +151,7 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
     }
 
     @Nullable
-    public float[] queryVector() {
+    public VectorData queryVector() {
         return queryVector;
     }
 
@@ -190,13 +198,17 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
                 out.writeVInt(numCands);
             }
         }
-        if (out.getTransportVersion().before(TransportVersions.V_8_7_0)
-            || out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
-            out.writeFloatArray(queryVector);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.KNN_EXPLICIT_BYTE_QUERY_VECTOR_PARSING)) {
+            out.writeOptionalWriteable(queryVector);
         } else {
-            out.writeBoolean(true);
-            out.writeFloatArray(queryVector);
-            out.writeBoolean(false); // used for byteQueryVector, which was always null
+            if (out.getTransportVersion().before(TransportVersions.V_8_7_0)
+                || out.getTransportVersion().onOrAfter(TransportVersions.V_8_12_0)) {
+                out.writeFloatArray(queryVector.asFloatVector());
+            } else {
+                out.writeBoolean(true);
+                out.writeFloatArray(queryVector.asFloatVector());
+                out.writeBoolean(false); // used for byteQueryVector, which was always null
+            }
         }
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_2_0)) {
             writeQueries(out, filterQueries);
@@ -326,13 +338,13 @@ public class KnnVectorQueryBuilder extends AbstractQueryBuilder<KnnVectorQueryBu
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(fieldName, Arrays.hashCode(queryVector), numCands, filterQueries, vectorSimilarity);
+        return Objects.hash(fieldName, Objects.hashCode(queryVector), numCands, filterQueries, vectorSimilarity);
     }
 
     @Override
     protected boolean doEquals(KnnVectorQueryBuilder other) {
         return Objects.equals(fieldName, other.fieldName)
-            && Arrays.equals(queryVector, other.queryVector)
+            && Objects.equals(queryVector, other.queryVector)
             && Objects.equals(numCands, other.numCands)
             && Objects.equals(filterQueries, other.filterQueries)
             && Objects.equals(vectorSimilarity, other.vectorSimilarity);

+ 168 - 0
server/src/main/java/org/elasticsearch/search/vectors/VectorData.java

@@ -0,0 +1,168 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HexFormat;
+import java.util.List;
+import java.util.Objects;
+
+import static org.elasticsearch.common.Strings.format;
+
+public record VectorData(float[] floatVector, byte[] byteVector) implements Writeable, ToXContentFragment {
+    // VectorData carries a vector as either a float[] or a byte[]; the private wrappers set one side and null the other.
+    private VectorData(float[] floatVector) {
+        this(floatVector, null);
+    }
+
+    private VectorData(byte[] byteVector) {
+        this(null, byteVector);
+    }
+    /** Reads an optional float array followed by an optional byte array, the same order {@link #writeTo} emits. */
+    public VectorData(StreamInput in) throws IOException {
+        this(in.readOptionalFloatArray(), in.readOptionalByteArray());
+    }
+    /** Canonical constructor: throws IllegalArgumentException unless exactly one of the two arrays is non-null. */
+    public VectorData {
+        if (false == (floatVector == null ^ byteVector == null)) { // XOR holds only when exactly one side is null
+            throw new IllegalArgumentException("please supply exactly either a float or a byte vector");
+        }
+    }
+    /** Returns the stored byte array as-is, or a new byte array cast from the floats after calling checkVectorBounds. */
+    public byte[] asByteVector() {
+        if (byteVector != null) {
+            return byteVector;
+        }
+        DenseVectorFieldMapper.ElementType.BYTE.checkVectorBounds(floatVector); // presumably rejects non-byte values - confirm
+        byte[] vec = new byte[floatVector.length];
+        for (int i = 0; i < floatVector.length; i++) {
+            vec[i] = (byte) floatVector[i];
+        }
+        return vec;
+    }
+    /** Returns the stored float array as-is, or a new float array widened element by element from the bytes. */
+    public float[] asFloatVector() {
+        if (floatVector != null) {
+            return floatVector;
+        }
+        float[] vec = new float[byteVector.length];
+        for (int i = 0; i < byteVector.length; i++) {
+            vec[i] = byteVector[i];
+        }
+        return vec;
+    }
+    /** Appends the vector to the buffer: one putFloat per element for floats, a single bulk put for bytes. */
+    public void addToBuffer(ByteBuffer byteBuffer) {
+        if (floatVector != null) {
+            for (float val : floatVector) {
+                byteBuffer.putFloat(val);
+            }
+        } else {
+            byteBuffer.put(byteVector);
+        }
+    }
+    /** Writes the optional float array, then the optional byte array. */
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeOptionalFloatArray(floatVector);
+        out.writeOptionalByteArray(byteVector);
+    }
+    /** Emits a float vector as an array of numbers and a byte vector as a single hex-string value. */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (floatVector != null) {
+            builder.startArray();
+            for (float v : floatVector) {
+                builder.value(v);
+            }
+            builder.endArray();
+        } else {
+            builder.value(HexFormat.of().formatHex(byteVector));
+        }
+        return builder;
+    }
+    /** Renders whichever array is set using {@code Arrays.toString}. */
+    @Override
+    public String toString() {
+        return floatVector != null ? Arrays.toString(floatVector) : Arrays.toString(byteVector);
+    }
+    /** Compares array contents; the equals a record generates would compare the array references instead. */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        VectorData other = (VectorData) obj;
+        return Arrays.equals(floatVector, other.floatVector) && Arrays.equals(byteVector, other.byteVector);
+    }
+    /** Content-based hash over both arrays, consistent with {@link #equals(Object)}. */
+    @Override
+    public int hashCode() {
+        return Objects.hash(Arrays.hashCode(floatVector), Arrays.hashCode(byteVector));
+    }
+    /** Parses from the parser's current token: an array, a string (decoded as hex), or a single number. */
+    public static VectorData parseXContent(XContentParser parser) throws IOException {
+        XContentParser.Token token = parser.currentToken();
+        return switch (token) {
+            case START_ARRAY -> parseQueryVectorArray(parser);
+            case VALUE_STRING -> parseHexEncodedVector(parser);
+            case VALUE_NUMBER -> parseNumberVector(parser);
+            default -> throw new ParsingException(parser.getTokenLocation(), format("Unknown type [%s] for parsing vector", token));
+        };
+    }
+    /** Reads elements up to END_ARRAY into a float vector; number and string tokens are both read via floatValue(). */
+    private static VectorData parseQueryVectorArray(XContentParser parser) throws IOException {
+        XContentParser.Token token;
+        List<Float> vectorArr = new ArrayList<>();
+        while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
+            if (token == XContentParser.Token.VALUE_NUMBER || token == XContentParser.Token.VALUE_STRING) {
+                vectorArr.add(parser.floatValue());
+            } else {
+                throw new ParsingException(parser.getTokenLocation(), format("Type [%s] not supported for query vector", token));
+            }
+        }
+        float[] floatVector = new float[vectorArr.size()]; // unbox the collected values into a primitive array
+        for (int i = 0; i < vectorArr.size(); i++) {
+            floatVector[i] = vectorArr.get(i);
+        }
+        return VectorData.fromFloats(floatVector);
+    }
+    /** Decodes the current text as hex into a byte vector; HexFormat.parseHex throws IllegalArgumentException on bad input. */
+    private static VectorData parseHexEncodedVector(XContentParser parser) throws IOException {
+        return VectorData.fromBytes(HexFormat.of().parseHex(parser.text()));
+    }
+    /** Wraps a lone numeric value as a one-element float vector. */
+    private static VectorData parseNumberVector(XContentParser parser) throws IOException {
+        return VectorData.fromFloats(new float[] { parser.floatValue() });
+    }
+    /** Returns null for a null array, otherwise a float-backed VectorData that keeps the given array (no copy). */
+    public static VectorData fromFloats(float[] vec) {
+        return vec == null ? null : new VectorData(vec);
+    }
+    /** Returns null for a null array, otherwise a byte-backed VectorData that keeps the given array (no copy). */
+    public static VectorData fromBytes(byte[] vec) {
+        return vec == null ? null : new VectorData(vec);
+    }
+
+}

+ 3 - 4
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -25,6 +25,7 @@ import org.elasticsearch.index.mapper.FieldTypeTestCase;
 import org.elasticsearch.index.mapper.MappedFieldType;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.VectorSimilarity;
+import org.elasticsearch.search.vectors.VectorData;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -179,7 +180,7 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
             for (int i = 0; i < dims; i++) {
                 queryVector[i] = randomFloat();
             }
-            Query query = field.createExactKnnQuery(queryVector);
+            Query query = field.createExactKnnQuery(VectorData.fromFloats(queryVector));
             assertTrue(query instanceof BooleanQuery);
             BooleanQuery booleanQuery = (BooleanQuery) query;
             boolean foundFunction = false;
@@ -202,12 +203,10 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
                 Collections.emptyMap()
             );
             byte[] queryVector = new byte[dims];
-            float[] floatQueryVector = new float[dims];
             for (int i = 0; i < dims; i++) {
                 queryVector[i] = randomByte();
-                floatQueryVector[i] = queryVector[i];
             }
-            Query query = field.createExactKnnQuery(floatQueryVector);
+            Query query = field.createExactKnnQuery(VectorData.fromBytes(queryVector));
             assertTrue(query instanceof BooleanQuery);
             BooleanQuery booleanQuery = (BooleanQuery) query;
             boolean foundFunction = false;

+ 24 - 29
server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java

@@ -43,10 +43,12 @@ import static org.hamcrest.Matchers.instanceOf;
 abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCase<KnnVectorQueryBuilder> {
     private static final String VECTOR_FIELD = "vector";
     private static final String VECTOR_ALIAS_FIELD = "vector_alias";
-    private static final int VECTOR_DIMENSION = 3;
+    static final int VECTOR_DIMENSION = 3;
 
     abstract DenseVectorFieldMapper.ElementType elementType();
 
+    abstract KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity);
+
     @Override
     protected void initializeAdditionalMappings(MapperService mapperService) throws IOException {
         XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -75,12 +77,9 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
     @Override
     protected KnnVectorQueryBuilder doCreateTestQueryBuilder() {
         String fieldName = randomBoolean() ? VECTOR_FIELD : VECTOR_ALIAS_FIELD;
-        float[] vector = new float[VECTOR_DIMENSION];
-        for (int i = 0; i < vector.length; i++) {
-            vector[i] = elementType().equals(DenseVectorFieldMapper.ElementType.BYTE) ? randomByte() : randomFloat();
-        }
         int numCands = randomIntBetween(DEFAULT_SIZE, 1000);
-        KnnVectorQueryBuilder queryBuilder = new KnnVectorQueryBuilder(fieldName, vector, numCands, randomBoolean() ? null : randomFloat());
+        KnnVectorQueryBuilder queryBuilder = createKnnVectorQueryBuilder(fieldName, numCands, randomBoolean() ? null : randomFloat());
+
         if (randomBoolean()) {
             List<QueryBuilder> filters = new ArrayList<>();
             int numFilters = randomIntBetween(1, 5);
@@ -120,11 +119,16 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
         Query knnVectorQueryBuilt = switch (elementType()) {
             case BYTE -> new ESKnnByteVectorQuery(
                 VECTOR_FIELD,
-                getByteQueryVector(queryBuilder.queryVector()),
+                queryBuilder.queryVector().asByteVector(),
+                queryBuilder.numCands(),
+                filterQuery
+            );
+            case FLOAT -> new ESKnnFloatVectorQuery(
+                VECTOR_FIELD,
+                queryBuilder.queryVector().asFloatVector(),
                 queryBuilder.numCands(),
                 filterQuery
             );
-            case FLOAT -> new ESKnnFloatVectorQuery(VECTOR_FIELD, queryBuilder.queryVector(), queryBuilder.numCands(), filterQuery);
         };
         if (query instanceof VectorSimilarityQuery vectorSimilarityQuery) {
             query = vectorSimilarityQuery.getInnerKnnQuery();
@@ -193,7 +197,8 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     public void testBWCVersionSerializationFilters() throws IOException {
         KnnVectorQueryBuilder query = createTestQueryBuilder();
-        KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), query.queryVector(), query.numCands(), null)
+        VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
+        KnnVectorQueryBuilder queryNoFilters = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null)
             .queryName(query.queryName())
             .boost(query.boost());
         TransportVersion beforeFilterVersion = TransportVersionUtils.randomVersionBetween(
@@ -206,12 +211,11 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
 
     public void testBWCVersionSerializationSimilarity() throws IOException {
         KnnVectorQueryBuilder query = createTestQueryBuilder();
-        KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(
-            query.getFieldName(),
-            query.queryVector(),
-            query.numCands(),
-            null
-        ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
+        VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
+        KnnVectorQueryBuilder queryNoSimilarity = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), null)
+            .queryName(query.queryName())
+            .boost(query.boost())
+            .addFilterQueries(query.filterQueries());
         assertBWCSerialization(query, queryNoSimilarity, TransportVersions.V_8_7_0);
     }
 
@@ -223,12 +227,11 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
             TransportVersions.V_8_12_0
         );
         Float similarity = differentQueryVersion.before(TransportVersions.V_8_8_0) ? null : query.getVectorSimilarity();
-        KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder(
-            query.getFieldName(),
-            query.queryVector(),
-            query.numCands(),
-            similarity
-        ).queryName(query.queryName()).boost(query.boost()).addFilterQueries(query.filterQueries());
+        VectorData vectorData = VectorData.fromFloats(query.queryVector().asFloatVector());
+        KnnVectorQueryBuilder queryOlderVersion = new KnnVectorQueryBuilder(query.getFieldName(), vectorData, query.numCands(), similarity)
+            .queryName(query.queryName())
+            .boost(query.boost())
+            .addFilterQueries(query.filterQueries());
         assertBWCSerialization(query, queryOlderVersion, differentQueryVersion);
     }
 
@@ -245,12 +248,4 @@ abstract class AbstractKnnVectorQueryBuilderTestCase extends AbstractQueryTestCa
             }
         }
     }
-
-    private static byte[] getByteQueryVector(float[] queryVector) {
-        byte[] byteQueryVector = new byte[queryVector.length];
-        for (int i = 0; i < queryVector.length; i++) {
-            byteQueryVector[i] = (byte) queryVector[i];
-        }
-        return byteQueryVector;
-    }
 }

+ 9 - 0
server/src/test/java/org/elasticsearch/search/vectors/KnnByteVectorQueryBuilderTests.java

@@ -15,4 +15,13 @@ public class KnnByteVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilde
     DenseVectorFieldMapper.ElementType elementType() {
         return DenseVectorFieldMapper.ElementType.BYTE;
     }
+
+    // Builds the query under test from a random byte vector of VECTOR_DIMENSION elements.
+    @Override
+    protected KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) {
+        final byte[] queryVector = new byte[VECTOR_DIMENSION];
+        int dim = 0;
+        while (dim < queryVector.length) {
+            queryVector[dim] = randomByte();
+            dim++;
+        }
+        return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, similarity);
+    }
 }

+ 10 - 1
server/src/test/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilderTests.java → server/src/test/java/org/elasticsearch/search/vectors/KnnFloatVectorQueryBuilderTests.java

@@ -10,9 +10,18 @@ package org.elasticsearch.search.vectors;
 
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 
-public class KnnVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilderTestCase {
+public class KnnFloatVectorQueryBuilderTests extends AbstractKnnVectorQueryBuilderTestCase {
     @Override
     DenseVectorFieldMapper.ElementType elementType() {
         return DenseVectorFieldMapper.ElementType.FLOAT;
     }
+
+    // Builds the query under test from a random float vector of VECTOR_DIMENSION elements.
+    @Override
+    KnnVectorQueryBuilder createKnnVectorQueryBuilder(String fieldName, int numCands, Float similarity) {
+        final float[] queryVector = new float[VECTOR_DIMENSION];
+        int dim = 0;
+        while (dim < queryVector.length) {
+            queryVector[dim] = randomFloat();
+            dim++;
+        }
+        return new KnnVectorQueryBuilder(fieldName, queryVector, numCands, similarity);
+    }
 }

+ 2 - 2
server/src/test/java/org/elasticsearch/search/vectors/KnnSearchBuilderTests.java

@@ -106,7 +106,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
                     instance.boost
                 );
             case 1:
-                float[] newVector = randomValueOtherThan(instance.queryVector, () -> randomVector(5));
+                float[] newVector = randomValueOtherThan(instance.queryVector.asFloatVector(), () -> randomVector(5));
                 return new KnnSearchBuilder(instance.field, newVector, instance.k, instance.numCands, instance.similarity).boost(
                     instance.boost
                 );
@@ -213,7 +213,7 @@ public class KnnSearchBuilderTests extends AbstractXContentSerializingTestCase<K
 
         assertThat(rewritten.field, equalTo(searchBuilder.field));
         assertThat(rewritten.boost, equalTo(searchBuilder.boost));
-        assertThat(rewritten.queryVector, equalTo(expectedArray));
+        assertThat(rewritten.queryVector.asFloatVector(), equalTo(expectedArray));
         assertThat(rewritten.queryVectorBuilder, nullValue());
         assertThat(rewritten.filterQueries, hasSize(1));
         assertThat(rewritten.similarity, equalTo(1f));

+ 199 - 0
server/src/test/java/org/elasticsearch/search/vectors/VectorDataTests.java

@@ -0,0 +1,199 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.vectors;
+
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+
+import java.io.IOException;
+
+import static org.hamcrest.Matchers.containsString;
+
+/**
+ * Unit tests for {@link VectorData}: constructor validation, float/byte conversion,
+ * equality, and parsing of the supported x-content representations.
+ */
+public class VectorDataTests extends ESTestCase {
+
+    private static final float DELTA = 1e-5f;
+
+    /**
+     * Parses the given JSON snippet with {@link VectorData#parseXContent(XContentParser)}.
+     * The parser is advanced to its first token before parsing and is closed before returning.
+     *
+     * @param json the JSON value to parse
+     * @return the parsed vector
+     * @throws IOException if the parser cannot be created or read
+     */
+    private static VectorData parseVector(String json) throws IOException {
+        try (
+            XContentParser parser = XContentHelper.createParserNotCompressed(
+                XContentParserConfiguration.EMPTY,
+                new BytesArray(json),
+                XContentType.JSON
+            )
+        ) {
+            parser.nextToken();
+            return VectorData.parseXContent(parser);
+        }
+    }
+
+    public void testThrowsIfBothVectorsAreNull() {
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> new VectorData(null, null));
+        assertThat(ex.getMessage(), containsString("please supply exactly either a float or a byte vector"));
+    }
+
+    public void testThrowsIfBothVectorsAreNonNull() {
+        IllegalArgumentException ex = expectThrows(
+            IllegalArgumentException.class,
+            () -> new VectorData(new float[] { 0f }, new byte[] { 1 })
+        );
+        assertThat(ex.getMessage(), containsString("please supply exactly either a float or a byte vector"));
+    }
+
+    public void testShouldCorrectlyConvertByteToFloatIfExplicitlyRequested() {
+        byte[] byteVector = new byte[] { 1, 2, -127 };
+        float[] expected = new float[] { 1f, 2f, -127f };
+
+        VectorData vectorData = new VectorData(null, byteVector);
+        float[] actual = vectorData.asFloatVector();
+        assertArrayEquals(expected, actual, DELTA);
+    }
+
+    public void testShouldThrowForDecimalsWhenConvertingToByte() {
+        float[] vec = new float[] { 1f, 2f, 3.1f };
+
+        VectorData vectorData = new VectorData(vec, null);
+        expectThrows(IllegalArgumentException.class, vectorData::asByteVector);
+    }
+
+    public void testShouldThrowForOutsideRangeWhenConvertingToByte() {
+        float[] vec = new float[] { 1f, 2f, 200f };
+
+        VectorData vectorData = new VectorData(vec, null);
+        expectThrows(IllegalArgumentException.class, vectorData::asByteVector);
+    }
+
+    public void testEqualsAndHashCode() {
+        VectorData v1 = new VectorData(new float[] { 1, 2, 3 }, null);
+        VectorData v2 = new VectorData(null, new byte[] { 1, 2, 3 });
+        assertNotEquals(v1, v2);
+        assertNotEquals(v1.hashCode(), v2.hashCode());
+
+        VectorData v3 = new VectorData(null, new byte[] { 1, 2, 3 });
+        assertEquals(v2, v3);
+        assertEquals(v2.hashCode(), v3.hashCode());
+    }
+
+    public void testParseHexCorrectly() throws IOException {
+        byte[] expected = new byte[] { 64, 10, -30, 10 };
+        VectorData parsed = parseVector("\"400ae20a\"");
+        assertArrayEquals(expected, parsed.asByteVector());
+    }
+
+    public void testParseFloatArray() throws IOException {
+        float[] expected = new float[] { 1f, -1f, .1f };
+        VectorData parsed = parseVector("[1.0, -1.0, 0.1]");
+        assertArrayEquals(expected, parsed.asFloatVector(), DELTA);
+    }
+
+    public void testParseByteArray() throws IOException {
+        byte[] expected = new byte[] { 64, 10, -30, 10 };
+        VectorData parsed = parseVector("[64,10,-30,10]");
+        assertArrayEquals(expected, parsed.asByteVector());
+    }
+
+    public void testByteThrowsForOutsideRange() throws IOException {
+        // Parsing itself succeeds; the range check only happens on conversion to bytes.
+        VectorData parsed = parseVector("[1000]");
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, parsed::asByteVector);
+        assertThat(ex.getMessage(), containsString("vectors only support integers between [-128, 127]"));
+    }
+
+    public void testAsByteThrowsForDecimals() throws IOException {
+        // Parsing itself succeeds; the decimal check only happens on conversion to bytes.
+        VectorData parsed = parseVector("[0.1]");
+        IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, parsed::asByteVector);
+        assertThat(ex.getMessage(), containsString("vectors only support non-decimal values but found decimal value"));
+    }
+
+    public void testParseSingleNumber() throws IOException {
+        float[] expected = new float[] { 0.1f };
+        VectorData parsed = parseVector("0.1");
+        assertArrayEquals(expected, parsed.asFloatVector(), DELTA);
+    }
+
+    public void testParseThrowsForUnknown() {
+        String unknown = "{\"foo\":\"bar\"}";
+        ParsingException ex = expectThrows(ParsingException.class, () -> parseVector(unknown));
+        assertThat(ex.getMessage(), containsString("Unknown type [" + XContentParser.Token.START_OBJECT + "] for parsing vector"));
+    }
+
+    public void testFailForUnknownArrayValue() {
+        String toParse = "[0.1, true]";
+        ParsingException ex = expectThrows(ParsingException.class, () -> parseVector(toParse));
+        assertThat(ex.getMessage(), containsString("Type [" + XContentParser.Token.VALUE_BOOLEAN + "] not supported for query vector"));
+    }
+}

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/test/AbstractQueryVectorBuilderTestCase.java

@@ -132,7 +132,7 @@ public abstract class AbstractQueryVectorBuilderTestCase<T extends QueryVectorBu
                 PlainActionFuture<KnnSearchBuilder> future = new PlainActionFuture<>();
                 Rewriteable.rewriteAndFetch(randomFrom(serialized, searchBuilder), context, future);
                 KnnSearchBuilder rewritten = future.get();
-                assertThat(rewritten.getQueryVector(), equalTo(expected));
+                assertThat(rewritten.getQueryVector().asFloatVector(), equalTo(expected));
                 assertThat(rewritten.getQueryVectorBuilder(), nullValue());
             }
         }