|
@@ -23,7 +23,6 @@ import java.util.List;
|
|
|
import java.util.Locale;
|
|
|
import java.util.function.Consumer;
|
|
|
import java.util.stream.Collectors;
|
|
|
-import java.util.stream.Stream;
|
|
|
|
|
|
import javax.lang.model.element.ExecutableElement;
|
|
|
import javax.lang.model.element.Modifier;
|
|
@@ -51,6 +50,7 @@ import static org.elasticsearch.compute.gen.Types.LIST_INTEGER;
|
|
|
import static org.elasticsearch.compute.gen.Types.LONG_BLOCK;
|
|
|
import static org.elasticsearch.compute.gen.Types.LONG_VECTOR;
|
|
|
import static org.elasticsearch.compute.gen.Types.PAGE;
|
|
|
+import static org.elasticsearch.compute.gen.Types.SEEN_GROUP_IDS;
|
|
|
import static org.elasticsearch.compute.gen.Types.blockType;
|
|
|
import static org.elasticsearch.compute.gen.Types.vectorType;
|
|
|
|
|
@@ -93,10 +93,10 @@ public class GroupingAggregatorImplementer {
|
|
|
this.combineIntermediate = findMethod(declarationType, "combineIntermediate");
|
|
|
this.evaluateFinal = findMethod(declarationType, "evaluateFinal");
|
|
|
this.valuesIsBytesRef = BYTES_REF.equals(TypeName.get(combine.getParameters().get(combine.getParameters().size() - 1).asType()));
|
|
|
- List<Parameter> createParameters = init.getParameters().stream().map(Parameter::from).toList();
|
|
|
- this.createParameters = createParameters.stream().anyMatch(p -> p.type().equals(BIG_ARRAYS))
|
|
|
- ? createParameters
|
|
|
- : Stream.concat(Stream.of(new Parameter(BIG_ARRAYS, "bigArrays")), createParameters.stream()).toList();
|
|
|
+ this.createParameters = init.getParameters().stream().map(Parameter::from).collect(Collectors.toList());
|
|
|
+ if (false == createParameters.stream().anyMatch(p -> p.type().equals(BIG_ARRAYS))) {
|
|
|
+ createParameters.add(0, new Parameter(BIG_ARRAYS, "bigArrays"));
|
|
|
+ }
|
|
|
|
|
|
this.implementation = ClassName.get(
|
|
|
elements.getPackageOf(declarationType).toString(),
|
|
@@ -161,10 +161,8 @@ public class GroupingAggregatorImplementer {
|
|
|
builder.addMethod(prepareProcessPage());
|
|
|
builder.addMethod(addRawInputLoop(LONG_VECTOR, valueBlockType(init, combine)));
|
|
|
builder.addMethod(addRawInputLoop(LONG_VECTOR, valueVectorType(init, combine)));
|
|
|
- builder.addMethod(addRawInputLoop(LONG_VECTOR, BLOCK));
|
|
|
builder.addMethod(addRawInputLoop(LONG_BLOCK, valueBlockType(init, combine)));
|
|
|
builder.addMethod(addRawInputLoop(LONG_BLOCK, valueVectorType(init, combine)));
|
|
|
- builder.addMethod(addRawInputLoop(LONG_BLOCK, BLOCK));
|
|
|
builder.addMethod(addIntermediateInput());
|
|
|
builder.addMethod(addIntermediateRowInput());
|
|
|
builder.addMethod(evaluateIntermediate());
|
|
@@ -250,21 +248,24 @@ public class GroupingAggregatorImplementer {
|
|
|
private MethodSpec prepareProcessPage() {
|
|
|
MethodSpec.Builder builder = MethodSpec.methodBuilder("prepareProcessPage");
|
|
|
builder.addAnnotation(Override.class).addModifiers(Modifier.PUBLIC).returns(GROUPING_AGGREGATOR_FUNCTION_ADD_INPUT);
|
|
|
- builder.addParameter(PAGE, "page");
|
|
|
+ builder.addParameter(SEEN_GROUP_IDS, "seenGroupIds").addParameter(PAGE, "page");
|
|
|
|
|
|
builder.addStatement("$T uncastValuesBlock = page.getBlock(channels.get(0))", BLOCK);
|
|
|
+
|
|
|
builder.beginControlFlow("if (uncastValuesBlock.areAllValuesNull())");
|
|
|
{
|
|
|
- builder.addStatement(
|
|
|
- "return $L",
|
|
|
- addInput(b -> b.addStatement("addRawInputAllNulls(positionOffset, groupIds, uncastValuesBlock)"))
|
|
|
- );
|
|
|
+ builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
|
|
|
+ builder.addStatement("return $L", addInput(b -> {}));
|
|
|
}
|
|
|
builder.endControlFlow();
|
|
|
+
|
|
|
builder.addStatement("$T valuesBlock = ($T) uncastValuesBlock", valueBlockType(init, combine), valueBlockType(init, combine));
|
|
|
builder.addStatement("$T valuesVector = valuesBlock.asVector()", valueVectorType(init, combine));
|
|
|
builder.beginControlFlow("if (valuesVector == null)");
|
|
|
{
|
|
|
+ builder.beginControlFlow("if (valuesBlock.mayHaveNulls())");
|
|
|
+ builder.addStatement("state.enableGroupIdTracking(seenGroupIds)");
|
|
|
+ builder.endControlFlow();
|
|
|
builder.addStatement("return $L", addInput(b -> b.addStatement("addRawInput(positionOffset, groupIds, valuesBlock)")));
|
|
|
}
|
|
|
builder.endControlFlow();
|
|
@@ -299,18 +300,8 @@ public class GroupingAggregatorImplementer {
|
|
|
*/
|
|
|
private MethodSpec addRawInputLoop(TypeName groupsType, TypeName valuesType) {
|
|
|
boolean groupsIsBlock = groupsType.toString().endsWith("Block");
|
|
|
- enum ValueType {
|
|
|
- VECTOR,
|
|
|
- TYPED_BLOCK,
|
|
|
- NULL_ONLY_BLOCK
|
|
|
- }
|
|
|
- ValueType valueType = valuesType.equals(BLOCK) ? ValueType.NULL_ONLY_BLOCK
|
|
|
- : valuesType.toString().endsWith("Block") ? ValueType.TYPED_BLOCK
|
|
|
- : ValueType.VECTOR;
|
|
|
+ boolean valuesIsBlock = valuesType.toString().endsWith("Block");
|
|
|
String methodName = "addRawInput";
|
|
|
- if (valueType == ValueType.NULL_ONLY_BLOCK) {
|
|
|
- methodName += "AllNulls";
|
|
|
- }
|
|
|
MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName);
|
|
|
builder.addModifiers(Modifier.PRIVATE);
|
|
|
builder.addParameter(TypeName.INT, "positionOffset").addParameter(groupsType, "groups").addParameter(valuesType, "values");
|
|
@@ -333,23 +324,17 @@ public class GroupingAggregatorImplementer {
|
|
|
builder.addStatement("int groupId = Math.toIntExact(groups.getLong(groupPosition))");
|
|
|
}
|
|
|
|
|
|
- switch (valueType) {
|
|
|
- case VECTOR -> combineRawInput(builder, "values", "groupPosition + positionOffset");
|
|
|
- case TYPED_BLOCK -> {
|
|
|
- builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))");
|
|
|
- builder.addStatement("state.putNull(groupId)");
|
|
|
- builder.addStatement("continue");
|
|
|
- builder.endControlFlow();
|
|
|
- builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)");
|
|
|
- builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)");
|
|
|
- builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
|
|
|
- combineRawInput(builder, "values", "v");
|
|
|
- builder.endControlFlow();
|
|
|
- }
|
|
|
- case NULL_ONLY_BLOCK -> {
|
|
|
- builder.addStatement("assert values.isNull(groupPosition + positionOffset)");
|
|
|
- builder.addStatement("state.putNull(groupPosition + positionOffset)");
|
|
|
- }
|
|
|
+ if (valuesIsBlock) {
|
|
|
+ builder.beginControlFlow("if (values.isNull(groupPosition + positionOffset))");
|
|
|
+ builder.addStatement("continue");
|
|
|
+ builder.endControlFlow();
|
|
|
+ builder.addStatement("int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset)");
|
|
|
+ builder.addStatement("int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset)");
|
|
|
+ builder.beginControlFlow("for (int v = valuesStart; v < valuesEnd; v++)");
|
|
|
+ combineRawInput(builder, "values", "v");
|
|
|
+ builder.endControlFlow();
|
|
|
+ } else {
|
|
|
+ combineRawInput(builder, "values", "groupPosition + positionOffset");
|
|
|
}
|
|
|
|
|
|
if (groupsIsBlock) {
|
|
@@ -391,7 +376,7 @@ public class GroupingAggregatorImplementer {
|
|
|
String offsetVariable
|
|
|
) {
|
|
|
builder.addStatement(
|
|
|
- "state.set($T.combine(state.getOrDefault(groupId), $L.$L($L)), groupId)",
|
|
|
+ "state.set(groupId, $T.combine(state.getOrDefault(groupId), $L.$L($L)))",
|
|
|
declarationType,
|
|
|
blockVariable,
|
|
|
secondParameterGetter,
|
|
@@ -426,6 +411,7 @@ public class GroupingAggregatorImplementer {
|
|
|
builder.addParameter(LONG_VECTOR, "groups");
|
|
|
builder.addParameter(PAGE, "page");
|
|
|
|
|
|
+ builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
|
|
|
builder.addStatement("assert channels.size() == intermediateBlockCount()");
|
|
|
int count = 0;
|
|
|
for (var interState : intermediateState) {
|
|
@@ -461,13 +447,11 @@ public class GroupingAggregatorImplementer {
|
|
|
var name = intermediateState.get(0).name();
|
|
|
var m = vectorAccessorName(intermediateState.get(0).elementType());
|
|
|
builder.addStatement(
|
|
|
- "state.set($T.combine(state.getOrDefault(groupId), $L.$L(groupPosition + positionOffset)), groupId)",
|
|
|
+ "state.set(groupId, $T.combine($L.$L(groupPosition + positionOffset), state.getOrDefault(groupId)))",
|
|
|
declarationType,
|
|
|
name,
|
|
|
m
|
|
|
);
|
|
|
- builder.nextControlFlow("else");
|
|
|
- builder.addStatement("state.putNull(groupId)");
|
|
|
builder.endControlFlow();
|
|
|
}
|
|
|
} else {
|
|
@@ -493,9 +477,7 @@ public class GroupingAggregatorImplementer {
|
|
|
private void combineStates(MethodSpec.Builder builder) {
|
|
|
if (combineStates == null) {
|
|
|
builder.beginControlFlow("if (inState.hasValue(position))");
|
|
|
- builder.addStatement("state.set($T.combine(state.getOrDefault(groupId), inState.get(position)), groupId)", declarationType);
|
|
|
- builder.nextControlFlow("else");
|
|
|
- builder.addStatement("state.putNull(groupId)");
|
|
|
+ builder.addStatement("state.set(groupId, $T.combine(state.getOrDefault(groupId), inState.get(position)))", declarationType);
|
|
|
builder.endControlFlow();
|
|
|
return;
|
|
|
}
|
|
@@ -512,6 +494,7 @@ public class GroupingAggregatorImplementer {
|
|
|
}
|
|
|
builder.endControlFlow();
|
|
|
builder.addStatement("$T inState = (($T) input).state", stateType, implementation);
|
|
|
+ builder.addStatement("state.enableGroupIdTracking(new $T.Empty())", SEEN_GROUP_IDS);
|
|
|
combineStates(builder);
|
|
|
return builder.build();
|
|
|
}
|