1
0
Эх сурвалжийг харах

ES|QL brute force l1_norm vector function (#131768)

Tommaso Teofili 2 сар өмнө
parent
commit
75ca87436b
16 өөрчлөгдсөн 323 нэмэгдсэн , 8 устгасан
  1. 6 0
      docs/reference/query-languages/esql/_snippets/functions/description/v_l1_norm.md
  2. 9 0
      docs/reference/query-languages/esql/_snippets/functions/examples/v_l1_norm.md
  3. 27 0
      docs/reference/query-languages/esql/_snippets/functions/layout/v_l1_norm.md
  4. 10 0
      docs/reference/query-languages/esql/_snippets/functions/parameters/v_l1_norm.md
  5. 1 0
      docs/reference/query-languages/esql/images/functions/v_l1_norm.svg
  6. 12 0
      docs/reference/query-languages/esql/kibana/definition/functions/v_l1_norm.json
  7. 8 0
      docs/reference/query-languages/esql/kibana/docs/functions/v_l1_norm.md
  8. 90 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec
  9. 12 7
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
  10. 5 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  11. 3 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  12. 84 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java
  13. 3 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
  14. 7 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
  15. 4 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  16. 42 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java

+ 6 - 0
docs/reference/query-languages/esql/_snippets/functions/description/v_l1_norm.md

@@ -0,0 +1,6 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Description**
+
+Calculates the l1 norm between two dense_vectors.
+

+ 9 - 0
docs/reference/query-languages/esql/_snippets/functions/examples/v_l1_norm.md

@@ -0,0 +1,9 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Example**
+
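+```esql
+from colors
+| eval similarity = v_l1_norm(rgb_vector, [0, 255, 255])
+| sort similarity desc, color asc
+```
+
+| color:text | similarity:double |
+| --- | --- |
+| red | 765.0 |
+| crimson | 650.0 |
+| maroon | 638.0 |
+| firebrick | 620.0 |
+| orange | 600.0 |
+| tomato | 595.0 |
+| brown | 591.0 |
+| chocolate | 585.0 |
+| coral | 558.0 |
+| gold | 550.0 |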
+
+

+ 27 - 0
docs/reference/query-languages/esql/_snippets/functions/layout/v_l1_norm.md

@@ -0,0 +1,27 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+## `V_L1_NORM` [esql-v_l1_norm]
+```{applies_to}
+stack: development
+serverless: preview
+```
+
+**Syntax**
+
+:::{image} ../../../images/functions/v_l1_norm.svg
+:alt: Embedded
+:class: text-center
+:::
+
+
+:::{include} ../parameters/v_l1_norm.md
+:::
+
+:::{include} ../description/v_l1_norm.md
+:::
+
+:::{include} ../types/v_l1_norm.md
+:::
+
+:::{include} ../examples/v_l1_norm.md
+:::

+ 10 - 0
docs/reference/query-languages/esql/_snippets/functions/parameters/v_l1_norm.md

@@ -0,0 +1,10 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Parameters**
+
+`left`
+:   first dense_vector to calculate l1 norm similarity
+
+`right`
+:   second dense_vector to calculate l1 norm similarity
+

+ 1 - 0
docs/reference/query-languages/esql/images/functions/v_l1_norm.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="432" height="46" viewbox="0 0 432 46"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m128 0h10m32 0h10m68 0h10m32 0h10m80 0h10m32 0h5"/><rect class="s" x="5" y="5" width="128" height="36"/><text class="k" x="15" y="31">V_L1_NORM</text><rect class="s" x="143" y="5" width="32" height="36" rx="7"/><text class="syn" x="153" y="31">(</text><rect class="s" x="185" y="5" width="68" height="36" rx="7"/><text class="k" x="195" y="31">left</text><rect class="s" x="263" y="5" width="32" height="36" rx="7"/><text class="syn" x="273" y="31">,</text><rect class="s" x="305" y="5" width="80" height="36" rx="7"/><text class="k" x="315" y="31">right</text><rect class="s" x="395" y="5" width="32" height="36" rx="7"/><text class="syn" x="405" y="31">)</text></svg>

+ 12 - 0
docs/reference/query-languages/esql/kibana/definition/functions/v_l1_norm.json

@@ -0,0 +1,12 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+  "type" : "scalar",
+  "name" : "v_l1_norm",
+  "description" : "Calculates the l1 norm between two dense_vectors.",
+  "signatures" : [ ],
+  "examples" : [
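+    "from colors\n| eval similarity = v_l1_norm(rgb_vector, [0, 255, 255])\n| sort similarity desc, color asc"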
+  ],
+  "preview" : true,
+  "snapshot_only" : true
+}

+ 8 - 0
docs/reference/query-languages/esql/kibana/docs/functions/v_l1_norm.md

@@ -0,0 +1,8 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### V L1 NORM
+Calculates the l1 norm between two dense_vectors.
+
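+```esql
+from colors
+| eval similarity = v_l1_norm(rgb_vector, [0, 255, 255])
+| sort similarity desc, color asc
+```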

+ 90 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l1-norm.csv-spec

@@ -0,0 +1,90 @@
+ # Tests for l1_norm similarity function
+ 
+ similarityWithVectorField
+ required_capability: l1_norm_vector_similarity_function
+ 
+// tag::vector-l1-norm-similarity[]
+ from colors
+ | eval similarity = v_l1_norm(rgb_vector, [0, 255, 255]) 
+ | sort similarity desc, color asc 
+// end::vector-l1-norm-similarity[]
+ | limit 10
+ | keep color, similarity
+ ;
+ 
+// tag::vector-l1-norm-similarity-result[]
+color:text | similarity:double
+red        | 765.0
+crimson    | 650.0
+maroon     | 638.0
+firebrick  | 620.0
+orange     | 600.0
+tomato     | 595.0
+brown      | 591.0
+chocolate  | 585.0
+coral      | 558.0
+gold       | 550.0
+// end::vector-l1-norm-similarity-result[] 
+;
+
+ similarityAsPartOfExpression
+ required_capability: l1_norm_vector_similarity_function
+ 
+ from colors
+ | eval score = round((1 + v_l1_norm(rgb_vector, [0, 255, 255]) / 2), 3) 
+ | sort score desc, color asc 
+ | limit 10
+ | keep color, score
+ ;
+
+color:text | score:double
+red        | 383.5
+crimson    | 326.0
+maroon     | 320.0
+firebrick  | 311.0
+orange     | 301.0
+tomato     | 298.5
+brown      | 296.5
+chocolate  | 293.5
+coral      | 280.0
+gold       | 276.0
+;
+
+similarityWithLiteralVectors
+required_capability: l1_norm_vector_similarity_function
+ 
+row a = 1
+| eval similarity = round(v_l1_norm([1, 2, 3], [0, 1, 2]), 3) 
+| keep similarity
+;
+
+similarity:double
+3.0
+;
+
+ similarityWithStats
+ required_capability: l1_norm_vector_similarity_function
+ 
+ from colors
+ | eval similarity = round(v_l1_norm(rgb_vector, [0, 255, 255]), 3) 
+ | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
+ ;
+
+avg:double | min:double | max:double
+391.254    | 0.0        | 765.0
+;
+
+# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
+# Disabled (-Ignore suffix) until then. The query only references columns produced by the row
+# source, and the expected value is |1-0| + |2-1| + |3-2| = 3.0.
+similarityWithRow-Ignore
+required_capability: l1_norm_vector_similarity_function
+ 
+row vector = [1, 2, 3] 
+| eval similarity = round(v_l1_norm(vector, [0, 1, 2]), 3) 
+| keep similarity
+;
+
+similarity:double
+3.0
+;

+ 12 - 7
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

@@ -20,6 +20,8 @@ import org.elasticsearch.xpack.esql.EsqlClientException;
 import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
 import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
+import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction;
 import org.junit.Before;
 
 import java.io.IOException;
@@ -37,22 +39,25 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
         List<Object[]> params = new ArrayList<>();
 
         if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
-            params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
+            params.add(new Object[] { "v_cosine", (SimilarityEvaluatorFunction) VectorSimilarityFunction.COSINE::compare });
         }
         if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
-            params.add(new Object[] { "v_dot_product", VectorSimilarityFunction.DOT_PRODUCT });
+            params.add(new Object[] { "v_dot_product", (SimilarityEvaluatorFunction) VectorSimilarityFunction.DOT_PRODUCT::compare });
+        }
+        if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity });
         }
 
         return params;
     }
 
     private final String functionName;
-    private final VectorSimilarityFunction similarityFunction;
+    private final SimilarityEvaluatorFunction similarityFunction;
     private int numDims;
 
     public VectorSimilarityFunctionsIT(
         @Name("functionName") String functionName,
-        @Name("similarityFunction") VectorSimilarityFunction similarityFunction
+        @Name("similarityFunction") SimilarityEvaluatorFunction similarityFunction
     ) {
         this.functionName = functionName;
         this.similarityFunction = similarityFunction;
@@ -74,7 +79,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
                 Double similarity = (Double) values.get(2);
 
                 assertNotNull(similarity);
-                float expectedSimilarity = similarityFunction.compare(left, right);
+                float expectedSimilarity = similarityFunction.calculateSimilarity(left, right);
                 assertEquals(expectedSimilarity, similarity, 0.0001);
             });
         }
@@ -96,7 +101,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
                 Double similarity = (Double) values.get(1);
 
                 assertNotNull(similarity);
-                float expectedSimilarity = similarityFunction.compare(left, randomVector);
+                float expectedSimilarity = similarityFunction.calculateSimilarity(left, randomVector);
                 assertEquals(expectedSimilarity, similarity, 0.0001);
             });
         }
@@ -130,7 +135,7 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
 
             Double similarity = (Double) valuesList.get(0).get(0);
             assertNotNull(similarity);
-            float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
+            float expectedSimilarity = similarityFunction.calculateSimilarity(vectorLeft, vectorRight);
             assertEquals(expectedSimilarity, similarity, 0.0001);
         }
     }

+ 5 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1296,6 +1296,11 @@ public class EsqlCapabilities {
          */
         DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),
 
+        /**
+         * l1 norm vector similarity function
+         */
+        L1_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),
+
         /**
          * Support for the options field of CATEGORIZE.
          */

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -183,6 +183,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
 import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
 import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
 import org.elasticsearch.xpack.esql.parser.ParsingException;
 import org.elasticsearch.xpack.esql.session.Configuration;
 
@@ -493,7 +494,8 @@ public class EsqlFunctionRegistry {
                 def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
                 def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
                 def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"),
-                def(DotProduct.class, DotProduct::new, "v_dot_product") } };
+                def(DotProduct.class, DotProduct::new, "v_dot_product"),
+                def(L1Norm.class, L1Norm::new, "v_l1_norm") } };
     }
 
     public EsqlFunctionRegistry snapshotRegistry() {

+ 84 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L1Norm.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+
+import java.io.IOException;
+
+/**
+ * ES|QL scalar vector function that sums the absolute component-wise differences of two
+ * {@code dense_vector} arguments (see {@link #calculateSimilarity(float[], float[])}).
+ */
+public class L1Norm extends VectorSimilarityFunction {
+
+    /** Serialization entry that lets this expression be read back from a stream by name. */
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L1Norm", L1Norm::new);
+    /** Shared evaluator reference; also used by tests to compute expected values. */
+    static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = L1Norm::calculateSimilarity;
+
+    /**
+     * Builds the function from its two vector arguments.
+     *
+     * @param source location of the function call in the query
+     * @param left   first {@code dense_vector} operand
+     * @param right  second {@code dense_vector} operand
+     */
+    @FunctionInfo(
+        returnType = "double",
+        preview = true,
+        description = "Calculates the l1 norm between two dense_vectors.",
+        // The tag must match the tag::...[] region name in vector-l1-norm.csv-spec, otherwise the generated docs get a null example.
+        examples = { @Example(file = "vector-l1-norm", tag = "vector-l1-norm-similarity") },
+        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
+    )
+    public L1Norm(
+        Source source,
+        @Param(
+            name = "left",
+            type = { "dense_vector" },
+            description = "first dense_vector to calculate l1 norm similarity"
+        ) Expression left,
+        @Param(
+            name = "right",
+            type = { "dense_vector" },
+            description = "second dense_vector to calculate l1 norm similarity"
+        ) Expression right
+    ) {
+        super(source, left, right);
+    }
+
+    /** Deserialization constructor referenced by {@link #ENTRY}. */
+    private L1Norm(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    /** Returns a copy of this function over the given children, keeping the original source. */
+    @Override
+    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
+        return new L1Norm(source(), newLeft, newRight);
+    }
+
+    /** Supplies the per-row evaluator, {@link #SIMILARITY_FUNCTION}. */
+    @Override
+    protected SimilarityEvaluatorFunction getSimilarityFunction() {
+        return SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, L1Norm::new, left(), right());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    /**
+     * Sums {@code |leftScratch[i] - rightScratch[i]|} over every index of {@code leftScratch}.
+     * <p>
+     * Only {@code leftScratch.length} drives the loop, so {@code rightScratch} must be at least as long.
+     * NOTE(review): presumably callers validate that both vectors have the same dimensions before
+     * calling this - confirm in the parent class.
+     *
+     * @param leftScratch  components of the first vector
+     * @param rightScratch components of the second vector
+     * @return the accumulated absolute difference, in single precision
+     */
+    public static float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
+        float result = 0f;
+        for (int i = 0; i < leftScratch.length; i++) {
+            result += Math.abs(leftScratch[i] - rightScratch[i]);
+        }
+        return result;
+    }
+
+}

+ 3 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java

@@ -36,6 +36,9 @@ public final class VectorWritables {
         if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
             entries.add(DotProduct.ENTRY);
         }
+        if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            entries.add(L1Norm.ENTRY);
+        }
 
         return Collections.unmodifiableList(entries);
     }

+ 7 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

@@ -2365,6 +2365,10 @@ public class AnalyzerTests extends ESTestCase {
             );
             checkDenseVectorImplicitCastingSimilarityFunction("v_dot_product(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
         }
+        if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
+            checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
+        }
     }
 
     private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
@@ -2391,6 +2395,9 @@ public class AnalyzerTests extends ESTestCase {
         if (EsqlCapabilities.Cap.DOT_PRODUCT_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
             checkNoDenseVectorFailsSimilarityFunction("v_dot_product([0, 1, 2], 0.342)");
         }
+        if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkNoDenseVectorFailsSimilarityFunction("v_l1_norm([0, 1, 2], 0.342)");
+        }
     }
 
     private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {

+ 4 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -2508,6 +2508,10 @@ public class VerifierTests extends ESTestCase {
             checkVectorSimilarityFunctionsNullArgs("v_dot_product(null, vector)", "first");
             checkVectorSimilarityFunctionsNullArgs("v_dot_product(vector, null)", "second");
         }
+        if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first");
+            checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second");
+        }
     }
 
     private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {

+ 42 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L1NormSimilarityTests.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+/**
+ * Parameterized unit tests for the {@code v_l1_norm} function, driven by the shared
+ * vector-similarity test harness in the parent class.
+ */
+@FunctionName("v_l1_norm")
+public class L1NormSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {
+
+    public L1NormSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        super(testCaseSupplier);
+    }
+
+    /** Test cases are generated by the parent from the function's simple class name and its evaluator. */
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        final String functionClassName = L1Norm.class.getSimpleName();
+        return similarityParameters(functionClassName, L1Norm.SIMILARITY_FUNCTION);
+    }
+
+    /** Capability flag associated with the function under test. */
+    protected EsqlCapabilities.Cap capability() {
+        return EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION;
+    }
+
+    /** Instantiates the function from the first two supplied argument expressions. */
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        final Expression left = args.get(0);
+        final Expression right = args.get(1);
+        return new L1Norm(source, left, right);
+    }
+}