소스 검색

Revert "Auto-normalize dot_product vectors at index & query (#98944)" (#99421)

This reverts commit 7b9c367aeb84493ec625c1125d0c50d63c61bc3a.
Benjamin Trent 2 년 전
부모
커밋
83b70e37ef

+ 0 - 5
docs/changelog/98944.yaml

@@ -1,5 +0,0 @@
-pr: 98944
-summary: Auto-normalize `dot_product` vectors at index & query
-area: Vector Search
-type: enhancement
-issues: []

+ 2 - 2
docs/reference/how-to/knn-search.asciidoc

@@ -21,8 +21,8 @@ options.
 The `cosine` option accepts any float vector and computes the cosine
 similarity. While this is convenient for testing, it's not the most efficient
 approach. Instead, we recommend using the `dot_product` option to compute the
-similarity. When using `dot_product`, all vectors are normalized during index to have
-a magnitude of 1. The `dot_product` option is significantly faster, since it
+similarity. To use `dot_product`, all vectors need to be normalized in advance
+to have length 1. The `dot_product` option is significantly faster, since it
 avoids performing extra vector length computations during the search.
 
 [discrete]

+ 5 - 5
docs/reference/mapping/types/dense-vector.asciidoc

@@ -163,9 +163,9 @@ Computes the dot product of two vectors. This option provides an optimized way
 to perform cosine similarity. The constraints and computed score are defined
 by `element_type`.
 +
-When `element_type` is `float`, all vectors are automatically converted to unit length, including both
-document and query vectors. Consequently, `dot_product` does not allow vectors with a zero magnitude.
-The document `_score` is computed as `(1 + dot_product(query, vector)) / 2`.
+When `element_type` is `float`, all vectors must be unit length, including both
+document and query vectors. The document `_score` is computed as
+`(1 + dot_product(query, vector)) / 2`.
 +
 When `element_type` is `byte`, all vectors must have the same
 length including both document and query vectors or results will be inaccurate.
@@ -175,9 +175,9 @@ where `dims` is the number of dimensions per vector.
 
 `cosine`:::
 Computes the cosine similarity. Note that the most efficient way to perform
-cosine similarity is to have all vectors normalized to unit length, and instead use
+cosine similarity is to normalize all vectors to unit length, and instead use
 `dot_product`. You should only use `cosine` if you need to preserve the
-original vectors and cannot allow Elasticsearch to normalize them. The document `_score`
+original vectors and cannot normalize them in advance. The document `_score`
 is computed as `(1 + cosine(query, vector)) / 2`. The `cosine` similarity does
 not allow vectors with zero magnitude, since cosine is not defined in this
 case.

+ 0 - 120
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml

@@ -368,123 +368,3 @@ setup:
             filter: {"term": {"name": "cow.jpg"}}
 
   - length: {hits.hits: 0}
----
-"kNN search with dot-product auto-normalized":
-  - skip:
-      features: close_to
-      version: ' - 8.10.99'
-      reason: 'dense_vector auto-normalized dot_product in 8.11'
-
-  - do:
-      indices.create:
-        index: test_dot_product
-        body:
-          mappings:
-            properties:
-              name:
-                type: keyword
-              dot_product_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: dot_product
-              cosine_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: cosine
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "1"
-        body:
-          name: cow.jpg
-          dot_product_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-          cosine_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "2"
-        body:
-          name: moose.jpg
-          dot_product_vector: [ -0.5, 100.0, -13, 14.8, -156.0 ]
-          cosine_vector: [ -0.5, 100.0, -13, 14.8, -156.0 ]
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "3"
-        body:
-          name: rabbit.jpg
-          dot_product_vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ]
-          cosine_vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ]
-
-  - do:
-      indices.refresh: { }
-
-  - do:
-      search:
-        index: test_dot_product
-        body:
-          fields: [ "name" ]
-          knn:
-            field: dot_product_vector
-            query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
-            k: 2
-            num_candidates: 3
-
-  - match: {hits.total.value: 2}
-  - match: {hits.hits.0._id: "2"}
-  - set: { hits.hits.0._score: score_0 }
-  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
-  - match: {hits.hits.1._id: "3"}
-  - set: { hits.hits.1._score: score_1 }
-  - match: {hits.hits.1.fields.name.0: "rabbit.jpg"}
-
-  - do:
-      search:
-        index: test_dot_product
-        body:
-          fields: [ "name" ]
-          knn:
-            field: cosine_vector
-            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
-            k: 2
-            num_candidates: 3
-
-  - match: {hits.total.value: 2}
-  - match: {hits.hits.0._id: "2"}
-  - close_to: { hits.hits.0._score: { value: $score_0, error: 0.00001 } }
-  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
-  - match: {hits.hits.1._id: "3"}
-  - close_to: { hits.hits.1._score: { value: $score_1, error: 0.00001 } }
-  - match: {hits.hits.1.fields.name.0: "rabbit.jpg"}
----
-"kNN search fails with non-normalized dot-product in older versions":
-  - skip:
-      version: '8.10.99 - '
-      reason: 'dense_vector auto-normalized dot_product in 8.11'
-
-  - do:
-      indices.create:
-        index: test_failing_dot_product
-        body:
-          mappings:
-            properties:
-              dot_product_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: dot_product
-
-  - do:
-      catch: bad_request
-      index:
-        index: test_failing_dot_product
-        id: "1"
-        body:
-          name: cow.jpg
-          dot_product_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-

