1
0
Эх сурвалжийг харах

SQL: Implement scripting inside aggs (#55241)

Implement the use of scalar functions inside aggregate functions.
This allows for complex expressions inside aggregations, with or without
GROUBY as well as with or without a HAVING clause. e.g.:

```
SELECT MAX(CASE WHEN a IS NULL then -1 ELSE abs(a * 10) + 1 END) AS max, b
FROM test
GROUP BY b
HAVING MAX(CASE WHEN a IS NULL then -1 ELSE abs(a * 10) + 1 END) > 5
```

Scalar functions are still not allowed for `KURTOSIS` and `SKEWNESS` as
this is currently not implemented on the ElasticSearch side.

Fixes: #29980
Fixes: #36865
Fixes: #37271
Marios Trivyzas 5 жил өмнө
parent
commit
506d1beea7
30 өөрчлөгдсөн 904 нэмэгдсэн , 187 устгасан
  1. 85 0
      docs/reference/sql/functions/aggs.asciidoc
  2. 0 6
      docs/reference/sql/limitations.asciidoc
  3. 193 0
      x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec
  4. 169 3
      x-pack/plugin/sql/qa/src/main/resources/docs/docs.csv-spec
  5. 18 0
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java
  6. 51 40
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java
  7. 10 9
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Agg.java
  8. 70 0
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/AggSource.java
  9. 2 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java
  10. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/AvgAgg.java
  11. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/CardinalityAgg.java
  12. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/ExtendedStatsAgg.java
  13. 20 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/FilterExistsAgg.java
  14. 9 9
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByDateHistogram.java
  15. 23 17
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByKey.java
  16. 6 6
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByNumericHistogram.java
  17. 6 6
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByValue.java
  18. 8 2
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/LeafAgg.java
  19. 1 1
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MatrixStatsAgg.java
  20. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MaxAgg.java
  21. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MedianAbsoluteDeviationAgg.java
  22. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MinAgg.java
  23. 3 8
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/PercentileRanksAgg.java
  24. 5 9
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/PercentilesAgg.java
  25. 3 3
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/StatsAgg.java
  26. 5 4
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/SumAgg.java
  27. 62 27
      x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/TopHitsAgg.java
  28. 11 4
      x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java
  29. 2 1
      x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java
  30. 124 8
      x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

+ 85 - 0
docs/reference/sql/functions/aggs.asciidoc

@@ -32,6 +32,11 @@ AVG(numeric_field) <1>
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggAvg]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggAvgScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-count]]
 ==== `COUNT`
 
@@ -82,6 +87,10 @@ COUNT(ALL field_name) <1>
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggCountAll]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggCountAllScalars]
+--------------------------------------------------
 
 [[sql-functions-aggs-count-distinct]]
 ==== `COUNT(DISTINCT)`
@@ -105,6 +114,11 @@ COUNT(DISTINCT field_name) <1>
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggCountDistinct]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggCountDistinctScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-first]]
 ==== `FIRST/FIRST_VALUE`
 
@@ -194,6 +208,11 @@ include-tagged::{sql-specs}/docs/docs.csv-spec[firstWithTwoArgsAndGroupBy]
 include-tagged::{sql-specs}/docs/docs.csv-spec[firstValueWithTwoArgsAndGroupBy]
 --------------------------------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[firstValueWithTwoArgsAndGroupByScalars]
+--------------------------------------------------------------------------
+
 [NOTE]
 `FIRST` cannot be used in a HAVING clause.
 [NOTE]
@@ -289,6 +308,11 @@ include-tagged::{sql-specs}/docs/docs.csv-spec[lastWithTwoArgsAndGroupBy]
 include-tagged::{sql-specs}/docs/docs.csv-spec[lastValueWithTwoArgsAndGroupBy]
 -------------------------------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+-------------------------------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[lastValueWithTwoArgsAndGroupByScalars]
+-------------------------------------------------------------------------
+
 [NOTE]
 `LAST` cannot be used in `HAVING` clause.
 [NOTE]
@@ -317,6 +341,11 @@ MAX(field_name) <1>
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggMax]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggMaxScalars]
+--------------------------------------------------
+
 [NOTE]
 `MAX` on a field of type <<text, `text`>> or <<keyword, `keyword`>> is translated into
 <<sql-functions-aggs-last>> and therefore, it cannot be used in `HAVING` clause.
@@ -369,6 +398,11 @@ SUM(field_name) <1>
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggSum]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggSumScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-statistics]]
 [float]
 === Statistics
@@ -397,6 +431,16 @@ https://en.wikipedia.org/wiki/Kurtosis[Quantify] the shape of the distribution o
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggKurtosis]
 --------------------------------------------------
 
+[NOTE]
+====
+`KURTOSIS` cannot be used on top of scalar functions or operators but only directly on a field. So, for example,
+the following is not allowed and an error is returned:
+[source, sql]
+---------------------------------------
+ SELECT KURTOSIS(salary / 12.0), gender FROM emp GROUP BY gender
+---------------------------------------
+====
+
 [[sql-functions-aggs-mad]]
 ==== `MAD`
 
@@ -421,6 +465,11 @@ https://en.wikipedia.org/wiki/Median_absolute_deviation[Measure] the variability
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggMad]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggMadScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-percentile]]
 ==== `PERCENTILE`
 
@@ -449,6 +498,11 @@ of input values in the field `field_name`.
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggPercentile]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggPercentileScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-percentile-rank]]
 ==== `PERCENTILE_RANK`
 
@@ -477,6 +531,11 @@ of input values in the field `field_name`.
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggPercentileRank]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggPercentileRankScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-skewness]]
 ==== `SKEWNESS`
 
@@ -501,6 +560,16 @@ https://en.wikipedia.org/wiki/Skewness[Quantify] the asymmetric distribution of
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggSkewness]
 --------------------------------------------------
 
+[NOTE]
+====
+`SKEWNESS` cannot be used on top of scalar functions but only directly on a field. So, for example, the following is
+not allowed and an error is returned:
+[source, sql]
+---------------------------------------
+ SELECT SKEWNESS(ROUND(salary / 12.0, 2), gender FROM emp GROUP BY gender
+---------------------------------------
+====
+
 [[sql-functions-aggs-stddev-pop]]
 ==== `STDDEV_POP`
 
@@ -525,6 +594,11 @@ Returns the https://en.wikipedia.org/wiki/Standard_deviations[population standar
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggStddevPop]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggStddevPopScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-sum-squares]]
 ==== `SUM_OF_SQUARES`
 
@@ -549,6 +623,11 @@ Returns the https://en.wikipedia.org/wiki/Total_sum_of_squares[sum of squares] o
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggSumOfSquares]
 --------------------------------------------------
 
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggSumOfSquaresScalars]
+--------------------------------------------------
+
 [[sql-functions-aggs-var-pop]]
 ==== `VAR_POP`
 
@@ -572,3 +651,9 @@ Returns the https://en.wikipedia.org/wiki/Variance[population variance] of input
 --------------------------------------------------
 include-tagged::{sql-specs}/docs/docs.csv-spec[aggVarPop]
 --------------------------------------------------
+
+
+["source","sql",subs="attributes,macros"]
+--------------------------------------------------
+include-tagged::{sql-specs}/docs/docs.csv-spec[aggVarPopScalars]
+--------------------------------------------------

+ 0 - 6
docs/reference/sql/limitations.asciidoc

@@ -129,12 +129,6 @@ SELECT age, ROUND(AVG(salary)) AS avg FROM test GROUP BY age ORDER BY avg;
 SELECT age, MAX(salary) - MIN(salary) AS diff FROM test GROUP BY age ORDER BY diff;
 --------------------------------------------------
 
-[float]
-=== Using aggregation functions on top of scalar functions
-
-Aggregation functions like <<sql-functions-aggs-min,`MIN`>>, <<sql-functions-aggs-max,`MAX`>>, etc. can only be used
-directly on fields, and so queries like `SELECT MAX(abs(age)) FROM test` are not possible.
-
 [float]
 === Using a sub-select
 

+ 193 - 0
x-pack/plugin/sql/qa/src/main/resources/agg.csv-spec

@@ -910,3 +910,196 @@ SELECT gender, MAD(salary) AS mad FROM test_emp GROUP BY gender HAVING mad > 100
 null           |10789.0        
 F              |12719.0         
 ;
