@@ -335,7 +335,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
                 // item 3
                 assertNull(bulkShardRequest.items()[3].getPrimaryResponse());
                 actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[3].request());
-                assertInferenceResults(useLegacyFormat, actualRequest, "obj.field1", EXPLICIT_NULL, 0);
+                assertInferenceResults(useLegacyFormat, actualRequest, "obj.field1", EXPLICIT_NULL, null);
 
                 // item 4
                 assertNull(bulkShardRequest.items()[4].getPrimaryResponse());
@@ -368,6 +368,59 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
 
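+    // Verifies that empty and whitespace-only values for a semantic text field do not
+    // trigger inference: the original source is preserved and zero chunks are produced.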
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testHandleEmptyInput() throws Exception {
+        StaticModel model = StaticModel.createRandomInstance();
+        ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(model.getInferenceEntityId(), model),
+            randomIntBetween(1, 10),
+            useLegacyFormat,
+            true
+        );
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
+                assertNull(bulkShardRequest.getInferenceFieldMap());
+                assertThat(bulkShardRequest.items().length, equalTo(3));
+
+                // Create with empty string
+                assertNull(bulkShardRequest.items()[0].getPrimaryResponse());
+                IndexRequest actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[0].request());
+                assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", "", 0);
+
+                // Create with whitespace only
+                assertNull(bulkShardRequest.items()[1].getPrimaryResponse());
+                actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[1].request());
+                assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", " ", 0);
+
+                // Update with multiple whitespace characters
+                assertNull(bulkShardRequest.items()[2].getPrimaryResponse());
+                actualRequest = getIndexRequestOrNull(bulkShardRequest.items()[2].request());
+                assertInferenceResults(useLegacyFormat, actualRequest, "semantic_text_field", "  ", 0);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "semantic_text_field",
+            new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" })
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").source(Map.of("semantic_text_field", "")));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").source(Map.of("semantic_text_field", " ")));
+        items[2] = new BulkItemRequest(2, new UpdateRequest().doc(new IndexRequest("index").source(Map.of("semantic_text_field", "  "))));
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+    }
+
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testManyRandomDocs() throws Exception {
         Map<String, StaticModel> inferenceModelMap = new HashMap<>();
@@ -591,7 +644,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         IndexRequest request,
         String fieldName,
         Object expectedOriginalValue,
-        int expectedChunkCount
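+        // Nullable: null means the semantic text field structure is expected to be absent
+        // (legacy format only), while 0 means an existing but empty chunk list is expected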
+        Integer expectedChunkCount
     ) {
         final Map<String, Object> requestMap = request.sourceAsMap();
         if (useLegacyFormat) {
@@ -601,13 +654,11 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             );
 
             List<Object> chunks = (List<Object>) XContentMapValues.extractValue(getChunksFieldName(fieldName), requestMap);
-            if (expectedChunkCount > 0) {
+            if (expectedChunkCount == null) {
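+                // A null expected count means no inference was performed, so the source was
+                // not transformed and the semantic text field structure was never created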
+                assertNull(chunks);
+            } else {
                 assertNotNull(chunks);
                 assertThat(chunks.size(), equalTo(expectedChunkCount));
-            } else {
-                // If the expected chunk count is 0, we expect that no inference has been performed. In this case, the source should not be
-                // transformed, and thus the semantic text field structure should not be created.
-                assertNull(chunks);
             }
         } else {
             assertThat(XContentMapValues.extractValue(fieldName, requestMap, EXPLICIT_NULL), equalTo(expectedOriginalValue));
@@ -627,8 +678,11 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
                 inferenceMetadataFields,
                 EXPLICIT_NULL
             );
+
+            // When using the new format, the chunks field should always exist
+            int expectedSize = expectedChunkCount == null ? 0 : expectedChunkCount;
             assertNotNull(chunks);
-            assertThat(chunks.size(), equalTo(expectedChunkCount));
+            assertThat(chunks.size(), equalTo(expectedSize));
         }
     }