|
@@ -20,6 +20,8 @@ import org.elasticsearch.xpack.esql.EsqlClientException;
|
|
|
import org.elasticsearch.xpack.esql.EsqlTestUtils;
|
|
|
import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
|
|
|
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
|
|
|
+import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
|
|
|
+import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction;
|
|
|
import org.junit.Before;
|
|
|
|
|
|
import java.io.IOException;
|
|
@@ -37,22 +39,25 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
|
|
|
List<Object[]> params = new ArrayList<>();
|
|
|
|
|
|
if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
|
|
- params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
|
|
|
+ params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare });
|
|
|
}
|
|
|
if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
|
|
- params.add(new Object[] { "v_dot_product", VectorSimilarityFunction.DOT_PRODUCT });
|
|
|
+ params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare });
|
|
|
+ }
|
|
|
+ if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
|
|
|
+ params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity });
|
|
|
}
|
|
|
|
|
|
return params;
|
|
|
}
|
|
|
|
|
|
private final String functionName;
|
|
|
- private final VectorSimilarityFunction similarityFunction;
|
|
|
+ private final SimilarityEvaluatorFunction similarityFunction;
|
|
|
private int numDims;
|
|
|
|
|
|
public VectorSimilarityFunctionsIT(
|
|
|
@Name("functionName") String functionName,
|
|
|
- @Name("similarityFunction") VectorSimilarityFunction similarityFunction
|
|
|
+ @Name("similarityFunction") SimilarityEvaluatorFunction similarityFunction
|
|
|
) {
|
|
|
this.functionName = functionName;
|
|
|
this.similarityFunction = similarityFunction;
|
|
@@ -74,7 +79,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
|
|
|
Double similarity = (Double) values.get(2);
|
|
|
|
|
|
assertNotNull(similarity);
|
|
|
- float expectedSimilarity = similarityFunction.compare(left, right);
|
|
|
+ float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
|
|
|
assertEquals(expectedSimilarity, similarity, 0.0001);
|
|
|
});
|
|
|
}
|
|
@@ -96,7 +101,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
|
|
|
Double similarity = (Double) values.get(1);
|
|
|
|
|
|
assertNotNull(similarity);
|
|
|
- float expectedSimilarity = similarityFunction.compare(left, randomVector);
|
|
|
+ float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
|
|
|
assertEquals(expectedSimilarity, similarity, 0.0001);
|
|
|
});
|
|
|
}
|
|
@@ -130,7 +135,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
|
|
|
|
|
|
Double similarity = (Double) valuesList.get(0).get(0);
|
|
|
assertNotNull(similarity);
|
|
|
- float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
|
|
|
+ float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
|
|
|
assertEquals(expectedSimilarity, similarity, 0.0001);
|
|
|
}
|
|
|
}
|