|
@@ -0,0 +1,225 @@
|
|
|
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
+// or more contributor license agreements. Licensed under the Elastic License
|
|
|
+// 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
+// 2.0.
|
|
|
+package org.elasticsearch.compute.aggregation;
|
|
|
+
|
|
|
+import java.lang.Integer;
|
|
|
+import java.lang.Override;
|
|
|
+import java.lang.String;
|
|
|
+import java.lang.StringBuilder;
|
|
|
+import java.util.List;
|
|
|
+import org.elasticsearch.compute.data.Block;
|
|
|
+import org.elasticsearch.compute.data.DoubleBlock;
|
|
|
+import org.elasticsearch.compute.data.DoubleVector;
|
|
|
+import org.elasticsearch.compute.data.ElementType;
|
|
|
+import org.elasticsearch.compute.data.FloatBlock;
|
|
|
+import org.elasticsearch.compute.data.FloatVector;
|
|
|
+import org.elasticsearch.compute.data.IntBlock;
|
|
|
+import org.elasticsearch.compute.data.IntVector;
|
|
|
+import org.elasticsearch.compute.data.LongBlock;
|
|
|
+import org.elasticsearch.compute.data.LongVector;
|
|
|
+import org.elasticsearch.compute.data.Page;
|
|
|
+import org.elasticsearch.compute.operator.DriverContext;
|
|
|
+
|
|
|
+/**
|
|
|
+ * {@link GroupingAggregatorFunction} implementation for {@link StdDevFloatAggregator}.
|
|
|
+ * This class is generated. Do not edit it.
|
|
|
+ */
|
|
|
+public final class StdDevFloatGroupingAggregatorFunction implements GroupingAggregatorFunction {
|
|
|
+ private static final List<IntermediateStateDesc> INTERMEDIATE_STATE_DESC = List.of(
|
|
|
+ new IntermediateStateDesc("mean", ElementType.DOUBLE),
|
|
|
+ new IntermediateStateDesc("m2", ElementType.DOUBLE),
|
|
|
+ new IntermediateStateDesc("count", ElementType.LONG) );
|
|
|
+
|
|
|
+ private final StdDevStates.GroupingState state;
|
|
|
+
|
|
|
+ private final List<Integer> channels;
|
|
|
+
|
|
|
+ private final DriverContext driverContext;
|
|
|
+
|
|
|
+ public StdDevFloatGroupingAggregatorFunction(List<Integer> channels,
|
|
|
+ StdDevStates.GroupingState state, DriverContext driverContext) {
|
|
|
+ this.channels = channels;
|
|
|
+ this.state = state;
|
|
|
+ this.driverContext = driverContext;
|
|
|
+ }
|
|
|
+
|
|
|
+ public static StdDevFloatGroupingAggregatorFunction create(List<Integer> channels,
|
|
|
+ DriverContext driverContext) {
|
|
|
+ return new StdDevFloatGroupingAggregatorFunction(channels, StdDevFloatAggregator.initGrouping(driverContext.bigArrays()), driverContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ public static List<IntermediateStateDesc> intermediateStateDesc() {
|
|
|
+ return INTERMEDIATE_STATE_DESC;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int intermediateBlockCount() {
|
|
|
+ return INTERMEDIATE_STATE_DESC.size();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds,
|
|
|
+ Page page) {
|
|
|
+ FloatBlock valuesBlock = page.getBlock(channels.get(0));
|
|
|
+ FloatVector valuesVector = valuesBlock.asVector();
|
|
|
+ if (valuesVector == null) {
|
|
|
+ if (valuesBlock.mayHaveNulls()) {
|
|
|
+ state.enableGroupIdTracking(seenGroupIds);
|
|
|
+ }
|
|
|
+ return new GroupingAggregatorFunction.AddInput() {
|
|
|
+ @Override
|
|
|
+ public void add(int positionOffset, IntBlock groupIds) {
|
|
|
+ addRawInput(positionOffset, groupIds, valuesBlock);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void add(int positionOffset, IntVector groupIds) {
|
|
|
+ addRawInput(positionOffset, groupIds, valuesBlock);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void close() {
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+ return new GroupingAggregatorFunction.AddInput() {
|
|
|
+ @Override
|
|
|
+ public void add(int positionOffset, IntBlock groupIds) {
|
|
|
+ addRawInput(positionOffset, groupIds, valuesVector);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void add(int positionOffset, IntVector groupIds) {
|
|
|
+ addRawInput(positionOffset, groupIds, valuesVector);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void close() {
|
|
|
+ }
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ private void addRawInput(int positionOffset, IntVector groups, FloatBlock values) {
|
|
|
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
|
|
|
+ int groupId = groups.getInt(groupPosition);
|
|
|
+ if (values.isNull(groupPosition + positionOffset)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
|
|
|
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
|
|
|
+ for (int v = valuesStart; v < valuesEnd; v++) {
|
|
|
+ StdDevFloatAggregator.combine(state, groupId, values.getFloat(v));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void addRawInput(int positionOffset, IntVector groups, FloatVector values) {
|
|
|
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
|
|
|
+ int groupId = groups.getInt(groupPosition);
|
|
|
+ StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void addRawInput(int positionOffset, IntBlock groups, FloatBlock values) {
|
|
|
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
|
|
|
+ if (groups.isNull(groupPosition)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ int groupStart = groups.getFirstValueIndex(groupPosition);
|
|
|
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
|
|
|
+ for (int g = groupStart; g < groupEnd; g++) {
|
|
|
+ int groupId = groups.getInt(g);
|
|
|
+ if (values.isNull(groupPosition + positionOffset)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset);
|
|
|
+ int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset);
|
|
|
+ for (int v = valuesStart; v < valuesEnd; v++) {
|
|
|
+ StdDevFloatAggregator.combine(state, groupId, values.getFloat(v));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void addRawInput(int positionOffset, IntBlock groups, FloatVector values) {
|
|
|
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
|
|
|
+ if (groups.isNull(groupPosition)) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ int groupStart = groups.getFirstValueIndex(groupPosition);
|
|
|
+ int groupEnd = groupStart + groups.getValueCount(groupPosition);
|
|
|
+ for (int g = groupStart; g < groupEnd; g++) {
|
|
|
+ int groupId = groups.getInt(g);
|
|
|
+ StdDevFloatAggregator.combine(state, groupId, values.getFloat(groupPosition + positionOffset));
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) {
|
|
|
+ state.enableGroupIdTracking(seenGroupIds);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void addIntermediateInput(int positionOffset, IntVector groups, Page page) {
|
|
|
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
|
|
|
+ assert channels.size() == intermediateBlockCount();
|
|
|
+ Block meanUncast = page.getBlock(channels.get(0));
|
|
|
+ if (meanUncast.areAllValuesNull()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ DoubleVector mean = ((DoubleBlock) meanUncast).asVector();
|
|
|
+ Block m2Uncast = page.getBlock(channels.get(1));
|
|
|
+ if (m2Uncast.areAllValuesNull()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ DoubleVector m2 = ((DoubleBlock) m2Uncast).asVector();
|
|
|
+ Block countUncast = page.getBlock(channels.get(2));
|
|
|
+ if (countUncast.areAllValuesNull()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ LongVector count = ((LongBlock) countUncast).asVector();
|
|
|
+ assert mean.getPositionCount() == m2.getPositionCount() && mean.getPositionCount() == count.getPositionCount();
|
|
|
+ for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) {
|
|
|
+ int groupId = groups.getInt(groupPosition);
|
|
|
+ StdDevFloatAggregator.combineIntermediate(state, groupId, mean.getDouble(groupPosition + positionOffset), m2.getDouble(groupPosition + positionOffset), count.getLong(groupPosition + positionOffset));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) {
|
|
|
+ if (input.getClass() != getClass()) {
|
|
|
+ throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass());
|
|
|
+ }
|
|
|
+ StdDevStates.GroupingState inState = ((StdDevFloatGroupingAggregatorFunction) input).state;
|
|
|
+ state.enableGroupIdTracking(new SeenGroupIds.Empty());
|
|
|
+ StdDevFloatAggregator.combineStates(state, groupId, inState, position);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) {
|
|
|
+ state.toIntermediate(blocks, offset, selected, driverContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void evaluateFinal(Block[] blocks, int offset, IntVector selected,
|
|
|
+ DriverContext driverContext) {
|
|
|
+ blocks[offset] = StdDevFloatAggregator.evaluateFinal(state, selected, driverContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String toString() {
|
|
|
+ StringBuilder sb = new StringBuilder();
|
|
|
+ sb.append(getClass().getSimpleName()).append("[");
|
|
|
+ sb.append("channels=").append(channels);
|
|
|
+ sb.append("]");
|
|
|
+ return sb.toString();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void close() {
|
|
|
+ state.close();
|
|
|
+ }
|
|
|
+}
|