Explorar o código

Script: Fields API for Dense Vector (#83550)

Adds the fields API for `dense_vector` field mapper.

Adds a `DenseVector` interface for the value type.

Implemented by:
 * `KnnDenseVector` which wraps a decoded float array from `VectorValues`
 * `BinaryDenseVector` which lazily decodes a `BytesRef` from `BinaryDocValues`

The vector operations have moved into those implements from `BinaryDenseVectorScriptDocValues.java` and  `KnnDenseVectorScriptDocValues.java`, respectively.

The `DenseVector` API is:
```
float getMagnitude();
double dotProduct(float[] | List);
double l1Norm(float[] | List);
double l2Norm(float[] | List);
float[] getVector();
int dims();

boolean isEmpty(); // does the value exist
int size();        // 0 if isEmpty(), 1 otherwise
Iterator<Float> iterator()
```

`dotProduct`, `l1Norm` and `l2Norm` take a `float[]` or a `List` via the
a delegating `default` method on the `DenseVector` interface.

The `DenseVectorDocValuesField` abstract class contains two getter APIS.
It is implemented by  `KnnDenseVectorDocValuesField` and
`BinaryDenseVectorDocValuesField`.

```
DenseVector get()
DenseVector get(DenseVector defaultValue)
```

The `get()` method is included because there isn't a good default dense vector,
so that API returns an empty `DenseVector` which throws an
`IllegalArgumentException` for all method calls other than `isEmpty()`,
`size()` and `iterator()`.

The empty dense vector will always be `DenseVector.EMPTY` in case users want
to use equality checks.

Refs: #79105
Stuart Tettemer %!s(int64=3) %!d(string=hai) anos
pai
achega
b44fcfbb3a
Modificáronse 19 ficheiros con 1883 adicións e 390 borrados
  1. 5 0
      docs/changelog/83550.yaml
  2. 848 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/60_knn_and_binary_dv_fields_api.yml
  3. 141 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVector.java
  4. 70 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorDocValuesField.java
  5. 0 119
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValues.java
  6. 227 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVector.java
  7. 51 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorDocValuesField.java
  8. 40 50
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java
  9. 4 1
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java
  10. 109 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVector.java
  11. 79 0
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorDocValuesField.java
  12. 0 122
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValues.java
  13. 13 21
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java
  14. 6 28
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java
  15. 37 0
      x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/org.elasticsearch.xpack.vectors.txt
  16. 51 15
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java
  17. 55 19
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java
  18. 84 0
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorTests.java
  19. 63 15
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java

+ 5 - 0
docs/changelog/83550.yaml

@@ -0,0 +1,5 @@
+pr: 83550
+summary: "Script: Fields API for Dense Vector"
+area: Infra/Scripting
+type: enhancement
+issues: []

+ 848 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/60_knn_and_binary_dv_fields_api.yml

@@ -0,0 +1,848 @@
+---
+"size and isEmpty code works for any vector, including empty":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  def dv = field(params.field).get();
+                  if (dv.isEmpty()) {
+                    return dv.size();
+                  }
+                  return dv.vector[2] * dv.size()
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 0 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  def dv = field(params.field).get();
+                  if (dv.isEmpty()) {
+                    return dv.size();
+                  }
+                  return dv.vector[2] * dv.size()
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 0 }
+
+---
+"null can be used for default value":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  DenseVector dv = field(params.field).get(null);
+                  if (dv == null) {
+                    return 1;
+                  }
+                  return dv.vector[2];
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 1 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  DenseVector dv = field(params.field).get(null);
+                  if (dv == null) {
+                    return 1;
+                  }
+                  return dv.vector[2];
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 3 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 2 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 1 }
+  - match: { hits.hits.3._id: "missing_vector" }
+  - match: { hits.hits.3._score: 1 }
+
+---
+"empty dense vector throws for vector accesses":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  field(params.field).get().vector[2]
+                params:
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Dense vector value missing for a field, use isEmpty() to check for a missing vector value" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  field(params.field).get().vector[2]
+                params:
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Dense vector value missing for a field, use isEmpty() to check for a missing vector value" }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  float[] q = new float[1];
+                  q[0] = 3;
+                  DenseVector dv = field(params.field).get();
+                  float score = 0;
+                  try { score += dv.magnitude } catch (IllegalArgumentException e) { score += 10; }
+                  try { score += dv.dotProduct(q) } catch (IllegalArgumentException e) { score += 200; }
+                  try { score += dv.l1Norm(q) } catch (IllegalArgumentException e) { score += 3000; }
+                  try { score += dv.l2Norm(q) } catch (IllegalArgumentException e) { score += 40000; }
+                  try { score += dv.vector[0] } catch (IllegalArgumentException e) { score += 500000; }
+                  try { score += dv.dims } catch (IllegalArgumentException e) { score += 6000000; }
+                  return score;
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "missing_vector" }
+  - match: { hits.hits.0._score: 6543210 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "bool": { "must_not": { "exists": { "field": "bdv" } } } }
+              script:
+                source: |
+                  float[] q = new float[1];
+                  q[0] = 3;
+                  DenseVector dv = field(params.field).get();
+                  float score = 0;
+                  try { score += dv.magnitude } catch (IllegalArgumentException e) { score += 10; }
+                  try { score += dv.dotProduct(q) } catch (IllegalArgumentException e) { score += 200; }
+                  try { score += dv.l1Norm(q) } catch (IllegalArgumentException e) { score += 3000; }
+                  try { score += dv.l2Norm(q) } catch (IllegalArgumentException e) { score += 40000; }
+                  try { score += dv.cosineSimilarity(q) } catch (IllegalArgumentException e) { score += 200000; }
+                  try { score += dv.vector[0] } catch (IllegalArgumentException e) { score += 500000; }
+                  try { score += dv.dims } catch (IllegalArgumentException e) { score += 6000000; }
+                  return score;
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "missing_vector" }
+  - match: { hits.hits.0._score: 6743210 }
+
+---
+"dot product works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().dotProduct(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().dotProduct(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().dotProduct(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().dotProduct(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 27 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 21 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 15 }
+
+---
+"iterator over dense vector values":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  float sum = 0.0f;
+                  for (def v : field(params.field)) {
+                    sum += v;
+                  }
+                  return sum;
+                params:
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "unsupported_operation_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot iterate over single valued dense_vector field, use get() instead" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { match_all: {} }
+              script:
+                source: |
+                  float sum = 0.0f;
+                  for (def v : field(params.field)) {
+                    sum += v;
+                  }
+                  return sum;
+                params:
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "unsupported_operation_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot iterate over single valued dense_vector field, use get() instead"}
+
+---
+"l1Norm works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().l1Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().l1Norm(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  field(params.field).get().l1Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  field(params.field).get().l1Norm(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 12 }
+
+---
+"l2Norm works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) field(params.field).get().l2Norm(query)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) field(params.field).get().l2Norm(query)
+                params:
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 7 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 6 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 5 }
+
+---
+"cosineSimilarity works on dense vectors":
+  - skip:
+      version: " - 8.1.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"bdv": [1, 1, 2], "knn": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"bdv": [1, 1, 3], "knn": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  float[] query = new float[3];
+                  query[0] = 4; query[1] = 5; query[2] = 6;
+                  (int) (field(params.field).get().cosineSimilarity(query) * 100.0f)
+                params:
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) (field(params.field).get().cosineSimilarity(params.query) * 100.0f)
+                params:
+                  query: [4, 5, 6]
+                  field: knn
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) (field(params.field).get().cosineSimilarity(params.query) * 100.0f)
+                params:
+                  query: [4, 5, 6]
+                  field: bdv
+
+  - match: { hits.hits.0._id: "1" }
+  - match: { hits.hits.0._score: 98 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 97 }
+  - match: { hits.hits.2._id: "3" }
+  - match: { hits.hits.2._score: 92 }
+
+---
+"query vector of wrong type errors":
+  - skip:
+      version: " - 8.0.99"
+      reason: "Fields API for dense vector added in 8.2"
+
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              bdv:
+                type: dense_vector
+                dims: 3
+              knn:
+                type: dense_vector
+                dims: 3
+                index: true
+                similarity: l2_norm
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"bdv": [1, 1, 1], "knn": [1, 1, 1]}'
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: "one, two, three"
+                  field: bdv
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot use vector [one, two, three] with class [java.lang.String] as query vector" }
+
+  - do:
+      catch: bad_request
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "bdv" } }
+              script:
+                source: |
+                  (int) field(params.field).get().l2Norm(params.query)
+                params:
+                  query: "one, two, three"
+                  field: knn
+
+  - match: { error.failed_shards.0.reason.caused_by.type: "illegal_argument_exception" }
+  - match: { error.failed_shards.0.reason.caused_by.reason: "Cannot use vector [one, two, three] with class [java.lang.String] as query vector" }

