Prechádzať zdrojové kódy

[Transform] Adding null check to fix potential NPE (#96785)

AjeetNathawat 2 rokov pred
rodič
commit
c9e37ef7c5

+ 6 - 0
docs/changelog/96785.yaml

@@ -0,0 +1,6 @@
+pr: 96785
+summary: Adding null check to fix potential NPE
+area: Transform
+type: enhancement
+issues:
+  - 96781

+ 6 - 1
x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.core.transform.transforms.TransformProgress;
 import org.elasticsearch.xpack.transform.transforms.Function;
 import org.elasticsearch.xpack.transform.transforms.pivot.AggregationResultUtils;
 
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -88,9 +89,13 @@ public abstract class AbstractCompositeAggFunction implements Function {
                         return;
                     }
                     final CompositeAggregation agg = aggregations.get(COMPOSITE_AGGREGATION_NAME);
+                    if (agg == null || agg.getBuckets().isEmpty()) {
+                        listener.onResponse(Collections.emptyList());
+                        return;
+                    }
+
                     TransformIndexerStats stats = new TransformIndexerStats();
                     TransformProgress progress = new TransformProgress();
-
                     List<Map<String, Object>> docs = extractResults(agg, fieldTypeMap, stats, progress).map(
                         this::documentTransformationFunction
                     ).collect(Collectors.toList());

+ 94 - 0
x-pack/plugin/transform/src/test/java/org/elasticsearch/xpack/transform/transforms/pivot/PivotTests.java

@@ -56,6 +56,7 @@ import org.junit.Before;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -70,6 +71,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.notNullValue;
 import static org.hamcrest.Matchers.nullValue;
@@ -265,6 +267,60 @@ public class PivotTests extends ESTestCase {
         assertThat(pivot.processSearchResponse(searchResponseFromAggs(aggs), null, null, null, null, null), is(nullValue()));
     }
 
+    public void testPreviewForEmptyAggregation() throws Exception {
+        Function pivot = new Pivot(
+            PivotConfigTests.randomPivotConfig(),
+            SettingsConfigTests.randomSettingsConfig(),
+            Version.CURRENT,
+            Collections.emptySet()
+        );
+
+        CountDownLatch latch = new CountDownLatch(1);
+        final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+        final AtomicReference<List<Map<String, Object>>> responseHolder = new AtomicReference<>();
+
+        Client emptyAggregationClient = new MyMockClientWithEmptyAggregation("empty aggregation test for preview");
+        pivot.preview(emptyAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> {
+            responseHolder.set(r);
+            latch.countDown();
+        }, e -> {
+            exceptionHolder.set(e);
+            latch.countDown();
+        }));
+        assertTrue(latch.await(100, TimeUnit.MILLISECONDS));
+        emptyAggregationClient.close();
+
+        assertThat(exceptionHolder.get(), is(nullValue()));
+        assertThat(responseHolder.get(), is(empty()));
+    }
+
+    public void testPreviewForCompositeAggregation() throws Exception {
+        Function pivot = new Pivot(
+            PivotConfigTests.randomPivotConfig(),
+            SettingsConfigTests.randomSettingsConfig(),
+            Version.CURRENT,
+            Collections.emptySet()
+        );
+
+        CountDownLatch latch = new CountDownLatch(1);
+        final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+        final AtomicReference<List<Map<String, Object>>> responseHolder = new AtomicReference<>();
+
+        Client compositeAggregationClient = new MyMockClientWithCompositeAggregation("composite aggregation test for preview");
+        pivot.preview(compositeAggregationClient, null, new HashMap<>(), new SourceConfig("test"), null, 1, ActionListener.wrap(r -> {
+            responseHolder.set(r);
+            latch.countDown();
+        }, e -> {
+            exceptionHolder.set(e);
+            latch.countDown();
+        }));
+        assertTrue(latch.await(100, TimeUnit.MILLISECONDS));
+        compositeAggregationClient.close();
+
+        assertThat(exceptionHolder.get(), is(nullValue()));
+        assertThat(responseHolder.get(), is(empty()));
+    }
+
     private static SearchResponse searchResponseFromAggs(Aggregations aggs) {
         SearchResponseSections sections = new SearchResponseSections(null, aggs, null, false, null, null, 1);
         SearchResponse searchResponse = new SearchResponse(sections, null, 10, 5, 0, 0, new ShardSearchFailure[0], null);
@@ -326,6 +382,44 @@ public class PivotTests extends ESTestCase {
         }
     }
 
+    private class MyMockClientWithEmptyAggregation extends NoOpClient {
+        MyMockClientWithEmptyAggregation(String testName) {
+            super(testName);
+        }
+
+        @SuppressWarnings("unchecked")
+        @Override
+        protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
+            ActionType<Response> action,
+            Request request,
+            ActionListener<Response> listener
+        ) {
+            SearchResponse response = mock(SearchResponse.class);
+            when(response.getAggregations()).thenReturn(new Aggregations(List.of()));
+            listener.onResponse((Response) response);
+        }
+    }
+
+    private class MyMockClientWithCompositeAggregation extends NoOpClient {
+        MyMockClientWithCompositeAggregation(String testName) {
+            super(testName);
+        }
+
+        @SuppressWarnings("unchecked")
+        @Override
+        protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
+            ActionType<Response> action,
+            Request request,
+            ActionListener<Response> listener
+        ) {
+            SearchResponse response = mock(SearchResponse.class);
+            CompositeAggregation compositeAggregation = mock(CompositeAggregation.class);
+            when(response.getAggregations()).thenReturn(new Aggregations(List.of(compositeAggregation)));
+            when(compositeAggregation.getBuckets()).thenReturn(new ArrayList<>());
+            listener.onResponse((Response) response);
+        }
+    }
+
     private PivotConfig getValidPivotConfig() throws IOException {
         return new PivotConfig(GroupConfigTests.randomGroupConfig(), getValidAggregationConfig(), null);
     }