浏览代码

ES|QL brute force l2_norm vector function (#132025)

Tommaso Teofili 2 月之前
父节点
当前提交
edc3a64aaf
共有 16 个文件被更改,包括 329 次插入1 次删除
  1. 6 0
      docs/reference/query-languages/esql/_snippets/functions/description/v_l2_norm.md
  2. 24 0
      docs/reference/query-languages/esql/_snippets/functions/examples/v_l2_norm.md
  3. 27 0
      docs/reference/query-languages/esql/_snippets/functions/layout/v_l2_norm.md
  4. 10 0
      docs/reference/query-languages/esql/_snippets/functions/parameters/v_l2_norm.md
  5. 1 0
      docs/reference/query-languages/esql/images/functions/v_l2_norm.svg
  6. 12 0
      docs/reference/query-languages/esql/kibana/definition/functions/v_l2_norm.json
  7. 10 0
      docs/reference/query-languages/esql/kibana/docs/functions/v_l2_norm.md
  8. 90 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec
  9. 4 0
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
  10. 5 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  11. 3 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  12. 81 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java
  13. 3 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
  14. 7 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
  15. 4 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  16. 42 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java

+ 6 - 0
docs/reference/query-languages/esql/_snippets/functions/description/v_l2_norm.md

@@ -0,0 +1,6 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Description**
+
+Calculates the l2 norm between two dense_vectors.
+

+ 24 - 0
docs/reference/query-languages/esql/_snippets/functions/examples/v_l2_norm.md

@@ -0,0 +1,24 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Example**
+
+```esql
+ from colors
+ | eval similarity = v_l2_norm(rgb_vector, [0, 255, 255])
+ | sort similarity desc, color asc
+```
+
+| color:text | similarity:double |
+| --- | --- |
+| red | 441.6729431152344 |
+| maroon | 382.6669616699219 |
+| crimson | 376.36419677734375 |
+| orange | 371.68536376953125 |
+| gold | 362.8360595703125 |
+| black | 360.62445068359375 |
+| magenta | 360.62445068359375 |
+| yellow | 360.62445068359375 |
+| firebrick | 359.67486572265625 |
+| tomato | 351.0227966308594 |
+
+

+ 27 - 0
docs/reference/query-languages/esql/_snippets/functions/layout/v_l2_norm.md

@@ -0,0 +1,27 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+## `V_L2_NORM` [esql-v_l2_norm]
+```{applies_to}
+stack: development
+serverless: preview
+```
+
+**Syntax**
+
+:::{image} ../../../images/functions/v_l2_norm.svg
+:alt: Embedded
+:class: text-center
+:::
+
+
+:::{include} ../parameters/v_l2_norm.md
+:::
+
+:::{include} ../description/v_l2_norm.md
+:::
+
+:::{include} ../types/v_l2_norm.md
+:::
+
+:::{include} ../examples/v_l2_norm.md
+:::

+ 10 - 0
docs/reference/query-languages/esql/_snippets/functions/parameters/v_l2_norm.md

@@ -0,0 +1,10 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+**Parameters**
+
+`left`
+:   first dense_vector to calculate l2 norm similarity
+
+`right`
+:   second dense_vector to calculate l2 norm similarity
+

+ 1 - 0
docs/reference/query-languages/esql/images/functions/v_l2_norm.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="432" height="46" viewbox="0 0 432 46"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m128 0h10m32 0h10m68 0h10m32 0h10m80 0h10m32 0h5"/><rect class="s" x="5" y="5" width="128" height="36"/><text class="k" x="15" y="31">V_L2_NORM</text><rect class="s" x="143" y="5" width="32" height="36" rx="7"/><text class="syn" x="153" y="31">(</text><rect class="s" x="185" y="5" width="68" height="36" rx="7"/><text class="k" x="195" y="31">left</text><rect class="s" x="263" y="5" width="32" height="36" rx="7"/><text class="syn" x="273" y="31">,</text><rect class="s" x="305" y="5" width="80" height="36" rx="7"/><text class="k" x="315" y="31">right</text><rect class="s" x="395" y="5" width="32" height="36" rx="7"/><text class="syn" x="405" y="31">)</text></svg>

+ 12 - 0
docs/reference/query-languages/esql/kibana/definition/functions/v_l2_norm.json

@@ -0,0 +1,12 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+  "type" : "scalar",
+  "name" : "v_l2_norm",
+  "description" : "Calculates the l2 norm between two dense_vectors.",
+  "signatures" : [ ],
+  "examples" : [
+    " from colors\n | eval similarity = v_l2_norm(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc"
+  ],
+  "preview" : true,
+  "snapshot_only" : true
+}

+ 10 - 0
docs/reference/query-languages/esql/kibana/docs/functions/v_l2_norm.md

@@ -0,0 +1,10 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### V L2 NORM
+Calculates the l2 norm between two dense_vectors.
+
+```esql
+ from colors
+ | eval similarity = v_l2_norm(rgb_vector, [0, 255, 255])
+ | sort similarity desc, color asc
+```

+ 90 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-l2-norm.csv-spec

@@ -0,0 +1,90 @@
+ # Tests for l2_norm similarity function
+ 
+ similarityWithVectorField
+ required_capability: l2_norm_vector_similarity_function
+ 
+// tag::vector-l2-norm[]
+ from colors
+ | eval similarity = v_l2_norm(rgb_vector, [0, 255, 255]) 
+ | sort similarity desc, color asc 
+// end::vector-l2-norm[]
+ | limit 10
+ | keep color, similarity
+ ;
+ 
+// tag::vector-l2-norm-result[]
+color:text | similarity:double
+red        | 441.6729431152344
+maroon     | 382.6669616699219
+crimson    | 376.36419677734375
+orange     | 371.68536376953125
+gold       | 362.8360595703125
+black      | 360.62445068359375
+magenta    | 360.62445068359375
+yellow     | 360.62445068359375
+firebrick  | 359.67486572265625
+tomato     | 351.0227966308594
+// end::vector-l2-norm-result[] 
+;
+
+ similarityAsPartOfExpression
+ required_capability: l2_norm_vector_similarity_function
+ 
+ from colors
+ | eval score = round((1 + v_l2_norm(rgb_vector, [0, 255, 255]) / 2), 3) 
+ | sort score desc, color asc 
+ | limit 10
+ | keep color, score
+ ;
+
+color:text | score:double
+red        | 221.836
+maroon     | 192.333
+crimson    | 189.182
+orange     | 186.843
+gold       | 182.418
+black      | 181.312
+magenta    | 181.312
+yellow     | 181.312
+firebrick  | 180.837
+tomato     | 176.511
+;
+
+similarityWithLiteralVectors
+required_capability: l2_norm_vector_similarity_function
+ 
+row a = 1
+| eval similarity = round(v_l2_norm([1, 2, 3], [0, 1, 2]), 3) 
+| keep similarity
+;
+
+similarity:double
+1.732
+;
+
+ similarityWithStats
+ required_capability: l2_norm_vector_similarity_function
+ 
+ from colors
+ | eval similarity = round(v_l2_norm(rgb_vector, [0, 255, 255]), 3) 
+ | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
+ ;
+
+avg:double | min:double | max:double
+274.974    | 0.0        | 441.673
+;
+
+# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
+similarityWithRow-Ignore
+required_capability: l2_norm_vector_similarity_function
+ 
+row vector = [1, 2, 3] 
+| eval similarity = round(v_l2_norm(vector, [0, 1, 2]), 3) 
+| sort similarity desc, color asc 
+| limit 10
+| keep color, similarity
+;
+
+similarity:double
+0.978  
+;

+ 4 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

@@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils;
 import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
 import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
 import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
+import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
 import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction.SimilarityEvaluatorFunction;
 import org.junit.Before;
 
@@ -47,6 +48,9 @@ public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
         if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
             params.add(new Object[] { "v_l1_norm", (SimilarityEvaluatorFunction) L1Norm::calculateSimilarity });
         }
+        if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            params.add(new Object[] { "v_l2_norm", (SimilarityEvaluatorFunction) L2Norm::calculateSimilarity });
+        }
 
         return params;
     }

+ 5 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1322,6 +1322,11 @@ public class EsqlCapabilities {
          */
         L1_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),
 
+        /**
+         * l2 norm vector similarity function
+         */
+        L2_NORM_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot()),
+
         /**
          * Support for the options field of CATEGORIZE.
          */

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -184,6 +184,7 @@ import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
 import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
 import org.elasticsearch.xpack.esql.expression.function.vector.L1Norm;
+import org.elasticsearch.xpack.esql.expression.function.vector.L2Norm;
 import org.elasticsearch.xpack.esql.parser.ParsingException;
 import org.elasticsearch.xpack.esql.session.Configuration;
 
@@ -495,7 +496,8 @@ public class EsqlFunctionRegistry {
                 def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
                 def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine"),
                 def(DotProduct.class, DotProduct::new, "v_dot_product"),
-                def(L1Norm.class, L1Norm::new, "v_l1_norm") } };
+                def(L1Norm.class, L1Norm::new, "v_l1_norm"),
+                def(L2Norm.class, L2Norm::new, "v_l2_norm") } };
     }
 
     public EsqlFunctionRegistry snapshotRegistry() {

+ 81 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/L2Norm.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.apache.lucene.util.VectorUtil;
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+
+import java.io.IOException;
+
+public class L2Norm extends VectorSimilarityFunction {
+
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "L2Norm", L2Norm::new);
+    static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = L2Norm::calculateSimilarity;
+
+    @FunctionInfo(
+        returnType = "double",
+        preview = true,
+        description = "Calculates the l2 norm between two dense_vectors.",
+        examples = { @Example(file = "vector-l2-norm", tag = "vector-l2-norm") },
+        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
+    )
+    public L2Norm(
+        Source source,
+        @Param(
+            name = "left",
+            type = { "dense_vector" },
+            description = "first dense_vector to calculate l2 norm similarity"
+        ) Expression left,
+        @Param(
+            name = "right",
+            type = { "dense_vector" },
+            description = "second dense_vector to calculate l2 norm similarity"
+        ) Expression right
+    ) {
+        super(source, left, right);
+    }
+
+    private L2Norm(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
+        return new L2Norm(source(), newLeft, newRight);
+    }
+
+    @Override
+    protected SimilarityEvaluatorFunction getSimilarityFunction() {
+        return SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, L2Norm::new, left(), right());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    public static float calculateSimilarity(float[] leftScratch, float[] rightScratch) {
+        return (float) Math.sqrt(VectorUtil.squareDistance(leftScratch, rightScratch));
+    }
+
+}

+ 3 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java

@@ -39,6 +39,9 @@ public final class VectorWritables {
         if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
             entries.add(L1Norm.ENTRY);
         }
+        if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            entries.add(L2Norm.ENTRY);
+        }
 
         return Collections.unmodifiableList(entries);
     }

+ 7 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

@@ -2369,6 +2369,10 @@ public class AnalyzerTests extends ESTestCase {
             checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
             checkDenseVectorImplicitCastingSimilarityFunction("v_l1_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
         }
+        if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
+            checkDenseVectorImplicitCastingSimilarityFunction("v_l2_norm(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
+        }
     }
 
     private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
@@ -2398,6 +2402,9 @@ public class AnalyzerTests extends ESTestCase {
         if (EsqlCapabilities.Cap.L1_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
             checkNoDenseVectorFailsSimilarityFunction("v_l1_norm([0, 1, 2], 0.342)");
         }
+        if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkNoDenseVectorFailsSimilarityFunction("v_l2_norm([0, 1, 2], 0.342)");
+        }
     }
 
     private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {

+ 4 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -2495,6 +2495,10 @@ public class VerifierTests extends ESTestCase {
             checkVectorSimilarityFunctionsNullArgs("v_l1_norm(null, vector)", "first");
             checkVectorSimilarityFunctionsNullArgs("v_l1_norm(vector, null)", "second");
         }
+        if (EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION.isEnabled()) {
+            checkVectorSimilarityFunctionsNullArgs("v_l2_norm(null, vector)", "first");
+            checkVectorSimilarityFunctionsNullArgs("v_l2_norm(vector, null)", "second");
+        }
     }
 
     private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {

+ 42 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/L2NormSimilarityTests.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+@FunctionName("v_l2_norm")
+public class L2NormSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {
+
+    public L2NormSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        super(testCaseSupplier);
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        return similarityParameters(L2Norm.class.getSimpleName(), L2Norm.SIMILARITY_FUNCTION);
+    }
+
+    protected EsqlCapabilities.Cap capability() {
+        return EsqlCapabilities.Cap.L2_NORM_VECTOR_SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new L2Norm(source, args.get(0), args.get(1));
+    }
+}