+
+
+// aggregates with scalars
+aggregateFunctionsWithScalars
+SELECT MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "max",
+MIN(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "min",
+AVG(cos(salary * 1.2) + 100 * (salary / 5)) AS "avg",
+SUM(-salary / 0.765 + sin((salary + 12345) / 12)) AS "sum",
+MAD(abs(salary / -0.813) / 2 + (12345 * (salary % 10))) AS "mad"
+FROM test_emp;
+
+       max        |      min      |       avg       |       sum        |       mad
+------------------+---------------+-----------------+------------------+-----------------
+155409.30000000002|23532.72       |964937.9295477575|-6307004.517507723|30811.76199261993
+;
+
+countWithScalars
+schema::cnt1:l|cnt2:l
+SELECT count(DISTINCT CASE WHEN (languages - 1) > 3 THEN (languages + 3) * 1.2 ELSE (languages - 1) * 2.7 END) AS "cnt1",
+count(CASE WHEN (languages - 2) > 2 THEN (languages + 5) * 1.2 ELSE ((languages / 0.87) - 11) * 2.7 END) AS "cnt2"
+FROM test_emp;
+
+   cnt1   |  cnt2
+----------+-------
+5         | 90
+;
+
+aggregateFunctionsWithScalarsAndGroupBy
+schema::max:d|min:d|avg:d|sum:d|mad:d|gender:s
+SELECT MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "max",
+MIN(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "min",
+AVG(cos(salary * 1.2) + 100 * (salary / 5)) AS "avg",
+SUM(-salary / 0.765 + sin((salary + 12345) / 12)) AS "sum",
+MAD(abs(salary / -0.813) / 2 + (12345 * (salary % 10))) AS "mad",
+gender
+FROM test_emp GROUP BY gender ORDER BY gender;
+
+       max        |      min      |       avg        |        sum        |       mad       |    gender
+------------------+---------------+------------------+-------------------+-----------------+---------------
+132335.1          |23532.72       |975179.5463883684 |-637388.2516376646 |33398.4963099631 |null
+155409.30000000002|24139.08       |1009778.6217005679|-2178038.0602625553|24031.90651906518|F
+151745.40000000002|24110.25       |937180.7539433916 |-3491578.2056075027|32956.9126691267 |M
+;
+
+countWithScalarsAndGroupBy
+schema::cnt1:l|cnt2:l|gender:s
+SELECT count(DISTINCT CASE WHEN (languages - 1) > 3 THEN (languages + 3) * 1.2 ELSE (languages - 1) * 2.7 END) AS "cnt1",
+count(CASE WHEN (languages - 2) > 2 THEN (languages + 5) * 1.2 ELSE ((languages / 0.87) - 11) * 2.7 END) AS "cnt2",
+gender
+FROM test_emp GROUP BY gender ORDER BY gender;
+
+     cnt1      |     cnt2      |    gender
+---------------+---------------+---------------
+4              |10             |null
+5              |30             |F
+5              |50             |M
+;
+
+aggregatesWithScalarsAndGroupByOrderByAgg
+schema::max:d|gender:s
+SELECT MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "max",
+gender
+FROM test_emp GROUP BY gender ORDER BY max DESC;
+
+       max        |    gender
+------------------+---------------
+155409.30000000002|F
+151745.40000000002|M
+132335.1          |null
+;
+
+aggregatesWithScalarsAndGroupByOrderByAggWithoutProjection
+schema::gender:s
+SELECT gender FROM test_emp GROUP BY gender ORDER BY MAX(salary % 100) DESC;
+
+    gender
+---------------
+M
+null
+F
+;
+
+topHitsWithScalars
+schema::first:s|last:s|gender:s
+SELECT FIRST(concat('aa_', substring(first_name, 3, 10)), birth_date) AS first,
+LAST(concat('bb_', substring(last_name, 4, 8)), birth_date) AS last,
+gender
+FROM test_emp GROUP BY gender ORDER By gender;
+
+     first     |     last      |    gender
+---------------+---------------+---------------
+aa_llian       |bb_kki         |null
+aa_mant        |bb_zuma        |F
+aa_mzi         |bb_ton         |M
+;
+
+aggregateFunctionsWithScalarsAndGroupByAndHaving
+schema::max:d|min:d|gender:s
+SELECT MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "max",
+MIN(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "min",
+gender FROM test_emp
+GROUP BY gender HAVING max > 152000 or min > 24000 ORDER BY gender;
+
+       max        |      min      |    gender
+------------------+---------------+---------------
+155409.30000000002|24139.08       |F
+151745.40000000002|24110.25       |M
+;
+
+aggregateFunctionsWithScalarsAndGroupByAndHaving_ComplexExpressions
+schema::max:d|min:d|gender:s
+SELECT ABS((MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) + 123) / -100) AS "max",
+cos(MIN(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) % 100) AS "min",
+gender
+FROM test_emp
+GROUP BY gender HAVING (max / 10) + 10 > 165  OR ABS(min * -100) > 60 ORDER BY gender;
+
+       max        |        min        |    gender
+------------------+-------------------+---------------
+1555.323          |0.1887687166044111 |F
+1518.6840000000002|-0.6783938504738453|M
+;
+
+aggregateFunctionsWithScalarsAndGroupByAndHaving_CombinedFields
+schema::min:d|max:d|gender:s
+SELECT MIN(ABS(salary * (languages / - 20.0))) AS "min",
+MAX(salary / ((languages / 3.0) + 1)) AS "max",
+gender
+FROM test_emp
+GROUP BY gender HAVING (min::long) / 120 > 12 OR ROUND(max) / 10 > 5200 ORDER BY gender;
+
+      min      |      max      |    gender
+---------------+---------------+---------------
+2436.75        |55287.75       |null
+1401.75        |52508.25       |M
+;
+
+aggregateFunctionsWithScalarsAndGroupByAndHavingConvertedToStats
+schema::max:d|min:d|gender:s
+SELECT MAX(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "max",
+MIN(CASE WHEN (salary - 10) > 70000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "min",
+gender FROM test_emp
+GROUP BY gender HAVING max > 155000 or min > 36000 ORDER BY gender;
+
+       max        |       min        |    gender
+------------------+------------------+---------------
+155409.30000000002|36803.700000000004|F
+151745.40000000002|36720.0           |M
+;
+
+percentileAggregateFunctionsWithScalars
+schema::percentile:d|percentile_rank:d|gender:s
+SELECT PERCENTILE(CASE WHEN (salary / 2) > 10000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END, 80) AS "percentile",
+PERCENTILE_RANK(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END, 40000) AS "percentile_rank",
+gender FROM test_emp
+GROUP BY gender ORDER BY gender;
+
+   percentile    | percentile_rank  |    gender
+-----------------+------------------+---------------
+86857.79999999999|32.69659025378865 |null
+94042.92000000001|37.03569653103581 |F
+87348.36         |44.337514210592246|M
+;
+
+extendedStatsAggregateFunctionsWithScalars
+schema::stddev_pop:d|sum_of_squares:d|var_pop:d|gender:s
+SELECT STDDEV_POP(CASE WHEN (salary / 2) > 10000 THEN (salary + 12345) * 1.2 ELSE (salary - 12345) * 2.7 END) AS "stddev_pop",
+SUM_OF_SQUARES(CASE WHEN (salary - 20) > 50000 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "sum_of_squares",
+VAR_POP(CASE WHEN (salary - 20) % 1000 > 200 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "var_pop",
+gender FROM test_emp
+GROUP BY gender ORDER BY gender;
+
+    stddev_pop    |   sum_of_squares    |      var_pop       |    gender
+------------------+---------------------+--------------------+---------------
+16752.73244172422 |3.06310583829007E10  |3.460331137445282E8 |null
+17427.462400181845|1.148127725047658E11 |3.1723426960671306E8|F
+15702.798665784752|1.5882243113919238E11|2.529402043805585E8 |M
+;
+
+extendedStatsAggregateFunctionsWithScalarAndSameArg
+schema::stddev_pop:d|sum_of_squares:d|var_pop:d|gender:s
+SELECT STDDEV_POP(CASE WHEN (salary - 20) % 1000 > 200 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "stddev_pop",
+SUM_OF_SQUARES(CASE WHEN (salary - 20) % 1000 > 200 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "sum_of_squares",
+VAR_POP(CASE WHEN (salary - 20) % 1000 > 200 THEN (salary * 1.2) - 1234 ELSE (salary - 20) * 0.93 END) AS "var_pop",
+gender FROM test_emp
+GROUP BY gender ORDER BY gender;
+
+    stddev_pop    |   sum_of_squares    |      var_pop       |    gender
+------------------+---------------------+--------------------+---------------
+18601.965319409886|3.4461553130896095E10|3.460331137445282E8 |null
+17811.071545718776|1.2151168881502939E11|3.1723426960671306E8|F
+15904.093950318531|1.699198993070239E11 |2.529402043805585E8 |M
+;

+ 169 - 3
x-pack/plugin/sql/qa/src/main/resources/docs/docs.csv-spec

@@ -1133,15 +1133,27 @@ Georgi         |Facello        |10001
 ///////////////////////////////
 
 aggAvg
+schema::avg:d
 // tag::aggAvg
 SELECT AVG(salary) AS avg FROM emp;
 
-      avg:d      
+      avg
 ---------------
-48248.55          
+48248.55
 // end::aggAvg
 ;
 
+aggAvgScalars
+schema::avg:d
+// tag::aggAvgScalars
+SELECT AVG(salary / 12.0) AS avg FROM emp;
+
+      avg
+---------------
+4020.7125
+// end::aggAvgScalars
+;
+
 aggCountStar
 // tag::aggCountStar
 SELECT COUNT(*) AS count FROM emp;
@@ -1162,9 +1174,19 @@ SELECT COUNT(ALL last_name) AS count_all, COUNT(DISTINCT last_name) count_distin
 // end::aggCountAll
 ;
 
+aggCountAllScalars
+// tag::aggCountAllScalars
+SELECT COUNT(ALL CASE WHEN languages IS NULL THEN -1 ELSE languages END) AS count_all, COUNT(DISTINCT CASE WHEN languages IS NULL THEN -1 ELSE languages END) count_distinct FROM emp;
+
+   count_all   |  count_distinct
+---------------+---------------
+100            |6
+
+// end::aggCountAllScalars
+;
+
 aggCountDistinct
 // tag::aggCountDistinct
-
 SELECT COUNT(DISTINCT hire_date) unique_hires, COUNT(hire_date) AS hires FROM emp;
 
   unique_hires  |     hires
@@ -1173,6 +1195,16 @@ SELECT COUNT(DISTINCT hire_date) unique_hires, COUNT(hire_date) AS hires FROM em
 // end::aggCountDistinct
 ;
 
+aggCountDistinctScalars
+// tag::aggCountDistinctScalars
+SELECT COUNT(DISTINCT DATE_TRUNC('YEAR', hire_date)) unique_hires, COUNT(DATE_TRUNC('YEAR', hire_date)) AS hires FROM emp;
+
+ unique_hires  |     hires
+---------------+---------------
+14             |100
+// end::aggCountDistinctScalars
+;
+
 firstWithOneArg
 schema::FIRST(first_name):s
 // tag::firstWithOneArg
@@ -1239,6 +1271,19 @@ M             |   Remzi
 // end::firstValueWithTwoArgsAndGroupBy
 ;
 
+firstValueWithTwoArgsAndGroupByScalars
+schema::gender:s|first:s
+// tag::firstValueWithTwoArgsAndGroupByScalars
+SELECT gender, FIRST_VALUE(SUBSTRING(first_name, 2, 6), birth_date) AS "first" FROM emp GROUP BY gender ORDER BY gender;
+
+    gender     |     first
+---------------+---------------
+null           |illian
+F              |umant
+M              |emzi
+// end::firstValueWithTwoArgsAndGroupByScalars
+;
+
 lastWithOneArg
 schema::LAST(first_name):s
 // tag::lastWithOneArg
@@ -1307,6 +1352,18 @@ M          |   Hilari
 // end::lastValueWithTwoArgsAndGroupBy
 ;
 
+lastValueWithTwoArgsAndGroupByScalars
+schema::gender:s|last:s
+// tag::lastValueWithTwoArgsAndGroupByScalars
+SELECT gender, LAST_VALUE(SUBSTRING(first_name, 3, 8), birth_date) AS "last" FROM emp GROUP BY gender ORDER BY gender;
+
+    gender     |     last
+---------------+---------------
+null           |erhardt
+F              |ldiodio
+M              |lari
+// end::lastValueWithTwoArgsAndGroupByScalars
+;
 
 aggMax
 // tag::aggMax
@@ -1318,6 +1375,17 @@ SELECT MAX(salary) AS max FROM emp;
 // end::aggMax
 ;
 
+aggMaxScalars
+schema::max:d
+// tag::aggMaxScalars
+SELECT MAX(ABS(salary / -12.0)) AS max FROM emp;
+
+       max
+-----------------
+6249.916666666667
+// end::aggMaxScalars
+;
+
 aggMin
 // tag::aggMin
 SELECT MIN(salary) AS min FROM emp;
@@ -1328,6 +1396,17 @@ SELECT MIN(salary) AS min FROM emp;
 // end::aggMin
 ;
 
+aggMinScalars
+schema::min:d
+// tag::aggMinScalars
+SELECT MIN(ABS(salary / 12.0)) AS min FROM emp;
+
+       min
+------------------
+2110.3333333333335
+// end::aggMinScalars
+;
+
 aggSum
 // tag::aggSum
 SELECT SUM(salary) AS sum FROM emp;
@@ -1338,6 +1417,17 @@ SELECT SUM(salary) AS sum FROM emp;
 // end::aggSum
 ;
 
+aggSumScalars
+schema::sum:d
+// tag::aggSumScalars
+SELECT ROUND(SUM(salary / 12.0), 1) AS sum FROM emp;
+
+      sum
+---------------
+402071.3
+// end::aggSumScalars
+;
+
 aggKurtosis
 // tag::aggKurtosis
 SELECT MIN(salary) AS min, MAX(salary) AS max, KURTOSIS(salary) AS k FROM emp;
@@ -1358,6 +1448,17 @@ SELECT MIN(salary) AS min, MAX(salary) AS max, AVG(salary) AS avg, MAD(salary) A
 // end::aggMad
 ;
 
+aggMadScalars
+schema::min:d|max:d|avg:d|mad:d
+// tag::aggMadScalars
+SELECT MIN(salary / 12.0) AS min, MAX(salary / 12.0) AS max, AVG(salary/ 12.0) AS avg, MAD(salary / 12.0) AS mad FROM emp;
+
+       min        |       max       |      avg      |       mad
+------------------+-----------------+---------------+-----------------
+2110.3333333333335|6249.916666666667|4020.7125      |841.3750000000002
+// end::aggMadScalars
+;
+
 aggPercentile
 // tag::aggPercentile
 SELECT languages, PERCENTILE(salary, 95) AS "95th" FROM emp 
@@ -1374,6 +1475,23 @@ null           |74999.0
 // end::aggPercentile
 ;
 
+aggPercentileScalars
+schema::languages:i|95th:d
+// tag::aggPercentileScalars
+SELECT languages, PERCENTILE(salary / 12.0, 95) AS "95th" FROM emp
+       GROUP BY languages;
+
+   languages   |       95th
+---------------+------------------
+null           |6249.916666666667
+1              |6065.875
+2              |5993.725
+3              |6136.520833333332
+4              |6009.633333333332
+5              |5089.3083333333325
+// end::aggPercentileScalars
+;
+
 aggPercentileRank
 // tag::aggPercentileRank
 SELECT languages, PERCENTILE_RANK(salary, 65000) AS rank FROM emp GROUP BY languages;
@@ -1389,6 +1507,22 @@ null           |73.65766569962062
 // end::aggPercentileRank
 ;
 
+aggPercentileRankScalars
+schema::languages:i|rank:d
+// tag::aggPercentileRankScalars
+SELECT languages, PERCENTILE_RANK(salary/12, 5000) AS rank FROM emp GROUP BY languages;
+
+   languages   |       rank
+---------------+------------------
+null           |66.91240875912409
+1              |66.70766707667076
+2              |84.13266895048271
+3              |61.052992625621684
+4              |76.55646443990001
+5              |94.00696864111498
+// end::aggPercentileRankScalars
+;
+
 aggSkewness
 // tag::aggSkewness
 SELECT MIN(salary) AS min, MAX(salary) AS max, SKEWNESS(salary) AS s FROM emp;
@@ -1410,6 +1544,17 @@ SELECT MIN(salary) AS min, MAX(salary) AS max, STDDEV_POP(salary) AS stddev
 // end::aggStddevPop
 ;
 
+aggStddevPopScalars
+schema::min:d|max:d|stddev:d
+// tag::aggStddevPopScalars
+SELECT MIN(salary / 12.0) AS min, MAX(salary / 12.0) AS max, STDDEV_POP(salary / 12.0) AS stddev FROM emp;
+
+       min        |       max       |     stddev
+------------------+-----------------+-----------------
+2110.3333333333335|6249.916666666667|1147.093791898986
+// end::aggStddevPopScalars
+;
+
 
 aggSumOfSquares
 // tag::aggSumOfSquares
@@ -1422,6 +1567,16 @@ SELECT MIN(salary) AS min, MAX(salary) AS max, SUM_OF_SQUARES(salary) AS sumsq
 // end::aggSumOfSquares
 ;
 
+aggSumOfSquaresScalars
+schema::min:d|max:d|sumsq:d
+// tag::aggSumOfSquaresScalars
+SELECT MIN(salary / 24.0) AS min, MAX(salary / 24.0) AS max, SUM_OF_SQUARES(salary / 24.0) AS sumsq FROM emp;
+
+       min        |       max        |       sumsq
+------------------+------------------+-------------------
+1055.1666666666667|3124.9583333333335|4.370488293767361E8
+// end::aggSumOfSquaresScalars
+;
 
 aggVarPop
 // tag::aggVarPop
@@ -1433,6 +1588,17 @@ SELECT MIN(salary) AS min, MAX(salary) AS max, VAR_POP(salary) AS varpop FROM em
 // end::aggVarPop
 ;
 
+aggVarPopScalars
+schema::min:d|max:d|varpop:d
+// tag::aggVarPopScalars
+SELECT MIN(salary / 24.0) AS min, MAX(salary / 24.0) AS max, VAR_POP(salary / 24.0) AS varpop FROM emp;
+
+       min        |       max        |      varpop
+------------------+------------------+------------------
+1055.1666666666667|3124.9583333333335|328956.04185329855
+
+// end::aggVarPopScalars
+;
 
 ///////////////////////////////
 //

+ 18 - 0
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/analysis/analyzer/Verifier.java

@@ -43,8 +43,10 @@ import org.elasticsearch.xpack.ql.util.Holder;
 import org.elasticsearch.xpack.ql.util.StringUtils;
 import org.elasticsearch.xpack.sql.expression.Exists;
 import org.elasticsearch.xpack.sql.expression.function.Score;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.Kurtosis;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Max;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.Min;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.Skewness;
 import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
 import org.elasticsearch.xpack.sql.plan.logical.Distinct;
 import org.elasticsearch.xpack.sql.plan.logical.LocalRelation;
@@ -217,6 +219,7 @@ public final class Verifier {
                 checkNestedUsedInGroupByOrHavingOrWhereOrOrderBy(p, localFailures, attributeRefs);
                 checkForGeoFunctionsOnDocValues(p, localFailures);
                 checkPivot(p, localFailures, attributeRefs);
+                checkMatrixStats(p, localFailures);
 
                 // everything checks out
                 // mark the plan as analyzed
@@ -847,4 +850,19 @@ public final class Verifier {
 
         }, Pivot.class);
     }
+
+    private static void checkMatrixStats(LogicalPlan p, Set<Failure> localFailures) {
+        // MatrixStats aggregate functions cannot operates on scalars
+        // https://github.com/elastic/elasticsearch/issues/55344
+        p.forEachExpressions(e -> e.forEachUp((Kurtosis s) -> {
+            if (s.field() instanceof Function) {
+                localFailures.add(fail(s.field(), "[{}()] cannot be used on top of operators or scalars", s.functionName()));
+            }
+        }, Kurtosis.class));
+        p.forEachExpressions(e -> e.forEachUp((Skewness s) -> {
+            if (s.field() instanceof Function) {
+                localFailures.add(fail(s.field(), "[{}()] cannot be used on top of operators or scalars", s.functionName()));
+            }
+        }, Skewness.class));
+    }
 }

+ 51 - 40
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/planner/QueryTranslator.java

@@ -11,7 +11,6 @@ import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.Foldables;
-import org.elasticsearch.xpack.ql.expression.Literal;
 import org.elasticsearch.xpack.ql.expression.NamedExpression;
 import org.elasticsearch.xpack.ql.expression.function.Function;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
@@ -60,6 +59,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.geo.StDistance;
 import org.elasticsearch.xpack.sql.expression.literal.geo.GeoShape;
 import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
 import org.elasticsearch.xpack.sql.querydsl.agg.AggFilter;
+import org.elasticsearch.xpack.sql.querydsl.agg.AggSource;
 import org.elasticsearch.xpack.sql.querydsl.agg.AndAggFilter;
 import org.elasticsearch.xpack.sql.querydsl.agg.AvgAgg;
 import org.elasticsearch.xpack.sql.querydsl.agg.CardinalityAgg;
@@ -92,7 +92,7 @@ final class QueryTranslator {
     public static final String DATE_FORMAT = "strict_date_time";
     public static final String TIME_FORMAT = "strict_hour_minute_second_millis";
 
-    private QueryTranslator(){}
+    private QueryTranslator() {}
 
     private static final List<SqlExpressionTranslator<?>> QUERY_TRANSLATORS = Arrays.asList(
             new BinaryComparisons(),
@@ -135,10 +135,6 @@ final class QueryTranslator {
             this(query, null);
         }
 
-        QueryTranslation(AggFilter aggFilter) {
-            this(null, aggFilter);
-        }
-
         QueryTranslation(Query query, AggFilter aggFilter) {
             this.query = query;
             this.aggFilter = aggFilter;
@@ -240,39 +236,36 @@ final class QueryTranslator {
         }
     }
 
-    static String dateFormat(Expression e) {
-        if (e instanceof DateTimeFunction) {
-            return ((DateTimeFunction) e).dateTimeFormat();
+    static String field(AggregateFunction af, Expression arg) {
+        if (arg.foldable()) {
+            return String.valueOf(arg.fold());
         }
-        return null;
-    }
-
-    static String field(AggregateFunction af) {
-        Expression arg = af.field();
         if (arg instanceof FieldAttribute) {
             FieldAttribute field = (FieldAttribute) arg;
             // COUNT(DISTINCT) uses cardinality aggregation which works on exact values (not changed by analyzers or normalizers)
-            if (af instanceof Count && ((Count) af).distinct()) {
+            if ((af instanceof Count && ((Count) af).distinct()) || af instanceof TopHits) {
                 // use the `keyword` version of the field, if there is one
                 return field.exactAttribute().name();
             }
             return field.name();
         }
-        if (arg instanceof Literal) {
-            return String.valueOf(((Literal) arg).value());
-        }
         throw new SqlIllegalArgumentException("Does not know how to convert argument {} for function {}", arg.nodeString(),
                                               af.nodeString());
     }
 
-    private static String topAggsField(AggregateFunction af, Expression e) {
+    private static boolean isFieldOrLiteral(Expression e) {
+        return e.foldable() || e instanceof FieldAttribute;
+    }
+
+    private static AggSource asFieldOrLiteralOrScript(AggregateFunction af) {
+        return asFieldOrLiteralOrScript(af, af.field());
+    }
+
+    private static AggSource asFieldOrLiteralOrScript(AggregateFunction af, Expression e) {
         if (e == null) {
             return null;
         }
-        if (e instanceof FieldAttribute) {
-            return ((FieldAttribute) e).exactAttribute().name();
-        }
-        throw new SqlIllegalArgumentException("Does not know how to convert argument {} for function {}", e.nodeString(), af.nodeString());
+        return isFieldOrLiteral(e) ? AggSource.of(field(af, e)) : AggSource.of(((ScalarFunction) e).asScript());
     }
 
     // TODO: see whether escaping is needed
@@ -524,9 +517,9 @@ final class QueryTranslator {
         @Override
         protected LeafAgg toAgg(String id, Count c) {
             if (c.distinct()) {
-                return new CardinalityAgg(id, field(c));
+                return new CardinalityAgg(id, asFieldOrLiteralOrScript(c));
             } else {
-                return new FilterExistsAgg(id, field(c));
+                return new FilterExistsAgg(id, asFieldOrLiteralOrScript(c));
             }
         }
     }
@@ -535,7 +528,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Sum s) {
-            return new SumAgg(id, field(s));
+            return new SumAgg(id, asFieldOrLiteralOrScript(s));
         }
     }
 
@@ -543,7 +536,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Avg a) {
-            return new AvgAgg(id, field(a));
+            return new AvgAgg(id, asFieldOrLiteralOrScript(a));
         }
     }
 
@@ -551,7 +544,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Max m) {
-            return new MaxAgg(id, field(m));
+            return new MaxAgg(id, asFieldOrLiteralOrScript(m));
         }
     }
 
@@ -559,14 +552,14 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Min m) {
-            return new MinAgg(id, field(m));
+            return new MinAgg(id, asFieldOrLiteralOrScript(m));
         }
     }
 
     static class MADs extends SingleValueAggTranslator<MedianAbsoluteDeviation> {
         @Override
         protected LeafAgg toAgg(String id, MedianAbsoluteDeviation m) {
-            return new MedianAbsoluteDeviationAgg(id, field(m));
+            return new MedianAbsoluteDeviationAgg(id, asFieldOrLiteralOrScript(m));
         }
     }
 
@@ -574,8 +567,14 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, First f) {
-            return new TopHitsAgg(id, topAggsField(f, f.field()), f.dataType(),
-                topAggsField(f, f.orderField()), f.orderField() == null ? null : f.orderField().dataType(), SortOrder.ASC);
+            return new TopHitsAgg(
+                id,
+                asFieldOrLiteralOrScript(f, f.field()),
+                f.dataType(),
+                asFieldOrLiteralOrScript(f, f.orderField()),
+                f.orderField() == null ? null : f.orderField().dataType(),
+                SortOrder.ASC
+            );
         }
     }
 
@@ -583,8 +582,14 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Last l) {
-            return new TopHitsAgg(id, topAggsField(l, l.field()), l.dataType(),
-                topAggsField(l, l.orderField()), l.orderField() == null ? null : l.orderField().dataType(), SortOrder.DESC);
+            return new TopHitsAgg(
+                id,
+                asFieldOrLiteralOrScript(l, l.field()),
+                l.dataType(),
+                asFieldOrLiteralOrScript(l, l.orderField()),
+                l.orderField() == null ? null : l.orderField().dataType(),
+                SortOrder.DESC
+            );
         }
     }
 