+ 141 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVector.java

@@ -0,0 +1,141 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
+
+import java.nio.ByteBuffer;
+import java.util.List;
+
+public class BinaryDenseVector implements DenseVector {
+    protected final BytesRef docVector;
+    protected final int dims;
+    protected final Version indexVersion;
+
+    protected float[] decodedDocVector;
+
+    public BinaryDenseVector(BytesRef docVector, int dims, Version indexVersion) {
+        this.docVector = docVector;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
+    }
+
+    @Override
+    public float[] getVector() {
+        if (decodedDocVector == null) {
+            decodedDocVector = new float[dims];
+            VectorEncoderDecoder.decodeDenseVector(docVector, decodedDocVector);
+        }
+        return decodedDocVector;
+    }
+
+    @Override
+    public float getMagnitude() {
+        return VectorEncoderDecoder.getMagnitude(indexVersion, docVector);
+    }
+
+    @Override
+    public double dotProduct(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double dotProduct = 0;
+        for (float v : queryVector) {
+            dotProduct += byteBuffer.getFloat() * v;
+        }
+        return dotProduct;
+    }
+
+    @Override
+    public double dotProduct(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double dotProduct = 0;
+        for (int i = 0; i < queryVector.size(); i++) {
+            dotProduct += byteBuffer.getFloat() * queryVector.get(i).floatValue();
+        }
+        return dotProduct;
+    }
+
+    @Override
+    public double l1Norm(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double l1norm = 0;
+        for (float v : queryVector) {
+            l1norm += Math.abs(v - byteBuffer.getFloat());
+        }
+        return l1norm;
+    }
+
+    @Override
+    public double l1Norm(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+
+        double l1norm = 0;
+        for (int i = 0; i < queryVector.size(); i++) {
+            l1norm += Math.abs(queryVector.get(i).floatValue() - byteBuffer.getFloat());
+        }
+        return l1norm;
+    }
+
+    @Override
+    public double l2Norm(float[] queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+        double l2norm = 0;
+        for (float queryValue : queryVector) {
+            double diff = byteBuffer.getFloat() - queryValue;
+            l2norm += diff * diff;
+        }
+        return Math.sqrt(l2norm);
+    }
+
+    @Override
+    public double l2Norm(List<Number> queryVector) {
+        ByteBuffer byteBuffer = wrap(docVector);
+        double l2norm = 0;
+        for (Number number : queryVector) {
+            double diff = byteBuffer.getFloat() - number.floatValue();
+            l2norm += diff * diff;
+        }
+        return Math.sqrt(l2norm);
+    }
+
+    @Override
+    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+        if (normalizeQueryVector) {
+            return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+        }
+        return dotProduct(queryVector) / getMagnitude();
+    }
+
+    @Override
+    public double cosineSimilarity(List<Number> queryVector) {
+        return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+    }
+
+    @Override
+    public int size() {
+        return 1;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return false;
+    }
+
+    @Override
+    public int getDims() {
+        return dims;
+    }
+
+    private static ByteBuffer wrap(BytesRef dv) {
+        return ByteBuffer.wrap(dv.bytes, dv.offset, dv.length);
+    }
+}

+ 70 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorDocValuesField.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+
+import java.io.IOException;
+
+public class BinaryDenseVectorDocValuesField extends DenseVectorDocValuesField {
+
+    protected final BinaryDocValues input;
+    protected final Version indexVersion;
+    protected final int dims;
+    protected BytesRef value;
+
+    public BinaryDenseVectorDocValuesField(BinaryDocValues input, String name, int dims, Version indexVersion) {
+        super(name);
+        this.input = input;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
+    }
+
+    @Override
+    public void setNextDocId(int docId) throws IOException {
+        if (input.advanceExact(docId)) {
+            value = input.binaryValue();
+        } else {
+            value = null;
+        }
+    }
+
+    @Override
+    public DenseVectorScriptDocValues getScriptDocValues() {
+        return new DenseVectorScriptDocValues(this, dims);
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return value == null;
+    }
+
+    @Override
+    public DenseVector get() {
+        if (isEmpty()) {
+            return DenseVector.EMPTY;
+        }
+
+        return new BinaryDenseVector(value, dims, indexVersion);
+    }
+
+    @Override
+    public DenseVector get(DenseVector defaultValue) {
+        if (isEmpty()) {
+            return defaultValue;
+        }
+        return new BinaryDenseVector(value, dims, indexVersion);
+    }
+
+    @Override
+    public DenseVector getInternal() {
+        return get(null);
+    }
+}

+ 0 - 119
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValues.java

