Browse Source

Revert "Auto-normalize dot_product vectors at index & query (#98944)" (#99421)

This reverts commit 7b9c367aeb84493ec625c1125d0c50d63c61bc3a.
Benjamin Trent 2 years ago
parent
commit
83b70e37ef

+ 0 - 5
docs/changelog/98944.yaml

@@ -1,5 +0,0 @@
-pr: 98944
-summary: Auto-normalize `dot_product` vectors at index & query
-area: Vector Search
-type: enhancement
-issues: []

+ 2 - 2
docs/reference/how-to/knn-search.asciidoc

@@ -21,8 +21,8 @@ options.
 The `cosine` option accepts any float vector and computes the cosine
 The `cosine` option accepts any float vector and computes the cosine
 similarity. While this is convenient for testing, it's not the most efficient
 similarity. While this is convenient for testing, it's not the most efficient
 approach. Instead, we recommend using the `dot_product` option to compute the
 approach. Instead, we recommend using the `dot_product` option to compute the
-similarity. When using `dot_product`, all vectors are normalized during index to have
-a magnitude of 1. The `dot_product` option is significantly faster, since it
+similarity. To use `dot_product`, all vectors need to be normalized in advance
+to have length 1. The `dot_product` option is significantly faster, since it
 avoids performing extra vector length computations during the search.
 avoids performing extra vector length computations during the search.
 
 
 [discrete]
 [discrete]

+ 5 - 5
docs/reference/mapping/types/dense-vector.asciidoc

@@ -163,9 +163,9 @@ Computes the dot product of two vectors. This option provides an optimized way
 to perform cosine similarity. The constraints and computed score are defined
 to perform cosine similarity. The constraints and computed score are defined
 by `element_type`.
 by `element_type`.
 +
 +
-When `element_type` is `float`, all vectors are automatically converted to unit length, including both
-document and query vectors. Consequently, `dot_product` does not allow vectors with a zero magnitude.
-The document `_score` is computed as `(1 + dot_product(query, vector)) / 2`.
+When `element_type` is `float`, all vectors must be unit length, including both
+document and query vectors. The document `_score` is computed as
+`(1 + dot_product(query, vector)) / 2`.
 +
 +
 When `element_type` is `byte`, all vectors must have the same
 When `element_type` is `byte`, all vectors must have the same
 length including both document and query vectors or results will be inaccurate.
 length including both document and query vectors or results will be inaccurate.
@@ -175,9 +175,9 @@ where `dims` is the number of dimensions per vector.
 
 
 `cosine`:::
 `cosine`:::
 Computes the cosine similarity. Note that the most efficient way to perform
 Computes the cosine similarity. Note that the most efficient way to perform
-cosine similarity is to have all vectors normalized to unit length, and instead use
+cosine similarity is to normalize all vectors to unit length, and instead use
 `dot_product`. You should only use `cosine` if you need to preserve the
 `dot_product`. You should only use `cosine` if you need to preserve the
-original vectors and cannot allow Elasticsearch to normalize them. The document `_score`
+original vectors and cannot normalize them in advance. The document `_score`
 is computed as `(1 + cosine(query, vector)) / 2`. The `cosine` similarity does
 is computed as `(1 + cosine(query, vector)) / 2`. The `cosine` similarity does
 not allow vectors with zero magnitude, since cosine is not defined in this
 not allow vectors with zero magnitude, since cosine is not defined in this
 case.
 case.

+ 0 - 120
rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml

@@ -368,123 +368,3 @@ setup:
             filter: {"term": {"name": "cow.jpg"}}
             filter: {"term": {"name": "cow.jpg"}}
 
 
   - length: {hits.hits: 0}
   - length: {hits.hits: 0}
----
-"kNN search with dot-product auto-normalized":
-  - skip:
-      features: close_to
-      version: ' - 8.10.99'
-      reason: 'dense_vector auto-normalized dot_product in 8.11'
-
-  - do:
-      indices.create:
-        index: test_dot_product
-        body:
-          mappings:
-            properties:
-              name:
-                type: keyword
-              dot_product_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: dot_product
-              cosine_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: cosine
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "1"
-        body:
-          name: cow.jpg
-          dot_product_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-          cosine_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "2"
-        body:
-          name: moose.jpg
-          dot_product_vector: [ -0.5, 100.0, -13, 14.8, -156.0 ]
-          cosine_vector: [ -0.5, 100.0, -13, 14.8, -156.0 ]
-
-  - do:
-      index:
-        index: test_dot_product
-        id: "3"
-        body:
-          name: rabbit.jpg
-          dot_product_vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ]
-          cosine_vector: [ 0.5, 111.3, -13.0, 14.8, -156.0 ]
-
-  - do:
-      indices.refresh: { }
-
-  - do:
-      search:
-        index: test_dot_product
-        body:
-          fields: [ "name" ]
-          knn:
-            field: dot_product_vector
-            query_vector: [-0.5, 90.0, -10, 14.8, -156.0]
-            k: 2
-            num_candidates: 3
-
-  - match: {hits.total.value: 2}
-  - match: {hits.hits.0._id: "2"}
-  - set: { hits.hits.0._score: score_0 }
-  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
-  - match: {hits.hits.1._id: "3"}
-  - set: { hits.hits.1._score: score_1 }
-  - match: {hits.hits.1.fields.name.0: "rabbit.jpg"}
-
-  - do:
-      search:
-        index: test_dot_product
-        body:
-          fields: [ "name" ]
-          knn:
-            field: cosine_vector
-            query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ]
-            k: 2
-            num_candidates: 3
-
-  - match: {hits.total.value: 2}
-  - match: {hits.hits.0._id: "2"}
-  - close_to: { hits.hits.0._score: { value: $score_0, error: 0.00001 } }
-  - match: {hits.hits.0.fields.name.0: "moose.jpg"}
-  - match: {hits.hits.1._id: "3"}
-  - close_to: { hits.hits.1._score: { value: $score_1, error: 0.00001 } }
-  - match: {hits.hits.1.fields.name.0: "rabbit.jpg"}
----
-"kNN search fails with non-normalized dot-product in older versions":
-  - skip:
-      version: '8.10.99 - '
-      reason: 'dense_vector auto-normalized dot_product in 8.11'
-
-  - do:
-      indices.create:
-        index: test_failing_dot_product
-        body:
-          mappings:
-            properties:
-              dot_product_vector:
-                type: dense_vector
-                dims: 5
-                index: true
-                similarity: dot_product
-
-  - do:
-      catch: bad_request
-      index:
-        index: test_failing_dot_product
-        id: "1"
-        body:
-          name: cow.jpg
-          dot_product_vector: [ 230.0, 300.33, -34.8988, 15.555, -200.0 ]
-

+ 19 - 79
server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java

@@ -26,7 +26,6 @@ import org.apache.lucene.search.KnnByteVectorQuery;
 import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.KnnFloatVectorQuery;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.BytesRef;
-import org.apache.lucene.util.VectorUtil;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.IndexVersion;
 import org.elasticsearch.index.fielddata.FieldDataContext;
 import org.elasticsearch.index.fielddata.FieldDataContext;
@@ -57,7 +56,6 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.nio.ByteOrder;
 import java.time.ZoneId;
 import java.time.ZoneId;
-import java.util.Arrays;
 import java.util.Locale;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Objects;
@@ -71,10 +69,8 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpect
  * A {@link FieldMapper} for indexing a dense vector of floats.
  * A {@link FieldMapper} for indexing a dense vector of floats.
  */
  */
 public class DenseVectorFieldMapper extends FieldMapper {
 public class DenseVectorFieldMapper extends FieldMapper {
-    private static final float EPS = 1e-4f;
     public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersion.V_7_5_0;
     public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersion.V_7_5_0;
     public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersion.V_8_11_0;
     public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersion.V_8_11_0;
-    public static final IndexVersion DOT_PRODUCT_AUTO_NORMALIZED = IndexVersion.V_8_11_0;
     public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersion.V_8_9_0;
     public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersion.V_8_9_0;
 
 
     public static final String CONTENT_TYPE = "dense_vector";
     public static final String CONTENT_TYPE = "dense_vector";
@@ -325,7 +321,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
 
             @Override
             @Override
             void checkVectorMagnitude(
             void checkVectorMagnitude(
-                IndexVersion indexVersion,
                 VectorSimilarity similarity,
                 VectorSimilarity similarity,
                 Function<StringBuilder, StringBuilder> appender,
                 Function<StringBuilder, StringBuilder> appender,
                 float squaredMagnitude
                 float squaredMagnitude
@@ -388,12 +383,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     squaredMagnitude += value * value;
                     squaredMagnitude += value * value;
                 }
                 }
                 fieldMapper.checkDimensionMatches(index, context);
                 fieldMapper.checkDimensionMatches(index, context);
-                checkVectorMagnitude(
-                    fieldMapper.indexCreatedVersion,
-                    fieldMapper.similarity,
-                    errorByteElementsAppender(vector),
-                    squaredMagnitude
-                );
+                checkVectorMagnitude(fieldMapper.similarity, errorByteElementsAppender(vector), squaredMagnitude);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
             }
             }
 
 
@@ -485,31 +475,20 @@ public class DenseVectorFieldMapper extends FieldMapper {
 
 
             @Override
             @Override
             void checkVectorMagnitude(
             void checkVectorMagnitude(
-                IndexVersion indexVersion,
                 VectorSimilarity similarity,
                 VectorSimilarity similarity,
                 Function<StringBuilder, StringBuilder> appender,
                 Function<StringBuilder, StringBuilder> appender,
                 float squaredMagnitude
                 float squaredMagnitude
             ) {
             ) {
                 StringBuilder errorBuilder = null;
                 StringBuilder errorBuilder = null;
 
 
-                if (indexVersion.before(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > EPS) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
-                        );
-                    }
-                    if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + similarity + "] similarity does not support vectors with zero magnitude."
-                        );
-                    }
-                } else {
-                    if ((similarity == VectorSimilarity.COSINE || similarity == VectorSimilarity.DOT_PRODUCT)
-                        && Math.sqrt(squaredMagnitude) == 0.0f) {
-                        errorBuilder = new StringBuilder(
-                            "The [" + similarity + "] similarity does not support vectors with zero magnitude."
-                        );
-                    }
+                if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > 1e-4f) {
+                    errorBuilder = new StringBuilder(
+                        "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
+                    );
+                } else if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
+                    errorBuilder = new StringBuilder(
+                        "The [" + VectorSimilarity.COSINE + "] similarity does not support vectors with zero magnitude."
+                    );
                 }
                 }
 
 
                 if (errorBuilder != null) {
                 if (errorBuilder != null) {
@@ -532,15 +511,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
                 }
                 }
                 fieldMapper.checkDimensionMatches(index, context);
                 fieldMapper.checkDimensionMatches(index, context);
                 checkVectorBounds(vector);
                 checkVectorBounds(vector);
-                checkVectorMagnitude(
-                    fieldMapper.indexCreatedVersion,
-                    fieldMapper.similarity,
-                    errorFloatElementsAppender(vector),
-                    squaredMagnitude
-                );
-                if (fieldMapper.indexCreatedVersion.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    fieldMapper.similarity.floatPreprocessing(vector, squaredMagnitude);
-                }
+                checkVectorMagnitude(fieldMapper.similarity, errorFloatElementsAppender(vector), squaredMagnitude);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
                 return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
             }
             }
 
 
@@ -598,7 +569,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
         public abstract void checkVectorBounds(float[] vector);
         public abstract void checkVectorBounds(float[] vector);
 
 
         abstract void checkVectorMagnitude(
         abstract void checkVectorMagnitude(
-            IndexVersion indexVersion,
             VectorSimilarity similarity,
             VectorSimilarity similarity,
             Function<StringBuilder, StringBuilder> errorElementsAppender,
             Function<StringBuilder, StringBuilder> errorElementsAppender,
             float squaredMagnitude
             float squaredMagnitude
@@ -717,21 +687,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
                     case FLOAT -> (1 + similarity) / 2f;
                     case FLOAT -> (1 + similarity) / 2f;
                 };
                 };
             }
             }
-
-            @Override
-            void floatPreprocessing(float[] vector, float squareSum) {
-                if (squareSum == 0) {
-                    throw new IllegalArgumentException("Cannot normalize a zero-length vector");
-                }
-                // Vector already has a magnitude have `1`
-                if (Math.abs(squareSum - 1.0f) < EPS) {
-                    return;
-                }
-                float length = (float) Math.sqrt(squareSum);
-                for (int i = 0; i < vector.length; i++) {
-                    vector[i] /= length;
-                }
-            }
         };
         };
 
 
         public final VectorSimilarityFunction function;
         public final VectorSimilarityFunction function;
@@ -746,8 +701,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
         }
         }
 
 
         abstract float score(float similarity, ElementType elementType, int dim);
         abstract float score(float similarity, ElementType elementType, int dim);
-
-        void floatPreprocessing(float[] vector, float squareSum) {}
     }
     }
 
 
     private abstract static class IndexOptions implements ToXContent {
     private abstract static class IndexOptions implements ToXContent {
@@ -906,13 +859,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             }
             }
 
 
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
-                int squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
-                elementType.checkVectorMagnitude(
-                    indexVersionCreated,
-                    similarity,
-                    elementType.errorByteElementsAppender(queryVector),
-                    squaredMagnitude
-                );
+                float squaredMagnitude = 0.0f;
+                for (byte b : queryVector) {
+                    squaredMagnitude += b * b;
+                }
+                elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude);
             }
             }
             Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
             Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
             if (similarityThreshold != null) {
             if (similarityThreshold != null) {
@@ -940,22 +891,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
             elementType.checkVectorBounds(queryVector);
             elementType.checkVectorBounds(queryVector);
 
 
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
             if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
-                float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
-                elementType.checkVectorMagnitude(
-                    indexVersionCreated,
-                    similarity,
-                    elementType.errorFloatElementsAppender(queryVector),
-                    squaredMagnitude
-                );
-                // We don't want to normalize the original query vector.
-                // It mutates it in place and might cause down stream weirdness
-                // Instead we copy the value and then normalize that copy
-                if (similarity == VectorSimilarity.DOT_PRODUCT
-                    && elementType == ElementType.FLOAT
-                    && indexVersionCreated.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
-                    queryVector = Arrays.copyOf(queryVector, queryVector.length);
-                    similarity.floatPreprocessing(queryVector, squaredMagnitude);
+                float squaredMagnitude = 0.0f;
+                for (float e : queryVector) {
+                    squaredMagnitude += e * e;
                 }
                 }
+                elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
             }
             }
             Query knnQuery = switch (elementType) {
             Query knnQuery = switch (elementType) {
                 case BYTE -> {
                 case BYTE -> {

+ 19 - 3
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapperTests.java

@@ -320,7 +320,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
                 b -> b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", VectorSimilarity.DOT_PRODUCT)
                 b -> b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", VectorSimilarity.DOT_PRODUCT)
             )
             )
         );
         );
-        float[] vector = { 0f, 0f, 0f };
+        float[] vector = { -12.1f, 2.7f, -4 };
         DocumentParsingException e = expectThrows(
         DocumentParsingException e = expectThrows(
             DocumentParsingException.class,
             DocumentParsingException.class,
             () -> mapper.parse(source(b -> b.array("field", vector)))
             () -> mapper.parse(source(b -> b.array("field", vector)))
@@ -329,7 +329,23 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         assertThat(
         assertThat(
             e.getCause().getMessage(),
             e.getCause().getMessage(),
             containsString(
             containsString(
-                "The [dot_product] similarity does not support vectors with zero magnitude. Preview of invalid vector: [0.0, 0.0, 0.0]"
+                "The [dot_product] similarity can only be used with unit-length vectors. Preview of invalid vector: [-12.1, 2.7, -4.0]"
+            )
+        );
+
+        DocumentMapper mapperWithLargerDim = createDocumentMapper(
+            fieldMapping(
+                b -> b.field("type", "dense_vector").field("dims", 6).field("index", true).field("similarity", VectorSimilarity.DOT_PRODUCT)
+            )
+        );
+        float[] largerVector = { -12.1f, 2.7f, -4, 1.05f, 10.0f, 29.9f };
+        e = expectThrows(DocumentParsingException.class, () -> mapperWithLargerDim.parse(source(b -> b.array("field", largerVector))));
+        assertNotNull(e.getCause());
+        assertThat(
+            e.getCause().getMessage(),
+            containsString(
+                "The [dot_product] similarity can only be used with unit-length vectors. "
+                    + "Preview of invalid vector: [-12.1, 2.7, -4.0, 1.05, 10.0, ...]"
             )
             )
         );
         );
     }
     }
@@ -499,7 +515,7 @@ public class DenseVectorFieldMapperTests extends MapperTestCase {
         assertNull(denseVectorFieldType.getSimilarity());
         assertNull(denseVectorFieldType.getSimilarity());
     }
     }
 
 
-    public void testParamsBeforeIndexByDefault() throws Exception {
+    public void testtParamsBeforeIndexByDefault() throws Exception {
         DocumentMapper documentMapper = createDocumentMapper(INDEXED_BY_DEFAULT_PREVIOUS_INDEX_VERSION, fieldMapping(b -> {
         DocumentMapper documentMapper = createDocumentMapper(INDEXED_BY_DEFAULT_PREVIOUS_INDEX_VERSION, fieldMapping(b -> {
             b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", "dot_product");
             b.field("type", "dense_vector").field("dims", 3).field("index", true).field("similarity", "dot_product");
         }));
         }));

+ 2 - 2
server/src/test/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldTypeTests.java

@@ -137,9 +137,9 @@ public class DenseVectorFieldTypeTests extends FieldTypeTestCase {
         );
         );
         e = expectThrows(
         e = expectThrows(
             IllegalArgumentException.class,
             IllegalArgumentException.class,
-            () -> dotProductField.createKnnQuery(new float[] { 0.0f, 0.0f, 0.0f }, 10, null, null)
+            () -> dotProductField.createKnnQuery(new float[] { 0.3f, 0.1f, 1.0f }, 10, null, null)
         );
         );
-        assertThat(e.getMessage(), containsString("The [dot_product] similarity does not support vectors with zero magnitude."));
+        assertThat(e.getMessage(), containsString("The [dot_product] similarity can only be used with unit-length vectors."));
 
 
         DenseVectorFieldType cosineField = new DenseVectorFieldType(
         DenseVectorFieldType cosineField = new DenseVectorFieldType(
             "f",
             "f",

+ 6 - 6
test/yaml-rest-runner/src/main/java/org/elasticsearch/test/rest/yaml/section/CloseToAssertion.java

@@ -36,14 +36,14 @@ public class CloseToAssertion extends Assertion {
                 throw new IllegalArgumentException("expected a map with value and error but got a map with " + map.size() + " fields");
                 throw new IllegalArgumentException("expected a map with value and error but got a map with " + map.size() + " fields");
             }
             }
             Object valObj = map.get("value");
             Object valObj = map.get("value");
-            if (valObj == null) {
-                throw new IllegalArgumentException("value is missing");
+            if (valObj instanceof Number == false) {
+                throw new IllegalArgumentException("value is missing or not a number");
             }
             }
             Object errObj = map.get("error");
             Object errObj = map.get("error");
             if (errObj instanceof Number == false) {
             if (errObj instanceof Number == false) {
                 throw new IllegalArgumentException("error is missing or not a number");
                 throw new IllegalArgumentException("error is missing or not a number");
             }
             }
-            return new CloseToAssertion(location, fieldValueTuple.v1(), valObj, ((Number) errObj).doubleValue());
+            return new CloseToAssertion(location, fieldValueTuple.v1(), ((Number) valObj).doubleValue(), ((Number) errObj).doubleValue());
         } else {
         } else {
             throw new IllegalArgumentException(
             throw new IllegalArgumentException(
                 "expected a map with value and error but got " + fieldValueTuple.v2().getClass().getSimpleName()
                 "expected a map with value and error but got " + fieldValueTuple.v2().getClass().getSimpleName()
@@ -56,7 +56,7 @@ public class CloseToAssertion extends Assertion {
 
 
     private final double error;
     private final double error;
 
 
-    public CloseToAssertion(XContentLocation location, String field, Object expectedValue, Double error) {
+    public CloseToAssertion(XContentLocation location, String field, Double expectedValue, Double error) {
         super(location, field, expectedValue);
         super(location, field, expectedValue);
         this.error = error;
         this.error = error;
     }
     }
@@ -69,9 +69,9 @@ public class CloseToAssertion extends Assertion {
     protected void doAssert(Object actualValue, Object expectedValue) {
     protected void doAssert(Object actualValue, Object expectedValue) {
         logger.trace("assert that [{}] is close to [{}] with error [{}] (field [{}])", actualValue, expectedValue, error, getField());
         logger.trace("assert that [{}] is close to [{}] with error [{}] (field [{}])", actualValue, expectedValue, error, getField());
         if (actualValue instanceof Number actualValueNumber) {
         if (actualValue instanceof Number actualValueNumber) {
-            assertThat(actualValueNumber.doubleValue(), closeTo(((Number) expectedValue).doubleValue(), error));
+            assertThat(actualValueNumber.doubleValue(), closeTo((Double) expectedValue, error));
         } else {
         } else {
-            throw new AssertionError("expected a value close to " + expectedValue + " but got " + actualValue + ", which is not a number");
+            throw new AssertionError("excpected a value close to " + expectedValue + " but got " + actualValue + ", which is not a number");
         }
         }
     }
     }
 }
 }

+ 1 - 1
test/yaml-rest-runner/src/test/java/org/elasticsearch/test/rest/yaml/section/AssertionTests.java

@@ -169,7 +169,7 @@ public class AssertionTests extends AbstractClientYamlTestFragmentParserTestCase
 
 
         parser = createParser(YamlXContent.yamlXContent, "{ field: { foo: 13, bar: 15 } }");
         parser = createParser(YamlXContent.yamlXContent, "{ field: { foo: 13, bar: 15 } }");
         exception = expectThrows(IllegalArgumentException.class, () -> CloseToAssertion.parse(parser));
         exception = expectThrows(IllegalArgumentException.class, () -> CloseToAssertion.parse(parser));
-        assertThat(exception.getMessage(), equalTo("value is missing"));
+        assertThat(exception.getMessage(), equalTo("value is missing or not a number"));
     }
     }
 
 
     public void testExists() throws IOException {
     public void testExists() throws IOException {