浏览代码

Optimize memory usage in ShardBulkInferenceActionFilter (#124313)

This refactor improves memory efficiency by processing inference requests in batches, capped by a max input length.

Changes include:
- A new dynamic operator setting to control the maximum batch size in bytes.
- Dropping input data from inference responses when the legacy semantic text format isn’t used, saving memory.
- Clearing inference results dynamically after each bulk item to free up memory sooner.

This is a step toward enabling circuit breakers to better handle memory usage when dealing with large inputs.
Jim Ferenczi 7 月之前
父节点
当前提交
361b51d436

+ 5 - 0
docs/changelog/124313.yaml

@@ -0,0 +1,5 @@
+pr: 124313
+summary: Optimize memory usage in `ShardBulkInferenceActionFilter`
+area: Search
+type: enhancement
+issues: []

+ 8 - 1
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterIT.java

@@ -20,6 +20,7 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.index.IndexSettings;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -44,6 +45,7 @@ import java.util.Locale;
 import java.util.Map;
 import java.util.Set;
 
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
@@ -85,7 +87,12 @@ public class ShardBulkInferenceActionFilterIT extends ESIntegTestCase {
 
     @Override
     protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
-        return Settings.builder().put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial").build();
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        return Settings.builder()
+            .put(otherSettings)
+            .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial")
+            .put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes))
+            .build();
     }
 
     @Override

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -142,6 +142,7 @@ import java.util.function.Predicate;
 import java.util.function.Supplier;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 
 public class InferencePlugin extends Plugin
@@ -442,6 +443,7 @@ public class InferencePlugin extends Plugin
         settings.addAll(Truncator.getSettingsDefinitions());
         settings.addAll(RequestExecutorServiceSettings.getSettingsDefinitions());
         settings.add(SKIP_VALIDATE_AND_START);
+        settings.add(INDICES_INFERENCE_BATCH_SIZE);
         settings.addAll(ElasticInferenceServiceSettings.getSettingsDefinitions());
 
         return settings;

+ 257 - 200
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

@@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
@@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.inference.InferenceException;
@@ -63,6 +71,8 @@ import java.util.Map;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
 
 /**
  * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FE
  * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the
  * results during indexing on the shard.
  *
- * TODO: batchSize should be configurable via a cluster setting
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
-    protected static final int DEFAULT_BATCH_SIZE = 512;
+    private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
+
+    /**
+     * Defines the cumulative size limit of input data before triggering a batch inference call.
+     * This setting controls how much data can be accumulated before an inference request is sent in batch.
+     */
+    public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
+        "indices.inference.batch_size",
+        DEFAULT_BATCH_SIZE,
+        ByteSizeValue.ONE,
+        ByteSizeValue.ofMb(100),
+        Setting.Property.NodeScope,
+        Setting.Property.OperatorDynamic
+    );
+
     private static final Object EXPLICIT_NULL = new Object();
     private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
 
@@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
-    private final int batchSize;
+    private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
         XPackLicenseState licenseState
-    ) {
-        this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
-    }
-
-    public ShardBulkInferenceActionFilter(
-        ClusterService clusterService,
-        InferenceServiceRegistry inferenceServiceRegistry,
-        ModelRegistry modelRegistry,
-        XPackLicenseState licenseState,
-        int batchSize
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
-        this.batchSize = batchSize;
+        this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
+    }
+
+    private void setBatchSize(ByteSizeValue newBatchSize) {
+        batchSizeInBytes = newBatchSize.getBytes();
     }
 
     @Override
@@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
 
     /**
      * A field inference request on a single input.
-     * @param index The index of the request in the original bulk request.
+     * @param bulkItemIndex The index of the item in the original bulk request.
      * @param field The target field.
      * @param sourceField The source field.
      * @param input The input to run inference on.
      * @param inputOrder The original order of the input.
      * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
-    private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
+    private record FieldInferenceRequest(
+        int bulkItemIndex,
+        String field,
+        String sourceField,
+        String input,
+        int inputOrder,
+        int offsetAdjustment
+    ) {}
 
     /**
      * The field inference response.
@@ -218,29 +243,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
 
         @Override
         public void run() {
-            Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
+            executeNext(0);
+        }
+
+        private void executeNext(int itemOffset) {
+            if (itemOffset >= bulkShardRequest.items().length) {
+                onCompletion.run();
+                return;
+            }
+
+            var items = bulkShardRequest.items();
+            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
+            long totalInputLength = 0;
+            int itemIndex = itemOffset;
+            while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
+                var item = items[itemIndex];
+                totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
+                itemIndex += 1;
+            }
+            int nextItemOffset = itemIndex;
             Runnable onInferenceCompletion = () -> {
                 try {
-                    for (var inferenceResponse : inferenceResults.asList()) {
-                        var request = bulkShardRequest.items()[inferenceResponse.id];
+                    for (int i = itemOffset; i < nextItemOffset; i++) {
+                        var result = inferenceResults.get(i);
+                        if (result == null) {
+                            continue;
+                        }
+                        var item = items[i];
                         try {
-                            applyInferenceResponses(request, inferenceResponse);
+                            applyInferenceResponses(item, result);
                         } catch (Exception exc) {
-                            request.abort(bulkShardRequest.index(), exc);
+                            item.abort(bulkShardRequest.index(), exc);
                         }
+                        // we don't need to keep the inference results around
+                        inferenceResults.set(i, null);
                     }
                 } finally {
-                    onCompletion.run();
+                    executeNext(nextItemOffset);
                 }
             };
+
             try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-                for (var entry : inferenceRequests.entrySet()) {
-                    executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
+                for (var entry : fieldRequestsMap.entrySet()) {
+                    executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
                 }
             }
         }
 
-        private void executeShardBulkInferenceAsync(
+        private void executeChunkedInferenceAsync(
             final String inferenceId,
             @Nullable InferenceProvider inferenceProvider,
             final List<FieldInferenceRequest> requests,
@@ -262,11 +312,11 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                         unparsedModel.secrets()
                                     )
                             );
-                            executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                            executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
                         } else {
                             try (onFinish) {
                                 for (FieldInferenceRequest request : requests) {
-                                    inferenceResults.get(request.index).failures.add(
+                                    inferenceResults.get(request.bulkItemIndex).failures.add(
                                         new ResourceNotFoundException(
                                             "Inference service [{}] not found for field [{}]",
                                             unparsedModel.service(),
@@ -297,7 +347,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                         request.field
                                     );
                                 }
-                                inferenceResults.get(request.index).failures.add(failure);
+                                inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                             }
                         }
                     }
@@ -305,18 +355,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
                 return;
             }
-            int currentBatchSize = Math.min(requests.size(), batchSize);
-            final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
-            final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
-            final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+            final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
             ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
                 @Override
                 public void onResponse(List<ChunkedInference> results) {
-                    try {
+                    try (onFinish) {
                         var requestsIterator = requests.iterator();
                         for (ChunkedInference result : results) {
                             var request = requestsIterator.next();
-                            var acc = inferenceResults.get(request.index);
+                            var acc = inferenceResults.get(request.bulkItemIndex);
                             if (result instanceof ChunkedInferenceError error) {
                                 acc.addFailure(
                                     new InferenceException(
@@ -331,7 +378,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                     new FieldInferenceResponse(
                                         request.field(),
                                         request.sourceField(),
-                                        request.input(),
+                                        useLegacyFormat ? request.input() : null,
                                         request.inputOrder(),
                                         request.offsetAdjustment(),
                                         inferenceProvider.model,
@@ -340,17 +387,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 );
                             }
                         }
-                    } finally {
-                        onFinish();
                     }
                 }
 
                 @Override
                 public void onFailure(Exception exc) {
-                    try {
+                    try (onFinish) {
                         for (FieldInferenceRequest request : requests) {
                             addInferenceResponseFailure(
-                                request.index,
+                                request.bulkItemIndex,
                                 new InferenceException(
                                     "Exception when running inference id [{}] on field [{}]",
                                     exc,
@@ -359,16 +404,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 )
                             );
                         }
-                    } finally {
-                        onFinish();
-                    }
-                }
-
-                private void onFinish() {
-                    if (nextBatch.isEmpty()) {
-                        onFinish.close();
-                    } else {
-                        executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
                     }
                 }
             };
@@ -376,6 +411,132 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
         }
 
+        /**
+         * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
+         * for the specified {@code item}.
+         *
+         * @param item       The bulk request item to process.
+         * @param itemIndex  The position of the item within the original bulk request.
+         * @param requestsMap A map storing inference requests, where each key is an inference ID,
+         *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
+         * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
+         */
+        private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
+            boolean isUpdateRequest = false;
+            final IndexRequest indexRequest;
+            if (item.request() instanceof IndexRequest ir) {
+                indexRequest = ir;
+            } else if (item.request() instanceof UpdateRequest updateRequest) {
+                isUpdateRequest = true;
+                if (updateRequest.script() != null) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new ElasticsearchStatusException(
+                            "Cannot apply update with a script on indices that contain [{}] field(s)",
+                            RestStatus.BAD_REQUEST,
+                            SemanticTextFieldMapper.CONTENT_TYPE
+                        )
+                    );
+                    return 0;
+                }
+                indexRequest = updateRequest.doc();
+            } else {
+                // ignore delete request
+                return 0;
+            }
+
+            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            long inputLength = 0;
+            for (var entry : fieldInferenceMap.values()) {
+                String field = entry.getName();
+                String inferenceId = entry.getInferenceId();
+
+                if (useLegacyFormat) {
+                    var originalFieldValue = XContentMapValues.extractValue(field, docMap);
+                    if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
+                        // Inference has already been computed, or there is no inference required.
+                        continue;
+                    }
+                } else {
+                    var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
+                        InferenceMetadataFieldsMapper.NAME + "." + field,
+                        docMap,
+                        EXPLICIT_NULL
+                    );
+                    if (inferenceMetadataFieldsValue != null) {
+                        // Inference has already been computed
+                        continue;
+                    }
+                }
+
+                int order = 0;
+                for (var sourceField : entry.getSourceFields()) {
+                    var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
+                    if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
+                        /**
+                         * It's an update request, and the source field is explicitly set to null,
+                         * so we need to propagate this information to the inference fields metadata
+                         * to overwrite any inference previously computed on the field.
+                         * This ensures that the field is treated as intentionally cleared,
+                         * preventing any unintended carryover of prior inference results.
+                         */
+                        var slot = ensureResponseAccumulatorSlot(itemIndex);
+                        slot.addOrUpdateResponse(
+                            new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
+                        );
+                        continue;
+                    }
+                    if (valueObj == null || valueObj == EXPLICIT_NULL) {
+                        if (isUpdateRequest && useLegacyFormat) {
+                            addInferenceResponseFailure(
+                                itemIndex,
+                                new ElasticsearchStatusException(
+                                    "Field [{}] must be specified on an update request to calculate inference for field [{}]",
+                                    RestStatus.BAD_REQUEST,
+                                    sourceField,
+                                    field
+                                )
+                            );
+                            break;
+                        }
+                        continue;
+                    }
+                    var slot = ensureResponseAccumulatorSlot(itemIndex);
+                    final List<String> values;
+                    try {
+                        values = SemanticTextUtils.nodeStringValues(field, valueObj);
+                    } catch (Exception exc) {
+                        addInferenceResponseFailure(itemIndex, exc);
+                        break;
+                    }
+
+                    if (INFERENCE_API_FEATURE.check(licenseState) == false) {
+                        addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
+                        break;
+                    }
+
+                    List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                    int offsetAdjustment = 0;
+                    for (String v : values) {
+                        inputLength += v.length();
+                        if (v.isBlank()) {
+                            slot.addOrUpdateResponse(
+                                new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
+                            );
+                        } else {
+                            requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                        }
+
+                        // When using the inference metadata fields format, all the input values are concatenated so that the
+                        // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
+                        // to apply to account for this.
+                        offsetAdjustment += v.length() + 1; // Add one for separator char length
+                    }
+                }
+            }
+            return inputLength;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -404,7 +565,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             }
 
             final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
-            var newDocMap = indexRequest.sourceAsMap();
             Map<String, Object> inferenceFieldsMap = new HashMap<>();
             for (var entry : response.responses.entrySet()) {
                 var fieldName = entry.getKey();
@@ -426,28 +586,22 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     }
 
                     var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
-                    lst.addAll(
-                        SemanticTextField.toSemanticTextFieldChunks(
-                            resp.input,
-                            resp.offsetAdjustment,
-                            resp.chunkedResults,
-                            indexRequest.getContentType(),
-                            useLegacyFormat
-                        )
-                    );
+                    var chunks = useLegacyFormat
+                        ? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType())
+                        : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType());
+                    lst.addAll(chunks);
                 }
 
-                List<String> inputs = responses.stream()
-                    .filter(r -> r.sourceField().equals(fieldName))
-                    .map(r -> r.input)
-                    .collect(Collectors.toList());
+                List<String> inputs = useLegacyFormat
+                    ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList())
+                    : null;
 
                 // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
                 // merge in the field's existing model settings on the data node.
                 var result = new SemanticTextField(
                     useLegacyFormat,
                     fieldName,
-                    useLegacyFormat ? inputs : null,
+                    inputs,
                     new SemanticTextField.InferenceResult(
                         inferenceFieldMetadata.getInferenceId(),
                         model != null ? new MinimalServiceSettings(model) : null,
@@ -455,149 +609,52 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     ),
                     indexRequest.getContentType()
                 );
