Browse Source

Add access to dense_vector values (#71313)

Allow direct access to a dense_vector's values in scripts
through the following functions:

- getVectorValue – returns a vector's value as an array of floats
- getMagnitude – returns a vector's magnitude

Closes #51964
Mayya Sharipova 4 years ago
parent
commit
853e68dfdf
18 changed files with 374 additions and 91 deletions
  1. 3 4
      docs/reference/mapping/types/dense-vector.asciidoc
  2. 58 0
      docs/reference/vectors/vector-functions.asciidoc
  3. 1 5
      server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java
  4. 0 1
      server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java
  5. 1 2
      server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java
  6. 0 22
      server/src/main/java/org/elasticsearch/script/ScoreScript.java
  7. 65 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml
  8. 6 4
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java
  9. 45 2
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java
  10. 32 2
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java
  11. 8 25
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java
  12. 7 2
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java
  13. 16 5
      x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java
  14. 2 0
      x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt
  15. 1 1
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java
  16. 12 6
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java
  17. 12 10
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java
  18. 105 0
      x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java

+ 3 - 4
docs/reference/mapping/types/dense-vector.asciidoc

@@ -10,9 +10,8 @@ A `dense_vector` field stores dense vectors of float values.
 The maximum number of dimensions that can be in a vector should
 not exceed 2048. A `dense_vector` field is a single-valued field.
 
-These vectors can be used for <<vector-functions,document scoring>>.
-For example, a document score can represent a distance between
-a given query vector and the indexed document vector.
+`dense_vector` fields do not support querying, sorting or aggregating. They can
+only be accessed in scripts through the dedicated <<vector-functions,vector functions>>.
 
 You index a dense vector as an array of floats.
 
@@ -47,4 +46,4 @@ PUT my-index-000001/_doc/2
 
 --------------------------------------------------
 
-<1> dims – the number of dimensions in the vector, required parameter.
+<1> dims – the number of dimensions in the vector, required parameter.

+ 58 - 0
docs/reference/vectors/vector-functions.asciidoc

@@ -8,6 +8,16 @@ linearly scanned. Thus, expect the query time grow linearly
 with the number of matched documents. For this reason, we recommend
 limiting the number of matched documents with a `query` parameter.
 
+This is the list of available vector functions and vector access methods:
+
+1. `cosineSimilarity` – calculates cosine similarity
+2. `dotProduct` – calculates dot product
+3. `l1norm` – calculates L^1^ distance
+4. `l2norm` – calculates L^2^ distance
+5. `doc[<field>].vectorValue` – returns a vector's value as an array of floats
+6. `doc[<field>].magnitude` – returns a vector's magnitude
+
+
 Let's create an index with a `dense_vector` mapping and index a couple
 of documents into it.
 
@@ -195,3 +205,51 @@ You can check if a document has a value for the field `my_vector` by
 "source": "doc['my_vector'].size() == 0 ? 0 : cosineSimilarity(params.queryVector, 'my_vector')"
 --------------------------------------------------
 // NOTCONSOLE
+
+The recommended way to access dense vectors is through `cosineSimilarity`,
+`dotProduct`, `l1norm` or `l2norm` functions. But for custom use cases,
+you can access dense vectors' values directly through the following functions:
+
+- `doc[<field>].vectorValue` – returns a vector's value as an array of floats
+
+- `doc[<field>].magnitude` – returns a vector's magnitude as a float
+(for vectors created prior to version 7.5 the magnitude is not stored,
+so this function calculates it anew every time it is called).
+
+For example, the script below implements a cosine similarity using these
+two functions:
+
+[source,console]
+--------------------------------------------------
+GET my-index-000001/_search
+{
+  "query": {
+    "script_score": {
+      "query" : {
+        "bool" : {
+          "filter" : {
+            "term" : {
+              "status" : "published"
+            }
+          }
+        }
+      },
+      "script": {
+        "source": """
+          float[] v = doc['my_dense_vector'].vectorValue;
+          float vm = doc['my_dense_vector'].magnitude;
+          float dotProduct = 0;
+          for (int i = 0; i < v.length; i++) {
+            dotProduct += v[i] * params.queryVector[i];
+          }
+          return dotProduct / (vm * (float) params.queryVectorMag);
+        """,
+        "params": {
+          "queryVector": [4, 3.4, -0.2],
+          "queryVectorMag": 5.25357
+        }
+      }
+    }
+  }
+}
+--------------------------------------------------

+ 1 - 5
server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreFunction.java

@@ -14,7 +14,6 @@ import org.apache.lucene.search.Scorable;
 import org.elasticsearch.script.ExplainableScoreScript;
 import org.elasticsearch.script.ScoreScript;
 import org.elasticsearch.script.Script;
-import org.elasticsearch.Version;
 
 import java.io.IOException;
 import java.util.Objects;
@@ -42,15 +41,13 @@ public class ScriptScoreFunction extends ScoreFunction {
 
     private final int shardId;
     private final String indexName;
-    private final Version indexVersion;
 
-    public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId, Version indexVersion) {
+    public ScriptScoreFunction(Script sScript, ScoreScript.LeafFactory script, String indexName, int shardId) {
         super(CombineFunction.REPLACE);
         this.sScript = sScript;
         this.script = script;
         this.indexName = indexName;
         this.shardId = shardId;
-        this.indexVersion = indexVersion;
     }
 
     @Override
@@ -60,7 +57,6 @@ public class ScriptScoreFunction extends ScoreFunction {
         leafScript.setScorer(scorer);
         leafScript._setIndexName(indexName);
         leafScript._setShard(shardId);
-        leafScript._setIndexVersion(indexVersion);
         return new LeafScoreFunction() {
             @Override
             public double score(int docId, float subQueryScore) throws IOException {

+ 0 - 1
server/src/main/java/org/elasticsearch/common/lucene/search/function/ScriptScoreQuery.java

@@ -146,7 +146,6 @@ public class ScriptScoreQuery extends Query {
                 final ScoreScript scoreScript = scriptBuilder.newInstance(context);
                 scoreScript._setIndexName(indexName);
                 scoreScript._setShard(shardId);
-                scoreScript._setIndexVersion(indexVersion);
                 return scoreScript;
             }
 

+ 1 - 2
server/src/main/java/org/elasticsearch/index/query/functionscore/ScriptScoreFunctionBuilder.java

@@ -83,8 +83,7 @@ public class ScriptScoreFunctionBuilder extends ScoreFunctionBuilder<ScriptScore
         try {
             ScoreScript.Factory factory = context.compile(script, ScoreScript.CONTEXT);
             ScoreScript.LeafFactory searchScript = factory.newFactory(script.getParams(), context.lookup());
-            return new ScriptScoreFunction(script, searchScript,
-                context.index().getName(), context.getShardId(), context.indexVersionCreated());
+            return new ScriptScoreFunction(script, searchScript, context.index().getName(), context.getShardId());
         } catch (Exception e) {
             throw new QueryShardException(context, "script_score: the script could not be loaded", e);
         }

+ 0 - 22
server/src/main/java/org/elasticsearch/script/ScoreScript.java

@@ -10,7 +10,6 @@ package org.elasticsearch.script;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.search.Scorable;
-import org.elasticsearch.Version;
 import org.elasticsearch.common.logging.DeprecationCategory;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
@@ -85,7 +84,6 @@ public abstract class ScoreScript {
     private int docId;
     private int shardId = -1;
     private String indexName = null;
-    private Version indexVersion = null;
 
     public ScoreScript(Map<String, Object> params, SearchLookup lookup, LeafReaderContext leafContext) {
         // null check needed b/c of expression engine subclass
@@ -185,19 +183,6 @@ public abstract class ScoreScript {
         }
     }
 
-    /**
-     *  Starting a name with underscore, so that the user cannot access this function directly through a script
-     *  It is only used within predefined painless functions.
-     * @return index version or throws an exception if the index version is not set up for this script instance
-     */
-    public Version _getIndexVersion() {
-        if (indexVersion != null) {
-            return indexVersion;
-        } else {
-            throw new IllegalArgumentException("index version can not be looked up!");
-        }
-    }
-
     /**
      *  Starting a name with underscore, so that the user cannot access this function directly through a script
      */
@@ -212,13 +197,6 @@ public abstract class ScoreScript {
         this.indexName = indexName;
     }
 
-    /**
-     *  Starting a name with underscore, so that the user cannot access this function directly through a script
-     */
-    public void _setIndexVersion(Version indexVersion) {
-        this.indexVersion = indexVersion;
-    }
-
 
     /** A factory to construct {@link ScoreScript} instances. */
     public interface LeafFactory {

+ 65 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/vectors/30_dense_vector_script_access.yml

@@ -0,0 +1,65 @@
+---
+"Access to values of dense_vector in script":
+  - skip:
+      version: " - 7.12.99"
+      reason: "Access to values of dense_vector in script was added in 7.13"
+  - do:
+      indices.create:
+        index: test-index
+        body:
+          mappings:
+            properties:
+              v:
+                type: dense_vector
+                dims: 3
+
+  - do:
+      bulk:
+        index: test-index
+        refresh: true
+        body:
+          - '{"index": {"_id": "1"}}'
+          - '{"v": [1, 1, 1]}'
+          - '{"index": {"_id": "2"}}'
+          - '{"v": [1, 1, 2]}'
+          - '{"index": {"_id": "3"}}'
+          - '{"v": [1, 1, 3]}'
+          - '{"index": {"_id": "missing_vector"}}'
+          - '{}'
+
+  # vector functions in loop – return the index of the closest parameter vector based on cosine similarity
+  - do:
+      search:
+        body:
+          query:
+            script_score:
+              query: { "exists": { "field": "v" } }
+              script:
+                source: |
+                  float[] v = doc['v'].vectorValue;
+                  float vm = doc['v'].magnitude;
+
+                  int closestPv = 0;
+                  float maxCosSim = -1;
+                  for (int i = 0; i < params.pvs.length; i++) {
+                    float dotProduct = 0;
+                    for (int j = 0; j < v.length; j++) {
+                      dotProduct += v[j] * params.pvs[i][j];
+                    }
+                    float cosSim = dotProduct / (vm * (float) params.pvs_magnts[i]);
+                    if (maxCosSim < cosSim) {
+                      maxCosSim = cosSim;
+                      closestPv = i;
+                    }
+                  }
+                  closestPv;
+                params:
+                  pvs: [ [ 1, 1, 1 ], [ 1, 1, 2 ], [ 1, 1, 3 ] ]
+                  pvs_magnts: [1.7320, 2.4495, 3.3166]
+
+  - match: { hits.hits.0._id: "3" }
+  - match: { hits.hits.0._score: 2 }
+  - match: { hits.hits.1._id: "2" }
+  - match: { hits.hits.1._score: 1 }
+  - match: { hits.hits.2._id: "1" }
+  - match: { hits.hits.2._score: 0 }

+ 6 - 4
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapper.java

@@ -82,7 +82,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public DenseVectorFieldMapper build(ContentPath contentPath) {
             return new DenseVectorFieldMapper(
                 name,
-                new DenseVectorFieldType(buildFullName(contentPath), dims.getValue(), meta.getValue()),
+                new DenseVectorFieldType(buildFullName(contentPath), indexVersionCreated, dims.getValue(), meta.getValue()),
                 dims.getValue(),
                 indexVersionCreated,
                 multiFieldsBuilder.build(this, contentPath),
@@ -94,10 +94,12 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
     public static final class DenseVectorFieldType extends MappedFieldType {
         private final int dims;
+        private final Version indexVersionCreated;
 
-        public DenseVectorFieldType(String name, int dims, Map<String, String> meta) {
+        public DenseVectorFieldType(String name, Version indexVersionCreated, int dims, Map<String, String> meta) {
             super(name, false, false, true, TextSearchInfo.NONE, meta);
             this.dims = dims;
+            this.indexVersionCreated = indexVersionCreated;
         }
 
         int dims() {
@@ -124,7 +126,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         @Override
         public DocValueFormat docValueFormat(String format, ZoneId timeZone) {
-            throw new UnsupportedOperationException(
+            throw new IllegalArgumentException(
                 "Field [" + name() + "] of type [" + typeName() + "] doesn't support docvalue_fields or aggregations");
         }
 
@@ -135,7 +137,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
         @Override
         public IndexFieldData.Builder fielddataBuilder(String fullyQualifiedIndexName, Supplier<SearchLookup> searchLookup) {
-            return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD);
+            return new VectorIndexFieldData.Builder(name(), CoreValuesSourceType.KEYWORD, indexVersionCreated, dims);
         }
 
         @Override

+ 45 - 2
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/mapper/VectorEncoderDecoder.java

@@ -13,6 +13,7 @@ import org.elasticsearch.Version;
 
 import java.nio.ByteBuffer;
 
+
 public final class VectorEncoderDecoder {
     public static final byte INT_BYTES = 4;
 
@@ -29,9 +30,51 @@ public final class VectorEncoderDecoder {
      * NOTE: this function can only be called on vectors from an index version greater than or
      * equal to 7.5.0, since vectors created prior to that do not store the magnitude.
      */
-    public static float decodeVectorMagnitude(Version indexVersion, BytesRef vectorBR) {
+    public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
         assert indexVersion.onOrAfter(Version.V_7_5_0);
         ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
-        return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
+        return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - INT_BYTES);
+    }
+
+    /**
+     * Calculates vector magnitude
+     */
+    private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
+        final int length = denseVectorLength(indexVersion, vectorBR);
+        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
+        double magnitude = 0.0f;
+        for (int i = 0; i < length; i++) {
+            float value = byteBuffer.getFloat();
+            magnitude += value * value;
+        }
+        magnitude = Math.sqrt(magnitude);
+        return (float) magnitude;
+    }
+
+    public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
+        if (vectorBR == null) {
+            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
+        }
+        if (indexVersion.onOrAfter(Version.V_7_5_0)) {
+            return decodeMagnitude(indexVersion, vectorBR);
+        } else {
+            return calculateMagnitude(indexVersion, vectorBR);
+        }
+    }
+
+    /**
+     * Decodes a BytesRef into the provided array of floats
+     * @param vectorBR - dense vector encoded in BytesRef
+     * @param vector - array of floats where the decoded vector should be stored
+     */
+    public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
+        if (vectorBR == null) {
+            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
+        }
+        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
+        for (int dim = 0; dim < vector.length; dim++) {
+            vector[dim] = byteBuffer.getFloat();
+        }
     }
+
 }

+ 32 - 2
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValues.java

@@ -10,17 +10,26 @@ package org.elasticsearch.xpack.vectors.query;
 
 import org.apache.lucene.index.BinaryDocValues;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
+import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
 
 import java.io.IOException;
 
 public class DenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {
 
     private final BinaryDocValues in;
+    private final Version indexVersion;
+    private final int dims;
+    private final float[] vector;
     private BytesRef value;
 
-    DenseVectorScriptDocValues(BinaryDocValues in) {
+
+    DenseVectorScriptDocValues(BinaryDocValues in, Version indexVersion, int dims) {
         this.in = in;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
+        this.vector = new float[dims];
     }
 
     @Override
@@ -37,9 +46,30 @@ public class DenseVectorScriptDocValues extends ScriptDocValues<BytesRef> {
         return value;
     }
 
+    // package private access only for {@link ScoreScriptUtils}
+    int dims() {
+        return dims;
+    }
+
     @Override
     public BytesRef get(int index) {
-        throw new UnsupportedOperationException("accessing a vector field's value through 'get' or 'value' is not supported");
+        // Always rejects the generic accessors and points script authors at the dedicated ones.
+        // The trailing space keeps the two concatenated sentences apart in the final message.
+        throw new UnsupportedOperationException("accessing a vector field's value through 'get' or 'value' is not supported! " +
+            "Use 'vectorValue' or 'magnitude' instead!");
+    }
+
+    /**
+     * Get dense vector's value as an array of floats
+     */
+    public float[] getVectorValue() {
+        VectorEncoderDecoder.decodeDenseVector(value, vector);
+        return vector;
+    }
+
+    /**
+     * Get dense vector's magnitude
+     */
+    public float getMagnitude() {
+        return VectorEncoderDecoder.getMagnitude(indexVersion, value);
     }
 
     @Override

+ 8 - 25
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/ScoreScriptUtils.java

@@ -10,9 +10,7 @@ package org.elasticsearch.xpack.vectors.query;
 
 import org.apache.lucene.util.BytesRef;
 import org.elasticsearch.ExceptionsHelper;
-import org.elasticsearch.Version;
 import org.elasticsearch.script.ScoreScript;
-import org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -45,6 +43,11 @@ public class ScoreScriptUtils {
             this.scoreScript = scoreScript;
             this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field);
 
+            if (docValues.dims() != queryVector.size()){
+                throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
+                    queryVector.size() + "] than the document vectors [" + docValues.dims() + "].");
+            }
+
             this.queryVector = new float[queryVector.size()];
             double queryMagnitude = 0.0;
             for (int i = 0; i < queryVector.size(); i++) {
@@ -67,18 +70,10 @@ public class ScoreScriptUtils {
             } catch (IOException e) {
                 throw ExceptionsHelper.convertToElastic(e);
             }
-
-            // Validate the encoded vector's length.
             BytesRef vector = docValues.getEncodedValue();
             if (vector == null) {
                 throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
             }
-
-            int vectorLength = VectorEncoderDecoder.denseVectorLength(scoreScript._getIndexVersion(), vector);
-            if (queryVector.length != vectorLength) {
-                throw new IllegalArgumentException("The query vector has a different number of dimensions [" +
-                    queryVector.length + "] than the document vectors [" + vectorLength + "].");
-            }
             return vector;
         }
     }
@@ -152,23 +147,11 @@ public class ScoreScriptUtils {
         public double cosineSimilarity() {
             BytesRef vector = getEncodedVector();
             ByteBuffer byteBuffer = ByteBuffer.wrap(vector.bytes, vector.offset, vector.length);
-
             double dotProduct = 0.0;
-            double vectorMagnitude = 0.0f;
-            if (scoreScript._getIndexVersion().onOrAfter(Version.V_7_5_0)) {
-                for (float queryValue : queryVector) {
-                    dotProduct += queryValue * byteBuffer.getFloat();
-                }
-                vectorMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(scoreScript._getIndexVersion(), vector);
-            } else {
-                for (float queryValue : queryVector) {
-                    float docValue = byteBuffer.getFloat();
-                    dotProduct += queryValue * docValue;
-                    vectorMagnitude += docValue * docValue;
-                }
-                vectorMagnitude = (float) Math.sqrt(vectorMagnitude);
+            for (float queryValue : queryVector) {
+                dotProduct += queryValue * byteBuffer.getFloat();
             }
-            return dotProduct / vectorMagnitude;
+            return dotProduct / docValues.getMagnitude();
         }
     }
 }

+ 7 - 2
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorDVLeafFieldData.java

@@ -13,6 +13,7 @@ import org.apache.lucene.index.DocValues;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
 import org.elasticsearch.index.fielddata.LeafFieldData;
 import org.elasticsearch.index.fielddata.ScriptDocValues;
 import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
@@ -25,10 +26,14 @@ final class VectorDVLeafFieldData implements LeafFieldData {
 
     private final LeafReader reader;
     private final String field;
+    private final Version indexVersion;
+    private final int dims;
 
-    VectorDVLeafFieldData(LeafReader reader, String field) {
+    VectorDVLeafFieldData(LeafReader reader, String field, Version indexVersion, int dims) {
         this.reader = reader;
         this.field = field;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
     }
 
     @Override
@@ -50,7 +55,7 @@ final class VectorDVLeafFieldData implements LeafFieldData {
     public ScriptDocValues<BytesRef> getScriptValues() {
         try {
             final BinaryDocValues values = DocValues.getBinary(reader, field);
-            return new DenseVectorScriptDocValues(values);
+            return new DenseVectorScriptDocValues(values, indexVersion, dims);
         } catch (IOException e) {
             throw new IllegalStateException("Cannot load doc values for vector field!", e);
         }

+ 16 - 5
x-pack/plugin/vectors/src/main/java/org/elasticsearch/xpack/vectors/query/VectorIndexFieldData.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.vectors.query;
 
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.SortField;
+import org.elasticsearch.Version;
 import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.util.BigArrays;
 import org.elasticsearch.index.fielddata.IndexFieldData;
@@ -21,16 +22,21 @@ import org.elasticsearch.search.MultiValueMode;
 import org.elasticsearch.search.aggregations.support.ValuesSourceType;
 import org.elasticsearch.search.sort.BucketedSort;
 import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.xpack.vectors.mapper.DenseVectorFieldMapper;
 
 
 public class VectorIndexFieldData implements IndexFieldData<VectorDVLeafFieldData> {
 
     protected final String fieldName;
     protected final ValuesSourceType valuesSourceType;
+    private final Version indexVersion;
+    private final int dims;
 
-    public VectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType) {
+    public VectorIndexFieldData(String fieldName, ValuesSourceType valuesSourceType, Version indexVersion, int dims) {
         this.fieldName = fieldName;
         this.valuesSourceType = valuesSourceType;
+        this.indexVersion = indexVersion;
+        this.dims = dims;
     }
 
     @Override
@@ -45,7 +51,8 @@ public class VectorIndexFieldData implements IndexFieldData<VectorDVLeafFieldDat
 
     @Override
     public SortField sortField(@Nullable Object missingValue, MultiValueMode sortMode, Nested nested, boolean reverse) {
-        throw new IllegalArgumentException("can't sort on the vector field");
+        throw new IllegalArgumentException("Field [" + fieldName + "] of type [" +
+            DenseVectorFieldMapper.CONTENT_TYPE + "] doesn't support sort");
     }
 
     @Override
@@ -56,7 +63,7 @@ public class VectorIndexFieldData implements IndexFieldData<VectorDVLeafFieldDat
 
     @Override
     public VectorDVLeafFieldData load(LeafReaderContext context) {
-        return new VectorDVLeafFieldData(context.reader(), fieldName);
+        return new VectorDVLeafFieldData(context.reader(), fieldName, indexVersion, dims);
     }
 
     @Override
@@ -67,15 +74,19 @@ public class VectorIndexFieldData implements IndexFieldData<VectorDVLeafFieldDat
     public static class Builder implements IndexFieldData.Builder {
         private final String name;
         private final ValuesSourceType valuesSourceType;
+        private final Version indexVersion;
+        private final int dims;
 
-        public Builder(String name, ValuesSourceType valuesSourceType) {
+        public Builder(String name, ValuesSourceType valuesSourceType, Version indexVersion, int dims) {
             this.name = name;
             this.valuesSourceType = valuesSourceType;
+            this.indexVersion = indexVersion;
+            this.dims = dims;
         }
 
         @Override
         public IndexFieldData<?> build(IndexFieldDataCache cache, CircuitBreakerService breakerService) {
-            return new VectorIndexFieldData(name, valuesSourceType);
+            return new VectorIndexFieldData(name, valuesSourceType, indexVersion, dims);
         }
 
     }

+ 2 - 0
x-pack/plugin/vectors/src/main/resources/org/elasticsearch/xpack/vectors/query/whitelist.txt

@@ -5,6 +5,8 @@
 # 2.0.
 #
 class org.elasticsearch.xpack.vectors.query.DenseVectorScriptDocValues {
+    float[] getVectorValue()
+    float getMagnitude()
 }
 class org.elasticsearch.script.ScoreScript @no_import {
 }

+ 1 - 1
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldMapperTests.java

@@ -95,7 +95,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         // assert that after decoding the indexed value is equal to expected
         BytesRef vectorBR = fields[0].binaryValue();
         float[] decodedValues = decodeDenseVector(Version.CURRENT, vectorBR);
-        float decodedMagnitude = VectorEncoderDecoder.decodeVectorMagnitude(Version.CURRENT, vectorBR);
+        float decodedMagnitude = VectorEncoderDecoder.decodeMagnitude(Version.CURRENT, vectorBR);
         assertEquals(expectedMagnitude, decodedMagnitude, 0.001f);
         assertArrayEquals(
             "Decoded dense vector values is not equal to the indexed one.",

+ 12 - 6
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/mapper/DenseVectorFieldTypeTests.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.vectors.mapper;
 
+import org.elasticsearch.Version;
 import org.elasticsearch.index.mapper.FieldTypeTestCase;
 
 import java.io.IOException;
@@ -16,29 +17,34 @@ import java.util.List;
 public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
 
     public void testHasDocValues() {
-        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap());
+        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType(
+            "f", Version.CURRENT, 1, Collections.emptyMap());
         assertTrue(ft.hasDocValues());
     }
 
     public void testIsAggregatable() {
-        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap());
+        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType(
+            "f", Version.CURRENT,1, Collections.emptyMap());
         assertFalse(ft.isAggregatable());
     }
 
     public void testFielddataBuilder() {
-        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap());
+        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType(
+            "f", Version.CURRENT,1, Collections.emptyMap());
         assertNotNull(ft.fielddataBuilder("index", () -> {
             throw new UnsupportedOperationException();
         }));
     }
 
     public void testDocValueFormat() {
-        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 1, Collections.emptyMap());
-        expectThrows(UnsupportedOperationException.class, () -> ft.docValueFormat(null, null));
+        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType(
+            "f", Version.CURRENT,1, Collections.emptyMap());
+        expectThrows(IllegalArgumentException.class, () -> ft.docValueFormat(null, null));
     }
 
     public void testFetchSourceValue() throws IOException {
-        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType("f", 5, Collections.emptyMap());
+        DenseVectorFieldMapper.DenseVectorFieldType ft = new DenseVectorFieldMapper.DenseVectorFieldType(
+            "f", Version.CURRENT, 5, Collections.emptyMap());
         List<Double> vector = List.of(0.0, 1.0, 2.0, 3.0, 4.0);
         assertEquals(vector, fetchSourceValue(ft, vector));
     }

+ 12 - 10
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorFunctionTests.java

@@ -44,11 +44,14 @@ public class DenseVectorFunctionTests extends ESTestCase {
     public void testVectorFunctions() {
         for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
             BytesRef encodedDocVector = mockEncodeDenseVector(docVector, indexVersion);
+            float magnitude = VectorEncoderDecoder.getMagnitude(indexVersion, encodedDocVector);
+
             DenseVectorScriptDocValues docValues = mock(DenseVectorScriptDocValues.class);
             when(docValues.getEncodedValue()).thenReturn(encodedDocVector);
+            when(docValues.getMagnitude()).thenReturn(magnitude);
+            when(docValues.dims()).thenReturn(docVector.length);
 
             ScoreScript scoreScript = mock(ScoreScript.class);
-            when(scoreScript._getIndexVersion()).thenReturn(indexVersion);
             when(scoreScript.getDoc()).thenReturn(Collections.singletonMap(field, docValues));
 
             testDotProduct(scoreScript);
@@ -63,8 +66,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
         double result = function.dotProduct();
         assertEquals("dotProduct result is not equal to the expected value!", 65425.624, result, 0.001);
 
-        DotProduct invalidFunction = new DotProduct(scoreScript, invalidQueryVector, field);
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::dotProduct);
+        IllegalArgumentException e =
+            expectThrows(IllegalArgumentException.class, () -> new DotProduct(scoreScript, invalidQueryVector, field));
         assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
     }
 
@@ -73,8 +76,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
         double result = function.cosineSimilarity();
         assertEquals("cosineSimilarity result is not equal to the expected value!", 0.790, result, 0.001);
 
-        CosineSimilarity invalidFunction = new CosineSimilarity(scoreScript, invalidQueryVector, field);
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::cosineSimilarity);
+        IllegalArgumentException e =
+            expectThrows(IllegalArgumentException.class, () -> new CosineSimilarity(scoreScript, invalidQueryVector, field));
         assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
     }
 
@@ -83,8 +86,8 @@ public class DenseVectorFunctionTests extends ESTestCase {
         double result = function.l1norm();
         assertEquals("l1norm result is not equal to the expected value!", 485.184, result, 0.001);
 
-        L1Norm invalidFunction = new L1Norm(scoreScript, invalidQueryVector, field);
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l1norm);
+        IllegalArgumentException e =
+            expectThrows(IllegalArgumentException.class,  () -> new L1Norm(scoreScript, invalidQueryVector, field));
         assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
     }
 
@@ -93,12 +96,11 @@ public class DenseVectorFunctionTests extends ESTestCase {
         double result = function.l2norm();
         assertEquals("l2norm result is not equal to the expected value!", 301.361, result, 0.001);
 
-        L2Norm invalidFunction = new L2Norm(scoreScript, invalidQueryVector, field);
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, invalidFunction::l2norm);
+        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new L2Norm(scoreScript, invalidQueryVector, field));
         assertThat(e.getMessage(), containsString("query vector has a different number of dimensions [2] than the document vectors [5]"));
     }
 
-    private static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
+    static BytesRef mockEncodeDenseVector(float[] values, Version indexVersion) {
         byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0)
             ? new byte[VectorEncoderDecoder.INT_BYTES * values.length + VectorEncoderDecoder.INT_BYTES]
             : new byte[VectorEncoderDecoder.INT_BYTES * values.length];

+ 105 - 0
x-pack/plugin/vectors/src/test/java/org/elasticsearch/xpack/vectors/query/DenseVectorScriptDocValuesTests.java

@@ -0,0 +1,105 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.vectors.query;
+
+import org.apache.lucene.index.BinaryDocValues;
+import org.apache.lucene.util.BytesRef;
+import org.elasticsearch.Version;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import static org.hamcrest.Matchers.containsString;
+
+/**
+ * Unit tests for {@link DenseVectorScriptDocValues}: reading a document's vector and
+ * magnitude, the error raised when a document has no vector, and the rejection of the
+ * generic {@code get} accessor.
+ */
+public class DenseVectorScriptDocValuesTests extends ESTestCase {
+
+    /**
+     * Builds an in-memory {@link BinaryDocValues} in which document {@code i} holds
+     * {@code vectors[i]}, encoded on demand with
+     * {@link DenseVectorFunctionTests#mockEncodeDenseVector} for the given index version.
+     *
+     * @param vectors      one vector per document; the array index is used as the doc id
+     * @param indexVersion version passed to the encoder for every value
+     * @return doc values supporting {@code advanceExact}, {@code binaryValue}, {@code docID}
+     *         and {@code nextDoc}; {@code advance} and {@code cost} always throw
+     */
+    private static BinaryDocValues wrap(float[][] vectors, Version indexVersion) {
+        return new BinaryDocValues() {
+            // Current position; -1 until the first advanceExact/nextDoc call.
+            int idx = -1;
+            // Number of documents, i.e. the first doc id that has no value.
+            int maxIdx = vectors.length;
+            @Override
+            public BytesRef binaryValue() {
+                if (idx >= maxIdx) {
+                    throw new IllegalStateException("max index exceeded");
+                }
+                // The vector is re-encoded on every call rather than cached.
+                return DenseVectorFunctionTests.mockEncodeDenseVector(vectors[idx], indexVersion);
+            }
+
+            @Override
+            public boolean advanceExact(int target) {
+                // The position is moved even when the target has no value; the return
+                // value alone reports whether a vector exists for it.
+                idx = target;
+                if (target < maxIdx) {
+                    return true;
+                }
+                return false;
+            }
+
+            @Override
+            public int docID() {
+                return idx;
+            }
+
+            @Override
+            public int nextDoc() {
+                // NOTE(review): post-increment returns the position held before advancing,
+                // and there is no end-of-iteration sentinel. This looks at odds with the
+                // Lucene iterator contract; it is not called directly by the tests in this
+                // class - confirm before reusing this mock elsewhere.
+                return idx++;
+            }
+
+            @Override
+            public int advance(int target) {
+                throw new IllegalArgumentException("not defined!");
+            }
+
+            @Override
+            public long cost() {
+                throw new IllegalArgumentException("not defined!");
+            }
+        };
+    }
+
+    /**
+     * Checks that, for each document, {@code getVectorValue()} returns the original vector
+     * and {@code getMagnitude()} returns the expected magnitude, both within a tolerance of
+     * {@code 0.0001f}. The check is repeated for an index created on 7.4.0 and for the
+     * current version.
+     */
+    public void testGetVectorValueAndGetMagnitude() throws IOException {
+        final int dims = 3;
+        float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
+        // Approximately sqrt(3), sqrt(6) and sqrt(11): the Euclidean norms of the vectors above.
+        float[] expectedMagnitudes = { 1.7320f, 2.4495f, 3.3166f };
+
+        for (Version indexVersion : Arrays.asList(Version.V_7_4_0, Version.CURRENT)) {
+            BinaryDocValues docValues = wrap(vectors, indexVersion);
+            final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, indexVersion, dims);
+            for (int i = 0; i < vectors.length; i++) {
+                scriptDocValues.setNextDocId(i);
+                assertArrayEquals(vectors[i], scriptDocValues.getVectorValue(), 0.0001f);
+                assertEquals(expectedMagnitudes[i], scriptDocValues.getMagnitude(), 0.0001f);
+            }
+        }
+    }
+
+    /**
+     * Checks that both accessors throw {@link IllegalArgumentException} with the
+     * "doesn't have a value" message when positioned on a document without a vector.
+     */
+    public void testMissingValues() throws IOException {
+        final int dims = 3;
+        float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
+        BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
+        final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims);
+
+        // Doc id 3 is one past the last vector, so the mock reports no value for it.
+        scriptDocValues.setNextDocId(3);
+        Exception e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getVectorValue());
+        assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
+
+        e = expectThrows(IllegalArgumentException.class, () -> scriptDocValues.getMagnitude());
+        assertEquals("A document doesn't have a value for a vector field!", e.getMessage());
+    }
+
+    /**
+     * Checks that the generic {@code get(int)} accessor is rejected with
+     * {@link UnsupportedOperationException}, even for a document that does have a vector.
+     */
+    public void testGetFunctionIsNotAccessible() throws IOException {
+        final int dims = 3;
+        float[][] vectors = {{ 1, 1, 1 }, { 1, 1, 2 }, { 1, 1, 3 } };
+        BinaryDocValues docValues = wrap(vectors, Version.CURRENT);
+        final DenseVectorScriptDocValues scriptDocValues = new DenseVectorScriptDocValues(docValues, Version.CURRENT, dims);
+
+        scriptDocValues.setNextDocId(0);
+        Exception e = expectThrows(UnsupportedOperationException.class, () -> scriptDocValues.get(0));
+        assertThat(e.getMessage(), containsString("accessing a vector field's value through 'get' or 'value' is not supported!"));
+    }
+}