@@ -592,7 +597,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Stats s) {
-            return new StatsAgg(id, field(s));
+            return new StatsAgg(id, asFieldOrLiteralOrScript(s));
         }
     }
 
@@ -600,7 +605,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, ExtendedStats e) {
-            return new ExtendedStatsAgg(id, field(e));
+            return new ExtendedStatsAgg(id, asFieldOrLiteralOrScript(e));
         }
     }
 
@@ -608,7 +613,13 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, MatrixStats m) {
-            return new MatrixStatsAgg(id, singletonList(field(m)));
+            if (isFieldOrLiteral(m.field())) {
+                return new MatrixStatsAgg(id, singletonList(field(m, m.field())));
+            }
+            throw new SqlIllegalArgumentException(
+                "Cannot use scalar functions or operators: [{}] in aggregate functions [KURTOSIS] and [SKEWNESS]",
+                m.field().toString()
+            );
         }
     }
 
@@ -616,7 +627,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Percentiles p) {
-            return new PercentilesAgg(id, field(p), foldAndConvertToDoubles(p.percents()));
+            return new PercentilesAgg(id, asFieldOrLiteralOrScript(p), foldAndConvertToDoubles(p.percents()));
         }
     }
 
@@ -624,7 +635,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, PercentileRanks p) {
-            return new PercentileRanksAgg(id, field(p), foldAndConvertToDoubles(p.values()));
+            return new PercentileRanksAgg(id, asFieldOrLiteralOrScript(p), foldAndConvertToDoubles(p.values()));
         }
     }
 
@@ -632,7 +643,7 @@ final class QueryTranslator {
 
         @Override
         protected LeafAgg toAgg(String id, Min m) {
-            return new MinAgg(id, field(m));
+            return new MinAgg(id, asFieldOrLiteralOrScript(m));
         }
     }
 

+ 10 - 9
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Agg.java

@@ -16,24 +16,25 @@ import static java.lang.String.format;
 public abstract class Agg {
 
     private final String id;
-    private final String fieldName;
+    private final AggSource source;
 
-    Agg(String id, String fieldName) {
+    Agg(String id, AggSource source) {
+        Objects.requireNonNull(source, "AggSource must not be null");
         this.id = id;
-        this.fieldName = fieldName;
+        this.source = source;
     }
 
     public String id() {
         return id;
     }
 
-    protected String fieldName() {
-        return fieldName;
+    public AggSource source() {
+        return source;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(id, fieldName);
+        return Objects.hash(id, source);
     }
 
     @Override
@@ -48,11 +49,11 @@ public abstract class Agg {
 
         Agg other = (Agg) obj;
         return Objects.equals(id, other.id)
-                && Objects.equals(fieldName, other.fieldName);
+            && Objects.equals(source, other.source);
     }
 
     @Override
     public String toString() {
-        return format(Locale.ROOT, "%s(%s)", getClass().getSimpleName(), fieldName);
+        return format(Locale.ROOT, "%s(%s)", getClass().getSimpleName(), source.toString());
     }
-}
+}

+ 70 - 0
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/AggSource.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License;
+ * you may not use this file except in compliance with the Elastic License.
+ */
+package org.elasticsearch.xpack.sql.querydsl.agg;
+
+import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
+import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
+
+import java.util.Objects;
+
+public class AggSource {
+
+    private final String fieldName;
+    private final ScriptTemplate script;
+
+    private AggSource(String fieldName, ScriptTemplate script) {
+        this.fieldName = fieldName;
+        this.script = script;
+    }
+
+    public static AggSource of(String fieldName) {
+        return new AggSource(fieldName, null);
+    }
+
+    public static AggSource of(ScriptTemplate script) {
+        return new AggSource(null, script);
+    }
+
+    @SuppressWarnings("rawtypes")
+    ValuesSourceAggregationBuilder with(ValuesSourceAggregationBuilder aggBuilder) {
+        if (fieldName != null) {
+            return aggBuilder.field(fieldName);
+        }
+        else {
+            return aggBuilder.script(script.toPainless());
+        }
+    }
+
+    public String fieldName() {
+        return fieldName;
+    }
+
+    public ScriptTemplate script() {
+        return script;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(fieldName, script);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        AggSource aggSource = (AggSource) o;
+        return Objects.equals(fieldName, aggSource.fieldName) && Objects.equals(script, aggSource.script);
+    }
+
+    @Override
+    public String toString() {
+        return fieldName != null ? fieldName : script.toString();
+    }
+}

+ 2 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/Aggs.java

@@ -9,7 +9,6 @@ import org.elasticsearch.search.aggregations.AggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.composite.CompositeAggregationBuilder;
 import org.elasticsearch.search.aggregations.bucket.composite.CompositeValuesSourceBuilder;
 import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregationBuilder;
-import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
 import org.elasticsearch.xpack.ql.querydsl.container.Sort.Direction;
 import org.elasticsearch.xpack.ql.util.StringUtils;
 import org.elasticsearch.xpack.sql.SqlIllegalArgumentException;
@@ -40,7 +39,7 @@ public class Aggs {
 
     public static final String ROOT_GROUP_NAME = "groupby";
 
-    public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, StringUtils.EMPTY, null, null) {
+    public static final GroupByKey IMPLICIT_GROUP_KEY = new GroupByKey(ROOT_GROUP_NAME, AggSource.of(StringUtils.EMPTY), null) {
 
         @Override
         public CompositeValuesSourceBuilder<?> createSourceBuilder() {
@@ -48,7 +47,7 @@ public class Aggs {
         }
 
         @Override
-        protected GroupByKey copy(String id, String fieldName, ScriptTemplate script, Direction direction) {
+        protected GroupByKey copy(String id, AggSource source, Direction direction) {
             return this;
         }
     };

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/AvgAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.avg;
 
 public class AvgAgg extends LeafAgg {
 
-    public AvgAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public AvgAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return avg(id()).field(fieldName());
+        return addAggSource(avg(id()));
     }
 }

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/CardinalityAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.cardinal
 
 public class CardinalityAgg extends LeafAgg {
 
-    public CardinalityAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public CardinalityAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return cardinality(id()).field(fieldName());
+        return addAggSource(cardinality(id()));
     }
 }

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/ExtendedStatsAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.extended
 
 public class ExtendedStatsAgg extends LeafAgg {
 
-    public ExtendedStatsAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public ExtendedStatsAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return extendedStats(id()).field(fieldName());
+        return addAggSource(extendedStats(id()));
     }
 }

+ 20 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/FilterExistsAgg.java

@@ -7,20 +7,37 @@ package org.elasticsearch.xpack.sql.querydsl.agg;
 
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
+import org.elasticsearch.xpack.ql.type.DataTypes;
 
+import java.util.Locale;
+
+import static java.lang.String.format;
 import static org.elasticsearch.search.aggregations.AggregationBuilders.filter;
+import static org.elasticsearch.xpack.ql.expression.gen.script.Scripts.formatTemplate;
 
 /**
  * Aggregation builder for a "filter" aggregation encapsulating an "exists" query.
  */
 public class FilterExistsAgg extends LeafAgg {
 
-    public FilterExistsAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public FilterExistsAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return filter(id(), QueryBuilders.existsQuery(fieldName()));
+        if (source().fieldName() != null) {
+            return filter(id(), QueryBuilders.existsQuery(source().fieldName()));
+        } else {
+            return filter(id(), QueryBuilders.scriptQuery(wrapWithIsNotNull(source().script()).toPainless()));
+        }
+    }
+
+    private static ScriptTemplate wrapWithIsNotNull(ScriptTemplate script) {
+        return new ScriptTemplate(formatTemplate(
+                format(Locale.ROOT, "{ql}.isNotNull(%s)", script.template())),
+                script.params(),
+                DataTypes.BOOLEAN);
     }
 }