@@ -1,119 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.vectors.query;
-
-import org.apache.lucene.index.BinaryDocValues;
-import org.apache.lucene.util.BytesRef;
-import org.elasticsearch.Version;
-import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-
-public class BinaryDenseVectorScriptDocValues extends DenseVectorScriptDocValues {
-
-    public static class BinaryDenseVectorSupplier implements DenseVectorSupplier<BytesRef> {
-
-        private final BinaryDocValues in;
-        private BytesRef value;
-
-        public BinaryDenseVectorSupplier(BinaryDocValues in) {
-            this.in = in;
-        }
-
-        @Override
-        public void setNextDocId(int docId) throws IOException {
-            if (in.advanceExact(docId)) {
-                value = in.binaryValue();
-            } else {
-                value = null;
-            }
-        }
-
-        @Override
-        public BytesRef getInternal(int index) {
-            throw new UnsupportedOperationException();
-        }
-
-        public BytesRef getInternal() {
-            return value;
-        }
-
-        @Override
-        public int size() {
-            if (value == null) {
-                return 0;
-            } else {
-                return 1;
-            }
-        }
-    }
-
-    private final BinaryDenseVectorSupplier bdvSupplier;
-    private final Version indexVersion;
-    private final float[] vector;
-
-    BinaryDenseVectorScriptDocValues(BinaryDenseVectorSupplier supplier, Version indexVersion, int dims) {
-        super(supplier, dims);
-        this.bdvSupplier = supplier;
-        this.indexVersion = indexVersion;
-        this.vector = new float[dims];
-    }
-
-    @Override
-    public int size() {
-        return supplier.size();
-    }
-
-    @Override
-    public float[] getVectorValue() {
-        VectorEncoderDecoder.decodeDenseVector(bdvSupplier.getInternal(), vector);
-        return vector;
-    }
-
-    @Override
-    public float getMagnitude() {
-        return VectorEncoderDecoder.getMagnitude(indexVersion, bdvSupplier.getInternal());
-    }
-
-    @Override
-    public double dotProduct(float[] queryVector) {
-        BytesRef value = bdvSupplier.getInternal();
-        ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length);
-
-        double dotProduct = 0;
-        for (float queryValue : queryVector) {
-            dotProduct += queryValue * byteBuffer.getFloat();
-        }
-        return (float) dotProduct;
-    }
-
-    @Override
-    public double l1Norm(float[] queryVector) {
-        BytesRef value = bdvSupplier.getInternal();
-        ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length);
-
-        double l1norm = 0;
-        for (float queryValue : queryVector) {
-            l1norm += Math.abs(queryValue - byteBuffer.getFloat());
-        }
-        return l1norm;
-    }
-
-    @Override
-    public double l2Norm(float[] queryVector) {
-        BytesRef value = bdvSupplier.getInternal();
-        ByteBuffer byteBuffer = ByteBuffer.wrap(value.bytes, value.offset, value.length);
-        double l2norm = 0;
-        for (float queryValue : queryVector) {
-            double diff = queryValue - byteBuffer.getFloat();
-            l2norm += diff * diff;
-        }
-        return Math.sqrt(l2norm);
-    }
-}

+ 227 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVector.java

@@ -0,0 +1,227 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import java.util.List;
+
+/**
+ * DenseVector value type for the painless.
+ */
+/* dotProduct, l1Norm, l2Norm, cosineSimilarity have three flavors depending on the type of the queryVector
+ * 1) float[], this is for the ScoreScriptUtils class bindings which have converted a List based query vector into an array
+ * 2) List, A painless script will typically use Lists since they are easy to pass as params and have an easy
+ *      literal syntax.  Working with Lists directly, instead of converting to a float[], trades off runtime operations against
+ *      memory pressure.  Dense Vectors may have high dimensionality, up to 2048.  Allocating a float[] per doc per script API
+ *      call is prohibitively expensive.
+ * 3) Object, the whitelisted method for the painless API.  Calls into the float[] or List version based on the
+        class of the argument and checks dimensionality.
+ */
+public interface DenseVector {
+    float[] getVector();
+
+    float getMagnitude();
+
+    double dotProduct(float[] queryVector);
+
+    double dotProduct(List<Number> queryVector);
+
+    @SuppressWarnings("unchecked")
+    default double dotProduct(Object queryVector) {
+        if (queryVector instanceof float[] array) {
+            checkDimensions(getDims(), array.length);
+            return dotProduct(array);
+
+        } else if (queryVector instanceof List<?> list) {
+            checkDimensions(getDims(), list.size());
+            return dotProduct((List<Number>) list);
+        }
+
+        throw new IllegalArgumentException(badQueryVectorType(queryVector));
+    }
+
+    double l1Norm(float[] queryVector);
+
+    double l1Norm(List<Number> queryVector);
+
+    @SuppressWarnings("unchecked")
+    default double l1Norm(Object queryVector) {
+        if (queryVector instanceof float[] array) {
+            checkDimensions(getDims(), array.length);
+            return l1Norm(array);
+
+        } else if (queryVector instanceof List<?> list) {
+            checkDimensions(getDims(), list.size());
+            return l1Norm((List<Number>) list);
+        }
+
+        throw new IllegalArgumentException(badQueryVectorType(queryVector));
+    }
+
+    double l2Norm(float[] queryVector);
+
+    double l2Norm(List<Number> queryVector);
+
+    @SuppressWarnings("unchecked")
+    default double l2Norm(Object queryVector) {
+        if (queryVector instanceof float[] array) {
+            checkDimensions(getDims(), array.length);
+            return l2Norm(array);
+
+        } else if (queryVector instanceof List<?> list) {
+            checkDimensions(getDims(), list.size());
+            return l2Norm((List<Number>) list);
+        }
+
+        throw new IllegalArgumentException(badQueryVectorType(queryVector));
+    }
+
+    /**
+     * Get the cosine similarity with the un-normalized query vector
+     */
+    default double cosineSimilarity(float[] queryVector) {
+        return cosineSimilarity(queryVector, true);
+    }
+
+    /**
+     * Get the cosine similarity with the query vector
+     * @param normalizeQueryVector - normalize the query vector, does not change the contents of passed in query vector
+     */
+    double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector);
+
+    /**
+     * Get the cosine similarity with the un-normalized query vector
+     */
+    double cosineSimilarity(List<Number> queryVector);
+
+    /**
+     * Get the cosine similarity with the un-normalized query vector.  Handles queryVectors of type float[] and List.
+     */
+    @SuppressWarnings("unchecked")
+    default double cosineSimilarity(Object queryVector) {
+        if (queryVector instanceof float[] array) {
+            checkDimensions(getDims(), array.length);
+            return cosineSimilarity(array);
+
+        } else if (queryVector instanceof List<?> list) {
+            checkDimensions(getDims(), list.size());
+            return cosineSimilarity((List<Number>) list);
+        }
+
+        throw new IllegalArgumentException(badQueryVectorType(queryVector));
+    }
+
+    boolean isEmpty();
+
+    int getDims();
+
+    int size();
+
+    static float getMagnitude(float[] vector) {
+        double mag = 0.0f;
+        for (float elem : vector) {
+            mag += elem * elem;
+        }
+        return (float) Math.sqrt(mag);
+    }
+
+    static float getMagnitude(List<Number> vector) {
+        double mag = 0.0f;
+        for (Number number : vector) {
+            float elem = number.floatValue();
+            mag += elem * elem;
+        }
+        return (float) Math.sqrt(mag);
+    }
+
+    static void checkDimensions(int dvDims, int qvDims) {
+        if (dvDims != qvDims) {
+            throw new IllegalArgumentException(
+                "The query vector has a different number of dimensions [" + qvDims + "] than the document vectors [" + dvDims + "]."
+            );
+        }
+    }
+
+    private static String badQueryVectorType(Object queryVector) {
+        return "Cannot use vector [" + queryVector + "] with class [" + queryVector.getClass().getName() + "] as query vector";
+    }
+
+    DenseVector EMPTY = new DenseVector() {
+        public static final String MISSING_VECTOR_FIELD_MESSAGE = "Dense vector value missing for a field,"
+            + " use isEmpty() to check for a missing vector value";
+
+        @Override
+        public float getMagnitude() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double dotProduct(float[] queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double dotProduct(List<Number> queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double l1Norm(List<Number> queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double l1Norm(float[] queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double l2Norm(List<Number> queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double l2Norm(float[] queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double cosineSimilarity(float[] queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public double cosineSimilarity(List<Number> queryVector) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public float[] getVector() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return true;
+        }
+
+        @Override
+        public int getDims() {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+
+        @Override
+        public int size() {
+            return 0;
+        }
+    };
+}

+ 51 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorDocValuesField.java

@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.script.field.DocValuesField;
+
+import java.util.Iterator;
+
+public abstract class DenseVectorDocValuesField implements DocValuesField<DenseVector>, DenseVectorScriptDocValues.DenseVectorSupplier {
+    protected final String name;
+
+    public DenseVectorDocValuesField(String name) {
+        this.name = name;
+    }
+
+    @Override
+    public String getName() {
+        return name;
+    }
+
+    @Override
+    public int size() {
+        return isEmpty() ? 0 : 1;
+    }
+
+    @Override
+    public BytesRef getInternal(int index) {
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Get the DenseVector for a document if one exists, DenseVector.EMPTY otherwise
+     */
+    public abstract DenseVector get();
+
+    public abstract DenseVector get(DenseVector defaultValue);
+
+    public abstract DenseVectorScriptDocValues getScriptDocValues();
+
+    // DenseVector fields are single valued, so Iterable does not make sense.
+    @Override
+    public Iterator<DenseVector> iterator() {
+        throw new UnsupportedOperationException("Cannot iterate over single valued dense_vector field, use get() instead");
+    }
+}

+ 40 - 50
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java

@@ -10,24 +10,16 @@ package org.elasticsearch.xpack.vectors.query;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
 
-public abstract class DenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {
-
-    public interface DenseVectorSupplier<T> extends Supplier<BytesRef> {
-
-        @Override
-        default BytesRef getInternal(int index) {
-            throw new UnsupportedOperationException();
-        }
-
-        T getInternal();
-    }
+public class DenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {
 
     public static final String MISSING_VECTOR_FIELD_MESSAGE = "A document doesn't have a value for a vector field!";
 
     private final int dims;
+    protected final DenseVectorSupplier dvSupplier;
 
-    public DenseVectorScriptDocValues(DenseVectorSupplier<?> supplier, int dims) {
+    public DenseVectorScriptDocValues(DenseVectorSupplier supplier, int dims) {
         super(supplier);
+        this.dvSupplier = supplier;
         this.dims = dims;
     }
 
@@ -35,60 +27,58 @@ public abstract class DenseVectorScriptDocValues extends ScriptDocValues<BytesRe
         return dims;
     }
 
+    private DenseVector getCheckedVector() {
+        DenseVector vector = dvSupplier.getInternal();
+        if (vector == null) {
+            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
+        }
+        return vector;
+    }
+
     /**
      * Get dense vector's value as an array of floats
      */
-    public abstract float[] getVectorValue();
+    public float[] getVectorValue() {
+        return getCheckedVector().getVector();
+    }
 
     /**
      * Get dense vector's magnitude
      */
-    public abstract float getMagnitude();
+    public float getMagnitude() {
+        return getCheckedVector().getMagnitude();
+    }
 
-    public abstract double dotProduct(float[] queryVector);
+    public double dotProduct(float[] queryVector) {
+        return getCheckedVector().dotProduct(queryVector);
+    }
 
-    public abstract double l1Norm(float[] queryVector);
+    public double l1Norm(float[] queryVector) {
+        return getCheckedVector().l1Norm(queryVector);
+    }
 
-    public abstract double l2Norm(float[] queryVector);
+    public double l2Norm(float[] queryVector) {
+        return getCheckedVector().l2Norm(queryVector);
+    }
 
     @Override
     public BytesRef get(int index) {
         throw new UnsupportedOperationException(
-            "accessing a vector field's value through 'get' or 'value' is not supported!" + "Use 'vectorValue' or 'magnitude' instead!'"
+            "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead."
         );
     }
 
-    public static DenseVectorScriptDocValues empty(DenseVectorSupplier<?> supplier, int dims) {
-        return new DenseVectorScriptDocValues(supplier, dims) {
-            @Override
-            public float[] getVectorValue() {
-                throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-            }
-
-            @Override
-            public float getMagnitude() {
-                throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-            }
-
-            @Override
-            public double dotProduct(float[] queryVector) {
-                throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-            }
-
-            @Override
-            public double l1Norm(float[] queryVector) {
-                throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-            }
-
-            @Override
-            public double l2Norm(float[] queryVector) {
-                throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-            }
-
-            @Override
-            public int size() {
-                return supplier.size();
-            }
-        };
+    @Override
+    public int size() {
+        return dvSupplier.getInternal() == null ? 0 : 1;
+    }
+
+    public interface DenseVectorSupplier extends Supplier<BytesRef> {
+        @Override
+        default BytesRef getInternal(int index) {
+            throw new UnsupportedOperationException();
+        }
+
+        DenseVector getInternal();
     }
 }

+ 4 - 1
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DocValuesWhitelistExtension.java

@@ -19,7 +19,10 @@ import java.util.Map;
 
 public class DocValuesWhitelistExtension implements PainlessExtension {
 
-    private static final Whitelist WHITELIST = WhitelistLoader.loadFromResourceFiles(DocValuesWhitelistExtension.class, "whitelist.txt");
+    private static final Whitelist WHITELIST = WhitelistLoader.loadFromResourceFiles(
+        DocValuesWhitelistExtension.class,
+        "org.elasticsearch.xpack.vectors.txt"
+    );
 
     @Override
     public Map<ScriptContext<?>, List<Whitelist>> getContextWhitelists() {

+ 109 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVector.java

@@ -0,0 +1,109 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.util.VectorUtil;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class KnnDenseVector implements DenseVector {
+    protected final float[] docVector;
+
+    public KnnDenseVector(float[] docVector) {
+        this.docVector = docVector;
+    }
+
+    @Override
+    public float[] getVector() {
+        // we need to copy the value, since {@link VectorValues} can reuse
+        // the underlying array across documents
+        return Arrays.copyOf(docVector, docVector.length);
+    }
+
+    @Override
+    public float getMagnitude() {
+        return DenseVector.getMagnitude(docVector);
+    }
+
+    @Override
+    public double dotProduct(float[] queryVector) {
+        return VectorUtil.dotProduct(docVector, queryVector);
+    }
+
+    @Override
+    public double dotProduct(List<Number> queryVector) {
+        double dotProduct = 0;
+        for (int i = 0; i < docVector.length; i++) {
+            dotProduct += docVector[i] * queryVector.get(i).floatValue();
+        }
+        return dotProduct;
+    }
+
+    @Override
+    public double l1Norm(float[] queryVector) {
+        double result = 0.0;
+        for (int i = 0; i < docVector.length; i++) {
+            result += Math.abs(docVector[i] - queryVector[i]);
+        }
+        return result;
+    }
+
+    @Override
+    public double l1Norm(List<Number> queryVector) {
+        double result = 0.0;
+        for (int i = 0; i < docVector.length; i++) {
+            result += Math.abs(docVector[i] - queryVector.get(i).floatValue());
+        }
+        return result;
+    }
+
+    @Override
+    public double l2Norm(float[] queryVector) {
+        return Math.sqrt(VectorUtil.squareDistance(docVector, queryVector));
+    }
+
+    @Override
+    public double l2Norm(List<Number> queryVector) {
+        double l2norm = 0;
+        for (int i = 0; i < docVector.length; i++) {
+            double diff = docVector[i] - queryVector.get(i).floatValue();
+            l2norm += diff * diff;
+        }
+        return Math.sqrt(l2norm);
+    }
+
+    @Override
+    public double cosineSimilarity(float[] queryVector, boolean normalizeQueryVector) {
+        if (normalizeQueryVector) {
+            return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+        }
+
+        return dotProduct(queryVector) / getMagnitude();
+    }
+
+    @Override
+    public double cosineSimilarity(List<Number> queryVector) {
+        return dotProduct(queryVector) / (DenseVector.getMagnitude(queryVector) * getMagnitude());
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return false;
+    }
+
+    @Override
+    public int getDims() {
+        return docVector.length;
+    }
+
+    @Override
+    public int size() {
+        return 1;
+    }
+}

+ 79 - 0
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorDocValuesField.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.index.VectorValues;
+import org.elasticsearch.core.Nullable;
+
+import java.io.IOException;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+public class KnnDenseVectorDocValuesField extends DenseVectorDocValuesField {
+    protected VectorValues input; // null if no vectors
+    protected float[] vector;
+    protected final int dims;
+
+    public KnnDenseVectorDocValuesField(@Nullable VectorValues input, String name, int dims) {
+        super(name);
+        this.dims = dims;
+        this.input = input;
+    }
+
+    @Override
+    public void setNextDocId(int docId) throws IOException {
+        if (input == null) {
+            return;
+        }
+        int currentDoc = input.docID();
+        if (currentDoc == NO_MORE_DOCS || docId < currentDoc) {
+            vector = null;
+        } else if (docId == currentDoc) {
+            vector = input.vectorValue();
+        } else {
+            currentDoc = input.advance(docId);
+            if (currentDoc == docId) {
+                vector = input.vectorValue();
+            } else {
+                vector = null;
+            }
+        }
+    }
+
+    @Override
+    public DenseVectorScriptDocValues getScriptDocValues() {
+        return new DenseVectorScriptDocValues(this, dims);
+    }
+
+    public boolean isEmpty() {
+        return vector == null;
+    }
+
+    @Override
+    public DenseVector get() {
+        if (isEmpty()) {
+            return DenseVector.EMPTY;
+        }
+
+        return new KnnDenseVector(vector);
+    }
+
+    @Override
+    public DenseVector get(DenseVector defaultValue) {
+        if (isEmpty()) {
+            return defaultValue;
+        }
+
+        return new KnnDenseVector(vector);
+    }
+
+    @Override
+    public DenseVector getInternal() {
+        return get(null);
+    }
+}

+ 0 - 122
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValues.java

@@ -1,122 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.vectors.query;
-
-import org.apache.lucene.index.VectorValues;
-import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
-
-import java.io.IOException;
-import java.util.Arrays;
-
-import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
-
-public class KnnDenseVectorScriptDocValues extends DenseVectorScriptDocValues {
-
-    public static class KnnDenseVectorSupplier implements DenseVectorSupplier<float[]> {
-
-        private final VectorValues in;
-        private float[] vector;
-
-        public KnnDenseVectorSupplier(VectorValues in) {
-            this.in = in;
-        }
-
-        @Override
-        public void setNextDocId(int docId) throws IOException {
-            int currentDoc = in.docID();
-            if (currentDoc == NO_MORE_DOCS || docId < currentDoc) {
-                vector = null;
-            } else if (docId == currentDoc) {
-                vector = in.vectorValue();
-            } else {
-                currentDoc = in.advance(docId);
-                if (currentDoc == docId) {
-                    vector = in.vectorValue();
-                } else {
-                    vector = null;
-                }
-            }
-        }
-
-        @Override
-        public BytesRef getInternal(int index) {
-            throw new UnsupportedOperationException();
-        }
-
-        public float[] getInternal() {
-            return vector;
-        }
-
-        @Override
-        public int size() {
-            if (vector == null) {
-                return 0;
-            } else {
-                return 1;
-            }
-        }
-    }
-
-    private final KnnDenseVectorSupplier kdvSupplier;
-
-    KnnDenseVectorScriptDocValues(KnnDenseVectorSupplier supplier, int dims) {
-        super(supplier, dims);
-        this.kdvSupplier = supplier;
-    }
-
-    private float[] getVectorChecked() {
-        if (kdvSupplier.getInternal() == null) {
-            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-        }
-        return kdvSupplier.getInternal();
-    }
-
-    @Override
-    public float[] getVectorValue() {
-        float[] vector = getVectorChecked();
-        // we need to copy the value, since {@link VectorValues} can reuse
-        // the underlying array across documents
-        return Arrays.copyOf(vector, vector.length);
-    }
-
-    @Override
-    public float getMagnitude() {
-        float[] vector = getVectorChecked();
-        double magnitude = 0.0f;
-        for (float elem : vector) {
-            magnitude += elem * elem;
-        }
-        return (float) Math.sqrt(magnitude);
-    }
-
-    @Override
-    public double dotProduct(float[] queryVector) {
-        return VectorUtil.dotProduct(getVectorChecked(), queryVector);
-    }
-
-    @Override
-    public double l1Norm(float[] queryVector) {
-        float[] vectorValue = getVectorChecked();
-        double result = 0.0;
-        for (int i = 0; i < queryVector.length; i++) {
-            result += Math.abs(vectorValue[i] - queryVector[i]);
-        }
-        return result;
-    }
-
-    @Override
-    public double l2Norm(float[] queryVector) {
-        return Math.sqrt(VectorUtil.squareDistance(getVectorValue(), queryVector));
-    }
-
-    @Override
-    public int size() {
-        return supplier.size();
-    }
-}

+ 13 - 21
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

@@ -18,10 +18,10 @@ public class ScoreScriptUtils {
     public static class DenseVectorFunction {
         final ScoreScript scoreScript;
         final float[] queryVector;
-        final DenseVectorScriptDocValues docValues;
+        final DenseVectorDocValuesField field;
 
-        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, String field) {
-            this(scoreScript, queryVector, field, false);
+        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, String fieldName) {
+            this(scoreScript, queryVector, fieldName, false);
         }
 
         /**
@@ -31,19 +31,10 @@ public class ScoreScriptUtils {
          * @param queryVector The query vector.
          * @param normalizeQuery Whether the provided query should be normalized to unit length.
          */
-        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, String field, boolean normalizeQuery) {
+        public DenseVectorFunction(ScoreScript scoreScript, List<Number> queryVector, String fieldName, boolean normalizeQuery) {
             this.scoreScript = scoreScript;
-            this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field);
-
-            if (docValues.dims() != queryVector.size()) {
-                throw new IllegalArgumentException(
-                    "The query vector has a different number of dimensions ["
-                        + queryVector.size()
-                        + "] than the document vectors ["
-                        + docValues.dims()
-                        + "]."
-                );
-            }
+            this.field = (DenseVectorDocValuesField) scoreScript.field(fieldName);
+            DenseVector.checkDimensions(field.get().getDims(), queryVector.size());
 
             this.queryVector = new float[queryVector.size()];
             double queryMagnitude = 0.0;
@@ -63,11 +54,11 @@ public class ScoreScriptUtils {
 
         void setNextVector() {
             try {
-                docValues.getSupplier().setNextDocId(scoreScript._getDocId());
+                field.setNextDocId(scoreScript._getDocId());
             } catch (IOException e) {
                 throw ExceptionsHelper.convertToElastic(e);
             }
-            if (docValues.size() == 0) {
+            if (field.isEmpty()) {
                 throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
             }
         }
@@ -82,7 +73,7 @@ public class ScoreScriptUtils {
 
         public double l1norm() {
             setNextVector();
-            return docValues.l1Norm(queryVector);
+            return field.get().l1Norm(queryVector);
         }
     }
 
@@ -95,7 +86,7 @@ public class ScoreScriptUtils {
 
         public double l2norm() {
             setNextVector();
-            return docValues.l2Norm(queryVector);
+            return field.get().l2Norm(queryVector);
         }
     }
 
@@ -108,7 +99,7 @@ public class ScoreScriptUtils {
 
         public double dotProduct() {
             setNextVector();
-            return docValues.dotProduct(queryVector);
+            return field.get().dotProduct(queryVector);
         }
     }
 
@@ -121,7 +112,8 @@ public class ScoreScriptUtils {
 
         public double cosineSimilarity() {
             setNextVector();
-            return docValues.dotProduct(queryVector) / docValues.getMagnitude();
+            // query vector normalized in constructor
+            return field.get().cosineSimilarity(queryVector, false);
         }
     }
 }

+ 6 - 28
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java

@@ -15,18 +15,12 @@ import org.apache.lucene.util.Accountable;
 import org.elasticsearch.Version;
 import org.elasticsearch.index.fielddata.LeafFieldData;
 import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
-import org.elasticsearch.script.field.DelegateDocValuesField;
 import org.elasticsearch.script.field.DocValuesField;
-import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier;
-import org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues.DenseVectorSupplier;
-import org.elasticsearch.xpack.vectors.query.KnnDenseVectorScriptDocValues.KnnDenseVectorSupplier;
 
 import java.io.IOException;
 import java.util.Collection;
 import java.util.Collections;
 
-import static org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues.MISSING_VECTOR_FIELD_MESSAGE;
-
 final class VectorDVLeafFieldData implements LeafFieldData {
 
     private final LeafReader reader;
@@ -63,31 +57,15 @@ final class VectorDVLeafFieldData implements LeafFieldData {
         try {
             if (indexed) {
                 VectorValues values = reader.getVectorValues(field);
-                if (values == null || values == VectorValues.EMPTY) {
-                    return new DelegateDocValuesField(DenseVectorScriptDocValues.empty(new DenseVectorSupplier<float[]>() {
-                        @Override
-                        public float[] getInternal() {
-                            throw new IllegalArgumentException(MISSING_VECTOR_FIELD_MESSAGE);
-                        }
-
-                        @Override
-                        public void setNextDocId(int docId) throws IOException {
-                            // do nothing
-                        }
-
-                        @Override
-                        public int size() {
-                            return 0;
-                        }
-                    }, dims), name);
+                if (values == VectorValues.EMPTY) {
+                    // There's no way for KnnDenseVectorDocValuesField to reliably differentiate between VectorValues.EMPTY and
+                    // values that can be iterated through. Since VectorValues.EMPTY throws on docID(), pass a null instead.
+                    values = null;
                 }
-                return new DelegateDocValuesField(new KnnDenseVectorScriptDocValues(new KnnDenseVectorSupplier(values), dims), name);
+                return new KnnDenseVectorDocValuesField(values, name, dims);
             } else {
                 BinaryDocValues values = DocValues.getBinary(reader, field);
-                return new DelegateDocValuesField(
-                    new BinaryDenseVectorScriptDocValues(new BinaryDenseVectorSupplier(values), indexVersion, dims),
-                    name
-                );
+                return new BinaryDenseVectorDocValuesField(values, name, dims, indexVersion);
             }
         } catch (IOException e) {
             throw new IllegalStateException("Cannot load doc values for vector field!", e);

+ 37 - 0
x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt → x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/org.elasticsearch.xpack.vectors.txt

@@ -11,6 +11,43 @@ class org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues {
 class org.elasticsearch.script.ScoreScript @no_import {
 }
 
+class org.elasticsearch.xpack.vectors.query.DenseVector {
+    DenseVector EMPTY
+    float getMagnitude()
+
+    # handle List<Number> and float[] arguments
+    double dotProduct(Object)
+    double l1Norm(Object)
+    double l2Norm(Object)
+    double cosineSimilarity(Object)
+
+    float[] getVector()
+    boolean isEmpty()
+    int getDims()
+    int size()
+}
+
+# implementation of DenseVector
+class org.elasticsearch.xpack.vectors.query.BinaryDenseVector {
+}
+
+# implementation of DenseVector
+class org.elasticsearch.xpack.vectors.query.KnnDenseVector {
+}
+
+class org.elasticsearch.xpack.vectors.query.DenseVectorDocValuesField {
+    DenseVector get()
+    DenseVector get(DenseVector)
+}
+
+# implementation of DenseVectorDocValuesField
+class org.elasticsearch.xpack.vectors.query.KnnDenseVectorDocValuesField {
+}
+
+# implementation of DenseVectorDocValuesField
+class org.elasticsearch.xpack.vectors.query.BinaryDenseVectorDocValuesField {
+}
+
 static_import {
     double l1norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L1Norm
     double l2norm(org.elasticsearch.script.ScoreScript, List, String) bound_to org.elasticsearch.xpack.vectors.query.ScoreScriptUtils$L2Norm

+ 51 - 15
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/BinaryDenseVectorScriptDocValuesTests.java

@@ -12,7 +12,6 @@ import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.Version;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
-import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -29,24 +28,56 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
 
         for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
             BinaryDocValues docValues = wrap(vectors, indexVersion);
-            BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues);
-            DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, indexVersion, dims);
+            BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion);
+            DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
             for (int i = 0; i < vectors.length; i++) {
-                supplier.setNextDocId(i);
+                field.setNextDocId(i);
+                assertEquals(1, field.size());
+                assertEquals(dims, scriptDocValues.dims());
                 assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f);
                 assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f);
             }
         }
     }
 
+    public void testMetadataAndIterator() throws IOException {
+        int dims = 3;
+        Version indexVersion = Version.CURRENT;
+        float[][] vectors = fill(new float[randomIntBetween(1, 5)][dims]);
+        BinaryDocValues docValues = wrap(vectors, indexVersion);
+        BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion);
+        for (int i = 0; i < vectors.length; i++) {
+            field.setNextDocId(i);
+            DenseVector dv = field.get();
+            assertEquals(1, dv.size());
+            assertFalse(dv.isEmpty());
+            assertEquals(dims, dv.getDims());
+            UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator);
+            assertEquals("Cannot iterate over single valued dense_vector field, use get() instead", e.getMessage());
+        }
+        field.setNextDocId(vectors.length);
+        DenseVector dv = field.get();
+        assertEquals(dv, DenseVector.EMPTY);
+    }
+
+    protected float[][] fill(float[][] vectors) {
+        for (float[] vector : vectors) {
+            for (int i = 0; i < vector.length; i++) {
+                vector[i] = randomFloat();
+            }
+        }
+        return vectors;
+    }
+
     public void testMissingValues() throws IOException {
         int dims = 3;
         float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
         BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
-        BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues);
-        DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims);
+        BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, Version.CURRENT);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
 
-        supplier.setNextDocId(3);
+        field.setNextDocId(3);
+        assertEquals(0, field.size());
         Exception e = expectThrows(IllegalArgumentException.class, scriptDocValues::getVectorValue);
         assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
 
@@ -58,12 +89,17 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
         int dims = 3;
         float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
         BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
-        BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues);
-        DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims);
+        BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, Version.CURRENT);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
 
-        supplier.setNextDocId(0);
+        field.setNextDocId(0);
         Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
-        assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!"));
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead."
+            )
+        );
     }
 
     public void testSimilarityFunctions() throws IOException {
@@ -73,10 +109,10 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
 
         for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
             BinaryDocValues docValues = wrap(new float[][] { docVector }, indexVersion);
-            BinaryDenseVectorSupplier supplier = new BinaryDenseVectorSupplier(docValues);
-            DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(supplier, Version.CURRENT, dims);
+            BinaryDenseVectorDocValuesField field = new BinaryDenseVectorDocValuesField(docValues, "test", dims, indexVersion);
+            DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
 
-            supplier.setNextDocId(0);
+            field.setNextDocId(0);
 
             assertEquals(
                 "dotProduct result is not equal to the expected value!",
@@ -133,7 +169,7 @@ public class BinaryDenseVectorScriptDocValuesTests extends ESTestCase {
         };
     }
 
-    private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
+    static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
         byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0)
             ? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]
             : new byte[VectorEncoderDecoder.INT_BYTES * values.length];

+ 55 - 19
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java

@@ -7,18 +7,16 @@
 
 package org.elasticsearch.xpack.vectors.query;
 
-import org.apache.lucene.index.BinaryDocValues;
 import org.elasticsearch.Version;
 import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.vectors.query.BinaryDenseVectorScriptDocValues.BinaryDenseVectorSupplier;
 import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.CosineSimilarity;
 import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.DotProduct;
 import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L1Norm;
 import org.elasticsearch.xpack.vectors.query.ScoreScriptUtils.L2Norm;
 
+import java.io.IOException;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.function.Supplier;
 
@@ -28,34 +26,72 @@ import static org.mockito.Mockito.when;
 
 public class DenseVectorFunctionTests extends ESTestCase {
 
-    public void testVectorFunctions() {
-        String field = "vector";
+    public void testVectorClassBindings() throws IOException {
+        String fieldName = "vector";
         int dims = 5;
         float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
         List<Number> queryVector = Arrays.asList(0.5f, 111.3f, -13.0f, 14.8f, -156.0f);
         List<Number> invalidQueryVector = Arrays.asList(0.5, 111.3);
 
-        for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
-            BinaryDocValues docValues = BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, indexVersion);
-            DenseVectorScriptDocValues scriptDocValues = new BinaryDenseVectorScriptDocValues(
-                new BinaryDenseVectorSupplier(docValues),
-                indexVersion,
-                dims
-            );
+        List<DenseVectorDocValuesField> fields = List.of(
+            new BinaryDenseVectorDocValuesField(
+                BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, Version.V_7_4_0),
+                "test",
+                dims,
+                Version.V_7_4_0
+            ),
+            new BinaryDenseVectorDocValuesField(
+                BinaryDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }, Version.CURRENT),
+                "test",
+                dims,
+                Version.CURRENT
+            ),
+            new KnnDenseVectorDocValuesField(KnnDenseVectorScriptDocValuesTests.wrap(new float[][] { docVector }), "test", dims)
+        );
+        for (DenseVectorDocValuesField field : fields) {
+            field.setNextDocId(0);
 
             ScoreScript scoreScript = mock(ScoreScript.class);
-            when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, scriptDocValues));
+            when(scoreScript.field("vector")).thenAnswer(mock -> field);
 
             // Test cosine similarity explicitly, as it must perform special logic on top of the doc values
-            CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, field);
-            assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, function.cosineSimilarity(), 0.001);
+            CosineSimilarity function = new CosineSimilarity(scoreScript, queryVector, fieldName);
+            float cosineSimilarityExpected = 0.790f;
+            assertEquals(
+                "cosineSimilarity result is not equal to the expected value!",
+                cosineSimilarityExpected,
+                function.cosineSimilarity(),
+                0.001
+            );
+
+            // Test normalization for cosineSimilarity
+            float[] queryVectorArray = new float[queryVector.size()];
+            for (int i = 0; i < queryVectorArray.length; i++) {
+                queryVectorArray[i] = queryVector.get(i).floatValue();
+            }
+            assertEquals(
+                "cosineSimilarity result is not equal to the expected value!",
+                cosineSimilarityExpected,
+                field.getInternal().cosineSimilarity(queryVectorArray, true),
+                0.001
+            );
 
             // Check each function rejects query vectors with the wrong dimension
-            assertDimensionMismatch(() -> new DotProduct(scoreScript, invalidQueryVector, field));
-            assertDimensionMismatch(() -> new CosineSimilarity(scoreScript, invalidQueryVector, field));
-            assertDimensionMismatch(() -> new L1Norm(scoreScript, invalidQueryVector, field));
-            assertDimensionMismatch(() -> new L2Norm(scoreScript, invalidQueryVector, field));
+            assertDimensionMismatch(() -> new DotProduct(scoreScript, invalidQueryVector, fieldName));
+            assertDimensionMismatch(() -> new CosineSimilarity(scoreScript, invalidQueryVector, fieldName));
+            assertDimensionMismatch(() -> new L1Norm(scoreScript, invalidQueryVector, fieldName));
+            assertDimensionMismatch(() -> new L2Norm(scoreScript, invalidQueryVector, fieldName));
+
+            // Check scripting infrastructure integration
+            DotProduct dotProduct = new DotProduct(scoreScript, queryVector, fieldName);
+            assertEquals(65425.6249, dotProduct.dotProduct(), 0.001);
+            assertEquals(485.1837, new L1Norm(scoreScript, queryVector, fieldName).l1norm(), 0.001);
+            assertEquals(301.3614, new L2Norm(scoreScript, queryVector, fieldName).l2norm(), 0.001);
+            when(scoreScript._getDocId()).thenReturn(1);
+            IllegalArgumentException e = expectThrows(IllegalArgumentException.class, dotProduct::dotProduct);
+            assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
         }
+
     }
 
     private void assertDimensionMismatch(Supplier<ScoreScriptUtils.DenseVectorFunction> supplier) {

+ 84 - 0
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorTests.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.hamcrest.Matchers.containsString;
+
+public class DenseVectorTests extends ESTestCase {
+    public void testBadVectorType() {
+        DenseVector knn = new KnnDenseVector(new float[] { 1.0f, 2.0f, 3.5f });
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> knn.dotProduct(new HashMap<>()));
+        assertThat(e.getMessage(), containsString("Cannot use vector ["));
+        assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector"));
+
+        e = expectThrows(IllegalArgumentException.class, () -> knn.l1Norm(new HashMap<>()));
+        assertThat(e.getMessage(), containsString("Cannot use vector ["));
+        assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector"));
+
+        e = expectThrows(IllegalArgumentException.class, () -> knn.l2Norm(new HashMap<>()));
+        assertThat(e.getMessage(), containsString("Cannot use vector ["));
+        assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector"));
+
+        e = expectThrows(IllegalArgumentException.class, () -> knn.cosineSimilarity(new HashMap<>()));
+        assertThat(e.getMessage(), containsString("Cannot use vector ["));
+        assertThat(e.getMessage(), containsString("] with class [java.util.HashMap] as query vector"));
+    }
+
+    public void testFloatVsListQueryVector() {
+        int dims = randomIntBetween(1, 16);
+        float[] docVector = new float[dims];
+        float[] arrayQV = new float[dims];
+        List<Number> listQV = new ArrayList<>(dims);
+        for (int i = 0; i < docVector.length; i++) {
+            docVector[i] = randomFloat();
+            float q = randomFloat();
+            arrayQV[i] = q;
+            listQV.add(q);
+        }
+
+        KnnDenseVector knn = new KnnDenseVector(docVector);
+        assertEquals(knn.dotProduct(arrayQV), knn.dotProduct(listQV), 0.001f);
+        assertEquals(knn.dotProduct((Object) listQV), knn.dotProduct((Object) arrayQV), 0.001f);
+
+        assertEquals(knn.l1Norm(arrayQV), knn.l1Norm(listQV), 0.001f);
+        assertEquals(knn.l1Norm((Object) listQV), knn.l1Norm((Object) arrayQV), 0.001f);
+
+        assertEquals(knn.l2Norm(arrayQV), knn.l2Norm(listQV), 0.001f);
+        assertEquals(knn.l2Norm((Object) listQV), knn.l2Norm((Object) arrayQV), 0.001f);
+
+        assertEquals(knn.cosineSimilarity(arrayQV), knn.cosineSimilarity(listQV), 0.001f);
+        assertEquals(knn.cosineSimilarity((Object) listQV), knn.cosineSimilarity((Object) arrayQV), 0.001f);
+
+        for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
+            BytesRef value = BinaryDenseVectorScriptDocValuesTests.mockEncodeDenseVector(docVector, indexVersion);
+            BinaryDenseVector bdv = new BinaryDenseVector(value, dims, indexVersion);
+
+            assertEquals(bdv.dotProduct(arrayQV), bdv.dotProduct(listQV), 0.001f);
+            assertEquals(bdv.dotProduct((Object) listQV), bdv.dotProduct((Object) arrayQV), 0.001f);
+
+            assertEquals(bdv.l1Norm(arrayQV), bdv.l1Norm(listQV), 0.001f);
+            assertEquals(bdv.l1Norm((Object) listQV), bdv.l1Norm((Object) arrayQV), 0.001f);
+
+            assertEquals(bdv.l2Norm(arrayQV), bdv.l2Norm(listQV), 0.001f);
+            assertEquals(bdv.l2Norm((Object) listQV), bdv.l2Norm((Object) arrayQV), 0.001f);
+
+            assertEquals(bdv.cosineSimilarity(arrayQV), bdv.cosineSimilarity(listQV), 0.001f);
+            assertEquals(bdv.cosineSimilarity((Object) listQV), bdv.cosineSimilarity((Object) arrayQV), 0.001f);
+        }
+    }
+
+}

+ 63 - 15
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/KnnDenseVectorScriptDocValuesTests.java

@@ -10,7 +10,6 @@ package org.elasticsearch.xpack.vectors.query;
 import org.apache.lucene.index.VectorValues;
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.vectors.query.KnnDenseVectorScriptDocValues.KnnDenseVectorSupplier;
 
 import java.io.IOException;
 
@@ -23,22 +22,52 @@ public class KnnDenseVectorScriptDocValuesTests extends ESTestCase {
         float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
         float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f };
 
-        KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors));
-        DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims);
+        DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
         for (int i = 0; i < vectors.length; i++) {
-            supplier.setNextDocId(i);
+            field.setNextDocId(i);
+            assertEquals(1, field.size());
+            assertEquals(dims, scriptDocValues.dims());
             assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f);
             assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f);
         }
     }
 
+    public void testMetadataAndIterator() throws IOException {
+        int dims = 3;
+        float[][] vectors = fill(new float[randomIntBetween(1, 5)][dims]);
+        KnnDenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims);
+        for (int i = 0; i < vectors.length; i++) {
+            field.setNextDocId(i);
+            DenseVector dv = field.get();
+            assertEquals(1, dv.size());
+            assertFalse(dv.isEmpty());
+            assertEquals(dims, dv.getDims());
+            UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, field::iterator);
+            assertEquals("Cannot iterate over single valued dense_vector field, use get() instead", e.getMessage());
+        }
+        assertEquals(1, field.size());
+        field.setNextDocId(vectors.length);
+        DenseVector dv = field.get();
+        assertEquals(dv, DenseVector.EMPTY);
+    }
+
+    protected float[][] fill(float[][] vectors) {
+        for (float[] vector : vectors) {
+            for (int i = 0; i < vector.length; i++) {
+                vector[i] = randomFloat();
+            }
+        }
+        return vectors;
+    }
+
     public void testMissingValues() throws IOException {
         int dims = 3;
         float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
-        KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors));
-        DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims);
+        DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
 
-        supplier.setNextDocId(3);
+        field.setNextDocId(3);
         Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue());
         assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
 
@@ -49,12 +78,17 @@ public class KnnDenseVectorScriptDocValuesTests extends ESTestCase {
     public void testGetFunctionIsNotAccessible() throws IOException {
         int dims = 3;
         float[][] vectors = { { 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
-        KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(vectors));
-        DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims);
+        DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(vectors), "test", dims);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
 
-        supplier.setNextDocId(0);
+        field.setNextDocId(0);
         Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
-        assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!"));
+        assertThat(
+            e.getMessage(),
+            containsString(
+                "accessing a vector field's value through 'get' or 'value' is not supported, use 'vectorValue' or 'magnitude' instead."
+            )
+        );
     }
 
     public void testSimilarityFunctions() throws IOException {
@@ -62,16 +96,30 @@ public class KnnDenseVectorScriptDocValuesTests extends ESTestCase {
         float[] docVector = new float[] { 230.0f, 300.33f, -34.8988f, 15.555f, -200.0f };
         float[] queryVector = new float[] { 0.5f, 111.3f, -13.0f, 14.8f, -156.0f };
 
-        KnnDenseVectorSupplier supplier = new KnnDenseVectorSupplier(wrap(new float[][] { docVector }));
-        DenseVectorScriptDocValues scriptDocValues = new KnnDenseVectorScriptDocValues(supplier, dims);
-        supplier.setNextDocId(0);
+        DenseVectorDocValuesField field = new KnnDenseVectorDocValuesField(wrap(new float[][] { docVector }), "test", dims);
+        DenseVectorScriptDocValues scriptDocValues = field.getScriptDocValues();
+        field.setNextDocId(0);
 
         assertEquals("dotProduct result is not equal to the expected value!", 65425.624, scriptDocValues.dotProduct(queryVector), 0.001);
         assertEquals("l1norm result is not equal to the expected value!", 485.184, scriptDocValues.l1Norm(queryVector), 0.001);
         assertEquals("l2norm result is not equal to the expected value!", 301.361, scriptDocValues.l2Norm(queryVector), 0.001);
     }
 
-    private static VectorValues wrap(float[][] vectors) {
+    public void testMissingVectorValues() throws IOException {
+        int dims = 7;
+        KnnDenseVectorDocValuesField emptyKnn = new KnnDenseVectorDocValuesField(null, "test", dims);
+
+        emptyKnn.setNextDocId(0);
+        assertEquals(0, emptyKnn.getScriptDocValues().size());
+        assertTrue(emptyKnn.getScriptDocValues().isEmpty());
+        assertEquals(DenseVector.EMPTY, emptyKnn.get());
+        assertNull(emptyKnn.get(null));
+        assertNull(emptyKnn.getInternal());
+        UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, emptyKnn::iterator);
+        assertEquals("Cannot iterate over single valued dense_vector field, use get() instead", e.getMessage());
+    }
+
+    static VectorValues wrap(float[][] vectors) {
         return new VectorValues() {
             int index = 0;