
[8.19] Semantic Text Chunking Indexing Pressure (#125517) (#127463)

* Semantic Text Chunking Indexing Pressure (#125517)

We have observed many OOMs due to the memory required to inject chunked inference results for semantic_text fields. This PR uses coordinating indexing pressure to account for this memory usage. When indexing pressure memory usage exceeds the threshold set by indexing_pressure.memory.limit, chunked inference result injection will be suspended to prevent OOMs.
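Below is a minimal sketch of the pattern this commit introduces, for context only: acquire a coordinating handle before injecting chunked inference results, account for the injected bytes with increment, and reject the document with an EsRejectedExecutionException once indexing_pressure.memory.limit is exceeded. The IndexingPressure, Coordinating, and ActionListener.releaseAfter names are taken from the diff below; the wrapper class, method, and byte estimate are hypothetical simplifications of the production ShardBulkInferenceActionFilter flow.

// Sketch only: illustrates the coordinating indexing pressure pattern added in this PR,
// not the exact production code path.
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.index.IndexingPressure;

class ChunkedInferencePressureSketch {
    private final IndexingPressure indexingPressure;

    ChunkedInferencePressureSketch(IndexingPressure indexingPressure) {
        this.indexingPressure = indexingPressure;
    }

    <T> void runWithPressure(long estimatedInjectedBytes, ActionListener<T> listener, Runnable injectInferenceResults) {
        // Acquire a coordinating handle up front; releaseAfter ties its release to the
        // completion of the downstream listener, so the memory stays accounted for until
        // the indexing operations finish.
        IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
        ActionListener<T> wrapped = ActionListener.releaseAfter(listener, coordinating);
        try {
            // Account for the memory the injected chunked inference results will occupy.
            // If coordinating memory usage already exceeds indexing_pressure.memory.limit,
            // this throws and the document is rejected instead of risking an OOM.
            coordinating.increment(1, estimatedInjectedBytes);
            injectInferenceResults.run();
            wrapped.onResponse(null);
        } catch (EsRejectedExecutionException e) {
            wrapped.onFailure(e);
        }
    }
}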

(cherry picked from commit 85713f78e00682e6abe3329440682ba2abef0d3f)

# Conflicts:
#	server/src/main/java/org/elasticsearch/node/NodeConstruction.java
#	server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Mike Pellegrini 5 months ago
parent commit 30d5a9dbfc

+ 5 - 0
docs/changelog/125517.yaml

@@ -0,0 +1,5 @@
+pr: 125517
+summary: Semantic Text Chunking Indexing Pressure
+area: Machine Learning
+type: enhancement
+issues: []

+ 6 - 2
server/src/main/java/org/elasticsearch/index/IndexingPressure.java

@@ -146,11 +146,15 @@ public class IndexingPressure {
     }
 
     public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
-        Coordinating coordinating = new Coordinating(forceExecution);
+        Coordinating coordinating = createCoordinatingOperation(forceExecution);
         coordinating.increment(operations, bytes);
         return coordinating;
     }
 
+    public Coordinating createCoordinatingOperation(boolean forceExecution) {
+        return new Coordinating(forceExecution);
+    }
+
     public class Incremental implements Releasable {
 
         private final AtomicBoolean closed = new AtomicBoolean();
@@ -243,7 +247,7 @@ public class IndexingPressure {
             this.forceExecution = forceExecution;
         }
 
-        private void increment(int operations, long bytes) {
+        public void increment(int operations, long bytes) {
             assert closed.get() == false;
             long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
             long replicaWriteBytes = currentReplicaBytes.get();

+ 4 - 2
server/src/main/java/org/elasticsearch/node/NodeConstruction.java

@@ -936,6 +936,8 @@ class NodeConstruction {
             metadataCreateIndexService
         );
 
+        final IndexingPressure indexingLimits = new IndexingPressure(settings);
+
         PluginServiceInstances pluginServices = new PluginServiceInstances(
             client,
             clusterService,
@@ -957,7 +959,8 @@ class NodeConstruction {
             dataStreamGlobalRetentionSettings,
             documentParsingProvider,
             taskManager,
-            slowLogFieldProvider
+            slowLogFieldProvider,
+            indexingLimits
         );
 
         Collection<?> pluginComponents = pluginsService.flatMap(plugin -> {
@@ -990,7 +993,6 @@ class NodeConstruction {
             .map(TerminationHandlerProvider::handler);
         terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);
 
-        final IndexingPressure indexingLimits = new IndexingPressure(settings);
         final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);
 
         ActionModule actionModule = new ActionModule(

+ 3 - 1
server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java

@@ -19,6 +19,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -53,5 +54,6 @@ public record PluginServiceInstances(
     DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings,
     DocumentParsingProvider documentParsingProvider,
     TaskManager taskManager,
-    SlowLogFieldProvider slowLogFieldProvider
+    SlowLogFieldProvider slowLogFieldProvider,
+    IndexingPressure indexingPressure
 ) implements Plugin.PluginServices {}

+ 6 - 0
server/src/main/java/org/elasticsearch/plugins/Plugin.java

@@ -27,6 +27,7 @@ import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.features.FeatureService;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettingProvider;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -179,6 +180,11 @@ public abstract class Plugin implements Closeable {
          * Provider for additional SlowLog fields
          */
         SlowLogFieldProvider slowLogFieldProvider();
+
+        /**
+         * Provider for indexing pressure
+         */
+        IndexingPressure indexingPressure();
     }
 
     /**

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -328,7 +328,8 @@ public class InferencePlugin extends Plugin
             services.clusterService(),
             serviceRegistry,
             modelRegistry.get(),
-            getLicenseState()
+            getLicenseState(),
+            services.indexingPressure()
         );
         shardBulkInferenceActionFilter.set(actionFilter);
 

+ 100 - 11
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

@@ -28,11 +28,13 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -145,8 +150,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
             var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
             if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-                Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+                // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+                Runnable onInferenceCompletion = () -> chain.proceed(
+                    task,
+                    action,
+                    request,
+                    ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+                );
+                processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
                 return;
             }
         }
@@ -156,11 +168,13 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }
 
     private record InferenceProvider(InferenceService service, Model model) {}
@@ -230,18 +244,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;
 
         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }
 
         @Override
@@ -429,9 +446,9 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -445,13 +462,13 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }
 
-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -487,6 +504,10 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                          * This ensures that the field is treated as intentionally cleared,
                          * preventing any unintended carryover of prior inference results.
                          */
+                        if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                            return inputLength;
+                        }
+
                         var slot = ensureResponseAccumulatorSlot(itemIndex);
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -508,6 +529,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                         }
                         continue;
                     }
+
                     var slot = ensureResponseAccumulatorSlot(itemIndex);
                     final List<String> values;
                     try {
@@ -525,7 +547,10 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                     int offsetAdjustment = 0;
                     for (String v : values) {
-                        inputLength += v.length();
+                        if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                            return inputLength;
+                        }
+
                         if (v.isBlank()) {
                             slot.addOrUpdateResponse(
                                 new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -534,6 +559,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             requests.add(
                                 new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                             );
+                            inputLength += v.length();
                         }
 
                         // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -543,9 +569,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     }
                 }
             }
+
             return inputLength;
         }
 
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -622,6 +693,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 inferenceFieldsMap.put(fieldName, result);
             }
 
+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -634,6 +706,23 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+            // in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }
 

+ 480 - 10
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.bulk.BulkItemRequest;
 import org.elasticsearch.action.bulk.BulkItemResponse;
 import org.elasticsearch.action.bulk.BulkShardRequest;
+import org.elasticsearch.action.bulk.BulkShardResponse;
 import org.elasticsearch.action.bulk.TransportShardBulkAction;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.ActionFilterChain;
@@ -26,19 +27,24 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
@@ -48,6 +54,8 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.XPackField;
@@ -72,13 +80,16 @@ import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 
+import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
+import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
-import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbedding;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
@@ -87,13 +98,23 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.longThat;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     private static final Object EXPLICIT_NULL = new Object();
+    private static final IndexingPressure NOOP_INDEXING_PRESSURE = new NoopIndexingPressure();
 
     private final boolean useLegacyFormat;
     private ThreadPool threadPool;
@@ -119,7 +140,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -145,7 +166,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -186,6 +207,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -227,16 +249,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testItemFailures() throws Exception {
-        StaticModel model = StaticModel.createRandomInstance();
+        StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
 
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
-        model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
+        model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -295,13 +318,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testExplicitNull() throws Exception {
-        StaticModel model = StaticModel.createRandomInstance();
+        StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
-        model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
+        model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
 
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -372,6 +396,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -444,7 +469,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             modifiedRequests[id] = res[1];
         }
 
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -474,10 +499,397 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
 
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testIndexingPressure() throws Exception {
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value");
+        XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", "another test value");
+        XContentBuilder doc2Source = IndexRequest.getXContentBuilder(
+            XContentType.JSON,
+            "sparse_field",
+            "a test value",
+            "dense_field",
+            "another test value"
+        );
+        XContentBuilder doc3Source = IndexRequest.getXContentBuilder(
+            XContentType.JSON,
+            "dense_field",
+            List.of("value one", "  ", "value two")
+        );
+        XContentBuilder doc4Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "   ");
+        XContentBuilder doc5Source = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            doc5Source.startObject();
+            if (useLegacyFormat == false) {
+                doc5Source.field("sparse_field", "a test value");
+            }
+            addSemanticTextInferenceResults(
+                useLegacyFormat,
+                doc5Source,
+                List.of(randomSemanticText(useLegacyFormat, "sparse_field", sparseModel, null, List.of("a test value"), XContentType.JSON))
+            );
+            doc5Source.endObject();
+        }
+        XContentBuilder doc0UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "an updated value");
+        XContentBuilder doc1UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", null);
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
+                assertNull(bulkShardRequest.getInferenceFieldMap());
+                assertThat(bulkShardRequest.items().length, equalTo(10));
+
+                for (BulkItemRequest item : bulkShardRequest.items()) {
+                    assertNull(item.getPrimaryResponse());
+                }
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource));
+                if (useLegacyFormat == false) {
+                    verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource));
+                }
+
+                verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
+
+                // Verify that the only times that increment is called are the times verified above
+                verify(coordinatingIndexingPressure, times(useLegacyFormat ? 12 : 14)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, like is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null),
+            "dense_field",
+            new InferenceFieldMetadata("dense_field", denseModel.getInferenceEntityId(), new String[] { "dense_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[10];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source(doc0Source));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
+        items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source(doc3Source));
+        items[4] = new BulkItemRequest(4, new IndexRequest("index").id("doc_4").source(doc4Source));
+        items[5] = new BulkItemRequest(5, new IndexRequest("index").id("doc_5").source(doc5Source));
+        items[6] = new BulkItemRequest(6, new IndexRequest("index").id("doc_6").source("non_inference_field", "yet another test value"));
+        items[7] = new BulkItemRequest(7, new UpdateRequest().doc(new IndexRequest("index").id("doc_0").source(doc0UpdateSource)));
+        items[8] = new BulkItemRequest(8, new UpdateRequest().doc(new IndexRequest("index").id("doc_1").source(doc1UpdateSource)));
+        items[9] = new BulkItemRequest(
+            9,
+            new UpdateRequest().doc(new IndexRequest("index").id("doc_3").source("non_inference_field", "yet another updated value"))
+        );
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
+        );
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(3));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[2].getPrimaryResponse());
+
+                BulkItemRequest doc1Request = request.items()[1];
+                BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
+                assertNotNull(doc1Response);
+                assertTrue(doc1Response.isFailed());
+                BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
+                assertThat(
+                    doc1Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to update source on document [doc_1]")
+                );
+                assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
+                assertThat(doc1IndexRequest, notNullValue());
+                assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, like is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
+        final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
+        );
+
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar")));
+
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(3));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[2].getPrimaryResponse());
+
+                BulkItemRequest doc1Request = request.items()[1];
+                BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
+                assertNotNull(doc1Response);
+                assertTrue(doc1Response.isFailed());
+                BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
+                assertThat(
+                    doc1Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to insert inference results into document [doc_1]")
+                );
+                assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
+                assertThat(doc1IndexRequest, notNullValue());
+                assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
+                verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, like is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressurePartialFailure() throws Exception {
+        // Use different length strings so that doc 1 and doc 2 sources are different sizes
+        final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+        final XContentBuilder doc2Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bazzz");
+
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final ChunkedInferenceEmbedding barEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bar"));
+        final ChunkedInferenceEmbedding bazzzEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bazzz"));
+        sparseModel.putResult("bar", barEmbedding);
+        sparseModel.putResult("bazzz", bazzzEmbedding);
+
+        CheckedBiFunction<List<String>, ChunkedInference, Long, IOException> estimateInferenceResultsBytes = (inputs, inference) -> {
+            SemanticTextField semanticTextField = semanticTextFieldFromChunkedInferenceResults(
+                useLegacyFormat,
+                "sparse_field",
+                sparseModel,
+                null,
+                inputs,
+                inference,
+                XContentType.JSON
+            );
+            XContentBuilder builder = XContentFactory.jsonBuilder();
+            semanticTextField.toXContent(builder, EMPTY_PARAMS);
+            return bytesUsed(builder);
+        };
+
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder()
+                .put(
+                    MAX_COORDINATING_BYTES.getKey(),
+                    (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
+                        + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
+                )
+                .build()
+        );
+
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(4));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[1].getPrimaryResponse());
+                assertNull(request.items()[3].getPrimaryResponse());
+
+                BulkItemRequest doc2Request = request.items()[2];
+                BulkItemResponse doc2Response = doc2Request.getPrimaryResponse();
+                assertNotNull(doc2Response);
+                assertTrue(doc2Response.isFailed());
+                BulkItemResponse.Failure doc2Failure = doc2Response.getFailure();
+                assertThat(
+                    doc2Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to insert inference results into document [doc_2]")
+                );
+                assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc2IndexRequest = getIndexRequestOrNull(doc2Request.request());
+                assertThat(doc2IndexRequest, notNullValue());
+                assertThat(doc2IndexRequest.source(), equalTo(BytesReference.bytes(doc2Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+                verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
+                verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, like is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[4];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
+        items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
     @SuppressWarnings("unchecked")
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
+        IndexingPressure indexingPressure,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -503,6 +915,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         };
         doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any());
 
+        Answer<MinimalServiceSettings> minimalServiceSettingsAnswer = invocationOnMock -> {
+            String inferenceId = (String) invocationOnMock.getArguments()[0];
+            var model = modelMap.get(inferenceId);
+            if (model == null) {
+                throw new ResourceNotFoundException("model id [{}] not found", inferenceId);
+            }
+
+            return new MinimalServiceSettings(model);
+        };
+        doAnswer(minimalServiceSettingsAnswer).when(modelRegistry).getMinimalServiceSettings(any());
+
         InferenceService inferenceService = mock(InferenceService.class);
         Answer<?> chunkedInferAnswer = invocationOnMock -> {
             StaticModel model = (StaticModel) invocationOnMock.getArguments()[0];
@@ -544,7 +967,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState
+            licenseState,
+            indexingPressure
         );
     }
 
@@ -629,6 +1053,10 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
     }
 
+    private static long bytesUsed(XContentBuilder builder) {
+        return BytesReference.bytes(builder).ramBytesUsed();
+    }
+
     @SuppressWarnings({ "unchecked" })
     private static void assertInferenceResults(
         boolean useLegacyFormat,
@@ -693,7 +1121,11 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         }
 
         public static StaticModel createRandomInstance() {
-            TestModel testModel = TestModel.createRandomInstance();
+            return createRandomInstance(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING));
+        }
+
+        public static StaticModel createRandomInstance(TaskType taskType) {
+            TestModel testModel = TestModel.createRandomInstance(taskType);
             return new StaticModel(
                 testModel.getInferenceEntityId(),
                 testModel.getTaskType(),
@@ -716,4 +1148,42 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             return resultMap.containsKey(text);
         }
     }
+
+    private static class InstrumentedIndexingPressure extends IndexingPressure {
+        private Coordinating coordinating = null;
+
+        private InstrumentedIndexingPressure(Settings settings) {
+            super(settings);
+        }
+
+        private Coordinating getCoordinating() {
+            return coordinating;
+        }
+
+        @Override
+        public Coordinating createCoordinatingOperation(boolean forceExecution) {
+            coordinating = spy(super.createCoordinatingOperation(forceExecution));
+            return coordinating;
+        }
+    }
+
+    private static class NoopIndexingPressure extends IndexingPressure {
+        private NoopIndexingPressure() {
+            super(Settings.EMPTY);
+        }
+
+        @Override
+        public Coordinating createCoordinatingOperation(boolean forceExecution) {
+            return new NoopCoordinating(forceExecution);
+        }
+
+        private class NoopCoordinating extends Coordinating {
+            private NoopCoordinating(boolean forceExecution) {
+                super(forceExecution);
+            }
+
+            @Override
+            public void increment(int operations, long bytes) {}
+        }
+    }
 }

+ 11 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

@@ -185,6 +185,17 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         assertThat(ex.getMessage(), containsString("required [element_type] field is missing"));
     }
 
+    public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model model, List<String> inputs) {
+        return switch (model.getTaskType()) {
+            case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs);
+            case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) {
+                case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs);
+                case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs);
+            };
+            default -> throw new AssertionError("invalid task type: " + model.getTaskType().name());
+        };
+    }
+
     public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Model model, List<String> inputs) {
         DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType();
         int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions());