|
@@ -27,9 +27,6 @@ import static org.elasticsearch.xpack.vectors.mapper.VectorEncoderDecoder.sortSp
|
|
|
|
|
|
public class ScoreScriptUtils {
|
|
|
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(ScoreScriptUtils.class));
|
|
|
- static final String DEPRECATION_MESSAGE = "The vector functions of the form function(query, doc['field']) are deprecated, and " +
|
|
|
- "the form function(query, 'field') should be used instead. For example, cosineSimilarity(query, doc['field']) is replaced by " +
|
|
|
- "cosineSimilarity(query, 'field').";
|
|
|
|
|
|
//**************FUNCTIONS FOR DENSE VECTORS
|
|
|
// Functions are implemented as classes to accept a hidden parameter scoreScript that contains some index settings.
|
|
@@ -43,7 +40,7 @@ public class ScoreScriptUtils {
|
|
|
|
|
|
public DenseVectorFunction(ScoreScript scoreScript,
|
|
|
List<Number> queryVector,
|
|
|
- Object field) {
|
|
|
+ String field) {
|
|
|
this(scoreScript, queryVector, field, false);
|
|
|
}
|
|
|
|
|
@@ -56,9 +53,10 @@ public class ScoreScriptUtils {
|
|
|
*/
|
|
|
public DenseVectorFunction(ScoreScript scoreScript,
|
|
|
List<Number> queryVector,
|
|
|
- Object field,
|
|
|
+ String field,
|
|
|
boolean normalizeQuery) {
|
|
|
this.scoreScript = scoreScript;
|
|
|
+ this.docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(field);
|
|
|
|
|
|
this.queryVector = new float[queryVector.size()];
|
|
|
double queryMagnitude = 0.0;
|
|
@@ -74,17 +72,6 @@ public class ScoreScriptUtils {
|
|
|
this.queryVector[dim] /= queryMagnitude;
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
- if (field instanceof String) {
|
|
|
- String fieldName = (String) field;
|
|
|
- docValues = (DenseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
|
|
|
- } else if (field instanceof DenseVectorScriptDocValues) {
|
|
|
- docValues = (DenseVectorScriptDocValues) field;
|
|
|
- deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
|
|
|
- } else {
|
|
|
- throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
|
|
|
- "VectorScriptDocValues");
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
BytesRef getEncodedVector() {
|
|
@@ -112,7 +99,7 @@ public class ScoreScriptUtils {
|
|
|
// Calculate l1 norm (Manhattan distance) between a query's dense vector and documents' dense vectors
|
|
|
public static final class L1Norm extends DenseVectorFunction {
|
|
|
|
|
|
- public L1Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
|
|
+ public L1Norm(ScoreScript scoreScript, List<Number> queryVector, String field) {
|
|
|
super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
@@ -132,7 +119,7 @@ public class ScoreScriptUtils {
|
|
|
// Calculate l2 norm (Euclidean distance) between a query's dense vector and documents' dense vectors
|
|
|
public static final class L2Norm extends DenseVectorFunction {
|
|
|
|
|
|
- public L2Norm(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
|
|
+ public L2Norm(ScoreScript scoreScript, List<Number> queryVector, String field) {
|
|
|
super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
@@ -152,7 +139,7 @@ public class ScoreScriptUtils {
|
|
|
// Calculate a dot product between a query's dense vector and documents' dense vectors
|
|
|
public static final class DotProduct extends DenseVectorFunction {
|
|
|
|
|
|
- public DotProduct(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
|
|
+ public DotProduct(ScoreScript scoreScript, List<Number> queryVector, String field) {
|
|
|
super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
@@ -171,7 +158,7 @@ public class ScoreScriptUtils {
|
|
|
// Calculate cosine similarity between a query's dense vector and documents' dense vectors
|
|
|
public static final class CosineSimilarity extends DenseVectorFunction {
|
|
|
|
|
|
- public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, Object field) {
|
|
|
+ public CosineSimilarity(ScoreScript scoreScript, List<Number> queryVector, String field) {
|
|
|
super(scoreScript, queryVector, field, true);
|
|
|
}
|
|
|
|
|
@@ -214,8 +201,10 @@ public class ScoreScriptUtils {
|
|
|
// queryVector represents a map of dimensions to values
|
|
|
public SparseVectorFunction(ScoreScript scoreScript,
|
|
|
Map<String, Number> queryVector,
|
|
|
- Object field) {
|
|
|
+ String field) {
|
|
|
this.scoreScript = scoreScript;
|
|
|
+ this.docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(field);
|
|
|
+
|
|
|
//break vector into two arrays dims and values
|
|
|
int n = queryVector.size();
|
|
|
queryValues = new float[n];
|
|
@@ -232,18 +221,6 @@ public class ScoreScriptUtils {
|
|
|
}
|
|
|
// Sort dimensions in the ascending order and sort values in the same order as their corresponding dimensions
|
|
|
sortSparseDimsFloatValues(queryDims, queryValues, n);
|
|
|
-
|
|
|
- if (field instanceof String) {
|
|
|
- String fieldName = (String) field;
|
|
|
- docValues = (SparseVectorScriptDocValues) scoreScript.getDoc().get(fieldName);
|
|
|
- } else if (field instanceof SparseVectorScriptDocValues) {
|
|
|
- docValues = (SparseVectorScriptDocValues) field;
|
|
|
- deprecationLogger.deprecatedAndMaybeLog("vector_function_signature", DEPRECATION_MESSAGE);
|
|
|
- } else {
|
|
|
- throw new IllegalArgumentException("For vector functions, the 'field' argument must be of type String or " +
|
|
|
- "VectorScriptDocValues");
|
|
|
- }
|
|
|
-
|
|
|
deprecationLogger.deprecatedAndMaybeLog("sparse_vector_function", SparseVectorFieldMapper.DEPRECATION_MESSAGE);
|
|
|
}
|
|
|
|
|
@@ -264,8 +241,8 @@ public class ScoreScriptUtils {
|
|
|
|
|
|
// Calculate l1 norm (Manhattan distance) between a query's sparse vector and documents' sparse vectors
|
|
|
public static final class L1NormSparse extends SparseVectorFunction {
|
|
|
- public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, Object docVector) {
|
|
|
- super(scoreScript, queryVector, docVector);
|
|
|
+ public L1NormSparse(ScoreScript scoreScript,Map<String, Number> queryVector, String field) {
|
|
|
+ super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
|
public double l1normSparse() {
|
|
@@ -303,8 +280,8 @@ public class ScoreScriptUtils {
|
|
|
|
|
|
// Calculate l2 norm (Euclidean distance) between a query's sparse vector and documents' sparse vectors
|
|
|
public static final class L2NormSparse extends SparseVectorFunction {
|
|
|
- public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
|
|
- super(scoreScript, queryVector, docVector);
|
|
|
+ public L2NormSparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
|
|
|
+ super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
|
public double l2normSparse() {
|
|
@@ -345,8 +322,8 @@ public class ScoreScriptUtils {
|
|
|
|
|
|
// Calculate a dot product between a query's sparse vector and documents' sparse vectors
|
|
|
public static final class DotProductSparse extends SparseVectorFunction {
|
|
|
- public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
|
|
- super(scoreScript, queryVector, docVector);
|
|
|
+ public DotProductSparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
|
|
|
+ super(scoreScript, queryVector, field);
|
|
|
}
|
|
|
|
|
|
public double dotProductSparse() {
|
|
@@ -362,8 +339,8 @@ public class ScoreScriptUtils {
|
|
|
public static final class CosineSimilaritySparse extends SparseVectorFunction {
|
|
|
final double queryVectorMagnitude;
|
|
|
|
|
|
- public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, Object docVector) {
|
|
|
- super(scoreScript, queryVector, docVector);
|
|
|
+ public CosineSimilaritySparse(ScoreScript scoreScript, Map<String, Number> queryVector, String field) {
|
|
|
+ super(scoreScript, queryVector, field);
|
|
|
double dotProduct = 0;
|
|
|
for (int i = 0; i< queryDims.length; i++) {
|
|
|
dotProduct += queryValues[i] * queryValues[i];
|