瀏覽代碼

ESQL: dense_vector cosine similarity function (#130641)

Carlos Delgado 3 月之前
父節點
當前提交
f1ddd4c312
共有 17 個文件被更改,包括 842 次插入15 次删除
  1. 1 0
      docs/reference/query-languages/esql/images/functions/v_cosine.svg
  2. 12 0
      docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json
  3. 11 0
      docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md
  4. 93 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec
  5. 208 0
      x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java
  6. 6 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  7. 12 4
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java
  8. 2 6
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java
  9. 3 1
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  10. 77 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java
  11. 174 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java
  12. 39 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java
  13. 44 2
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java
  14. 14 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java
  15. 102 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java
  16. 42 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java
  17. 2 1
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml

+ 1 - 0
docs/reference/query-languages/esql/images/functions/v_cosine.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="420" height="46" viewbox="0 0 420 46"><defs><style type="text/css">.c{fill:none;stroke:#222222;}.k{fill:#000000;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}.s{fill:#e4f4ff;stroke:#222222;}.syn{fill:#8D8D8D;font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace;font-size:20px;}</style></defs><path class="c" d="M0 31h5m116 0h10m32 0h10m68 0h10m32 0h10m80 0h10m32 0h5"/><rect class="s" x="5" y="5" width="116" height="36"/><text class="k" x="15" y="31">V_COSINE</text><rect class="s" x="131" y="5" width="32" height="36" rx="7"/><text class="syn" x="141" y="31">(</text><rect class="s" x="173" y="5" width="68" height="36" rx="7"/><text class="k" x="183" y="31">left</text><rect class="s" x="251" y="5" width="32" height="36" rx="7"/><text class="syn" x="261" y="31">,</text><rect class="s" x="293" y="5" width="80" height="36" rx="7"/><text class="k" x="303" y="31">right</text><rect class="s" x="383" y="5" width="32" height="36" rx="7"/><text class="syn" x="393" y="31">)</text></svg>

+ 12 - 0
docs/reference/query-languages/esql/kibana/definition/functions/v_cosine.json

@@ -0,0 +1,12 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.",
+  "type" : "scalar",
+  "name" : "v_cosine",
+  "description" : "Calculates the cosine similarity between two dense_vectors.",
+  "signatures" : [ ],
+  "examples" : [
+    " from colors\n | where color != \"black\"\n | eval similarity = v_cosine(rgb_vector, [0, 255, 255])\n | sort similarity desc, color asc"
+  ],
+  "preview" : true,
+  "snapshot_only" : true
+}

+ 11 - 0
docs/reference/query-languages/esql/kibana/docs/functions/v_cosine.md

@@ -0,0 +1,11 @@
+% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+### V COSINE
+Calculates the cosine similarity between two dense_vectors.
+
+```esql
+ from colors
+ | where color != "black"
+ | eval similarity = v_cosine(rgb_vector, [0, 255, 255])
+ | sort similarity desc, color asc
+```

+ 93 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/vector-cosine-similarity.csv-spec

@@ -0,0 +1,93 @@
+ # Tests for cosine similarity function
+ 
+ similarityWithVectorField
+ required_capability: cosine_vector_similarity_function
+ 
+// tag::vector-cosine-similarity[]
+ from colors
+ | where color != "black" 
+ | eval similarity = v_cosine(rgb_vector, [0, 255, 255]) 
+ | sort similarity desc, color asc 
+// end::vector-cosine-similarity[]
+ | limit 10
+ | keep color, similarity
+ ;
+ 
+// tag::vector-cosine-similarity-result[]
+color:text     | similarity:double
+cyan           | 1.0
+teal           | 1.0
+turquoise      | 0.9890533685684204
+aqua marine    | 0.964962363243103
+azure          | 0.916246771812439
+lavender       | 0.9136701822280884
+mint cream     | 0.9122757911682129
+honeydew       | 0.9122424125671387
+gainsboro      | 0.9082483053207397
+gray           | 0.9082483053207397  
+// end::vector-cosine-similarity-result[] 
+;
+
+ similarityAsPartOfExpression
+ required_capability: cosine_vector_similarity_function
+ 
+ from colors
+ | where color != "black" 
+ | eval score = round((1 + v_cosine(rgb_vector, [0, 255, 255]) / 2), 3) 
+ | sort score desc, color asc 
+ | limit 10
+ | keep color, score
+ ;
+
+color:text   | score:double
+cyan         | 1.5
+teal         | 1.5
+turquoise    | 1.495
+aqua marine  | 1.482
+azure        | 1.458
+lavender     | 1.457
+honeydew     | 1.456
+mint cream   | 1.456
+gainsboro    | 1.454
+gray         | 1.454  
+;
+
+similarityWithLiteralVectors
+required_capability: cosine_vector_similarity_function
+ 
+row a = 1
+| eval similarity = round(v_cosine([1, 2, 3], [0, 1, 2]), 3) 
+| keep similarity
+;
+
+similarity:double
+0.978  
+;
+
+ similarityWithStats
+ required_capability: cosine_vector_similarity_function
+ 
+ from colors
+ | where color != "black" 
+ | eval similarity = round(v_cosine(rgb_vector, [0, 255, 255]), 3) 
+ | stats avg = round(avg(similarity), 3), min = min(similarity), max = max(similarity)
+ ;
+
+avg:double | min:double | max:double
+0.832      | 0.5        | 1.0
+;
+
+# TODO Need to implement a conversion function to convert a non-foldable row to a dense_vector
+# The row source only defines `vector`, so the query must not reference any other column
+similarityWithRow-Ignore
+required_capability: cosine_vector_similarity_function
+
+row vector = [1, 2, 3]
+| eval similarity = round(v_cosine(vector, [0, 1, 2]), 3)
+| keep similarity
+;
+
+similarity:double
+0.978
+;

+ 208 - 0
x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/vector/VectorSimilarityFunctionsIT.java

@@ -0,0 +1,208 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.elasticsearch.action.index.IndexRequestBuilder;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xpack.esql.EsqlClientException;
+import org.elasticsearch.xpack.esql.EsqlTestUtils;
+import org.elasticsearch.xpack.esql.action.AbstractEsqlIntegTestCase;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+
+public class VectorSimilarityFunctionsIT extends AbstractEsqlIntegTestCase {
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() throws Exception {
+        List<Object[]> params = new ArrayList<>();
+
+        params.add(new Object[] { "v_cosine", VectorSimilarityFunction.COSINE });
+
+        return params;
+    }
+
+    private final String functionName;
+    private final VectorSimilarityFunction similarityFunction;
+    private int numDims;
+
+    public VectorSimilarityFunctionsIT(
+        @Name("functionName") String functionName,
+        @Name("similarityFunction") VectorSimilarityFunction similarityFunction
+    ) {
+        this.functionName = functionName;
+        this.similarityFunction = similarityFunction;
+    }
+
+    /**
+     * Compares the function result for two indexed vector fields against the Lucene similarity computed on the same values.
+     */
+    @SuppressWarnings("unchecked")
+    public void testSimilarityBetweenVectors() {
+        var query = String.format(Locale.ROOT, """
+                FROM test
+                | EVAL similarity = %s(left_vector, right_vector)
+                | KEEP left_vector, right_vector, similarity
+            """, functionName);
+
+        try (var resp = run(query)) {
+            for (List<Object> row : EsqlTestUtils.getValuesList(resp)) {
+                float[] leftValues = readVector((List<Float>) row.get(0));
+                float[] rightValues = readVector((List<Float>) row.get(1));
+                Double actual = (Double) row.get(2);
+
+                assertNotNull(actual);
+                float expected = similarityFunction.compare(leftValues, rightValues);
+                assertEquals(expected, actual, 0.0001);
+            }
+        }
+    }
+
+    /**
+     * Compares the function result for an indexed vector field and a literal vector against the Lucene similarity.
+     */
+    @SuppressWarnings("unchecked")
+    public void testSimilarityBetweenConstantVectorAndField() {
+        var randomVector = randomVectorArray();
+        String vectorLiteral = Arrays.toString(randomVector);
+        var query = String.format(Locale.ROOT, """
+                FROM test
+                | EVAL similarity = %s(left_vector, %s)
+                | KEEP left_vector, similarity
+            """, functionName, vectorLiteral);
+
+        try (var resp = run(query)) {
+            for (List<Object> row : EsqlTestUtils.getValuesList(resp)) {
+                float[] fieldValues = readVector((List<Float>) row.get(0));
+                Double actual = (Double) row.get(1);
+
+                assertNotNull(actual);
+                float expected = similarityFunction.compare(fieldValues, randomVector);
+                assertEquals(expected, actual, 0.0001);
+            }
+        }
+    }
+
+    public void testDifferentDimensions() {
+        var randomVector = randomVectorArray(randomValueOtherThan(numDims, () -> randomIntBetween(32, 64) * 2));
+        var query = String.format(Locale.ROOT, """
+                FROM test
+                | EVAL similarity = %s(left_vector, %s)
+                | KEEP left_vector, similarity
+            """, functionName, Arrays.toString(randomVector));
+
+        EsqlClientException iae = expectThrows(EsqlClientException.class, () -> { run(query); });
+        assertTrue(iae.getMessage().contains("Vectors must have the same dimensions"));
+    }
+
+    /**
+     * Compares the function result for two literal vectors against the Lucene similarity computed on the same values.
+     * No unchecked cast is performed here, so no warning suppression is needed.
+     */
+    public void testSimilarityBetweenConstantVectors() {
+        var vectorLeft = randomVectorArray();
+        var vectorRight = randomVectorArray();
+        var query = String.format(Locale.ROOT, """
+                ROW a = 1
+                | EVAL similarity = %s(%s, %s)
+                | KEEP similarity
+            """, functionName, Arrays.toString(vectorLeft), Arrays.toString(vectorRight));
+
+        try (var resp = run(query)) {
+            List<List<Object>> valuesList = EsqlTestUtils.getValuesList(resp);
+            // A ROW source produces exactly one row
+            assertEquals(1, valuesList.size());
+
+            Double similarity = (Double) valuesList.get(0).get(0);
+            assertNotNull(similarity);
+            float expectedSimilarity = similarityFunction.compare(vectorLeft, vectorRight);
+            assertEquals(expectedSimilarity, similarity, 0.0001);
+        }
+    }
+
+    /**
+     * Copies a list of boxed floats into a primitive float array of the same length and order.
+     */
+    private static float[] readVector(List<Float> leftVector) {
+        float[] values = new float[leftVector.size()];
+        int next = 0;
+        for (Float component : leftVector) {
+            values[next] = component;
+            next++;
+        }
+        return values;
+    }
+
+    @Before
+    public void setup() throws IOException {
+        assumeTrue("Dense vector type is disabled", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
+
+        createIndexWithDenseVector("test");
+
+        numDims = randomIntBetween(32, 64) * 2; // min 64, even number
+        int numDocs = randomIntBetween(10, 100);
+        IndexRequestBuilder[] docs = new IndexRequestBuilder[numDocs];
+        for (int i = 0; i < numDocs; i++) {
+            List<Float> leftVector = randomVector();
+            List<Float> rightVector = randomVector();
+            docs[i] = prepareIndex("test").setId("" + i)
+                .setSource("id", String.valueOf(i), "left_vector", leftVector, "right_vector", rightVector);
+        }
+
+        indexRandom(true, docs);
+    }
+
+    /**
+     * Builds a list of {@code numDims} random floats; {@code numDims} must already have been initialised.
+     */
+    private List<Float> randomVector() {
+        assert numDims != 0 : "numDims must be set before calling randomVector()";
+        List<Float> components = new ArrayList<>(numDims);
+        int remaining = numDims;
+        while (remaining > 0) {
+            components.add(randomFloat());
+            remaining--;
+        }
+        return components;
+    }
+
+    /**
+     * Builds an array of {@code numDims} random floats; {@code numDims} must already have been initialised.
+     */
+    private float[] randomVectorArray() {
+        assert numDims != 0 : "numDims must be set before calling randomVectorArray()";
+        return randomVectorArray(numDims);
+    }
+
+    /**
+     * Builds an array of {@code dimensions} random floats.
+     */
+    private static float[] randomVectorArray(int dimensions) {
+        float[] components = new float[dimensions];
+        int index = 0;
+        while (index < dimensions) {
+            components[index] = randomFloat();
+            index++;
+        }
+        return components;
+    }
+
+    /**
+     * Creates {@code indexName} with an integer {@code id} field and the {@code left_vector} and {@code right_vector}
+     * dense_vector fields, using zero replicas and a random number of shards between 1 and 5.
+     *
+     * @param indexName name of the index to create
+     * @throws IOException if the mapping cannot be built
+     */
+    private void createIndexWithDenseVector(String indexName) throws IOException {
+        var indicesClient = client().admin().indices();
+        XContentBuilder mapping = XContentFactory.jsonBuilder()
+            .startObject()
+            .startObject("properties")
+            .startObject("id")
+            .field("type", "integer")
+            .endObject();
+        createDenseVectorField(mapping, "left_vector");
+        createDenseVectorField(mapping, "right_vector");
+        mapping.endObject().endObject();
+        Settings.Builder settingsBuilder = Settings.builder()
+            .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
+            .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, randomIntBetween(1, 5));
+
+        // Settings are supplied once; an earlier hard-coded single-shard setting was overwritten by this call anyway
+        var createRequest = indicesClient.prepareCreate(indexName).setMapping(mapping).setSettings(settingsBuilder.build());
+        assertAcked(createRequest);
+    }
+
+    /**
+     * Appends a dense_vector field definition with cosine similarity, named {@code fieldName}, to {@code mapping}.
+     */
+    private void createDenseVectorField(XContentBuilder mapping, String fieldName) throws IOException {
+        mapping.startObject(fieldName);
+        mapping.field("type", "dense_vector");
+        mapping.field("similarity", "cosine");
+        mapping.endObject();
+    }
+}

+ 6 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -1254,7 +1254,12 @@ public class EsqlCapabilities {
          * Forbid usage of brackets in unquoted index and enrich policy names
          * https://github.com/elastic/elasticsearch/issues/130378
          */
-        NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES;
+        NO_BRACKETS_IN_UNQUOTED_INDEX_NAMES,
+
+        /*
+         * Cosine vector similarity function
+         */
+        COSINE_VECTOR_SIMILARITY_FUNCTION(Build.current().isSnapshot());
 
         private final boolean enabled;
 

+ 12 - 4
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java

@@ -1400,15 +1400,15 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
             if (f instanceof In in) {
                 return processIn(in);
             }
+            if (f instanceof VectorFunction) {
+                return processVectorFunction(f);
+            }
             if (f instanceof EsqlScalarFunction || f instanceof GroupingFunction) { // exclude AggregateFunction until it is needed
                 return processScalarOrGroupingFunction(f, registry);
             }
             if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) {
                 return processBinaryOperator((BinaryOperator) f);
             }
-            if (f instanceof VectorFunction vectorFunction) {
-                return processVectorFunction(f);
-            }
             return f;
         }
 
@@ -1613,6 +1613,7 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
             }
         }
 
+        @SuppressWarnings("unchecked")
         private static Expression processVectorFunction(org.elasticsearch.xpack.esql.core.expression.function.Function vectorFunction) {
             List<Expression> args = vectorFunction.arguments();
             List<Expression> newArgs = new ArrayList<>();
@@ -1620,7 +1621,14 @@ public class Analyzer extends ParameterizedRuleExecutor<LogicalPlan, AnalyzerCon
                 if (arg.resolved() && arg.dataType().isNumeric() && arg.foldable()) {
                     Object folded = arg.fold(FoldContext.small() /* TODO remove me */);
                     if (folded instanceof List) {
-                        Literal denseVector = new Literal(arg.source(), folded, DataType.DENSE_VECTOR);
+                        // Convert to floats so blocks are created accordingly
+                        List<Float> floatVector;
+                        if (arg.dataType() == FLOAT) {
+                            floatVector = (List<Float>) folded;
+                        } else {
+                            floatVector = ((List<Number>) folded).stream().map(Number::floatValue).collect(Collectors.toList());
+                        }
+                        Literal denseVector = new Literal(arg.source(), floatVector, DataType.DENSE_VECTOR);
                         newArgs.add(denseVector);
                         continue;
                     }

+ 2 - 6
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/ExpressionWritables.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.esql.expression;
 
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
-import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
 import org.elasticsearch.xpack.esql.core.expression.ExpressionCoreWritables;
 import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute;
 import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateWritables;
@@ -85,7 +84,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLik
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList;
 import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
-import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.function.vector.VectorWritables;
 import org.elasticsearch.xpack.esql.expression.predicate.logical.Not;
 import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull;
 import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNull;
@@ -259,9 +258,6 @@ public class ExpressionWritables {
     }
 
     private static List<NamedWriteableRegistry.Entry> vector() {
-        if (EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled()) {
-            return List.of(Knn.ENTRY);
-        }
-        return List.of();
+        return VectorWritables.getNamedWritables();
     }
 }

+ 3 - 1
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -180,6 +180,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToLower;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.ToUpper;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Trim;
 import org.elasticsearch.xpack.esql.expression.function.scalar.util.Delay;
+import org.elasticsearch.xpack.esql.expression.function.vector.CosineSimilarity;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
 import org.elasticsearch.xpack.esql.parser.ParsingException;
 import org.elasticsearch.xpack.esql.session.Configuration;
@@ -489,7 +490,8 @@ public class EsqlFunctionRegistry {
                 def(StGeotileToString.class, StGeotileToString::new, "st_geotile_to_string"),
                 def(StGeohex.class, StGeohex::new, "st_geohex"),
                 def(StGeohexToLong.class, StGeohexToLong::new, "st_geohex_to_long"),
-                def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string") } };
+                def(StGeohexToString.class, StGeohexToString::new, "st_geohex_to_string"),
+                def(CosineSimilarity.class, CosineSimilarity::new, "v_cosine") } };
     }
 
     public EsqlFunctionRegistry snapshotRegistry() {

+ 77 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarity.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo;
+import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+
+import java.io.IOException;
+
+import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
+
+public class CosineSimilarity extends VectorSimilarityFunction {
+
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        Expression.class,
+        "CosineSimilarity",
+        CosineSimilarity::new
+    );
+    static final SimilarityEvaluatorFunction SIMILARITY_FUNCTION = COSINE::compare;
+
+    @FunctionInfo(
+        returnType = "double",
+        preview = true,
+        description = "Calculates the cosine similarity between two dense_vectors.",
+        examples = { @Example(file = "vector-cosine-similarity", tag = "vector-cosine-similarity") },
+        appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.DEVELOPMENT) }
+    )
+    public CosineSimilarity(
+        Source source,
+        @Param(name = "left", type = { "dense_vector" }, description = "first dense_vector to calculate cosine similarity") Expression left,
+        @Param(
+            name = "right",
+            type = { "dense_vector" },
+            description = "second dense_vector to calculate cosine similarity"
+        ) Expression right
+    ) {
+        super(source, left, right);
+    }
+
+    private CosineSimilarity(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    protected BinaryScalarFunction replaceChildren(Expression newLeft, Expression newRight) {
+        return new CosineSimilarity(source(), newLeft, newRight);
+    }
+
+    @Override
+    protected SimilarityEvaluatorFunction getSimilarityFunction() {
+        return SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, CosineSimilarity::new, left(), right());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+}

+ 174 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorSimilarityFunction.java

@@ -0,0 +1,174 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DoubleVector;
+import org.elasticsearch.compute.data.FloatBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.xpack.esql.EsqlClientException;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.FoldContext;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
+
+import java.io.IOException;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNull;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+
+/**
+ * Base class for vector similarity functions, which compute a similarity score between two dense vectors
+ */
+public abstract class VectorSimilarityFunction extends BinaryScalarFunction implements EvaluatorMapper, VectorFunction {
+
+    protected VectorSimilarityFunction(Source source, Expression left, Expression right) {
+        super(source, left, right);
+    }
+
+    protected VectorSimilarityFunction(StreamInput in) throws IOException {
+        super(in);
+    }
+
+    @Override
+    public DataType dataType() {
+        return DataType.DOUBLE;
+    }
+
+    @Override
+    protected TypeResolution resolveType() {
+        if (childrenResolved() == false) {
+            return new TypeResolution("Unresolved children");
+        }
+
+        return checkDenseVectorParam(left(), FIRST).and(checkDenseVectorParam(right(), SECOND));
+    }
+
+    private TypeResolution checkDenseVectorParam(Expression param, TypeResolutions.ParamOrdinal paramOrdinal) {
+        return isNotNull(param, sourceText(), paramOrdinal).and(
+            isType(param, dt -> dt == DENSE_VECTOR, sourceText(), paramOrdinal, "dense_vector")
+        );
+    }
+
+    /**
+     * Functional interface for evaluating the similarity between two float arrays
+     */
+    @FunctionalInterface
+    public interface SimilarityEvaluatorFunction {
+        float calculateSimilarity(float[] leftScratch, float[] rightScratch);
+    }
+
+    @Override
+    public Object fold(FoldContext ctx) {
+        return EvaluatorMapper.super.fold(source(), ctx);
+    }
+
+    @Override
+    public final EvalOperator.ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) {
+        return new SimilarityEvaluatorFactory(
+            toEvaluator.apply(left()),
+            toEvaluator.apply(right()),
+            getSimilarityFunction(),
+            getClass().getSimpleName() + "Evaluator"
+        );
+    }
+
+    /**
+     * Returns the similarity function to be used for evaluating the similarity between two vectors.
+     */
+    protected abstract SimilarityEvaluatorFunction getSimilarityFunction();
+
+    /**
+     * Factory for the evaluator that scores, position by position, the vectors produced by the {@code left} and
+     * {@code right} child evaluators using {@code similarityFunction}. The result holds one double per position.
+     */
+    private record SimilarityEvaluatorFactory(
+        EvalOperator.ExpressionEvaluator.Factory left,
+        EvalOperator.ExpressionEvaluator.Factory right,
+        SimilarityEvaluatorFunction similarityFunction,
+        String evaluatorName
+    ) implements EvalOperator.ExpressionEvaluator.Factory {
+
+        @Override
+        public EvalOperator.ExpressionEvaluator get(DriverContext context) {
+            // TODO check whether to use this custom evaluator or reuse / define an existing one
+            // Child evaluators are created once per evaluator (not once per page) so that close() can release them
+            final EvalOperator.ExpressionEvaluator leftEvaluator = left.get(context);
+            final EvalOperator.ExpressionEvaluator rightEvaluator;
+            try {
+                rightEvaluator = right.get(context);
+            } catch (RuntimeException e) {
+                // Do not leak the left evaluator when the right one cannot be created
+                leftEvaluator.close();
+                throw e;
+            }
+            return new EvalOperator.ExpressionEvaluator() {
+                @Override
+                public Block eval(Page page) {
+                    try (
+                        FloatBlock leftBlock = (FloatBlock) leftEvaluator.eval(page);
+                        FloatBlock rightBlock = (FloatBlock) rightEvaluator.eval(page)
+                    ) {
+                        int positionCount = page.getPositionCount();
+                        // Scratch arrays are sized from the vectors actually read, and resized if the dimensions change
+                        float[] leftScratch = null;
+                        float[] rightScratch = null;
+                        // Exactly one double is appended per position
+                        try (DoubleVector.Builder builder = context.blockFactory().newDoubleVectorBuilder(positionCount)) {
+                            for (int p = 0; p < positionCount; p++) {
+                                int dimsLeft = leftBlock.getValueCount(p);
+                                int dimsRight = rightBlock.getValueCount(p);
+
+                                if (dimsLeft == 0 || dimsRight == 0) {
+                                    // A null value on the left or right vector. Similarity is 0
+                                    builder.appendDouble(0.0);
+                                    continue;
+                                } else if (dimsLeft != dimsRight) {
+                                    throw new EsqlClientException(
+                                        "Vectors must have the same dimensions; first vector has {}, and second has {}",
+                                        dimsLeft,
+                                        dimsRight
+                                    );
+                                }
+                                if (leftScratch == null || leftScratch.length != dimsLeft) {
+                                    leftScratch = new float[dimsLeft];
+                                    rightScratch = new float[dimsLeft];
+                                }
+                                readFloatArray(leftBlock, leftBlock.getFirstValueIndex(p), dimsLeft, leftScratch);
+                                readFloatArray(rightBlock, rightBlock.getFirstValueIndex(p), dimsLeft, rightScratch);
+                                float result = similarityFunction.calculateSimilarity(leftScratch, rightScratch);
+                                builder.appendDouble(result);
+                            }
+                            return builder.build().asBlock();
+                        }
+                    }
+                }
+
+                @Override
+                public String toString() {
+                    return evaluatorName() + "[left=" + left + ", right=" + right + "]";
+                }
+
+                @Override
+                public void close() {
+                    // Release both child evaluators, even if closing the first one fails
+                    try {
+                        leftEvaluator.close();
+                    } finally {
+                        rightEvaluator.close();
+                    }
+                }
+            };
+        }
+
+        /**
+         * Copies {@code dimensions} consecutive floats from {@code block}, starting at value index {@code position},
+         * into the first {@code dimensions} slots of {@code scratch}.
+         */
+        private static void readFloatArray(FloatBlock block, int position, int dimensions, float[] scratch) {
+            for (int i = 0; i < dimensions; i++) {
+                scratch[i] = block.getFloat(position + i);
+            }
+        }
+
+        @Override
+        public String toString() {
+            return evaluatorName() + "[left=" + left + ", right=" + right + "]";
+        }
+    }
+}

+ 39 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/vector/VectorWritables.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Collects the {@link NamedWriteableRegistry.Entry} instances contributed by the ESQL vector functions.
+ */
+public final class VectorWritables {
+
+    private VectorWritables() {
+        // Static holder only; instantiation is always rejected
+        throw new UnsupportedOperationException();
+    }
+
+    /**
+     * Returns the entries of the vector functions whose capability is currently enabled.
+     *
+     * @return an unmodifiable list holding, in this order, the KNN entry and the cosine similarity entry
+     *         for each of those capabilities that is enabled
+     */
+    public static List<NamedWriteableRegistry.Entry> getNamedWritables() {
+        final List<NamedWriteableRegistry.Entry> enabledEntries = new ArrayList<>();
+
+        // Each ENTRY constant is only touched once its capability check has passed
+        final boolean knnEnabled = EsqlCapabilities.Cap.KNN_FUNCTION_V2.isEnabled();
+        if (knnEnabled) {
+            enabledEntries.add(Knn.ENTRY);
+        }
+
+        final boolean cosineEnabled = EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled();
+        if (cosineEnabled) {
+            enabledEntries.add(CosineSimilarity.ENTRY);
+        }
+
+        return Collections.unmodifiableList(enabledEntries);
+    }
+}

+ 44 - 2
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java

@@ -57,6 +57,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat;
 import org.elasticsearch.xpack.esql.expression.function.scalar.string.Substring;
 import org.elasticsearch.xpack.esql.expression.function.vector.Knn;
+import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals;
 import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan;
@@ -92,6 +93,7 @@ import java.io.IOException;
 import java.time.Period;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
@@ -123,6 +125,7 @@ import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY;
 import static org.elasticsearch.xpack.esql.core.type.DataType.DATETIME;
 import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS;
 import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_PERIOD;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
 import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
 import static org.elasticsearch.xpack.esql.core.type.DataType.LONG;
 import static org.elasticsearch.xpack.esql.core.type.DataType.UNSUPPORTED;
@@ -2337,7 +2340,7 @@ public class AnalyzerTests extends ESTestCase {
         assertThat(e.getMessage(), containsString("[+] has arguments with incompatible types [datetime] and [datetime]"));
     }
 
-    public void testDenseVectorImplicitCasting() {
+    public void testDenseVectorImplicitCastingKnn() {
         assumeTrue("dense_vector capability not available", EsqlCapabilities.Cap.DENSE_VECTOR_FIELD_TYPE.isEnabled());
         Analyzer analyzer = analyzer(loadMapping("mapping-dense_vector.json", "vectors"));
 
@@ -2351,7 +2354,46 @@ public class AnalyzerTests extends ESTestCase {
         var field = knn.field();
         var queryVector = as(knn.query(), Literal.class);
         assertEquals(DataType.DENSE_VECTOR, queryVector.dataType());
-        assertThat(queryVector.value(), equalTo(List.of(0.342, 0.164, 0.234)));
+        assertThat(queryVector.value(), equalTo(List.of(0.342f, 0.164f, 0.234f)));
+    }
+
+    /**
+     * Checks that literal arrays passed to v_cosine are implicitly cast to dense_vector.
+     * Nothing is checked when the cosine similarity capability is disabled.
+     */
+    public void testDenseVectorImplicitCastingSimilarityFunctions() {
+        if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
+            return;
+        }
+        // Decimal literals
+        checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [0.342, 0.164, 0.234])", List.of(0.342f, 0.164f, 0.234f));
+        // Integer literals
+        checkDenseVectorImplicitCastingSimilarityFunction("v_cosine(vector, [1, 2, 3])", List.of(1f, 2f, 3f));
+    }
+
+    /**
+     * Analyzes {@code from test | eval similarity = <similarityFunction>} with the dense_vector mapping and verifies
+     * that the function's left argument is the {@code vector} field and its right argument is a dense_vector literal
+     * equal to {@code expectedElems}.
+     */
+    private void checkDenseVectorImplicitCastingSimilarityFunction(String similarityFunction, List<Number> expectedElems) {
+        String query = String.format(Locale.ROOT, """
+            from test | eval similarity = %s
+            """, similarityFunction);
+        var analyzed = analyze(query, "mapping-dense_vector.json");
+
+        var topLimit = as(analyzed, Limit.class);
+        var evalNode = as(topLimit.child(), Eval.class);
+        var similarityAlias = as(evalNode.fields().get(0), Alias.class);
+        assertEquals("similarity", similarityAlias.name());
+
+        var function = as(similarityAlias.child(), VectorSimilarityFunction.class);
+        var fieldArgument = as(function.left(), FieldAttribute.class);
+        assertEquals("vector", fieldArgument.name());
+
+        var literalArgument = as(function.right(), Literal.class);
+        assertThat(literalArgument.dataType(), is(DENSE_VECTOR));
+        assertThat(literalArgument.value(), equalTo(expectedElems));
+    }
+
+    /**
+     * Checks that v_cosine rejects a second argument that is not a dense_vector.
+     * Nothing is checked when the cosine similarity capability is disabled.
+     */
+    public void testNoDenseVectorFailsSimilarityFunction() {
+        if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
+            return;
+        }
+        checkNoDenseVectorFailsSimilarityFunction("v_cosine([0, 1, 2], 0.342)");
+    }
+
+    /**
+     * Analyzes {@code row a = 1 | eval similarity = <similarityFunction>} and expects a verification failure reporting
+     * that the second argument, the double value 0.342, is not a dense_vector.
+     */
+    private void checkNoDenseVectorFailsSimilarityFunction(String similarityFunction) {
+        String failingQuery = String.format(Locale.ROOT, "row a = 1 |  eval similarity = %s", similarityFunction);
+        VerificationException failure = expectThrows(VerificationException.class, () -> analyze(failingQuery));
+        String expectedFragment = "second argument of ["
+            + similarityFunction
+            + "] must be [dense_vector], found value [0.342] type [double]";
+        assertThat(failure.getMessage(), containsString(expectedFragment));
     }
 
     public void testRateRequiresCounterTypes() {

+ 14 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java

@@ -2300,6 +2300,20 @@ public class VerifierTests extends ESTestCase {
         );
     }
 
+    /**
+     * Checks that v_cosine reports a null literal in either argument position.
+     * Nothing is checked when the cosine similarity capability is disabled.
+     */
+    public void testVectorSimilarityFunctionsNullArgs() throws Exception {
+        if (EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION.isEnabled() == false) {
+            return;
+        }
+        checkVectorSimilarityFunctionsNullArgs("v_cosine(null, vector)", "first");
+        checkVectorSimilarityFunctionsNullArgs("v_cosine(vector, null)", "second");
+    }
+
+    private void checkVectorSimilarityFunctionsNullArgs(String functionInvocation, String argOrdinal) throws Exception {
+        assertThat(
+            error("from test | eval similarity = " + functionInvocation, fullTextAnalyzer),
+            containsString(argOrdinal + " argument of [" + functionInvocation + "] cannot be null, received [null]")
+        );
+    }
+
     private void query(String query) {
         query(query, defaultAnalyzer);
     }

+ 102 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/AbstractVectorSimilarityFunctionTestCase.java

@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Shared scaffolding for unit tests of ESQL vector similarity functions that take two {@code dense_vector} arguments.
+ * <p>
+ * Each concrete test names the capability guarding its function through {@link #capability()} and builds its
+ * parameters with {@code similarityParameters}.
+ */
+public abstract class AbstractVectorSimilarityFunctionTestCase extends AbstractScalarFunctionTestCase {
+
+    protected AbstractVectorSimilarityFunctionTestCase(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    /**
+     * Skips the test when the capability returned by {@link #capability()} is not enabled.
+     */
+    @Before
+    public void checkCapability() {
+        assumeTrue("Similarity function is not enabled", capability().isEnabled());
+    }
+
+    /**
+     * Get the capability of the vector similarity function to check
+     */
+    protected abstract EsqlCapabilities.Cap capability();
+
+    /**
+     * Builds the test parameters for a similarity function.
+     *
+     * @param className          simple name of the function class; prefix of the expected evaluator description
+     * @param similarityFunction the similarity calculation used to compute the expected value
+     * @return the parameters produced by {@code parameterSuppliersFromTypedData} for the generated suppliers
+     */
+    protected static Iterable<Object[]> similarityParameters(
+        String className,
+        VectorSimilarityFunction.SimilarityEvaluatorFunction similarityFunction
+    ) {
+
+        final String evaluatorName = className + "Evaluator[left=Attribute[channel=0], right=Attribute[channel=1]]";
+
+        List<TestCaseSupplier> suppliers = new ArrayList<>();
+
+        // Basic test with two dense vectors of the same, randomly chosen, number of dimensions
+        suppliers.add(new TestCaseSupplier(List.of(DENSE_VECTOR, DENSE_VECTOR), () -> {
+            int dimensions = between(64, 128);
+            List<Float> left = randomDenseVector(dimensions);
+            List<Float> right = randomDenseVector(dimensions);
+            float[] leftArray = listToFloatArray(left);
+            float[] rightArray = listToFloatArray(right);
+            // Reference value: the same similarity calculation applied directly to the generated vectors
+            double expected = similarityFunction.calculateSimilarity(leftArray, rightArray);
+            return new TestCaseSupplier.TestCase(
+                List.of(
+                    new TestCaseSupplier.TypedData(left, DENSE_VECTOR, "vector1"),
+                    new TestCaseSupplier.TypedData(right, DENSE_VECTOR, "vector2")
+                ),
+                evaluatorName,
+                DOUBLE,
+                equalTo(expected) // The evaluated result must match the reference value exactly
+            );
+        }));
+
+        return parameterSuppliersFromTypedData(suppliers);
+    }
+
+    /**
+     * Unboxes a list of floats into a primitive array of the same length and order.
+     */
+    private static float[] listToFloatArray(List<Float> floatList) {
+        float[] floatArray = new float[floatList.size()];
+        for (int i = 0; i < floatList.size(); i++) {
+            floatArray[i] = floatList.get(i);
+        }
+        return floatArray;
+    }
+
+    // NOTE(review): always returns 0 and is not called anywhere in this class; it looks like a leftover hook.
+    // Kept so that any subclass overriding or calling it keeps compiling; confirm before removing.
+    protected double calculateSimilarity(List<Float> left, List<Float> right) {
+        return 0;
+    }
+
+    /**
+     * Creates a random dense vector for testing.
+     *
+     * @param dimensions number of elements to generate
+     * @return a list of {@code dimensions} values, each drawn with {@code randomFloat()}
+     */
+    private static List<Float> randomDenseVector(int dimensions) {
+        // The final size is known up front, so the backing array is allocated once
+        List<Float> vector = new ArrayList<>(dimensions);
+        for (int i = 0; i < dimensions; i++) {
+            vector.add(randomFloat());
+        }
+        return vector;
+    }
+
+    @Override
+    protected Matcher<Object> allNullsMatcher() {
+        // A null value on the left or right vector. Similarity is 0
+        return equalTo(0.0);
+    }
+}

+ 42 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/vector/CosineSimilarityTests.java

@@ -0,0 +1,42 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.vector;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.FunctionName;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.List;
+import java.util.function.Supplier;
+
+/**
+ * Unit tests for {@link CosineSimilarity}, registered under the function name {@code v_cosine}.
+ */
+@FunctionName("v_cosine")
+public class CosineSimilarityTests extends AbstractVectorSimilarityFunctionTestCase {
+
+    public CosineSimilarityTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        super(testCaseSupplier);
+    }
+
+    /**
+     * Supplies the shared similarity test cases, using {@link CosineSimilarity#SIMILARITY_FUNCTION} to compute
+     * the expected values.
+     */
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        return similarityParameters(CosineSimilarity.class.getSimpleName(), CosineSimilarity.SIMILARITY_FUNCTION);
+    }
+
+    // Implements the abstract hook of the base class; @Override lets the compiler catch signature drift
+    @Override
+    protected EsqlCapabilities.Cap capability() {
+        return EsqlCapabilities.Cap.COSINE_VECTOR_SIMILARITY_FUNCTION;
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new CosineSimilarity(source, args.get(0), args.get(1));
+    }
+}

+ 2 - 1
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/60_usage.yml

@@ -41,6 +41,7 @@ setup:
             - sum_over_time
             - count_over_time
             - distinct_over_time
+            - cosine_vector_similarity_function
       reason: "Test that should only be executed on snapshot versions"
 
   - do: {xpack.usage: {}}
@@ -130,7 +131,7 @@ setup:
   - match: {esql.functions.coalesce: $functions_coalesce}
   - gt: {esql.functions.categorize: $functions_categorize}
   # Testing for the entire function set isn't feasible, so we just check that we return the correct count as an approximation.
-  - length: {esql.functions: 156} # check the "sister" test below for a likely update to the same esql.functions length check
+  - length: {esql.functions: 157} # check the "sister" test below for a likely update to the same esql.functions length check
 
 ---
 "Basic ESQL usage output (telemetry) non-snapshot version":