+ 9 - 9
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByDateHistogram.java

@@ -25,24 +25,24 @@ public class GroupByDateHistogram extends GroupByKey {
     private final ZoneId zoneId;
 
     public GroupByDateHistogram(String id, String fieldName, long fixedInterval, ZoneId zoneId) {
-        this(id, fieldName, null, null, fixedInterval, null, zoneId);
+        this(id, AggSource.of(fieldName), null, fixedInterval, null, zoneId);
     }
 
     public GroupByDateHistogram(String id, ScriptTemplate script, long fixedInterval, ZoneId zoneId) {
-        this(id, null, script, null, fixedInterval, null, zoneId);
+        this(id, AggSource.of(script), null, fixedInterval, null, zoneId);
     }
     
     public GroupByDateHistogram(String id, String fieldName, String calendarInterval, ZoneId zoneId) {
-        this(id, fieldName, null, null, -1L, calendarInterval, zoneId);
+        this(id, AggSource.of(fieldName), null, -1L, calendarInterval, zoneId);
     }
     
     public GroupByDateHistogram(String id, ScriptTemplate script, String calendarInterval, ZoneId zoneId) {
-        this(id, null, script, null, -1L, calendarInterval, zoneId);
+        this(id, AggSource.of(script), null, -1L, calendarInterval, zoneId);
     }
 
-    private GroupByDateHistogram(String id, String fieldName, ScriptTemplate script, Direction direction, long fixedInterval,
-            String calendarInterval, ZoneId zoneId) {
-        super(id, fieldName, script, direction);
+    private GroupByDateHistogram(String id, AggSource source, Direction direction, long fixedInterval,
+                                 String calendarInterval, ZoneId zoneId) {
+        super(id, source, direction);
         if (fixedInterval <= 0 && (calendarInterval == null || calendarInterval.isBlank())) {
             throw new SqlIllegalArgumentException("Either fixed interval or calendar interval needs to be specified");
         }
@@ -64,8 +64,8 @@ public class GroupByDateHistogram extends GroupByKey {
     }
 
     @Override
-    protected GroupByKey copy(String id, String fieldName, ScriptTemplate script, Direction direction) {
-        return new GroupByDateHistogram(id, fieldName, script, direction, fixedInterval, calendarInterval, zoneId);
+    protected GroupByKey copy(String id, AggSource source, Direction direction) {
+        return new GroupByDateHistogram(id, source(), direction, fixedInterval, calendarInterval, zoneId);
     }
 
     @Override

+ 23 - 17
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByKey.java

@@ -25,18 +25,20 @@ import static org.elasticsearch.xpack.sql.type.SqlDataTypes.TIME;
 public abstract class GroupByKey extends Agg {
 
     protected final Direction direction;
-    private final ScriptTemplate script;
 
-    protected GroupByKey(String id, String fieldName, ScriptTemplate script, Direction direction) {
-        super(id, fieldName);
+    protected GroupByKey(String id, AggSource source, Direction direction) {
+        super(id, source);
         // ASC is the default order of CompositeValueSource
         this.direction = direction == null ? Direction.ASC : direction;
-        this.script = script;
+    }
+
+    public ScriptTemplate script() {
+        return source().script();
     }
 
     public final CompositeValuesSourceBuilder<?> asValueSource() {
         CompositeValuesSourceBuilder<?> builder = createSourceBuilder();
-        
+        ScriptTemplate script = source().script();
         if (script != null) {
             builder.script(script.toPainless());
             if (script.outputType().isInteger()) {
@@ -59,7 +61,7 @@ public abstract class GroupByKey extends Agg {
         }
         // field based
         else {
-            builder.field(fieldName());
+            builder.field(source().fieldName());
         }
         return builder.order(direction.asOrder())
                .missingBucket(true);
@@ -67,25 +69,29 @@ public abstract class GroupByKey extends Agg {
 
     protected abstract CompositeValuesSourceBuilder<?> createSourceBuilder();
 
-    protected abstract GroupByKey copy(String id, String fieldName, ScriptTemplate script, Direction direction);
+    protected abstract GroupByKey copy(String id, AggSource source, Direction direction);
 
     public GroupByKey with(Direction direction) {
-        return this.direction == direction ? this : copy(id(), fieldName(), script, direction);
-    }
-
-    public ScriptTemplate script() {
-        return script;
+        return this.direction == direction ? this : copy(id(), source(), direction);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(id(), fieldName(), script, direction);
+        return Objects.hash(super.hashCode(), direction);
     }
 
     @Override
-    public boolean equals(Object obj) {
-        return super.equals(obj)
-                && Objects.equals(script, ((GroupByKey) obj).script)
-                && Objects.equals(direction, ((GroupByKey) obj).direction);
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        if (super.equals(o) == false) {
+            return false;
+        }
+        GroupByKey that = (GroupByKey) o;
+        return direction == that.direction;
     }
 }

+ 6 - 6
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByNumericHistogram.java

@@ -20,15 +20,15 @@ public class GroupByNumericHistogram extends GroupByKey {
     private final double interval;
 
     public GroupByNumericHistogram(String id, String fieldName, double interval) {
-        this(id, fieldName, null, null, interval);
+        this(id, AggSource.of(fieldName), null, interval);
     }
 
     public GroupByNumericHistogram(String id, ScriptTemplate script, double interval) {
-        this(id, null, script, null, interval);
+        this(id, AggSource.of(script), null, interval);
     }
 
-    private GroupByNumericHistogram(String id, String fieldName, ScriptTemplate script, Direction direction, double interval) {
-        super(id, fieldName, script, direction);
+    private GroupByNumericHistogram(String id, AggSource aggSource, Direction direction, double interval) {
+        super(id, aggSource, direction);
         this.interval = interval;
     }
 
@@ -39,8 +39,8 @@ public class GroupByNumericHistogram extends GroupByKey {
     }
 
     @Override
-    protected GroupByKey copy(String id, String fieldName, ScriptTemplate script, Direction direction) {
-        return new GroupByNumericHistogram(id, fieldName, script, direction, interval);
+    protected GroupByKey copy(String id, AggSource source, Direction direction) {
+        return new GroupByNumericHistogram(id, source(), direction, interval);
     }
 
     @Override

+ 6 - 6
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/GroupByValue.java

@@ -16,15 +16,15 @@ import org.elasticsearch.xpack.ql.querydsl.container.Sort.Direction;
 public class GroupByValue extends GroupByKey {
 
     public GroupByValue(String id, String fieldName) {
-        this(id, fieldName, null, null);
+        this(id, AggSource.of(fieldName), null);
     }
 
     public GroupByValue(String id, ScriptTemplate script) {
-        this(id, null, script, null);
+        this(id, AggSource.of(script), null);
     }
 
-    private GroupByValue(String id, String fieldName, ScriptTemplate script, Direction direction) {
-        super(id, fieldName, script, direction);
+    private GroupByValue(String id, AggSource source, Direction direction) {
+        super(id, source, direction);
     }
 
     @Override
@@ -33,7 +33,7 @@ public class GroupByValue extends GroupByKey {
     }
 
     @Override
-    protected GroupByKey copy(String id, String fieldName, ScriptTemplate script, Direction direction) {
-        return new GroupByValue(id, fieldName, script, direction);
+    protected GroupByKey copy(String id, AggSource source, Direction direction) {
+        return new GroupByValue(id, source(), direction);
     }
 }

+ 8 - 2
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/LeafAgg.java

@@ -6,12 +6,18 @@
 package org.elasticsearch.xpack.sql.querydsl.agg;
 
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.support.ValuesSourceAggregationBuilder;
 
 public abstract class LeafAgg extends Agg {
 
-    LeafAgg(String id, String fieldName) {
-        super(id, fieldName);
+    LeafAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     abstract AggregationBuilder toBuilder();
+
+    @SuppressWarnings("rawtypes")
+    protected ValuesSourceAggregationBuilder addAggSource(ValuesSourceAggregationBuilder builder) {
+        return source().with(builder);
+    }
 }

+ 1 - 1
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MatrixStatsAgg.java

@@ -16,7 +16,7 @@ public class MatrixStatsAgg extends LeafAgg {
     private final List<String> fields;
 
     public MatrixStatsAgg(String id, List<String> fields) {
-        super(id, "<multi-field>");
+        super(id, AggSource.of("<multi-field>"));
         this.fields = fields;
     }
 

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MaxAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.max;
 
 public class MaxAgg extends LeafAgg {
 
-    public MaxAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public MaxAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return max(id()).field(fieldName());
+        return addAggSource(max(id()));
     }
 }

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MedianAbsoluteDeviationAgg.java

@@ -12,12 +12,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.medianAb
 
 public class MedianAbsoluteDeviationAgg extends LeafAgg {
 
-    public MedianAbsoluteDeviationAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public MedianAbsoluteDeviationAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return medianAbsoluteDeviation(id()).field(fieldName());
+        return addAggSource(medianAbsoluteDeviation(id()));
     }
 }

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/MinAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.min;
 
 public class MinAgg extends LeafAgg {
 
-    public MinAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public MinAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return min(id()).field(fieldName());
+        return addAggSource(min(id()));
     }
 }

+ 3 - 8
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/PercentileRanksAgg.java

@@ -15,18 +15,13 @@ public class PercentileRanksAgg extends LeafAgg {
 
     private final List<Double> values;
 
-    public PercentileRanksAgg(String id, String fieldName, List<Double> values) {
-        super(id, fieldName);
+    public PercentileRanksAgg(String id, AggSource source, List<Double> values) {
+        super(id, source);
         this.values = values;
     }
 
-    public List<Double> percents() {
-        return values;
-    }
-
     @Override
     AggregationBuilder toBuilder() {
-        return percentileRanks(id(), values.stream().mapToDouble(Double::doubleValue).toArray())
-                .field(fieldName());
+        return addAggSource(percentileRanks(id(), values.stream().mapToDouble(Double::doubleValue).toArray()));
     }
 }

+ 5 - 9
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/PercentilesAgg.java

@@ -6,6 +6,7 @@
 package org.elasticsearch.xpack.sql.querydsl.agg;
 
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.PercentilesAggregationBuilder;
 
 import java.util.List;
 
@@ -15,20 +16,15 @@ public class PercentilesAgg extends LeafAgg {
 
     private final List<Double> percents;
 
-    public PercentilesAgg(String id, String fieldName, List<Double> percents) {
-        super(id, fieldName);
+    public PercentilesAgg(String id, AggSource source, List<Double> percents) {
+        super(id, source);
         this.percents = percents;
     }
 
-    public List<Double> percents() {
-        return percents;
-    }
-
     @Override
     AggregationBuilder toBuilder() {
         // TODO: look at keyed
-        return percentiles(id())
-                .field(fieldName())
-                .percentiles(percents.stream().mapToDouble(Double::doubleValue).toArray());
+        PercentilesAggregationBuilder builder = (PercentilesAggregationBuilder) addAggSource(percentiles(id()));
+        return builder.percentiles(percents.stream().mapToDouble(Double::doubleValue).toArray());
     }
 }

+ 3 - 3
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/StatsAgg.java

@@ -11,12 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.stats;
 
 public class StatsAgg extends LeafAgg {
 
-    public StatsAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public StatsAgg(String id, AggSource source) {
+        super(id, source);
     }
 
     @Override
     AggregationBuilder toBuilder() {
-        return stats(id()).field(fieldName());
+        return addAggSource(stats(id()));
     }
 }

+ 5 - 4
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/SumAgg.java

@@ -11,11 +11,12 @@ import static org.elasticsearch.search.aggregations.AggregationBuilders.sum;
 
 public class SumAgg extends LeafAgg {
 
-    public SumAgg(String id, String fieldName) {
-        super(id, fieldName);
+    public SumAgg(String id, AggSource source) {
+        super(id, source);
     }
 
-    @Override AggregationBuilder toBuilder() {
-        return sum(id()).field(fieldName());
+    @Override
+    AggregationBuilder toBuilder() {
+        return addAggSource(sum(id()));
     }
 }

+ 62 - 27
x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/agg/TopHitsAgg.java

@@ -6,9 +6,12 @@
 package org.elasticsearch.xpack.sql.querydsl.agg;
 
 import org.elasticsearch.search.aggregations.AggregationBuilder;
+import org.elasticsearch.search.aggregations.metrics.TopHitsAggregationBuilder;
 import org.elasticsearch.search.sort.FieldSortBuilder;
+import org.elasticsearch.search.sort.ScriptSortBuilder;
 import org.elasticsearch.search.sort.SortBuilder;
 import org.elasticsearch.search.sort.SortOrder;
+import org.elasticsearch.xpack.ql.expression.gen.script.Scripts;
 import org.elasticsearch.xpack.ql.type.DataType;
 import org.elasticsearch.xpack.sql.type.SqlDataTypes;
 
@@ -21,18 +24,23 @@ import static org.elasticsearch.xpack.ql.querydsl.container.Sort.Missing.LAST;
 
 public class TopHitsAgg extends LeafAgg {
 
-    private final String sortField;
+    private final AggSource sortSource;
     private final SortOrder sortOrder;
     private final DataType fieldDataType;
     private final DataType sortFieldDataType;
 
-
-    public TopHitsAgg(String id, String fieldName, DataType fieldDataType, String sortField,
-                      DataType sortFieldDataType, SortOrder sortOrder) {
-        super(id, fieldName);
-        this.sortField = sortField;
-        this.sortOrder = sortOrder;
+    public TopHitsAgg(
+        String id,
+        AggSource source,
+        DataType fieldDataType,
+        AggSource sortSource,
+        DataType sortFieldDataType,
+        SortOrder sortOrder
+    ) {
+        super(id, source);
         this.fieldDataType = fieldDataType;
+        this.sortSource = sortSource;
+        this.sortOrder = sortOrder;
         this.sortFieldDataType = sortFieldDataType;
     }
 
@@ -40,20 +48,51 @@ public class TopHitsAgg extends LeafAgg {
     AggregationBuilder toBuilder() {
         // Sort missing values (NULLs) as last to get the first/last non-null value
         List<SortBuilder<?>> sortBuilderList = new ArrayList<>(2);
-        if (sortField != null) {
+        if (sortSource!= null) {
+            if (sortSource.fieldName() != null) {
+                sortBuilderList.add(
+                    new FieldSortBuilder(sortSource.fieldName()).order(sortOrder)
+                        .missing(LAST.position())
+                        .unmappedType(sortFieldDataType.esType())
+                );
+            } else if (sortSource.script() != null) {
+                sortBuilderList.add(
+                    new ScriptSortBuilder(
+                        Scripts.nullSafeSort(sortSource.script()).toPainless(),
+                        sortSource.script().outputType().isNumeric()
+                            ? ScriptSortBuilder.ScriptSortType.NUMBER
+                            : ScriptSortBuilder.ScriptSortType.STRING
+                    ).order(sortOrder)
+                );
+            }
+        }
+
+        if (source().fieldName() != null) {
+            sortBuilderList.add(
+                new FieldSortBuilder(source().fieldName()).order(sortOrder).missing(LAST.position()).unmappedType(fieldDataType.esType())
+            );
+        } else {
             sortBuilderList.add(
-                new FieldSortBuilder(sortField)
-                    .order(sortOrder)
-                    .missing(LAST.position())
-                    .unmappedType(sortFieldDataType.esType()));
+                new ScriptSortBuilder(
+                    Scripts.nullSafeSort(source().script()).toPainless(),
+                    source().script().outputType().isNumeric()
+                        ? ScriptSortBuilder.ScriptSortType.NUMBER
+                        : ScriptSortBuilder.ScriptSortType.STRING
+                ).order(sortOrder)
+            );
         }
-        sortBuilderList.add(
-                new FieldSortBuilder(fieldName())
-                    .order(sortOrder)
-                    .missing(LAST.position())
-                    .unmappedType(fieldDataType.esType()));
 
-        return topHits(id()).docValueField(fieldName(), SqlDataTypes.format(fieldDataType)).sorts(sortBuilderList).size(1);
+        TopHitsAggregationBuilder builder = topHits(id());
+        if (source().fieldName() != null) {
+            return builder.docValueField(source().fieldName(), SqlDataTypes.format(fieldDataType)).sorts(sortBuilderList).size(1);
+        } else {
+            return builder.scriptField(id(), source().script().toPainless()).sorts(sortBuilderList).size(1);
+        }
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(super.hashCode(), sortSource, sortOrder, fieldDataType, sortFieldDataType);
     }
 
     @Override
@@ -64,17 +103,13 @@ public class TopHitsAgg extends LeafAgg {
         if (o == null || getClass() != o.getClass()) {
             return false;
         }
-        if (!super.equals(o)) {
+        if (super.equals(o) == false) {
             return false;
         }
         TopHitsAgg that = (TopHitsAgg) o;
-        return Objects.equals(sortField, that.sortField)
-            && sortOrder == that.sortOrder
-            && fieldDataType == that.fieldDataType;
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(super.hashCode(), sortField, sortOrder, fieldDataType);
+        return Objects.equals(sortSource, that.sortSource) &&
+                sortOrder==that.sortOrder &&
+                Objects.equals(fieldDataType, that.fieldDataType) &&
+                Objects.equals(sortFieldDataType, that.sortFieldDataType);
     }
 }

+ 11 - 4
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/analysis/analyzer/VerifierErrorMessagesTests.java

@@ -36,12 +36,12 @@ import static org.elasticsearch.xpack.ql.type.DataTypes.KEYWORD;
 import static org.elasticsearch.xpack.ql.type.DataTypes.OBJECT;
 import static org.elasticsearch.xpack.sql.types.SqlTypesTests.loadMapping;
 
-
 public class VerifierErrorMessagesTests extends ESTestCase {
 
-    private SqlParser parser = new SqlParser();
-    private IndexResolution indexResolution = IndexResolution.valid(new EsIndex("test",
-            loadMapping("mapping-multi-field-with-nested.json")));
+    private final SqlParser parser = new SqlParser();
+    private final IndexResolution indexResolution = IndexResolution.valid(
+        new EsIndex("test", loadMapping("mapping-multi-field-with-nested.json"))
+    );
 
     private String error(String sql) {
         return error(indexResolution, sql);
@@ -1098,4 +1098,11 @@ public class VerifierErrorMessagesTests extends ESTestCase {
         assertEquals("1:81: Literal ['bla'] of type [keyword] does not match type [boolean] of PIVOT column [bool]",
                 error("SELECT * FROM (SELECT int, keyword, bool FROM test) " + "PIVOT(AVG(int) FOR bool IN ('bla', true))"));
     }
+
+    public void testErrorMessageForMatrixStatsWithScalars() {
+        assertEquals("1:17: [KURTOSIS()] cannot be used on top of operators or scalars",
+                error("SELECT KURTOSIS(ABS(int * 10.123)) FROM test"));
+        assertEquals("1:17: [SKEWNESS()] cannot be used on top of operators or scalars",
+                error("SELECT SKEWNESS(ABS(int * 10.123)) FROM test"));
+    }
 }

+ 2 - 1
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/execution/search/SourceGeneratorTests.java

@@ -26,6 +26,7 @@ import org.elasticsearch.xpack.ql.querydsl.query.MatchQuery;
 import org.elasticsearch.xpack.ql.tree.Source;
 import org.elasticsearch.xpack.ql.type.KeywordEsField;
 import org.elasticsearch.xpack.sql.expression.function.Score;
+import org.elasticsearch.xpack.sql.querydsl.agg.AggSource;
 import org.elasticsearch.xpack.sql.querydsl.agg.AvgAgg;
 import org.elasticsearch.xpack.sql.querydsl.agg.GroupByValue;
 import org.elasticsearch.xpack.sql.querydsl.container.QueryContainer;
@@ -133,7 +134,7 @@ public class SourceGeneratorTests extends ESTestCase {
     public void testNoSortIfAgg() {
         QueryContainer container = new QueryContainer()
                 .addGroups(singletonList(new GroupByValue("group_id", "group_column")))
-                .addAgg("group_id", new AvgAgg("agg_id", "avg_column"));
+                .addAgg("group_id", new AvgAgg("agg_id", AggSource.of("avg_column")));
         SearchSourceBuilder sourceBuilder = SourceGenerator.sourceBuilder(container, null, randomIntBetween(1, 10));
         assertNull(sourceBuilder.sorts());
     }

+ 124 - 8
x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java

@@ -14,9 +14,12 @@ import org.elasticsearch.search.aggregations.metrics.CardinalityAggregationBuild
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.ql.QlIllegalArgumentException;
 import org.elasticsearch.xpack.ql.expression.Alias;
+import org.elasticsearch.xpack.ql.expression.Attribute;
 import org.elasticsearch.xpack.ql.expression.Expression;
 import org.elasticsearch.xpack.ql.expression.FieldAttribute;
 import org.elasticsearch.xpack.ql.expression.Literal;
+import org.elasticsearch.xpack.ql.expression.function.FunctionDefinition;
+import org.elasticsearch.xpack.ql.expression.function.aggregate.AggregateFunction;
 import org.elasticsearch.xpack.ql.expression.function.aggregate.Count;
 import org.elasticsearch.xpack.ql.expression.gen.script.ScriptTemplate;
 import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
@@ -42,6 +45,11 @@ import org.elasticsearch.xpack.sql.SqlTestUtils;
 import org.elasticsearch.xpack.sql.analysis.analyzer.Analyzer;
 import org.elasticsearch.xpack.sql.analysis.analyzer.Verifier;
 import org.elasticsearch.xpack.sql.expression.function.SqlFunctionRegistry;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.ExtendedStatsEnclosed;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.MatrixStatsEnclosed;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.Percentile;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.PercentileRank;
+import org.elasticsearch.xpack.sql.expression.function.aggregate.TopHits;
 import org.elasticsearch.xpack.sql.expression.function.grouping.Histogram;
 import org.elasticsearch.xpack.sql.expression.function.scalar.Cast;
 import org.elasticsearch.xpack.sql.expression.function.scalar.datetime.DateTimeProcessor.DateTimeExtractor;
@@ -58,7 +66,6 @@ import org.elasticsearch.xpack.sql.querydsl.agg.GroupByDateHistogram;
 import org.elasticsearch.xpack.sql.stats.Metrics;
 import org.elasticsearch.xpack.sql.types.SqlTypesTests;
 import org.elasticsearch.xpack.sql.util.DateUtils;
-import org.junit.AfterClass;
 import org.junit.BeforeClass;
 
 import java.time.ZonedDateTime;
@@ -87,6 +94,7 @@ import static org.hamcrest.Matchers.startsWith;
 
 public class QueryTranslatorTests extends ESTestCase {
 
+    private static SqlFunctionRegistry sqlFunctionRegistry;
     private static SqlParser parser;
     private static Analyzer analyzer;
     private static Optimizer optimizer;
@@ -95,21 +103,16 @@ public class QueryTranslatorTests extends ESTestCase {
     @BeforeClass
     public static void init() {
         parser = new SqlParser();
+        sqlFunctionRegistry = new SqlFunctionRegistry();
 
         Map<String, EsField> mapping = SqlTypesTests.loadMapping("mapping-multi-field-variation.json");
         EsIndex test = new EsIndex("test", mapping);
         IndexResolution getIndexResult = IndexResolution.valid(test);
-        analyzer = new Analyzer(SqlTestUtils.TEST_CFG, new SqlFunctionRegistry(), getIndexResult, new Verifier(new Metrics()));
+        analyzer = new Analyzer(SqlTestUtils.TEST_CFG, sqlFunctionRegistry, getIndexResult, new Verifier(new Metrics()));
         optimizer = new Optimizer();
         planner = new Planner();
     }
 
-    @AfterClass
-    public static void destroy() {
-        parser = null;
-        analyzer = null;
-    }
-
     private LogicalPlan plan(String sql) {
         return analyzer.analyze(parser.createStatement(sql), true);
     }
@@ -1816,4 +1819,117 @@ public class QueryTranslatorTests extends ESTestCase {
                 "\"script\":{\"source\":\"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(params.a0,params.v0))\","
                 + "\"lang\":\"painless\",\"params\":{\"v0\":0}}"));
     }
+
+    public void testScriptsInsideAggregateFunctions() {
+        for (FunctionDefinition fd : sqlFunctionRegistry.listFunctions()) {
+            if (AggregateFunction.class.isAssignableFrom(fd.clazz()) && (MatrixStatsEnclosed.class.isAssignableFrom(fd.clazz()) == false)) {
+                String aggFunction = fd.name() + "(ABS((int * 10) / 3) + 1";
+                if (fd.clazz() == Percentile.class || fd.clazz() == PercentileRank.class) {
+                    aggFunction += ", 50";
+                }
+                aggFunction += ")";
+                PhysicalPlan p = optimizeAndPlan("SELECT " + aggFunction + " FROM test");
+                assertEquals(EsQueryExec.class, p.getClass());
+                EsQueryExec eqe = (EsQueryExec) p;
+                assertEquals(1, eqe.output().size());
+                assertEquals(aggFunction, eqe.output().get(0).qualifiedName());
+                if (fd.clazz() == Count.class) {
+                    assertThat(
+                        eqe.queryContainer().aggs().asAggBuilder().toString().replaceAll("\\s+", ""),
+                        containsString(
+                            ":{\"script\":{\"source\":\"InternalQlScriptUtils.isNotNull(InternalSqlScriptUtils.add("
+                                + "InternalSqlScriptUtils.abs(InternalSqlScriptUtils.div(InternalSqlScriptUtils.mul("
+                                + "InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2)),params.v3))\","
+                                + "\"lang\":\"painless\",\"params\":{\"v0\":\"int\",\"v1\":10,\"v2\":3,\"v3\":1}}"
+                        )
+                    );
+                } else {
+                    assertThat(
+                        eqe.queryContainer().aggs().asAggBuilder().toString().replaceAll("\\s+", ""),
+                        containsString(
+                            ":{\"script\":{\"source\":\"InternalSqlScriptUtils.add(InternalSqlScriptUtils.abs("
+                                + "InternalSqlScriptUtils.div(InternalSqlScriptUtils.mul(InternalQlScriptUtils.docValue("
+                                + "doc,params.v0),params.v1),params.v2)),params.v3)\",\"lang\":\"painless\",\"params\":{"
+                                + "\"v0\":\"int\",\"v1\":10,\"v2\":3,\"v3\":1}}"
+                        )
+                    );
+                }
+            }
+        }
+    }
+
+    public void testScriptsInsideAggregateFunctions_WithHaving() {
+        for (FunctionDefinition fd : sqlFunctionRegistry.listFunctions()) {
+            if (AggregateFunction.class.isAssignableFrom(fd.clazz())
+                    && (MatrixStatsEnclosed.class.isAssignableFrom(fd.clazz()) == false)
+                    // First/Last don't support having: https://github.com/elastic/elasticsearch/issues/37938
+                    && (TopHits.class.isAssignableFrom(fd.clazz()) == false)) {
+                String aggFunction = fd.name() + "(ABS((int * 10) / 3) + 1";
+                if (fd.clazz() == Percentile.class || fd.clazz() == PercentileRank.class) {
+                    aggFunction += ", 50";
+                }
+                aggFunction += ")";
+                LogicalPlan p = plan("SELECT " + aggFunction + ", keyword FROM test " + "GROUP BY keyword HAVING " + aggFunction + " > 20");
+                assertTrue(p instanceof Filter);
+                assertTrue(((Filter) p).child() instanceof Aggregate);
+                List<Attribute> outputs = ((Filter) p).child().output();
+                assertEquals(2, outputs.size());
+                assertEquals(aggFunction, outputs.get(0).qualifiedName());
+                assertEquals("test.keyword", outputs.get(1).qualifiedName());
+
+                Expression condition = ((Filter) p).condition();
+                assertFalse(condition.foldable());
+                QueryTranslation translation = QueryTranslator.toQuery(condition, true);
+                assertNull(translation.query);
+                AggFilter aggFilter = translation.aggFilter;
+                assertEquals(
+                    "InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.gt(params.a0,params.v0))",
+                    aggFilter.scriptTemplate().toString()
+                );
+                assertEquals("[{a=" + aggFunction + "}, {v=20}]", aggFilter.scriptTemplate().params().toString());
+            }
+        }
+    }
+
+    public void testScriptsInsideAggregateFunctions_ConvertedToStats() {
+        String aggFunctionArgs1 = "MIN(ABS((int * 10) / 3) + 1)";
+        String aggFunctionArgs2 = "MAX(ABS((int * 10) / 3) + 1)";
+        PhysicalPlan p = optimizeAndPlan("SELECT " + aggFunctionArgs1 + ", " + aggFunctionArgs2 + " FROM test");
+        assertEquals(EsQueryExec.class, p.getClass());
+        EsQueryExec eqe = (EsQueryExec) p;
+        assertEquals(2, eqe.output().size());
+        assertEquals(aggFunctionArgs1, eqe.output().get(0).qualifiedName());
+        assertEquals(aggFunctionArgs2, eqe.output().get(1).qualifiedName());
+        assertThat(
+            eqe.queryContainer().aggs().asAggBuilder().toString().replaceAll("\\s+", ""),
+            containsString(
+                "{\"stats\":{\"script\":{\"source\":\"InternalSqlScriptUtils.add(InternalSqlScriptUtils.abs("
+                    + "InternalSqlScriptUtils.div(InternalSqlScriptUtils.mul(InternalQlScriptUtils.docValue("
+                    + "doc,params.v0),params.v1),params.v2)),params.v3)\",\"lang\":\"painless\",\"params\":{"
+                    + "\"v0\":\"int\",\"v1\":10,\"v2\":3,\"v3\":1}}"
+            )
+        );
+    }
+
+    public void testScriptsInsideAggregateFunctions_ExtendedStats() {
+        for (FunctionDefinition fd : sqlFunctionRegistry.listFunctions()) {
+            if (ExtendedStatsEnclosed.class.isAssignableFrom(fd.clazz())) {
+                String aggFunction = fd.name() + "(ABS((int * 10) / 3) + 1)";
+                PhysicalPlan p = optimizeAndPlan("SELECT " + aggFunction + " FROM test");
+                assertEquals(EsQueryExec.class, p.getClass());
+                EsQueryExec eqe = (EsQueryExec) p;
+                assertEquals(1, eqe.output().size());
+                assertEquals(aggFunction, eqe.output().get(0).qualifiedName());
+                assertThat(
+                    eqe.queryContainer().aggs().asAggBuilder().toString().replaceAll("\\s+", ""),
+                    containsString(
+                        "{\"extended_stats\":{\"script\":{\"source\":\"InternalSqlScriptUtils.add(InternalSqlScriptUtils.abs("
+                            + "InternalSqlScriptUtils.div(InternalSqlScriptUtils.mul(InternalQlScriptUtils.docValue("
+                            + "doc,params.v0),params.v1),params.v2)),params.v3)\",\"lang\":\"painless\",\"params\":{"
+                            + "\"v0\":\"int\",\"v1\":10,\"v2\":3,\"v3\":1}}"
+                    )
+                );
+            }
+        }
+    }
 }