+ 19 - 79
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -26,7 +26,6 @@ import org.apache.lucene.search.KnnByteVectorQuery;
 import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.fielddata.FieldDataContext;
@@ -57,7 +56,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.time.ZoneId;
-import java.util.Arrays;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Objects;
@@ -71,10 +69,8 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpect
  * A {@link FieldMapper} for indexing a dense vector of floats.
  */
 public class DenseVectorFieldMapper extends FieldMapper {
-    private static final float EPS = 1e-4f;
     public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersion.V_7_5_0;
     public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersion.V_8_11_0;
-    public static final IndexVersion DOT_PRODUCT_AUTO_NORMALIZED = IndexVersion.V_8_11_0;
     public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersion.V_8_9_0;
 
     public static final String CONTENT_TYPE = "dense_vector";
@@ -325,7 +321,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
             @Override
             void checkVectorMagnitude(
-                IndexVersion indexVersion,
                 VectorSimilarity similarity,
                 Function<StringBuilder, StringBuilder> appender,
                 float squaredMagnitude
@@ -388,12 +383,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     squaredMagnitude += value * value;
                 }
                 fieldMapper.checkDimensionMatches(index, context);
-                checkVectorMagnitude(
-                    fieldMapper.indexCreatedVersion,
-                    fieldMapper.similarity,
-                    errorByteElementsAppender(vector),
-                    squaredMagnitude
-                );
+                checkVectorMagnitude(fieldMapper.similarity, errorByteElementsAppender(vector), squaredMagnitude);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
             }
 
@@ -485,31 +475,20 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
             @Override
             void checkVectorMagnitude(
-                IndexVersion indexVersion,
                 VectorSimilarity similarity,
                 Function<StringBuilder, StringBuilder> appender,
                 float squaredMagnitude
             ) {
                 StringBuilder errorBuilder = null;
 
-                if (indexVersion.before(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > EPS) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
-                        );
-                    }
-                    if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + similarity + "] similarity does not support vectors with zero magnitude."
-                        );
-                    }
-                } else {
-                    if ((similarity == VectorSimilarity.COSINE || similarity == VectorSimilarity.DOT_PRODUCT)
-                        && Math.sqrt(squaredMagnitude) == 0.0f) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + similarity + "] similarity does not support vectors with zero magnitude."
-                        );
-                    }
+                if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > 1e-4f) {
+                    errorBuilder = new StringBuilder(
+                        "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
+                    );
+                } else if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
+                    errorBuilder = new StringBuilder(
+                        "The [" + VectorSimilarity.COSINE + "] similarity does not support vectors with zero magnitude."
+                    );
                 }
 
                 if (errorBuilder != null) {
@@ -532,15 +511,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 fieldMapper.checkDimensionMatches(index, context);
                 checkVectorBounds(vector);
-                checkVectorMagnitude(
-                    fieldMapper.indexCreatedVersion,
-                    fieldMapper.similarity,
-                    errorFloatElementsAppender(vector),
-                    squaredMagnitude
-                );
-                if (fieldMapper.indexCreatedVersion.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    fieldMapper.similarity.floatPreprocessing(vector, squaredMagnitude);
-                }
+                checkVectorMagnitude(fieldMapper.similarity, errorFloatElementsAppender(vector), squaredMagnitude);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
             }
 
@@ -598,7 +569,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public abstract void checkVectorBounds(float[] vector);
 
         abstract void checkVectorMagnitude(
-            IndexVersion indexVersion,
             VectorSimilarity similarity,
             Function<StringBuilder, StringBuilder> errorElementsAppender,
             float squaredMagnitude
@@ -717,21 +687,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     case FLOAT -> (1 + similarity) / 2f;
                 };
             }
-
-            @Override
-            void floatPreprocessing(float[] vector, float squareSum) {
-                if (squareSum == 0) {
-                    throw new IllegalArgumentException("Cannot normalize a zero-length vector");
-                }
-                // Vector already has a magnitude have `1`
-                if (Math.abs(squareSum - 1.0f) < EPS) {
-                    return;
-                }
-                float length = (float) Math.sqrt(squareSum);
-                for (int i = 0; i < vector.length; i++) {
-                    vector[i] /= length;
-                }
-            }
         };
 
         public final VectorSimilarityFunction function;
@@ -746,8 +701,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }
 
         abstract float score(float similarity, ElementType elementType, int dim);
-
-        void floatPreprocessing(float[] vector, float squareSum) {}
     }
 
     private abstract static class IndexOptions implements ToXContent {
@@ -906,13 +859,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
 
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
-                int squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
-                elementType.checkVectorMagnitude(
-                    indexVersionCreated,
-                    similarity,
-                    elementType.errorByteElementsAppender(queryVector),
-                    squaredMagnitude
-                );
+                float squaredMagnitude = 0.0f;
+                for (byte b : queryVector) {
+                    squaredMagnitude += b * b;
+                }
+                elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
             Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
             if (similarityThreshold != null) {
@@ -940,22 +891,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             elementType.checkVectorBounds(queryVector);
 
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
-                float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
-                elementType.checkVectorMagnitude(
-                    indexVersionCreated,
-                    similarity,
-                    elementType.errorFloatElementsAppender(queryVector),
-                    squaredMagnitude
-                );
-                // We don't want to normalize the original query vector.
-                // It mutates it in place and might cause down stream weirdness
-                // Instead we copy the value and then normalize that copy
-                if (similarity == VectorSimilarity.DOT_PRODUCT
-                    && elementType == ElementType.FLOAT
-                    && indexVersionCreated.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    queryVector = Arrays.copyOf(queryVector, queryVector.length);
-                    similarity.floatPreprocessing(queryVector, squaredMagnitude);
+                float squaredMagnitude = 0.0f;
+                for (float e : queryVector) {
+                    squaredMagnitude += e * e;
                 }
+                elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
             }
             Query knnQuery = switch (elementType) {
                 case BYTE -> {

+ 19 - 3
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

@@ -320,7 +320,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 b -> b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", VectorSimilarity.DOT_PRODUCT)
             )
         );
-        float[] vector = { 0f, 0f, 0f };
+        float[] vector = { -12.1f, 2.7f, -4 };
         DocumentParsingException e = expectThrows(
             DocumentParsingException.class,
             () -> mapper.parse(source(b -> b.array("field", vector)))
@@ -329,7 +329,23 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         assertThat(
             e.getCause().getMessage(),
             containsString(
-                "The [dot_product] similarity does not support vectors with zero magnitude. Preview of invalid vector: [0.0, 0.0, 0.0]"
+                "The [dot_product] similarity can only be used with unit-length vectors. Preview of invalid vector: [-12.1, 2.7, -4.0]"
+            )
+        );
+
+        DocumentMapper mapperWithLargerDim = createDocumentMapper(
+            fieldMapping(
+                b -> b.field("type", "dense_vector").field("dims", 6).field("index", true).field("similarity", VectorSimilarity.DOT_PRODUCT)
+            )
+        );
+        float[] largerVector = { -12.1f, 2.7f, -4, 1.05f, 10.0f, 29.9f };
+        e = expectThrows(DocumentParsingException.class, () -> mapperWithLargerDim.parse(source(b -> b.array("field", largerVector))));
+        assertNotNull(e.getCause());
+        assertThat(
+            e.getCause().getMessage(),
+            containsString(
+                "The [dot_product] similarity can only be used with unit-length vectors. "
+                    + "Preview of invalid vector: [-12.1, 2.7, -4.0, 1.05, 10.0, ...]"
             )
         );
     }
@@ -499,7 +515,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         assertNull(denseVectorFieldType.getSimilarity());
     }
 
-    public void testParamsBeforeIndexByDefault() throws Exception {
+    public void testtParamsBeforeIndexByDefault() throws Exception {
         DocumentMapper documentMapper = createDocumentMapper(INDEXED_BY_DEFAULT_PREVIOUS_INDEX_VERSION, fieldMapping(b -> {
             b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", "dot_product");
         }));

+ 2 - 2
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -137,9 +137,9 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         e = expectThrows(
             IllegalArgumentException.class,
-            () -> dotProductField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null, null)
+            () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null)
         );
-        assertThat(e.getMessage(), containsString("The [dot_product] similarity does not support vectors with zero magnitude."));
+        assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
 
         DenseVectorFieldType cosineField = new DenseVectorFieldType(
             "f",

+ 6 - 6
test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/CloseToAssertion.java

@@ -36,14 +36,14 @@ public class CloseToAssertion extends Assertion {
                 throw new IllegalArgumentException("expected a map with value and error but got a map with " + map.size() + " fields");
             }
             Object valObj = map.get("value");
-            if (valObj == null) {
-                throw new IllegalArgumentException("value is missing");
+            if (valObj instanceof Number == false) {
+                throw new IllegalArgumentException("value is missing or not a number");
             }
             Object errObj = map.get("error");
             if (errObj instanceof Number == false) {
                 throw new IllegalArgumentException("error is missing or not a number");
             }
-            return new CloseToAssertion(location, fieldValueTuple.v1(), valObj, ((Number) errObj).doubleValue());
+            return new CloseToAssertion(location, fieldValueTuple.v1(), ((Number) valObj).doubleValue(), ((Number) errObj).doubleValue());
         } else {
             throw new IllegalArgumentException(
                 "expected a map with value and error but got " + fieldValueTuple.v2().getClass().getSimpleName()
@@ -56,7 +56,7 @@ public class CloseToAssertion extends Assertion {
 
     private final double error;
 
-    public CloseToAssertion(XContentLocation location, String field, Object expectedValue, Double error) {
+    public CloseToAssertion(XContentLocation location, String field, Double expectedValue, Double error) {
         super(location, field, expectedValue);
         this.error = error;
     }
@@ -69,9 +69,9 @@ public class CloseToAssertion extends Assertion {
     protected void doAssert(Object actualValue, Object expectedValue) {
         logger.trace("assert that [{}] is close to [{}] with error [{}] (field [{}])", actualValue, expectedValue, error, getField());
         if (actualValue instanceof Number actualValueNumber) {
-            assertThat(actualValueNumber.doubleValue(), closeTo(((Number) expectedValue).doubleValue(), error));
+            assertThat(actualValueNumber.doubleValue(), closeTo((Double) expectedValue, error));
         } else {
-            throw new AssertionError("expected a value close to " + expectedValue + " but got " + actualValue + ", which is not a number");
+            throw new AssertionError("excpected a value close to " + expectedValue + " but got " + actualValue + ", which is not a number");
         }
     }
 }

+ 1 - 1
test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/AssertionTests.java

@@ -169,7 +169,7 @@ public class AssertionTests extends AbstractClientYamlTestFragmentParserTestCase
 
         parser = createParser(YamlXContent.yamlXContent, "{ field: { foo: 13, bar: 15 } }");
         exception = expectThrows(IllegalArgumentException.class, () -> CloseToAssertion.parse(parser));
-        assertThat(exception.getMessage(), equalTo("value is missing"));
+        assertThat(exception.getMessage(), equalTo("value is missing or not a number"));
     }
 
     public void testExists() throws IOException {