-
-                if (useLegacyFormat) {
-                    SemanticTextUtils.insertValue(fieldName, newDocMap, result);
-                } else {
-                    inferenceFieldsMap.put(fieldName, result);
-                }
-            }
-            if (useLegacyFormat == false) {
-                newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
+                inferenceFieldsMap.put(fieldName, result);
             }
-            indexRequest.source(newDocMap, indexRequest.getContentType());
-        }
 
-        /**
-         * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index.
-         * If results are already populated for fields in the original index request, the inference request for this specific
-         * field is skipped, and the existing results remain unchanged.
-         * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
-         * where an error will be thrown if they mismatch or if the content is malformed.
-         * <p>
-         * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
-         */
-        private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
-            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
-            for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) {
-                var item = bulkShardRequest.items()[itemIndex];
-                if (item.getPrimaryResponse() != null) {
-                    // item was already aborted/processed by a filter in the chain upstream (e.g. security)
-                    continue;
+            if (useLegacyFormat) {
+                var newDocMap = indexRequest.sourceAsMap();
+                for (var entry : inferenceFieldsMap.entrySet()) {
+                    SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
                 }
-                boolean isUpdateRequest = false;
-                final IndexRequest indexRequest;
-                if (item.request() instanceof IndexRequest ir) {
-                    indexRequest = ir;
-                } else if (item.request() instanceof UpdateRequest updateRequest) {
-                    isUpdateRequest = true;
-                    if (updateRequest.script() != null) {
-                        addInferenceResponseFailure(
-                            itemIndex,
-                            new ElasticsearchStatusException(
-                                "Cannot apply update with a script on indices that contain [{}] field(s)",
-                                RestStatus.BAD_REQUEST,
-                                SemanticTextFieldMapper.CONTENT_TYPE
-                            )
-                        );
-                        continue;
-                    }
-                    indexRequest = updateRequest.doc();
-                } else {
-                    // ignore delete request
-                    continue;
+                indexRequest.source(newDocMap, indexRequest.getContentType());
+            } else {
+                try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
+                    appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
+                    indexRequest.source(builder);
                 }
+            }
+        }
+    }
 
