Browse Source

ESQL: Add `MV_PSERIES_WEIGHTED_SUM` for score calculations used by security solution (#109017)

* Create MV_RIEMANN_ZETA scalar multivalue function



---------

Co-authored-by: Nik Everett <nik9000@gmail.com>
Pablo Machado 1 year ago
parent
commit
f79c62157d
24 changed files with 636 additions and 5 deletions
  1. 6 0
      docs/changelog/109017.yaml
  2. 5 0
      docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc
  3. 13 0
      docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc
  4. 29 0
      docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json
  5. 12 0
      docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md
  6. 15 0
      docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc
  7. 2 0
      docs/reference/esql/functions/mv-functions.asciidoc
  8. 9 0
      docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc
  9. 1 0
      docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg
  10. 9 0
      docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc
  11. 2 0
      x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java
  12. 11 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv
  13. 10 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json
  14. 5 1
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec
  15. 89 0
      x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec
  16. 105 0
      x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java
  17. 9 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java
  18. 3 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java
  19. 1 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java
  20. 174 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java
  21. 10 0
      x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java
  22. 6 4
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java
  23. 39 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java
  24. 71 0
      x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java

+ 6 - 0
docs/changelog/109017.yaml

@@ -0,0 +1,6 @@
+pr: 109017
+summary: "ESQL: Add `MV_PSERIES_WEIGHTED_SUM` for score calculations used by security\
+  \ solution"
+area: ES|QL
+type: "feature"
+issues: [ ]

+ 5 - 0
docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc

@@ -0,0 +1,5 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Description*
+
+Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.

+ 13 - 0
docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc

@@ -0,0 +1,13 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Example*
+
+[source.merge.styled,esql]
+----
+include::{esql-specs}/mv_pseries_weighted_sum.csv-spec[tag=example]
+----
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+include::{esql-specs}/mv_pseries_weighted_sum.csv-spec[tag=example-result]
+|===
+

+ 29 - 0
docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json

@@ -0,0 +1,29 @@
+{
+  "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.",
+  "type" : "eval",
+  "name" : "mv_pseries_weighted_sum",
+  "description" : "Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.",
+  "signatures" : [
+    {
+      "params" : [
+        {
+          "name" : "number",
+          "type" : "double",
+          "optional" : false,
+          "description" : "Multivalue expression."
+        },
+        {
+          "name" : "p",
+          "type" : "double",
+          "optional" : false,
+          "description" : "It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum."
+        }
+      ],
+      "variadic" : false,
+      "returnType" : "double"
+    }
+  ],
+  "examples" : [
+    "ROW a = [70.0, 45.0, 21.0, 21.0, 21.0]\n| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5)\n| KEEP sum"
+  ]
+}

+ 12 - 0
docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md

@@ -0,0 +1,12 @@
+<!--
+This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+-->
+
+### MV_PSERIES_WEIGHTED_SUM
+Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.
+
+```
+ROW a = [70.0, 45.0, 21.0, 21.0, 21.0]
+| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5)
+| KEEP sum
+```

+ 15 - 0
docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc

@@ -0,0 +1,15 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+[discrete]
+[[esql-mv_pseries_weighted_sum]]
+=== `MV_PSERIES_WEIGHTED_SUM`
+
+*Syntax*
+
+[.text-center]
+image::esql/functions/signature/mv_pseries_weighted_sum.svg[Embedded,opts=inline]
+
+include::../parameters/mv_pseries_weighted_sum.asciidoc[]
+include::../description/mv_pseries_weighted_sum.asciidoc[]
+include::../types/mv_pseries_weighted_sum.asciidoc[]
+include::../examples/mv_pseries_weighted_sum.asciidoc[]

+ 2 - 0
docs/reference/esql/functions/mv-functions.asciidoc

@@ -18,6 +18,7 @@
 * <<esql-mv_max>>
 * <<esql-mv_median>>
 * <<esql-mv_min>>
+* <<esql-mv_pseries_weighted_sum>>
 * <<esql-mv_sort>>
 * <<esql-mv_slice>>
 * <<esql-mv_sum>>
@@ -34,6 +35,7 @@ include::layout/mv_last.asciidoc[]
 include::layout/mv_max.asciidoc[]
 include::layout/mv_median.asciidoc[]
 include::layout/mv_min.asciidoc[]
+include::layout/mv_pseries_weighted_sum.asciidoc[]
 include::layout/mv_slice.asciidoc[]
 include::layout/mv_sort.asciidoc[]
 include::layout/mv_sum.asciidoc[]

+ 9 - 0
docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc

@@ -0,0 +1,9 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Parameters*
+
+`number`::
+Multivalue expression.
+
+`p`::
+It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum.

+ 1 - 0
docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg

@@ -0,0 +1 @@
+<svg version="1.1" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg" width="576" height="46" viewbox="0 0 576 46"><defs><style type="text/css">#guide .c{fill:none;stroke:#222222;}#guide .k{fill:#000000;font-family:Roboto Mono,Sans-serif;font-size:20px;}#guide .s{fill:#e4f4ff;stroke:#222222;}#guide .syn{fill:#8D8D8D;font-family:Roboto Mono,Sans-serif;font-size:20px;}</style></defs><path class="c" d="M0 31h5m296 0h10m32 0h10m92 0h10m32 0h10m32 0h10m32 0h5"/><rect class="s" x="5" y="5" width="296" height="36"/><text class="k" x="15" y="31">MV_PSERIES_WEIGHTED_SUM</text><rect class="s" x="311" y="5" width="32" height="36" rx="7"/><text class="syn" x="321" y="31">(</text><rect class="s" x="353" y="5" width="92" height="36" rx="7"/><text class="k" x="363" y="31">number</text><rect class="s" x="455" y="5" width="32" height="36" rx="7"/><text class="syn" x="465" y="31">,</text><rect class="s" x="497" y="5" width="32" height="36" rx="7"/><text class="k" x="507" y="31">p</text><rect class="s" x="539" y="5" width="32" height="36" rx="7"/><text class="syn" x="549" y="31">)</text></svg>

+ 9 - 0
docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc

@@ -0,0 +1,9 @@
+// This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.
+
+*Supported types*
+
+[%header.monospaced.styled,format=dsv,separator=|]
+|===
+number | p | result
+double | double | double
+|===

+ 2 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java

@@ -55,6 +55,7 @@ public class CsvTestsDataLoader {
     private static final TestsDataset HOSTS = new TestsDataset("hosts", "mapping-hosts.json", "hosts.csv");
     private static final TestsDataset APPS = new TestsDataset("apps", "mapping-apps.json", "apps.csv");
     private static final TestsDataset LANGUAGES = new TestsDataset("languages", "mapping-languages.json", "languages.csv");
+    private static final TestsDataset ALERTS = new TestsDataset("alerts", "mapping-alerts.json", "alerts.csv");
     private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs", "mapping-ul_logs.json", "ul_logs.csv");
     private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data", "mapping-sample_data.json", "sample_data.csv");
     private static final TestsDataset SAMPLE_DATA_STR = new TestsDataset(
@@ -106,6 +107,7 @@ public class CsvTestsDataLoader {
         Map.entry(LANGUAGES.indexName, LANGUAGES),
         Map.entry(UL_LOGS.indexName, UL_LOGS),
         Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA),
+        Map.entry(ALERTS.indexName, ALERTS),
         Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR),
         Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG),
         Map.entry(CLIENT_IPS.indexName, CLIENT_IPS),

+ 11 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv

@@ -0,0 +1,11 @@
+host.name:keyword,kibana.alert.risk_score:double
+test-host-1,21.0
+test-host-2,17.0
+test-host-2,23.0
+test-host-1,45.0
+test-host-2,12.0
+test-host-2,16.0
+test-host-1,21.0
+test-host-1,70.0
+test-host-1,21.0
+test-host-2,5.0

+ 10 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json

@@ -0,0 +1,10 @@
+{
+  "properties": {
+    "host.name": {
+      "type": "keyword"
+    },
+    "kibana.alert.risk_score": {
+      "type": "double"
+    }
+  }
+}

+ 5 - 1
x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec

@@ -53,6 +53,7 @@ double e()
 "boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version mv_max(field:boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version)"
 "double|integer|long|unsigned_long mv_median(number:double|integer|long|unsigned_long)"
 "boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version mv_min(field:boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version)"
+"double mv_pseries_weighted_sum(number:double, p:double)"
 "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_slice(field:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, start:integer, ?end:integer)"
 "boolean|date|double|integer|ip|keyword|long|text|version mv_sort(field:boolean|date|double|integer|ip|keyword|long|text|version, ?order:keyword)"
 "double|integer|long|unsigned_long mv_sum(number:double|integer|long|unsigned_long)"
@@ -174,6 +175,7 @@ mv_last       |field                               |"boolean|cartesian_point|car
 mv_max        |field                               |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version"                                                          |Multivalue expression.
 mv_median     |number                              |"double|integer|long|unsigned_long"                                                                                               |Multivalue expression.
 mv_min        |field                               |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version"                                                          |Multivalue expression.
+mv_pseries_wei|[number, p]                         |[double, double]                                                                                                                  |[Multivalue expression., It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum.]
 mv_slice      |[field, start, end]                 |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", integer, integer]|[Multivalue expression. If `null`\, the function returns `null`., Start position. If `null`\, the function returns `null`. The start argument can be negative. An index of -1 is used to specify the last value in the list., End position(included). Optional; if omitted\, the position at `start` is returned. The end argument can be negative. An index of -1 is used to specify the last value in the list.]
 mv_sort       |[field, order]                      |["boolean|date|double|integer|ip|keyword|long|text|version", keyword]                                                             |[Multivalue expression. If `null`\, the function returns `null`., Sort order. The valid options are ASC and DESC\, the default is ASC.]
 mv_sum        |number                              |"double|integer|long|unsigned_long"                                                                                               |Multivalue expression.
@@ -296,6 +298,7 @@ mv_last       |Converts a multivalue expression into a single valued column cont
 mv_max        |Converts a multivalued expression into a single valued column containing the maximum value.
 mv_median     |Converts a multivalued field into a single valued field containing the median value.
 mv_min        |Converts a multivalued expression into a single valued column containing the minimum value.
+mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.
 mv_slice      |Returns a subset of the multivalued field using the start and end index values.
 mv_sort       |Sorts a multivalued field in lexicographical order.
 mv_sum        |Converts a multivalued field into a single valued field containing the sum of all of the values.
@@ -419,6 +422,7 @@ mv_last       |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|ge
 mv_max        |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version"                                                    |false                       |false           |false
 mv_median     |"double|integer|long|unsigned_long"                                                                                         |false                       |false           |false
 mv_min        |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version"                                                    |false                       |false           |false
+mv_pseries_wei|"double"                                                                                                                    |[false, false]              |false           |false 
 mv_slice      |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"              |[false, false, true]        |false           |false
 mv_sort       |"boolean|date|double|integer|ip|keyword|long|text|version"                                                                  |[false, true]               |false           |false
 mv_sum        |"double|integer|long|unsigned_long"                                                                                         |false                       |false           |false
@@ -497,5 +501,5 @@ countFunctions#[skip:-8.15.99]
 meta functions |  stats  a = count(*), b = count(*), c = count(*) |  mv_expand c;
 
 a:long | b:long | c:long
-113    | 113    | 113
+114    | 114    | 114
 ;

+ 89 - 0
x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec

@@ -0,0 +1,89 @@
+default
+required_capability: mv_pseries_weighted_sum
+
+// tag::example[]
+ROW a = [70.0, 45.0, 21.0, 21.0, 21.0]
+| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5)
+| KEEP sum
+// end::example[]
+;
+
+// tag::example-result[]
+sum:double
+94.45465156212452
+// end::example-result[]
+;
+
+oneElement
+required_capability: mv_pseries_weighted_sum
+
+ROW data = [3.0]
+| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, 9999.9)
+| KEEP score;
+
+score:double
+3.0
+;
+
+zeroP
+required_capability: mv_pseries_weighted_sum
+
+ROW data = [3.0, 10.0, 15.0]
+| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, 0.0)
+| KEEP score;
+
+score:double
+28.0
+;
+
+negativeP
+required_capability: mv_pseries_weighted_sum
+
+ROW data = [10.0, 5.0, 3.0]
+| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, -2.0)
+| KEEP score;
+
+score:double
+57.0
+;
+
+composed
+required_capability: mv_pseries_weighted_sum
+
+ROW data = [21.0, 45.0, 21.0, 70.0, 21.0]
+| EVAL sorted = MV_SORT(data, "desc")
+| EVAL score = MV_PSERIES_WEIGHTED_SUM(sorted, 1.5)
+| EVAL normalized_score = ROUND(100 * score / 261.2, 2)
+| KEEP normalized_score, score;
+
+normalized_score:double|score:double
+36.16                  |94.45465156212452
+;
+
+multivalueAggregation
+required_capability: mv_pseries_weighted_sum
+
+FROM alerts
+| WHERE host.name is not null
+| SORT host.name, kibana.alert.risk_score
+| STATS score=MV_PSERIES_WEIGHTED_SUM(
+    TOP(kibana.alert.risk_score, 10000, "desc"), 1.5
+) BY host.name
+| EVAL normalized_score = ROUND(100 * score / 261.2, 2)
+| KEEP host.name, normalized_score, score;
+
+host.name:keyword|normalized_score:double|score:double
+test-host-1      |36.16                  |94.45465156212452
+test-host-2      |13.03                  |34.036822671263614
+;
+
+asArgument
+required_capability: mv_pseries_weighted_sum
+
+ROW data = [70.0, 45.0, 21.0, 21.0, 21.0]
+| EVAL score = ROUND(MV_PSERIES_WEIGHTED_SUM(data, 1.5), 1)
+| KEEP score;
+
+score:double
+94.5
+;

+ 105 - 0
x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java

@@ -0,0 +1,105 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License
+// 2.0; you may not use this file except in compliance with the Elastic License
+// 2.0.
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import java.lang.Override;
+import java.lang.String;
+import java.util.function.Function;
+import org.elasticsearch.compute.data.Block;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.data.Page;
+import org.elasticsearch.compute.operator.DriverContext;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.core.Releasables;
+import org.elasticsearch.search.aggregations.metrics.CompensatedSum;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.function.Warnings;
+
+/**
+ * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvPSeriesWeightedSum}.
+ * This class is generated. Do not edit it.
+ */
+public final class MvPSeriesWeightedSumDoubleEvaluator implements EvalOperator.ExpressionEvaluator {
+  private final Warnings warnings;
+
+  private final EvalOperator.ExpressionEvaluator block;
+
+  private final CompensatedSum sum;
+
+  private final double p;
+
+  private final DriverContext driverContext;
+
+  public MvPSeriesWeightedSumDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator block,
+      CompensatedSum sum, double p, DriverContext driverContext) {
+    this.block = block;
+    this.sum = sum;
+    this.p = p;
+    this.driverContext = driverContext;
+    this.warnings = Warnings.createWarnings(driverContext.warningsMode(), source);
+  }
+
+  @Override
+  public Block eval(Page page) {
+    try (DoubleBlock blockBlock = (DoubleBlock) block.eval(page)) {
+      return eval(page.getPositionCount(), blockBlock);
+    }
+  }
+
+  public DoubleBlock eval(int positionCount, DoubleBlock blockBlock) {
+    try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) {
+      position: for (int p = 0; p < positionCount; p++) {
+        boolean allBlocksAreNulls = true;
+        if (!blockBlock.isNull(p)) {
+          allBlocksAreNulls = false;
+        }
+        if (allBlocksAreNulls) {
+          result.appendNull();
+          continue position;
+        }
+        MvPSeriesWeightedSum.process(result, p, blockBlock, this.sum, this.p);
+      }
+      return result.build();
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "MvPSeriesWeightedSumDoubleEvaluator[" + "block=" + block + ", p=" + p + "]";
+  }
+
+  @Override
+  public void close() {
+    Releasables.closeExpectNoException(block);
+  }
+
+  static class Factory implements EvalOperator.ExpressionEvaluator.Factory {
+    private final Source source;
+
+    private final EvalOperator.ExpressionEvaluator.Factory block;
+
+    private final Function<DriverContext, CompensatedSum> sum;
+
+    private final double p;
+
+    public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory block,
+        Function<DriverContext, CompensatedSum> sum, double p) {
+      this.source = source;
+      this.block = block;
+      this.sum = sum;
+      this.p = p;
+    }
+
+    @Override
+    public MvPSeriesWeightedSumDoubleEvaluator get(DriverContext context) {
+      return new MvPSeriesWeightedSumDoubleEvaluator(source, block.get(context), sum.apply(context), p, context);
+    }
+
+    @Override
+    public String toString() {
+      return "MvPSeriesWeightedSumDoubleEvaluator[" + "block=" + block + ", p=" + p + "]";
+    }
+  }
+}

+ 9 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java

@@ -183,6 +183,15 @@ public class EsqlCapabilities {
          */
         FIXED_PUSHDOWN_PAST_PROJECT,
 
+        /**
+         * Adds the {@code MV_PSERIES_WEIGHTED_SUM} function for converting sorted lists of numbers into
+         * a bounded score. This is a generalization of the
+         * <a href="https://en.wikipedia.org/wiki/Riemann_zeta_function">Riemann zeta function</a> but we
+         * don't name it that because we don't support complex numbers and don't want to make folks think
+         * of mystical number theory things. This is just a weighted sum that is adjacent to magic.
+         */
+        MV_PSERIES_WEIGHTED_SUM,
+
         /**
          * Support for match operator
          */

+ 3 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java

@@ -95,6 +95,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvLast
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin;
+import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPSeriesWeightedSum;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSlice;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSort;
 import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
@@ -360,11 +361,13 @@ public class EsqlFunctionRegistry {
                 def(MvMax.class, MvMax::new, "mv_max"),
                 def(MvMedian.class, MvMedian::new, "mv_median"),
                 def(MvMin.class, MvMin::new, "mv_min"),
+                def(MvPSeriesWeightedSum.class, MvPSeriesWeightedSum::new, "mv_pseries_weighted_sum"),
                 def(MvSort.class, MvSort::new, "mv_sort"),
                 def(MvSlice.class, MvSlice::new, "mv_slice"),
                 def(MvZip.class, MvZip::new, "mv_zip"),
                 def(MvSum.class, MvSum::new, "mv_sum"),
                 def(Split.class, Split::new, "split") } };
+
     }
 
     private static FunctionDefinition[][] snapshotFunctions() {

+ 1 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java

@@ -44,6 +44,7 @@ public abstract class AbstractMultivalueFunction extends UnaryScalarFunction {
             MvMax.ENTRY,
             MvMedian.ENTRY,
             MvMin.ENTRY,
+            MvPSeriesWeightedSum.ENTRY,
             MvSlice.ENTRY,
             MvSort.ENTRY,
             MvSum.ENTRY,

+ 174 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java

@@ -0,0 +1,174 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.compute.ann.Evaluator;
+import org.elasticsearch.compute.ann.Fixed;
+import org.elasticsearch.compute.data.DoubleBlock;
+import org.elasticsearch.compute.operator.EvalOperator;
+import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator;
+import org.elasticsearch.search.aggregations.metrics.CompensatedSum;
+import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException;
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.expression.TypeResolutions;
+import org.elasticsearch.xpack.esql.core.tree.NodeInfo;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper;
+import org.elasticsearch.xpack.esql.expression.function.Example;
+import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
+import org.elasticsearch.xpack.esql.expression.function.Param;
+import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction;
+import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput;
+import org.elasticsearch.xpack.esql.planner.PlannerUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.Function;
+
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND;
+import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable;
+import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE;
+
+/**
+ * Reduce a multivalued field to a single valued field containing the weighted sum of all elements, applying the P-Series function.
+ */
+public class MvPSeriesWeightedSum extends EsqlScalarFunction implements EvaluatorMapper {
+    public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(
+        Expression.class,
+        "MvPSeriesWeightedSum",
+        MvPSeriesWeightedSum::new
+    );
+
+    private final Expression field, p;
+
+    @FunctionInfo(
+        returnType = { "double" },
+
+        description = "Converts a multivalued expression into a single-valued column by multiplying every "
+            + "element on the input list by its corresponding term in P-Series and computing the sum.",
+        examples = @Example(file = "mv_pseries_weighted_sum", tag = "example")
+    )
+    public MvPSeriesWeightedSum(
+        Source source,
+        @Param(name = "number", type = { "double" }, description = "Multivalue expression.") Expression field,
+        @Param(
+            name = "p",
+            type = { "double" },
+            description = "It is a constant number that represents the 'p' parameter in the P-Series. "
+                + "It impacts every element's contribution to the weighted sum."
+        ) Expression p
+    ) {
+        super(source, Arrays.asList(field, p));
+        this.field = field;
+        this.p = p;
+    }
+
+    private MvPSeriesWeightedSum(StreamInput in) throws IOException {
+        this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class));
+    }
+
+    @Override
+    protected TypeResolution resolveType() {
+        if (childrenResolved() == false) {
+            return new TypeResolution("Unresolved children");
+        }
+
+        TypeResolution resolution = TypeResolutions.isType(field, dt -> dt == DOUBLE, sourceText(), FIRST, "double");
+        if (resolution.unresolved()) {
+            return resolution;
+        }
+
+        resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double")
+            .and(isNotNullAndFoldable(p, sourceText(), SECOND));
+
+        if (resolution.unresolved()) {
+            return resolution;
+        }
+
+        return resolution;
+    }
+
+    @Override
+    public boolean foldable() {
+        return field.foldable() && p.foldable();
+    }
+
+    @Override
+    public EvalOperator.ExpressionEvaluator.Factory toEvaluator(Function<Expression, ExpressionEvaluator.Factory> toEvaluator) {
+        return switch (PlannerUtils.toElementType(field.dataType())) {
+            case DOUBLE -> new MvPSeriesWeightedSumDoubleEvaluator.Factory(
+                source(),
+                toEvaluator.apply(field),
+                ctx -> new CompensatedSum(),
+                (Double) p.fold()
+            );
+            case NULL -> EvalOperator.CONSTANT_NULL_FACTORY;
+            default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType());
+        };
+    }
+
+    @Override
+    public Expression replaceChildren(List<Expression> newChildren) {
+        return new MvPSeriesWeightedSum(source(), newChildren.get(0), newChildren.get(1));
+    }
+
+    @Override
+    protected NodeInfo<? extends Expression> info() {
+        return NodeInfo.create(this, MvPSeriesWeightedSum::new, field, p);
+    }
+
+    @Override
+    public DataType dataType() {
+        return field.dataType();
+    }
+
+    @Evaluator(extraName = "Double")
+    static void process(
+        DoubleBlock.Builder builder,
+        int position,
+        DoubleBlock block,
+        @Fixed(includeInToString = false, build = true) CompensatedSum sum,
+        @Fixed double p
+    ) {
+        sum.reset(0, 0);
+        int start = block.getFirstValueIndex(position);
+        int end = block.getValueCount(position) + start;
+
+        for (int i = start; i < end; i++) {
+            double current_score = block.getDouble(i) / Math.pow(i - start + 1, p);
+            sum.add(current_score);
+        }
+        builder.appendDouble(sum.value());
+    }
+
+    @Override
+    public String getWriteableName() {
+        return ENTRY.name;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        source().writeTo(out);
+        out.writeNamedWriteable(field);
+        out.writeNamedWriteable(p);
+    }
+
+    Expression field() {
+        return field;
+    }
+
+    Expression p() {
+        return p;
+    }
+}

+ 10 - 0
x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java

@@ -184,6 +184,16 @@
  *         looks ok.
  *     </li>
  *     <li>
+ *         Let's finish up the code by making the tests backwards compatible. Since this is a new
+ *         feature we just have to convince the tests not to run in a cluster that includes older
+ *         versions of Elasticsearch. We do that with a {@link org.elasticsearch.rest.RestHandler#supportedCapabilities capability}
+ *         on the REST handler. ESQL has a <strong>ton</strong> of capabilities so we list them
+ *         all in {@link org.elasticsearch.xpack.esql.action.EsqlCapabilities}. Add a new one
+ *         for your function. Now add something like {@code required_capability: my_function}
+ *         to all of your csv-spec tests. Run those csv-spec tests as integration tests to double
+ *         check that they run on the main branch.
+ *     </li>
+ *     <li>
  *         Open the PR. The subject and description of the PR are important because those'll turn
  *         into the commit message we see in the commit history. Good PR descriptions make me very
  *         happy. But functions don't need an essay.

+ 6 - 4
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java

@@ -1305,11 +1305,11 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
         private final Class<? extends Throwable> foldingExceptionClass;
         private final String foldingExceptionMessage;
 
-        public TestCase(List<TypedData> data, String evaluatorToString, DataType expectedType, Matcher<Object> matcher) {
+        public TestCase(List<TypedData> data, String evaluatorToString, DataType expectedType, Matcher<?> matcher) {
             this(data, equalTo(evaluatorToString), expectedType, matcher);
         }
 
-        public TestCase(List<TypedData> data, Matcher<String> evaluatorToString, DataType expectedType, Matcher<Object> matcher) {
+        public TestCase(List<TypedData> data, Matcher<String> evaluatorToString, DataType expectedType, Matcher<?> matcher) {
             this(data, evaluatorToString, expectedType, matcher, null, null, null, null);
         }
 
@@ -1321,7 +1321,7 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
             List<TypedData> data,
             Matcher<String> evaluatorToString,
             DataType expectedType,
-            Matcher<Object> matcher,
+            Matcher<?> matcher,
             String[] expectedWarnings,
             String expectedTypeError,
             Class<? extends Throwable> foldingExceptionClass,
@@ -1331,7 +1331,9 @@ public record TestCaseSupplier(String name, List<DataType> types, Supplier<TestC
             this.data = data;
             this.evaluatorToString = evaluatorToString;
             this.expectedType = expectedType;
-            this.matcher = matcher;
+            @SuppressWarnings("unchecked")
+            Matcher<Object> downcast = (Matcher<Object>) matcher;
+            this.matcher = downcast;
             this.expectedWarnings = expectedWarnings;
             this.expectedTypeError = expectedTypeError;
             this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type));

+ 39 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests;
+
+import java.io.IOException;
+
+/**
+ * Serialization round-trip tests for {@link MvPSeriesWeightedSum}.
+ */
+public class MvPSeriesWeightedSumSerializationTests extends AbstractExpressionSerializationTests<MvPSeriesWeightedSum> {
+    /**
+     * Creates an instance from a random source and two random child expressions.
+     */
+    @Override
+    protected MvPSeriesWeightedSum createTestInstance() {
+        // Arguments are evaluated left to right: source, then field, then p.
+        return new MvPSeriesWeightedSum(randomSource(), randomChild(), randomChild());
+    }
+
+    /**
+     * Returns a copy of {@code instance} in which exactly one of the two children has been
+     * replaced by a different random expression; the source is kept as is.
+     */
+    @Override
+    protected MvPSeriesWeightedSum mutateInstance(MvPSeriesWeightedSum instance) throws IOException {
+        Source source = instance.source();
+        Expression mutatedField = instance.field();
+        Expression mutatedP = instance.p();
+
+        if (between(0, 1) == 0) {
+            mutatedField = randomValueOtherThan(mutatedField, AbstractExpressionSerializationTests::randomChild);
+        } else {
+            mutatedP = randomValueOtherThan(mutatedP, AbstractExpressionSerializationTests::randomChild);
+        }
+        return new MvPSeriesWeightedSum(source, mutatedField, mutatedP);
+    }
+}

+ 71 - 0
x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue;
+
+import com.carrotsearch.randomizedtesting.annotations.Name;
+import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
+
+import org.elasticsearch.xpack.esql.core.expression.Expression;
+import org.elasticsearch.xpack.esql.core.tree.Source;
+import org.elasticsearch.xpack.esql.core.type.DataType;
+import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase;
+import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static org.hamcrest.Matchers.closeTo;
+
+public class MvPSeriesWeightedSumTests extends AbstractScalarFunctionTestCase {
+    public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
+        this.testCase = testCaseSupplier.get();
+    }
+
+    @ParametersFactory
+    public static Iterable<Object[]> parameters() {
+        List<TestCaseSupplier> cases = new ArrayList<>();
+
+        doubles(cases);
+
+        // TODO use parameterSuppliersFromTypedDataWithDefaultChecks instead of parameterSuppliersFromTypedData and fix errors
+        return parameterSuppliersFromTypedData(cases);
+    }
+
+    @Override
+    protected Expression build(Source source, List<Expression> args) {
+        return new MvPSeriesWeightedSum(source, args.get(0), args.get(1));
+    }
+
+    private static void doubles(List<TestCaseSupplier> cases) {
+
+        cases.add(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE), () -> {
+            List<Double> field = randomList(1, 10, () -> randomDouble());
+            double p = randomDoubleBetween(-100.0, 100.0, true);
+
+            return new TestCaseSupplier.TestCase(
+                List.of(
+                    new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"),
+                    new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral()
+                ),
+                "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]",
+                DataType.DOUBLE,
+                closeTo(calcPSeriesWeightedSum(field, p), 0.00000001)
+            );
+        }));
+    }
+
+    private static double calcPSeriesWeightedSum(List<Double> field, double p) {
+        double sum = 0;
+        for (int i = 0; i < field.size(); i++) {
+            double current = field.get(i) / Math.pow(i + 1, p);
+            sum += current;
+        }
+        return sum;
+    }
+}