|
@@ -26,7 +26,6 @@ import org.apache.lucene.search.KnnByteVectorQuery;
|
|
|
import org.apache.lucene.search.KnnFloatVectorQuery;
|
|
|
import org.apache.lucene.search.Query;
|
|
|
import org.apache.lucene.util.BytesRef;
|
|
|
-import org.apache.lucene.util.VectorUtil;
|
|
|
import org.elasticsearch.common.xcontent.support.XContentMapValues;
|
|
|
import org.elasticsearch.index.IndexVersion;
|
|
|
import org.elasticsearch.index.fielddata.FieldDataContext;
|
|
@@ -57,7 +56,6 @@ import java.io.IOException;
|
|
|
import java.nio.ByteBuffer;
|
|
|
import java.nio.ByteOrder;
|
|
|
import java.time.ZoneId;
|
|
|
-import java.util.Arrays;
|
|
|
import java.util.Locale;
|
|
|
import java.util.Map;
|
|
|
import java.util.Objects;
|
|
@@ -71,10 +69,8 @@ import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpect
|
|
|
* A {@link FieldMapper} for indexing a dense vector of floats.
|
|
|
*/
|
|
|
public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
- private static final float EPS = 1e-4f;
|
|
|
public static final IndexVersion MAGNITUDE_STORED_INDEX_VERSION = IndexVersion.V_7_5_0;
|
|
|
public static final IndexVersion INDEXED_BY_DEFAULT_INDEX_VERSION = IndexVersion.V_8_11_0;
|
|
|
- public static final IndexVersion DOT_PRODUCT_AUTO_NORMALIZED = IndexVersion.V_8_11_0;
|
|
|
public static final IndexVersion LITTLE_ENDIAN_FLOAT_STORED_INDEX_VERSION = IndexVersion.V_8_9_0;
|
|
|
|
|
|
public static final String CONTENT_TYPE = "dense_vector";
|
|
@@ -325,7 +321,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
|
|
|
@Override
|
|
|
void checkVectorMagnitude(
|
|
|
- IndexVersion indexVersion,
|
|
|
VectorSimilarity similarity,
|
|
|
Function<StringBuilder, StringBuilder> appender,
|
|
|
float squaredMagnitude
|
|
@@ -388,12 +383,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
squaredMagnitude += value * value;
|
|
|
}
|
|
|
fieldMapper.checkDimensionMatches(index, context);
|
|
|
- checkVectorMagnitude(
|
|
|
- fieldMapper.indexCreatedVersion,
|
|
|
- fieldMapper.similarity,
|
|
|
- errorByteElementsAppender(vector),
|
|
|
- squaredMagnitude
|
|
|
- );
|
|
|
+ checkVectorMagnitude(fieldMapper.similarity, errorByteElementsAppender(vector), squaredMagnitude);
|
|
|
return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
|
|
|
}
|
|
|
|
|
@@ -485,31 +475,20 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
|
|
|
@Override
|
|
|
void checkVectorMagnitude(
|
|
|
- IndexVersion indexVersion,
|
|
|
VectorSimilarity similarity,
|
|
|
Function<StringBuilder, StringBuilder> appender,
|
|
|
float squaredMagnitude
|
|
|
) {
|
|
|
StringBuilder errorBuilder = null;
|
|
|
|
|
|
- if (indexVersion.before(DOT_PRODUCT_AUTO_NORMALIZED)) {
|
|
|
- if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > EPS) {
|
|
|
- errorBuilder = new StringBuilder(
|
|
|
- "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
|
|
|
- );
|
|
|
- }
|
|
|
- if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
|
|
|
- errorBuilder = new StringBuilder(
|
|
|
- "The [" + similarity + "] similarity does not support vectors with zero magnitude."
|
|
|
- );
|
|
|
- }
|
|
|
- } else {
|
|
|
- if ((similarity == VectorSimilarity.COSINE || similarity == VectorSimilarity.DOT_PRODUCT)
|
|
|
- && Math.sqrt(squaredMagnitude) == 0.0f) {
|
|
|
- errorBuilder = new StringBuilder(
|
|
|
- "The [" + similarity + "] similarity does not support vectors with zero magnitude."
|
|
|
- );
|
|
|
- }
|
|
|
+ if (similarity == VectorSimilarity.DOT_PRODUCT && Math.abs(squaredMagnitude - 1.0f) > 1e-4f) {
|
|
|
+ errorBuilder = new StringBuilder(
|
|
|
+ "The [" + VectorSimilarity.DOT_PRODUCT + "] similarity can only be used with unit-length vectors."
|
|
|
+ );
|
|
|
+ } else if (similarity == VectorSimilarity.COSINE && Math.sqrt(squaredMagnitude) == 0.0f) {
|
|
|
+ errorBuilder = new StringBuilder(
|
|
|
+ "The [" + VectorSimilarity.COSINE + "] similarity does not support vectors with zero magnitude."
|
|
|
+ );
|
|
|
}
|
|
|
|
|
|
if (errorBuilder != null) {
|
|
@@ -532,15 +511,7 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
}
|
|
|
fieldMapper.checkDimensionMatches(index, context);
|
|
|
checkVectorBounds(vector);
|
|
|
- checkVectorMagnitude(
|
|
|
- fieldMapper.indexCreatedVersion,
|
|
|
- fieldMapper.similarity,
|
|
|
- errorFloatElementsAppender(vector),
|
|
|
- squaredMagnitude
|
|
|
- );
|
|
|
- if (fieldMapper.indexCreatedVersion.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
|
|
|
- fieldMapper.similarity.floatPreprocessing(vector, squaredMagnitude);
|
|
|
- }
|
|
|
+ checkVectorMagnitude(fieldMapper.similarity, errorFloatElementsAppender(vector), squaredMagnitude);
|
|
|
return createKnnVectorField(fieldMapper.fieldType().name(), vector, fieldMapper.similarity.function);
|
|
|
}
|
|
|
|
|
@@ -598,7 +569,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
public abstract void checkVectorBounds(float[] vector);
|
|
|
|
|
|
abstract void checkVectorMagnitude(
|
|
|
- IndexVersion indexVersion,
|
|
|
VectorSimilarity similarity,
|
|
|
Function<StringBuilder, StringBuilder> errorElementsAppender,
|
|
|
float squaredMagnitude
|
|
@@ -717,21 +687,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
case FLOAT -> (1 + similarity) / 2f;
|
|
|
};
|
|
|
}
|
|
|
-
|
|
|
- @Override
|
|
|
- void floatPreprocessing(float[] vector, float squareSum) {
|
|
|
- if (squareSum == 0) {
|
|
|
- throw new IllegalArgumentException("Cannot normalize a zero-length vector");
|
|
|
- }
|
|
|
- // Vector already has a magnitude have `1`
|
|
|
- if (Math.abs(squareSum - 1.0f) < EPS) {
|
|
|
- return;
|
|
|
- }
|
|
|
- float length = (float) Math.sqrt(squareSum);
|
|
|
- for (int i = 0; i < vector.length; i++) {
|
|
|
- vector[i] /= length;
|
|
|
- }
|
|
|
- }
|
|
|
};
|
|
|
|
|
|
public final VectorSimilarityFunction function;
|
|
@@ -746,8 +701,6 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
}
|
|
|
|
|
|
abstract float score(float similarity, ElementType elementType, int dim);
|
|
|
-
|
|
|
- void floatPreprocessing(float[] vector, float squareSum) {}
|
|
|
}
|
|
|
|
|
|
private abstract static class IndexOptions implements ToXContent {
|
|
@@ -906,13 +859,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
}
|
|
|
|
|
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
|
|
- int squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
|
|
- elementType.checkVectorMagnitude(
|
|
|
- indexVersionCreated,
|
|
|
- similarity,
|
|
|
- elementType.errorByteElementsAppender(queryVector),
|
|
|
- squaredMagnitude
|
|
|
- );
|
|
|
+ float squaredMagnitude = 0.0f;
|
|
|
+ for (byte b : queryVector) {
|
|
|
+ squaredMagnitude += b * b;
|
|
|
+ }
|
|
|
+ elementType.checkVectorMagnitude(similarity, elementType.errorByteElementsAppender(queryVector), squaredMagnitude);
|
|
|
}
|
|
|
Query knnQuery = new KnnByteVectorQuery(name(), queryVector, numCands, filter);
|
|
|
if (similarityThreshold != null) {
|
|
@@ -940,22 +891,11 @@ public class DenseVectorFieldMapper extends FieldMapper {
|
|
|
elementType.checkVectorBounds(queryVector);
|
|
|
|
|
|
if (similarity == VectorSimilarity.DOT_PRODUCT || similarity == VectorSimilarity.COSINE) {
|
|
|
- float squaredMagnitude = VectorUtil.dotProduct(queryVector, queryVector);
|
|
|
- elementType.checkVectorMagnitude(
|
|
|
- indexVersionCreated,
|
|
|
- similarity,
|
|
|
- elementType.errorFloatElementsAppender(queryVector),
|
|
|
- squaredMagnitude
|
|
|
- );
|
|
|
- // We don't want to normalize the original query vector.
|
|
|
- // It mutates it in place and might cause down stream weirdness
|
|
|
- // Instead we copy the value and then normalize that copy
|
|
|
- if (similarity == VectorSimilarity.DOT_PRODUCT
|
|
|
- && elementType == ElementType.FLOAT
|
|
|
- && indexVersionCreated.onOrAfter(DOT_PRODUCT_AUTO_NORMALIZED)) {
|
|
|
- queryVector = Arrays.copyOf(queryVector, queryVector.length);
|
|
|
- similarity.floatPreprocessing(queryVector, squaredMagnitude);
|
|
|
+ float squaredMagnitude = 0.0f;
|
|
|
+ for (float e : queryVector) {
|
|
|
+ squaredMagnitude += e * e;
|
|
|
}
|
|
|
+ elementType.checkVectorMagnitude(similarity, elementType.errorFloatElementsAppender(queryVector), squaredMagnitude);
|
|
|
}
|
|
|
Query knnQuery = switch (elementType) {
|
|
|
case BYTE -> {
|