-                final Map<String, Object> docMap = indexRequest.sourceAsMap();
-                for (var entry : fieldInferenceMap.values()) {
-                    String field = entry.getName();
-                    String inferenceId = entry.getInferenceId();
-
-                    if (useLegacyFormat) {
-                        var originalFieldValue = XContentMapValues.extractValue(field, docMap);
-                        if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
-                            // Inference has already been computed, or there is no inference required.
-                            continue;
-                        }
-                    } else {
-                        var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
-                            InferenceMetadataFieldsMapper.NAME + "." + field,
-                            docMap,
-                            EXPLICIT_NULL
-                        );
-                        if (inferenceMetadataFieldsValue != null) {
-                            // Inference has already been computed
-                            continue;
-                        }
-                    }
-
-                    int order = 0;
-                    for (var sourceField : entry.getSourceFields()) {
-                        var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
-                        if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
-                            /**
-                             * It's an update request, and the source field is explicitly set to null,
-                             * so we need to propagate this information to the inference fields metadata
-                             * to overwrite any inference previously computed on the field.
-                             * This ensures that the field is treated as intentionally cleared,
-                             * preventing any unintended carryover of prior inference results.
-                             */
-                            var slot = ensureResponseAccumulatorSlot(itemIndex);
-                            slot.addOrUpdateResponse(
-                                new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
-                            );
-                            continue;
-                        }
-                        if (valueObj == null || valueObj == EXPLICIT_NULL) {
-                            if (isUpdateRequest && useLegacyFormat) {
-                                addInferenceResponseFailure(
-                                    itemIndex,
-                                    new ElasticsearchStatusException(
-                                        "Field [{}] must be specified on an update request to calculate inference for field [{}]",
-                                        RestStatus.BAD_REQUEST,
-                                        sourceField,
-                                        field
-                                    )
-                                );
-                                break;
-                            }
-                            continue;
-                        }
-                        var slot = ensureResponseAccumulatorSlot(itemIndex);
-                        final List<String> values;
-                        try {
-                            values = SemanticTextUtils.nodeStringValues(field, valueObj);
-                        } catch (Exception exc) {
-                            addInferenceResponseFailure(itemIndex, exc);
-                            break;
-                        }
-
-                        if (INFERENCE_API_FEATURE.check(licenseState) == false) {
-                            addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
-                            break;
-                        }
-
-                        List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
-                        int offsetAdjustment = 0;
-                        for (String v : values) {
-                            if (v.isBlank()) {
-                                slot.addOrUpdateResponse(
-                                    new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
-                                );
-                            } else {
-                                fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
-                            }
-
-                            // When using the inference metadata fields format, all the input values are concatenated so that the
-                            // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
-                            // to apply to account for this.
-                            offsetAdjustment += v.length() + 1; // Add one for separator char length
-                        }
-                    }
-                }
+    /**
+     * Appends the original source and the new inference metadata field directly to the provided
+     * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
+     */
+    private static void appendSourceAndInferenceMetadata(
+        XContentBuilder builder,
+        BytesReference source,
+        XContentType xContentType,
+        Map<String, Object> inferenceFieldsMap
+    ) throws IOException {
+        builder.startObject();
+
+        // append the original source
+        try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
+            // skip start object
+            parser.nextToken();
+            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                builder.copyCurrentStructure(parser);
             }
-            return fieldRequestsMap;
         }
+
+        // add the inference metadata field
+        builder.field(InferenceMetadataFieldsMapper.NAME);
+        try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
+            builder.copyCurrentStructure(parser);
+        }
+
+        builder.endObject();
     }
 
     static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {

+ 23 - 22
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

@@ -267,37 +267,38 @@ public record SemanticTextField(
     /**
      * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
      */
-    public static List<Chunk> toSemanticTextFieldChunks(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference results,
-        XContentType contentType,
-        boolean useLegacyFormat
-    ) throws IOException {
+    public static List<Chunk> toSemanticTextFieldChunks(int offsetAdjustment, ChunkedInference results, XContentType contentType)
+        throws IOException {
         List<Chunk> chunks = new ArrayList<>();
         Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
         while (it.hasNext()) {
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, it.next(), useLegacyFormat));
+            chunks.add(toSemanticTextFieldChunk(offsetAdjustment, it.next()));
         }
         return chunks;
     }
 
-    public static Chunk toSemanticTextFieldChunk(
-        String input,
-        int offsetAdjustment,
-        ChunkedInference.Chunk chunk,
-        boolean useLegacyFormat
-    ) {
+    /**
+     * Converts the provided {@link ChunkedInference} into a list of {@link Chunk}.
+     */
+    public static Chunk toSemanticTextFieldChunk(int offsetAdjustment, ChunkedInference.Chunk chunk) {
         String text = null;
-        int startOffset = -1;
-        int endOffset = -1;
-        if (useLegacyFormat) {
-            text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
-        } else {
-            startOffset = chunk.textOffset().start() + offsetAdjustment;
-            endOffset = chunk.textOffset().end() + offsetAdjustment;
+        int startOffset = chunk.textOffset().start() + offsetAdjustment;
+        int endOffset = chunk.textOffset().end() + offsetAdjustment;
+        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+    }
+
+    public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedInference results, XContentType contentType)
+        throws IOException {
+        List<Chunk> chunks = new ArrayList<>();
+        Iterator<ChunkedInference.Chunk> it = results.chunksAsByteReference(contentType.xContent());
+        while (it.hasNext()) {
+            chunks.add(toSemanticTextFieldChunkLegacy(input, it.next()));
         }
+        return chunks;
+    }
 
-        return new Chunk(text, startOffset, endOffset, chunk.bytesReference());
+    public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+        var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
+        return new Chunk(text, -1, -1, chunk.bytesReference());
     }
 }

+ 16 - 21
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

@@ -28,7 +28,9 @@ import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
@@ -66,12 +68,13 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
-import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.DEFAULT_BATCH_SIZE;
+import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
@@ -118,7 +121,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -144,7 +147,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), DEFAULT_BATCH_SIZE, useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -185,7 +188,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -232,7 +234,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -303,7 +304,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -374,7 +374,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
-            randomIntBetween(1, 10),
             useLegacyFormat,
             true
         );
@@ -447,13 +446,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             modifiedRequests[id] = res[1];
         }
 
-        ShardBulkInferenceActionFilter filter = createFilter(
-            threadPool,
-            inferenceModelMap,
-            randomIntBetween(10, 30),
-            useLegacyFormat,
-            true
-        );
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -487,7 +480,6 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
-        int batchSize,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -554,18 +546,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState,
-            batchSize
+            licenseState
         );
     }
 
     private static ClusterService createClusterService(boolean useLegacyFormat) {
         IndexMetadata indexMetadata = mock(IndexMetadata.class);
-        var settings = Settings.builder()
+        var indexSettings = Settings.builder()
             .put(IndexMetadata.SETTING_INDEX_VERSION_CREATED.getKey(), IndexVersion.current())
             .put(InferenceMetadataFieldsMapper.USE_LEGACY_SEMANTIC_TEXT_FORMAT.getKey(), useLegacyFormat)
             .build();
-        when(indexMetadata.getSettings()).thenReturn(settings);
+        when(indexMetadata.getSettings()).thenReturn(indexSettings);
 
         ProjectMetadata project = spy(ProjectMetadata.builder(Metadata.DEFAULT_PROJECT_ID).build());
         when(project.index(anyString())).thenReturn(indexMetadata);
@@ -576,7 +567,10 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ClusterState clusterState = ClusterState.builder(new ClusterName("test")).metadata(metadata).build();
         ClusterService clusterService = mock(ClusterService.class);
         when(clusterService.state()).thenReturn(clusterState);
-
+        long batchSizeInBytes = randomLongBetween(0, ByteSizeValue.ofKb(1).getBytes());
+        Settings settings = Settings.builder().put(INDICES_INFERENCE_BATCH_SIZE.getKey(), ByteSizeValue.ofBytes(batchSizeInBytes)).build();
+        when(clusterService.getSettings()).thenReturn(settings);
+        when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, Set.of(INDICES_INFERENCE_BATCH_SIZE)));
         return clusterService;
     }
 
@@ -587,7 +581,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     ) throws IOException {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> expectedDocMap = new LinkedHashMap<>();
-        XContentType requestContentType = randomFrom(XContentType.values());
+        // force JSON to avoid double/float conversions
+        XContentType requestContentType = XContentType.JSON;
 
         Map<String, Object> inferenceMetadataFields = new HashMap<>();
         for (var entry : fieldInferenceMap.values()) {

+ 2 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

@@ -41,6 +41,7 @@ import java.util.function.Predicate;
 
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunk;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunkLegacy;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 
@@ -274,7 +275,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         while (inputsIt.hasNext() && chunkIt.hasNext()) {
             String input = inputsIt.next();
             var chunk = chunkIt.next();
-            chunks.add(toSemanticTextFieldChunk(input, offsetAdjustment, chunk, useLegacyFormat));
+            chunks.add(useLegacyFormat ? toSemanticTextFieldChunkLegacy(input, chunk) : toSemanticTextFieldChunk(offsetAdjustment, chunk));
 
             // When using the inference metadata fields format, all the input values are concatenated so that the
             // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment