1
0
Эх сурвалжийг харах

[8.x] [Inference API] Add node-local rate limiting for the inference API (#120400) (#121251)

* [Inference API] Add node-local rate limiting for the inference API (#120400)

* Add node-local rate limiting for the inference API

* Fix integration tests by using new LocalStateInferencePlugin instead of InferencePlugin and adjust formatting.

* Correct feature flag name

* Add more docs, reorganize methods and make some methods package private

* Clarify comment in BaseInferenceActionRequest

* Fix wrong merge

* Fix checkstyle

* Fix checkstyle in tests

* Check that the service we want to the read the rate limit config for actually exists

* [CI] Auto commit changes from spotless

* checkStyle apply

* Update docs/changelog/120400.yaml

* Move rate limit division logic to RequestExecutorService

* Spotless apply

* Remove debug sout

* Adding a few suggestions

* Adam feedback

* Fix compilation error

* [CI] Auto commit changes from spotless

* Add BWC test case to InferenceActionRequestTests

* Add BWC test case to UnifiedCompletionActionRequestTests

* Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Co-authored-by: Adam Demjen <demjened@gmail.com>

* Update x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

Co-authored-by: Adam Demjen <demjened@gmail.com>

* Remove addressed TODO

* Spotless apply

* Only use new rate limit specific feature flag

* Use ThreadLocalRandom

* [CI] Auto commit changes from spotless

* Use Randomness.get()

* [CI] Auto commit changes from spotless

* Fix import

* Use ConcurrentHashMap in InferenceServiceNodeLocalRateLimitCalculator

* Check for null value in getRateLimitAssignment and remove AtomicReference

* Remove newAssignments

* Up the default rate limit for completions

* Put deprecated feature flag back in

* Check feature flag in BaseTransportInferenceAction

* spotlessApply

* Export inference.common

* Do not export inference.common

* Provide noop rate limit calculator, if feature flag is disabled

* Add proper dependency injection

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
Co-authored-by: Adam Demjen <demjened@gmail.com>

* Use .get(0) as getFirst() doesn't exist in 8.18 (probably JDK difference?)

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Co-authored-by: Jonathan Buttner <jonathan.buttner@elastic.co>
Co-authored-by: Adam Demjen <demjened@gmail.com>
Tim Grein 8 сар өмнө
parent
commit
f0a5e25fca
29 өөрчлөгдсөн 1015 нэмэгдсэн , 49 устгасан
  1. 5 0
      docs/changelog/120400.yaml
  2. 1 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 31 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java
  4. 23 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java
  5. 20 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java
  6. 30 7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  7. 154 24
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
  8. 11 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
  9. 11 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
  10. 28 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java
  11. 197 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java
  12. 18 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java
  13. 27 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java
  14. 19 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java
  15. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java
  16. 5 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java
  17. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java
  18. 4 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java
  19. 53 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java
  20. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java
  21. 2 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java
  22. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
  23. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java
  24. 19 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
  25. 128 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java
  26. 11 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java
  27. 205 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java
  28. 5 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java
  29. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java

+ 5 - 0
docs/changelog/120400.yaml

@@ -0,0 +1,5 @@
+pr: 120400
+summary: "[Inference API] Add node-local rate limiting for the inference API"
+area: Machine Learning
+type: feature
+issues: []

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -176,6 +176,7 @@ public class TransportVersions {
     public static final TransportVersion RESOURCE_DEPRECATION_CHECKS = def(8_836_00_0);
     public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
     public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
+    public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 31 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java

@@ -7,20 +7,35 @@
 
 package org.elasticsearch.xpack.core.inference.action;
 
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionRequest;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.inference.TaskType;
 
 import java.io.IOException;
 
+/**
+ * Base class for inference action requests. Tracks request routing state to prevent potential routing loops
+ * and supports both streaming and non-streaming inference operations.
+ */
 public abstract class BaseInferenceActionRequest extends ActionRequest {
 
+    private boolean hasBeenRerouted;
+
     public BaseInferenceActionRequest() {
         super();
     }
 
     public BaseInferenceActionRequest(StreamInput in) throws IOException {
         super(in);
+        if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
+            this.hasBeenRerouted = in.readBoolean();
+        } else {
+            // For backwards compatibility, we treat all inference requests coming from ES nodes having
+            // a version pre-node-local-rate-limiting as already rerouted to maintain pre-node-local-rate-limiting behavior.
+            this.hasBeenRerouted = true;
+        }
     }
 
     public abstract boolean isStreaming();
@@ -28,4 +43,20 @@ public abstract class BaseInferenceActionRequest extends ActionRequest {
     public abstract TaskType getTaskType();
 
     public abstract String getInferenceEntityId();
+
+    public void setHasBeenRerouted(boolean hasBeenRerouted) {
+        this.hasBeenRerouted = hasBeenRerouted;
+    }
+
+    public boolean hasBeenRerouted() {
+        return hasBeenRerouted;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        super.writeTo(out);
+        if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) {
+            out.writeBoolean(hasBeenRerouted);
+        }
+    }
 }

+ 23 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java

@@ -386,6 +386,29 @@ public class InferenceActionRequestTests extends AbstractBWCWireSerializationTes
         assertThat(deserializedInstance.getInputType(), is(InputType.UNSPECIFIED));
     }
 
+    public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
+        var instance = new InferenceAction.Request(
+            TaskType.TEXT_EMBEDDING,
+            "model",
+            null,
+            List.of("input"),
+            Map.of(),
+            InputType.UNSPECIFIED,
+            InferenceAction.Request.DEFAULT_TIMEOUT,
+            false
+        );
+
+        InferenceAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.V_8_13_0
+        );
+
+        // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
+        assertTrue(deserializedInstance.hasBeenRerouted());
+    }
+
     public void testGetInputTypeToWrite_ReturnsIngest_WhenInputTypeIsUnspecified_VersionBeforeUnspecifiedIntroduced() {
         assertThat(getInputTypeToWrite(InputType.UNSPECIFIED, TransportVersions.V_8_12_1), is(InputType.INGEST));
     }

+ 20 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.core.inference.action;
 
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.TimeValue;
@@ -65,6 +66,25 @@ public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializ
         assertNull(request.validate());
     }
 
+    public void testWriteTo_WhenVersionIsBeforeAdaptiveRateLimiting_ShouldSetHasBeenReroutedToTrue() throws IOException {
+        var instance = new UnifiedCompletionAction.Request(
+            "model",
+            TaskType.ANY,
+            UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())),
+            TimeValue.timeValueSeconds(10)
+        );
+
+        UnifiedCompletionAction.Request deserializedInstance = copyWriteable(
+            instance,
+            getNamedWriteableRegistry(),
+            instanceReader(),
+            TransportVersions.ELASTIC_INFERENCE_SERVICE_UNIFIED_CHAT_COMPLETIONS_INTEGRATION
+        );
+
+        // Verify that hasBeenRerouted is true after deserializing a request coming from an older transport version
+        assertTrue(deserializedInstance.hasBeenRerouted());
+    }
+
     @Override
     protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) {
         return instance;

+ 30 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -72,6 +72,9 @@ import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction
 import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction;
 import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction;
 import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.NoopNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -133,6 +136,7 @@ import java.util.function.Predicate;
 import java.util.function.Supplier;
 
 import static java.util.Collections.singletonList;
+import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 
 public class InferencePlugin extends Plugin
     implements
@@ -229,6 +233,7 @@ public class InferencePlugin extends Plugin
 
     @Override
     public Collection<?> createComponents(PluginServices services) {
+        var components = new ArrayList<>();
         var throttlerManager = new ThrottlerManager(settings, services.threadPool(), services.clusterService());
         var truncator = new Truncator(settings, services.clusterService());
         serviceComponents.set(new ServiceComponents(services.threadPool(), throttlerManager, settings, truncator));
@@ -297,20 +302,38 @@ public class InferencePlugin extends Plugin
 
         // This must be done after the HttpRequestSenderFactory is created so that the services can get the
         // reference correctly
-        var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
-        registry.init(services.client());
-        for (var service : registry.getServices().values()) {
+        var serviceRegistry = new InferenceServiceRegistry(inferenceServices, factoryContext);
+        serviceRegistry.init(services.client());
+        for (var service : serviceRegistry.getServices().values()) {
             service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
         }
-        inferenceServiceRegistry.set(registry);
+        inferenceServiceRegistry.set(serviceRegistry);
 
-        var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), registry, modelRegistry);
+        var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry);
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
+        var inferenceStats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
+
+        components.add(serviceRegistry);
+        components.add(modelRegistry);
+        components.add(httpClientManager);
+        components.add(inferenceStats);
+
+        // Only add InferenceServiceNodeLocalRateLimitCalculator (which is a ClusterStateListener) for cluster aware rate limiting,
+        // if the rate limiting feature flags are enabled, otherwise provide noop implementation
+        InferenceServiceRateLimitCalculator calculator;
+        if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled()) {
+            calculator = new InferenceServiceNodeLocalRateLimitCalculator(services.clusterService(), serviceRegistry);
+        } else {
+            calculator = new NoopNodeLocalRateLimitCalculator();
+        }
+
+        // Add binding for interface -> implementation
+        components.add(new PluginComponentBinding<>(InferenceServiceRateLimitCalculator.class, calculator));
+        components.add(calculator);
 
-        return List.of(modelRegistry, registry, httpClientManager, stats);
+        return components;
     }
 
     @Override

+ 154 - 24
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

@@ -13,6 +13,10 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.Randomness;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
@@ -27,24 +31,42 @@ import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
 import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
+import java.io.IOException;
+import java.util.Random;
+import java.util.concurrent.Executor;
 import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
+import static org.elasticsearch.xpack.inference.common.InferenceAPIClusterAwareRateLimitingFeature.INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG;
 import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
 import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
+/**
+ * Base class for transport actions that handle inference requests.
+ * Works in conjunction with {@link InferenceServiceNodeLocalRateLimitCalculator} to
+ * route requests to specific nodes, iff they support "node-local" rate limiting, which is described in detail
+ * in {@link InferenceServiceNodeLocalRateLimitCalculator}.
+ *
+ * @param <Request> The specific type of inference request being handled
+ */
 public abstract class BaseTransportInferenceAction<Request extends BaseInferenceActionRequest> extends HandledTransportAction<
     Request,
     InferenceAction.Response> {
@@ -57,6 +79,11 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
     private final StreamingTaskManager streamingTaskManager;
+    private final InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
+    private final NodeClient nodeClient;
+    private final ThreadPool threadPool;
+    private final TransportService transportService;
+    private final Random random;
 
     public BaseTransportInferenceAction(
         String inferenceActionName,
@@ -67,7 +94,10 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
         StreamingTaskManager streamingTaskManager,
-        Writeable.Reader<Request> requestReader
+        Writeable.Reader<Request> requestReader,
+        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     ) {
         super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
         this.licenseState = licenseState;
@@ -75,8 +105,24 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         this.serviceRegistry = serviceRegistry;
         this.inferenceStats = inferenceStats;
         this.streamingTaskManager = streamingTaskManager;
+        this.inferenceServiceRateLimitCalculator = inferenceServiceNodeLocalRateLimitCalculator;
+        this.nodeClient = nodeClient;
+        this.threadPool = threadPool;
+        this.transportService = transportService;
+        this.random = Randomness.get();
     }
 
+    protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);
+
+    protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);
+
+    protected abstract void doInference(
+        Model model,
+        Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
+    );
+
     @Override
     protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
         if (INFERENCE_API_FEATURE.check(licenseState) == false) {
@@ -87,31 +133,32 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         var timer = InferenceTimer.start();
 
         var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
-            var service = serviceRegistry.getService(unparsedModel.service());
+            var serviceName = unparsedModel.service();
+
             try {
-                validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
-                validationHelper(
-                    () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
-                    () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType())
-                );
-                validationHelper(
-                    () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
-                    () -> createInvalidTaskTypeException(request, unparsedModel)
-                );
+                validateRequest(request, unparsedModel);
             } catch (Exception e) {
                 recordMetrics(unparsedModel, timer, e);
                 listener.onFailure(e);
                 return;
             }
 
-            var model = service.get()
-                .parsePersistedConfigWithSecrets(
+            var service = serviceRegistry.getService(serviceName).get();
+            var routingDecision = determineRouting(serviceName, request, unparsedModel);
+
+            if (routingDecision.currentNodeShouldHandleRequest()) {
+                var model = service.parsePersistedConfigWithSecrets(
                     unparsedModel.inferenceEntityId(),
                     unparsedModel.taskType(),
                     unparsedModel.settings(),
                     unparsedModel.secrets()
                 );
-            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+                inferOnServiceWithMetrics(model, request, service, timer, listener);
+            } else {
+                // Reroute request
+                request.setHasBeenRerouted(true);
+                rerouteRequest(request, listener, routingDecision.targetNode);
+            }
         }, e -> {
             try {
                 inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
@@ -124,15 +171,95 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
+    private void validateRequest(Request request, UnparsedModel unparsedModel) {
+        var serviceName = unparsedModel.service();
+        var requestTaskType = request.getTaskType();
+        var service = serviceRegistry.getService(serviceName);
+
+        validationHelper(service::isEmpty, () -> unknownServiceException(serviceName, request.getInferenceEntityId()));
+        validationHelper(
+            () -> request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false,
+            () -> requestModelTaskTypeMismatchException(requestTaskType, unparsedModel.taskType())
+        );
+        validationHelper(
+            () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel),
+            () -> createInvalidTaskTypeException(request, unparsedModel)
+        );
+    }
+
+    private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
+        if (INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG.isEnabled() == false) {
+            return NodeRoutingDecision.handleLocally();
+        }
+
+        var modelTaskType = unparsedModel.taskType();
+
+        // Rerouting not supported or request was already rerouted
+        if (inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceName, modelTaskType) == false
+            || request.hasBeenRerouted()) {
+            return NodeRoutingDecision.handleLocally();
+        }
+
+        var rateLimitAssignment = inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceName, modelTaskType);
+
+        // No assignment yet
+        if (rateLimitAssignment == null) {
+            return NodeRoutingDecision.handleLocally();
+        }
+
+        var responsibleNodes = rateLimitAssignment.responsibleNodes();
+
+        // Empty assignment
+        if (responsibleNodes == null || responsibleNodes.isEmpty()) {
+            return NodeRoutingDecision.handleLocally();
+        }
+
+        var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));
+        String localNodeId = nodeClient.getLocalNodeId();
+
+        // The drawn node is the current node
+        if (nodeToHandleRequest.getId().equals(localNodeId)) {
+            return NodeRoutingDecision.handleLocally();
+        }
+
+        // Reroute request
+        return NodeRoutingDecision.routeTo(nodeToHandleRequest);
+    }
+
     private static void validationHelper(Supplier<Boolean> validationFailure, Supplier<ElasticsearchStatusException> exceptionCreator) {
         if (validationFailure.get()) {
             throw exceptionCreator.get();
         }
     }
 
-    protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel);
-
-    protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel);
+    private void rerouteRequest(Request request, ActionListener<InferenceAction.Response> listener, DiscoveryNode nodeToHandleRequest) {
+        transportService.sendRequest(
+            nodeToHandleRequest,
+            InferenceAction.NAME,
+            request,
+            new TransportResponseHandler<InferenceAction.Response>() {
+                @Override
+                public Executor executor() {
+                    return threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
+                }
+
+                @Override
+                public void handleResponse(InferenceAction.Response response) {
+                    listener.onResponse(response);
+                }
+
+                @Override
+                public void handleException(TransportException exp) {
+                    listener.onFailure(exp);
+                }
+
+                @Override
+                public InferenceAction.Response read(StreamInput in) throws IOException {
+                    return new InferenceAction.Response(in);
+                }
+            }
+        );
+    }
 
     private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
         try {
@@ -185,13 +312,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         }
     }
 
-    protected abstract void doInference(
-        Model model,
-        Request request,
-        InferenceService service,
-        ActionListener<InferenceServiceResults> listener
-    );
-
     private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) {
         var supportedTasks = service.supportedStreamingTasks();
         if (supportedTasks.isEmpty()) {
@@ -259,4 +379,14 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
             super.onComplete();
         }
     }
+
+    private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
+        static NodeRoutingDecision handleLocally() {
+            return new NodeRoutingDecision(true, null);
+        }
+
+        static NodeRoutingDecision routeTo(DiscoveryNode node) {
+            return new NodeRoutingDecision(false, node);
+        }
+    }
 }

+ 11 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.action;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -17,9 +18,11 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
@@ -33,7 +36,10 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
-        StreamingTaskManager streamingTaskManager
+        StreamingTaskManager streamingTaskManager,
+        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     ) {
         super(
             InferenceAction.NAME,
@@ -44,7 +50,10 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
             serviceRegistry,
             inferenceStats,
             streamingTaskManager,
-            InferenceAction.Request::new
+            InferenceAction.Request::new,
+            inferenceServiceNodeLocalRateLimitCalculator,
+            nodeClient,
+            threadPool
         );
     }
 

+ 11 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.action;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -19,9 +20,11 @@ import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
@@ -35,7 +38,10 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
-        StreamingTaskManager streamingTaskManager
+        StreamingTaskManager streamingTaskManager,
+        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     ) {
         super(
             UnifiedCompletionAction.NAME,
@@ -46,7 +52,10 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
             serviceRegistry,
             inferenceStats,
             streamingTaskManager,
-            UnifiedCompletionAction.Request::new
+            UnifiedCompletionAction.Request::new,
+            inferenceServiceNodeLocalRateLimitCalculator,
+            nodeClient,
+            threadPool
         );
     }
 

+ 28 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceAPIClusterAwareRateLimitingFeature.java

@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.common.util.FeatureFlag;
+import org.elasticsearch.xpack.inference.InferencePlugin;
+
+/**
+ * Cluster aware rate limiting feature flag. When the feature is complete and fully rolled out, this flag will be removed.
+ * Enable feature via JVM option: `-Des.inference_cluster_aware_rate_limiting_feature_flag_enabled=true`.
+ *
+ * This controls, whether {@link InferenceServiceNodeLocalRateLimitCalculator} gets instantiated and
+ * added as injectable {@link InferencePlugin} component.
+ */
+public class InferenceAPIClusterAwareRateLimitingFeature {
+
+    public static final FeatureFlag INFERENCE_API_CLUSTER_AWARE_RATE_LIMITING_FEATURE_FLAG = new FeatureFlag(
+        "inference_cluster_aware_rate_limiting"
+    );
+
+    private InferenceAPIClusterAwareRateLimitingFeature() {}
+
+}

+ 197 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculator.java

@@ -0,0 +1,197 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.injection.guice.Inject;
+import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
+import org.elasticsearch.xpack.inference.action.BaseTransportInferenceAction;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.external.http.sender.Sender;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Note: {@link InferenceAPIClusterAwareRateLimitingFeature} needs to be enabled for this class to get
+ * instantiated inside {@link org.elasticsearch.xpack.inference.InferencePlugin} and be available via dependency injection.
+ *
+ * Calculates and manages node-local rate limits for inference services based on changes in the cluster topology.
+ * This calculator calculates a "node-local" rate-limit, which essentially divides the rate limit for a service/task type
+ * through the number of nodes, which got assigned to this service/task type pair. Without this calculator the rate limit stored
+ * in the inference endpoint configuration would get effectively multiplied by the number of nodes in a cluster (assuming a ~ uniform
+ * distribution of requests to the nodes in the cluster).
+ *
+ * The calculator works in conjunction with several other components:
+ * - {@link BaseTransportInferenceAction} - Uses the calculator to determine, whether to reroute a request or not
+ * - {@link BaseInferenceActionRequest} - Tracks, if the request (an instance of a subclass of {@link BaseInferenceActionRequest})
+ *   already got re-routed at least once
+ * - {@link HttpRequestSender} - Provides original rate limits that this calculator divides through the number of nodes
+ *   responsible for a service/task type
+ */
+public class InferenceServiceNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator {
+
+    public static final Integer DEFAULT_MAX_NODES_PER_GROUPING = 3;
+
+    /**
+     * Configuration mapping services to their task type rate limiting settings.
+     * Each service can have multiple configs defining:
+     * - Which task types support request re-routing and "node-local" rate limit calculation
+     * - How many nodes should handle requests for each task type, based on cluster size (dynamically calculated or statically provided)
+     **/
+    static final Map<String, Collection<NodeLocalRateLimitConfig>> SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS = Map.of(
+        ElasticInferenceService.NAME,
+        // TODO: should probably be a map/set
+        List.of(new NodeLocalRateLimitConfig(TaskType.SPARSE_EMBEDDING, (numNodesInCluster) -> DEFAULT_MAX_NODES_PER_GROUPING))
+    );
+
+    record NodeLocalRateLimitConfig(TaskType taskType, MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy) {}
+
+    @FunctionalInterface
+    private interface MaxNodesPerGroupingStrategy {
+
+        Integer calculate(Integer numberOfNodesInCluster);
+
+    }
+
+    private static final Logger logger = LogManager.getLogger(InferenceServiceNodeLocalRateLimitCalculator.class);
+
+    private final InferenceServiceRegistry serviceRegistry;
+
+    private final ConcurrentHashMap<String, Map<TaskType, RateLimitAssignment>> serviceAssignments;
+
+    @Inject
+    public InferenceServiceNodeLocalRateLimitCalculator(ClusterService clusterService, InferenceServiceRegistry serviceRegistry) {
+        clusterService.addListener(this);
+        this.serviceRegistry = serviceRegistry;
+        this.serviceAssignments = new ConcurrentHashMap<>();
+    }
+
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        boolean clusterTopologyChanged = event.nodesChanged();
+
+        // TODO: feature flag per node? We should not reroute to nodes not having eis and/or the inference plugin enabled
+        // Every node should land on the same grouping by calculation, so no need to put anything into the cluster state
+        if (clusterTopologyChanged) {
+            updateAssignments(event);
+        }
+    }
+
+    public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) {
+        return SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.getOrDefault(serviceName, Collections.emptyList())
+            .stream()
+            .anyMatch(rateLimitConfig -> taskType.equals(rateLimitConfig.taskType));
+    }
+
+    public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) {
+        var assignmentsPerTaskType = serviceAssignments.get(service);
+
+        if (assignmentsPerTaskType == null) {
+            return null;
+        }
+
+        return assignmentsPerTaskType.get(taskType);
+    }
+
+    /**
+     * Updates instances of {@link RateLimitAssignment} for each service and task type when the cluster topology changes.
+     * For each service and supported task type, calculates which nodes should handle requests
+     * and what their local rate limits should be per inference endpoint.
+     */
+    private void updateAssignments(ClusterChangedEvent event) {
+        var newClusterState = event.state();
+        var nodes = newClusterState.nodes().getAllNodes();
+
+        // Sort nodes by id (every node lands on the same result)
+        var sortedNodes = nodes.stream().sorted(Comparator.comparing(DiscoveryNode::getId)).toList();
+
+        // Sort inference services by name (every node lands on the same result)
+        var sortedServices = new ArrayList<>(serviceRegistry.getServices().values());
+        sortedServices.sort(Comparator.comparing(InferenceService::name));
+
+        for (String serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
+            Optional<InferenceService> service = serviceRegistry.getService(serviceName);
+
+            if (service.isPresent()) {
+                var inferenceService = service.get();
+
+                for (NodeLocalRateLimitConfig rateLimitConfig : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName)) {
+                    Map<TaskType, RateLimitAssignment> perTaskTypeAssignments = new HashMap<>();
+                    TaskType taskType = rateLimitConfig.taskType();
+
+                    // Calculate node assignments needed for re-routing
+                    var assignedNodes = calculateServiceAssignment(rateLimitConfig.maxNodesPerGroupingStrategy(), sortedNodes);
+
+                    // Update rate limits to be "node-local"
+                    var numAssignedNodes = assignedNodes.size();
+                    updateRateLimits(inferenceService, numAssignedNodes);
+
+                    perTaskTypeAssignments.put(taskType, new RateLimitAssignment(assignedNodes));
+                    serviceAssignments.put(serviceName, perTaskTypeAssignments);
+                }
+            } else {
+                logger.warn(
+                    "Service [{}] is configured for node-local rate limiting but was not found in the service registry",
+                    serviceName
+                );
+            }
+        }
+    }
+
+    private List<DiscoveryNode> calculateServiceAssignment(
+        MaxNodesPerGroupingStrategy maxNodesPerGroupingStrategy,
+        List<DiscoveryNode> sortedNodes
+    ) {
+        int numberOfNodes = sortedNodes.size();
+        int nodesPerGrouping = Math.min(numberOfNodes, maxNodesPerGroupingStrategy.calculate(numberOfNodes));
+
+        List<DiscoveryNode> assignedNodes = new ArrayList<>();
+
+        // TODO: here we can probably be smarter: if |num nodes in cluster| > |num nodes per task types|
+        // -> make sure a service provider is not assigned the same nodes for all task types; only relevant as soon as we support more task
+        // types
+        for (int j = 0; j < nodesPerGrouping; j++) {
+            var assignedNode = sortedNodes.get(j % numberOfNodes);
+            assignedNodes.add(assignedNode);
+        }
+
+        return assignedNodes;
+    }
+
+    private void updateRateLimits(InferenceService service, int responsibleNodes) {
+        if ((service instanceof SenderService) == false) {
+            return;
+        }
+
+        SenderService senderService = (SenderService) service;
+        Sender sender = senderService.getSender();
+        // TODO: this needs to take in service and task type as soon as multiple services/task types are supported
+        sender.updateRateLimitDivisor(responsibleNodes);
+    }
+
+    InferenceServiceRegistry serviceRegistry() {
+        return serviceRegistry;
+    }
+}

+ 18 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/InferenceServiceRateLimitCalculator.java

@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.cluster.ClusterStateListener;
+import org.elasticsearch.inference.TaskType;
+
+public interface InferenceServiceRateLimitCalculator extends ClusterStateListener {
+
+    boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType);
+
+    RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType);
+}

+ 27 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/NoopNodeLocalRateLimitCalculator.java

@@ -0,0 +1,27 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.cluster.ClusterChangedEvent;
+import org.elasticsearch.inference.TaskType;
+
+public class NoopNodeLocalRateLimitCalculator implements InferenceServiceRateLimitCalculator {
+
+    @Override
+    public void clusterChanged(ClusterChangedEvent event) {
+        // Do nothing
+    }
+
+    public boolean isTaskTypeReroutingSupported(String serviceName, TaskType taskType) {
+        return false;
+    }
+
+    public RateLimitAssignment getRateLimitAssignment(String service, TaskType taskType) {
+        return null;
+    }
+}

+ 19 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimitAssignment.java

@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.cluster.node.DiscoveryNode;
+
+import java.util.List;
+
+/**
+ * Record for storing rate limit assignment information.
+ *
+ * @param responsibleNodes - nodes responsible for a certain service and task type
+ */
+public record RateLimitAssignment(List<DiscoveryNode> responsibleNodes) {}

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/RateLimiter.java

@@ -55,7 +55,7 @@ public class RateLimiter {
         setRate(accumulatedTokensLimit, tokensPerTimeUnit, unit);
     }
 
-    public final synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) {
+    public synchronized void setRate(double newAccumulatedTokensLimit, double newTokensPerTimeUnit, TimeUnit newUnit) {
         Objects.requireNonNull(newUnit);
 
         if (newAccumulatedTokensLimit < 0) {

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java

@@ -88,6 +88,11 @@ public class AmazonBedrockRequestSender implements Sender {
         );
     }
 
+    @Override
+    public void updateRateLimitDivisor(int rateLimitDivisor) {
+        executorService.updateRateLimitDivisor(rateLimitDivisor);
+    }
+
     @Override
     public void start() {
         if (started.compareAndSet(false, true)) {

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/RequestExecutor.java

@@ -21,6 +21,8 @@ public interface RequestExecutor {
 
     void shutdown();
 
+    void updateRateLimitDivisor(int newDivisor);
+
     boolean isShutdown();
 
     boolean isTerminated();

+ 4 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

@@ -111,6 +111,10 @@ public class HttpRequestSender implements Sender {
         }
     }
 
+    public void updateRateLimitDivisor(int rateLimitDivisor) {
+        service.updateRateLimitDivisor(rateLimitDivisor);
+    }
+
     private void waitForStartToComplete() {
         try {
             if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) {

+ 53 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

@@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.common.AdjustableCapacityBlockingQueue;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.common.RateLimiter;
 import org.elasticsearch.xpack.inference.external.http.RequestExecutor;
 import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
@@ -36,6 +37,7 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 
@@ -92,12 +94,22 @@ class RequestExecutorService implements RequestExecutor {
         RateLimiter create(double accumulatedTokensLimit, double tokensPerTimeUnit, TimeUnit unit);
     }
 
+    // TODO: for later (after 8.18)
+    // TODO: pass in divisor to RateLimiterCreator
+    // TODO: another map for service/task-type-key -> set of RateLimitingEndpointHandler (used for updates; update divisor and then update
+    // all endpoint handlers)
+    // TODO: one map for service/task-type-key -> divisor (this gets also read when we create an inference endpoint)
+    // TODO: divisor value read/writes need to be synchronized in some way
+
     // default for testing
     static final RateLimiterCreator DEFAULT_RATE_LIMIT_CREATOR = RateLimiter::new;
     private static final Logger logger = LogManager.getLogger(RequestExecutorService.class);
     private static final TimeValue RATE_LIMIT_GROUP_CLEANUP_INTERVAL = TimeValue.timeValueDays(1);
 
     private final ConcurrentMap<Object, RateLimitingEndpointHandler> rateLimitGroupings = new ConcurrentHashMap<>();
+    // TODO: add one atomic integer (number of nodes); also explain the assumption and why this works
+    // TODO: document that this impacts chat completion (and increase the default rate limit)
+    private final AtomicInteger rateLimitDivisor = new AtomicInteger(1);
     private final ThreadPool threadPool;
     private final CountDownLatch startupLatch;
     private final CountDownLatch terminationLatch = new CountDownLatch(1);
@@ -174,6 +186,19 @@ class RequestExecutorService implements RequestExecutor {
         return rateLimitGroupings.values().stream().mapToInt(RateLimitingEndpointHandler::queueSize).sum();
     }
 
+    @Override
+    public void updateRateLimitDivisor(int numResponsibleNodes) {
+        // in the unlikely case where we get an invalid value, we'll just ignore it
+        if (numResponsibleNodes <= 0) {
+            return;
+        }
+
+        rateLimitDivisor.set(numResponsibleNodes);
+        for (var rateLimitingEndpointHandler : rateLimitGroupings.values()) {
+            rateLimitingEndpointHandler.updateTokensPerTimeUnit(rateLimitDivisor.get());
+        }
+    }
+
     /**
      * Begin servicing tasks.
      * <p>
@@ -299,9 +324,12 @@ class RequestExecutorService implements RequestExecutor {
                 clock,
                 requestManager.rateLimitSettings(),
                 this::isShutdown,
-                rateLimiterCreator
+                rateLimiterCreator,
+                rateLimitDivisor.get()
             );
 
+            // TODO: add or create/compute if absent set for new map (service/task-type-key -> rate limit endpoint handler)
+
             endpointHandler.init();
             return endpointHandler;
         });
@@ -314,7 +342,7 @@ class RequestExecutorService implements RequestExecutor {
      * This allows many requests to be serialized if they are being sent too fast. If the rate limit has not been met they will be sent
      * as soon as a thread is available.
      */
-    private static class RateLimitingEndpointHandler {
+    static class RateLimitingEndpointHandler {
 
         private static final TimeValue NO_TASKS_AVAILABLE = TimeValue.MAX_VALUE;
         private static final TimeValue EXECUTED_A_TASK = TimeValue.ZERO;
@@ -329,6 +357,8 @@ class RequestExecutorService implements RequestExecutor {
         private final Clock clock;
         private final RateLimiter rateLimiter;
         private final RequestExecutorServiceSettings requestExecutorServiceSettings;
+        private final RateLimitSettings rateLimitSettings;
+        private final Long originalRequestsPerTimeUnit;
 
         RateLimitingEndpointHandler(
             String id,
@@ -338,7 +368,8 @@ class RequestExecutorService implements RequestExecutor {
             Clock clock,
             RateLimitSettings rateLimitSettings,
             Supplier<Boolean> isShutdownMethod,
-            RateLimiterCreator rateLimiterCreator
+            RateLimiterCreator rateLimiterCreator,
+            Integer rateLimitDivisor
         ) {
             this.requestExecutorServiceSettings = Objects.requireNonNull(settings);
             this.id = Objects.requireNonNull(id);
@@ -346,6 +377,8 @@ class RequestExecutorService implements RequestExecutor {
             this.requestSender = Objects.requireNonNull(requestSender);
             this.clock = Objects.requireNonNull(clock);
             this.isShutdownMethod = Objects.requireNonNull(isShutdownMethod);
+            this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings);
+            this.originalRequestsPerTimeUnit = rateLimitSettings.requestsPerTimeUnit();
 
             Objects.requireNonNull(rateLimitSettings);
             Objects.requireNonNull(rateLimiterCreator);
@@ -355,12 +388,29 @@ class RequestExecutorService implements RequestExecutor {
                 rateLimitSettings.timeUnit()
             );
 
+            this.updateTokensPerTimeUnit(rateLimitDivisor);
         }
 
         public void init() {
             requestExecutorServiceSettings.registerQueueCapacityCallback(id, this::onCapacityChange);
         }
 
+        /**
+         * This method is solely called by {@link InferenceServiceNodeLocalRateLimitCalculator} to update
+         * rate limits, so they're "node-local".
+         * The general idea is described in {@link InferenceServiceNodeLocalRateLimitCalculator} in more detail.
+         *
+         * @param divisor - divisor to divide the initial requests per time unit by
+         */
+        public synchronized void updateTokensPerTimeUnit(Integer divisor) {
+            double updatedTokensPerTimeUnit = (double) originalRequestsPerTimeUnit / divisor;
+            rateLimiter.setRate(ACCUMULATED_TOKENS_LIMIT, updatedTokensPerTimeUnit, rateLimitSettings.timeUnit());
+        }
+
+        public String id() {
+            return id;
+        }
+
         private void onCapacityChange(int capacity) {
             logger.debug(() -> Strings.format("Executor service grouping [%s] setting queue capacity to [%s]", id, capacity));
 

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java

@@ -30,4 +30,6 @@ public interface RequestManager extends RateLimitable {
     // executePreparedRequest() which will execute all prepared requests aka sends the batch
 
     String inferenceEntityId();
+
+    // TODO: add service() and taskType()
 }

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/Sender.java

@@ -27,6 +27,8 @@ public interface Sender extends Closeable {
         ActionListener<InferenceServiceResults> listener
     );
 
+    void updateRateLimitDivisor(int rateLimitDivisor);
+
     void sendWithoutQueuing(
         Logger logger,
         Request request,

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -47,7 +47,7 @@ public abstract class SenderService implements InferenceService {
         this.serviceComponents = Objects.requireNonNull(serviceComponents);
     }
 
-    protected Sender getSender() {
+    public Sender getSender() {
         return sender;
     }
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java

@@ -36,7 +36,7 @@ public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC
     public static final String NAME = "elastic_inference_service_completion_service_settings";
 
     // TODO what value do we put here?
-    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240L);
+    private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(720L);
 
     public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
         ValidationException validationException = new ValidationException();

+ 19 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.action;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -21,11 +22,13 @@ import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 import org.junit.Before;
@@ -61,6 +64,9 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
     protected static final String inferenceId = "inferenceEntityId";
     protected InferenceServiceRegistry serviceRegistry;
     protected InferenceStats inferenceStats;
+    protected InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator;
+    protected TransportService transportService;
+    protected NodeClient nodeClient;
 
     public BaseTransportInferenceActionTestCase(TaskType taskType) {
         this.taskType = taskType;
@@ -69,13 +75,17 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
     @Before
     public void setUp() throws Exception {
         super.setUp();
-        TransportService transportService = mock();
         ActionFilters actionFilters = mock();
+        ThreadPool threadPool = mock();
+        nodeClient = mock();
+        transportService = mock();
+        inferenceServiceNodeLocalRateLimitCalculator = mock();
         licenseState = mock();
         modelRegistry = mock();
         serviceRegistry = mock();
         inferenceStats = new InferenceStats(mock(), mock());
         streamingTaskManager = mock();
+
         action = createAction(
             transportService,
             actionFilters,
@@ -83,7 +93,10 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
             modelRegistry,
             serviceRegistry,
             inferenceStats,
-            streamingTaskManager
+            streamingTaskManager,
+            inferenceServiceNodeLocalRateLimitCalculator,
+            nodeClient,
+            threadPool
         );
 
         mockValidLicenseState();
@@ -96,7 +109,10 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
-        StreamingTaskManager streamingTaskManager
+        StreamingTaskManager streamingTaskManager,
+        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     );
 
     protected abstract Request createRequest();

+ 128 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java

@@ -8,16 +8,32 @@
 package org.elasticsearch.xpack.inference.action;
 
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.license.MockLicenseState;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportException;
+import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
+import java.util.List;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase<InferenceAction.Request> {
 
@@ -33,7 +49,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
-        StreamingTaskManager streamingTaskManager
+        StreamingTaskManager streamingTaskManager,
+        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     ) {
         return new TransportInferenceAction(
             transportService,
@@ -42,7 +61,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
             modelRegistry,
             serviceRegistry,
             inferenceStats,
-            streamingTaskManager
+            streamingTaskManager,
+            inferenceServiceNodeLocalRateLimitCalculator,
+            nodeClient,
+            threadPool
         );
     }
 
@@ -50,4 +72,108 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
     protected InferenceAction.Request createRequest() {
         return mock();
     }
+
+    public void testNoRerouting_WhenTaskTypeNotSupported() {
+        TaskType unsupportedTaskType = TaskType.COMPLETION;
+        mockService(listener -> listener.onResponse(mock()));
+
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
+
+        var listener = doExecute(unsupportedTaskType);
+
+        verify(listener).onResponse(any());
+        // Verify request was handled locally (not rerouted using TransportService)
+        verify(transportService, never()).sendRequest(any(), any(), any(), any());
+    }
+
+    public void testNoRerouting_WhenNoGroupingCalculatedYet() {
+        mockService(listener -> listener.onResponse(mock()));
+
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
+
+        var listener = doExecute(taskType);
+
+        verify(listener).onResponse(any());
+        // Verify request was handled locally (not rerouted using TransportService)
+        verify(transportService, never()).sendRequest(any(), any(), any(), any());
+    }
+
+    public void testNoRerouting_WhenEmptyNodeList() {
+        mockService(listener -> listener.onResponse(mock()));
+
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
+            new RateLimitAssignment(List.of())
+        );
+
+        var listener = doExecute(taskType);
+
+        verify(listener).onResponse(any());
+        // Verify request was handled locally (not rerouted using TransportService)
+        verify(transportService, never()).sendRequest(any(), any(), any(), any());
+    }
+
+    public void testRerouting_ToOtherNode() {
+        DiscoveryNode otherNode = mock(DiscoveryNode.class);
+        when(otherNode.getId()).thenReturn("other-node");
+
+        // The local node is different to the "other-node" responsible for serviceId
+        when(nodeClient.getLocalNodeId()).thenReturn("local-node");
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        // Requests for serviceId are always routed to "other-node"
+        var assignment = new RateLimitAssignment(List.of(otherNode));
+        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+
+        mockService(listener -> listener.onResponse(mock()));
+        var listener = doExecute(taskType);
+
+        // Verify request was rerouted
+        verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
+        // Verify local execution didn't happen
+        verify(listener, never()).onResponse(any());
+    }
+
+    public void testRerouting_ToLocalNode_WithoutGoingThroughTransportLayerAgain() {
+        DiscoveryNode localNode = mock(DiscoveryNode.class);
+        String localNodeId = "local-node";
+        when(localNode.getId()).thenReturn(localNodeId);
+
+        // The local node is the only one responsible for serviceId
+        when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        var assignment = new RateLimitAssignment(List.of(localNode));
+        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+
+        mockService(listener -> listener.onResponse(mock()));
+        var listener = doExecute(taskType);
+
+        verify(listener).onResponse(any());
+        // Verify request was handled locally (not rerouted using TransportService)
+        verify(transportService, never()).sendRequest(any(), any(), any(), any());
+    }
+
+    public void testRerouting_HandlesTransportException_FromOtherNode() {
+        DiscoveryNode otherNode = mock(DiscoveryNode.class);
+        when(otherNode.getId()).thenReturn("other-node");
+
+        when(nodeClient.getLocalNodeId()).thenReturn("local-node");
+        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        var assignment = new RateLimitAssignment(List.of(otherNode));
+        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+
+        mockService(listener -> listener.onResponse(mock()));
+
+        TransportException expectedException = new TransportException("Failed to route");
+        doAnswer(invocation -> {
+            TransportResponseHandler<?> handler = invocation.getArgument(3);
+            handler.handleException(expectedException);
+            return null;
+        }).when(transportService).sendRequest(any(), any(), any(), any());
+
+        var listener = doExecute(taskType);
+
+        // Verify exception was propagated from "other-node" to "local-node"
+        verify(listener).onFailure(same(expectedException));
+    }
 }

+ 11 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java

@@ -9,13 +9,16 @@ package org.elasticsearch.xpack.inference.action;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.client.internal.node.NodeClient;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.license.MockLicenseState;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
@@ -45,7 +48,10 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenc
         ModelRegistry modelRegistry,
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
-        StreamingTaskManager streamingTaskManager
+        StreamingTaskManager streamingTaskManager,
+        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        NodeClient nodeClient,
+        ThreadPool threadPool
     ) {
         return new TransportUnifiedCompletionInferenceAction(
             transportService,
@@ -54,7 +60,10 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenc
             modelRegistry,
             serviceRegistry,
             inferenceStats,
-            streamingTaskManager
+            streamingTaskManager,
+            inferenceServiceNodeLocalRateLimitCalculator,
+            nodeClient,
+            threadPool
         );
     }
 

+ 205 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/InferenceServiceNodeLocalRateLimitCalculatorTests.java

@@ -0,0 +1,205 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
+import org.elasticsearch.xpack.inference.services.SenderService;
+import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Set;
+
+import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.DEFAULT_MAX_NODES_PER_GROUPING;
+import static org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator.SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS;
+import static org.hamcrest.Matchers.equalTo;
+
+@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE, numDataNodes = 0)
+public class InferenceServiceNodeLocalRateLimitCalculatorTests extends ESIntegTestCase {
+
+    public void setUp() throws Exception {
+        super.setUp();
+    }
+
+    public void testInitialClusterGrouping_Correct() {
+        // Start with 2-5 nodes
+        var numNodes = randomIntBetween(2, 5);
+        var nodeNames = internalCluster().startNodes(numNodes);
+        ensureStableCluster(numNodes);
+
+        RateLimitAssignment firstAssignment = null;
+
+        for (String nodeName : nodeNames) {
+            var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeName);
+
+            // Check first node's assignments
+            if (firstAssignment == null) {
+                // Get assignment for a specific service (e.g., EIS)
+                firstAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
+
+                assertNotNull(firstAssignment);
+                // Verify there are assignments for this service
+                assertFalse(firstAssignment.responsibleNodes().isEmpty());
+            } else {
+                // Verify other nodes see the same assignment
+                var currentAssignment = calculator.getRateLimitAssignment(ElasticInferenceService.NAME, TaskType.SPARSE_EMBEDDING);
+                assertEquals(firstAssignment, currentAssignment);
+            }
+        }
+    }
+
+    public void testNumberOfNodesPerGroup_Decreases_When_NodeLeavesCluster() throws IOException {
+        // Start with 3-5 nodes
+        var numNodes = randomIntBetween(3, 5);
+        var nodeNames = internalCluster().startNodes(numNodes);
+        ensureStableCluster(numNodes);
+
+        var nodeLeftInCluster = nodeNames.get(0);
+        var currentNumberOfNodes = numNodes;
+
+        // Stop all nodes except one
+        for (String nodeName : nodeNames) {
+            if (nodeName.equals(nodeLeftInCluster)) {
+                continue;
+            }
+            internalCluster().stopNode(nodeName);
+            currentNumberOfNodes--;
+            ensureStableCluster(currentNumberOfNodes);
+        }
+
+        var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeLeftInCluster);
+
+        Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
+
+        // Check assignments for each supported service
+        for (var service : supportedServices) {
+            var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING);
+
+            assertNotNull(assignment);
+            // Should have exactly one responsible node
+            assertEquals(1, assignment.responsibleNodes().size());
+            // That node should be our remaining node
+            assertEquals(nodeLeftInCluster, assignment.responsibleNodes().get(0).getName());
+        }
+    }
+
+    public void testGrouping_RespectsMaxNodesPerGroupingLimit() {
+        // Start with more nodes possible per grouping
+        var numNodes = DEFAULT_MAX_NODES_PER_GROUPING + randomIntBetween(1, 3);
+        var nodeNames = internalCluster().startNodes(numNodes);
+        ensureStableCluster(numNodes);
+
+        var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
+
+        Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
+
+        for (var service : supportedServices) {
+            var assignment = calculator.getRateLimitAssignment(service, TaskType.SPARSE_EMBEDDING);
+
+            assertNotNull(assignment);
+            assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
+        }
+    }
+
+    public void testInitialRateLimitsCalculation_Correct() throws IOException {
+        // Start with max nodes per grouping (=3)
+        int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
+        var nodeNames = internalCluster().startNodes(numNodes);
+        ensureStableCluster(numNodes);
+
+        var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
+
+        Set<String> supportedServices = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet();
+
+        for (var serviceName : supportedServices) {
+            try (var serviceRegistry = calculator.serviceRegistry()) {
+                var serviceOptional = serviceRegistry.getService(serviceName);
+                assertTrue(serviceOptional.isPresent());
+                var service = serviceOptional.get();
+
+                if ((service instanceof SenderService senderService)) {
+                    var sender = senderService.getSender();
+                    if (sender instanceof HttpRequestSender httpSender) {
+                        var assignment = calculator.getRateLimitAssignment(service.name(), TaskType.SPARSE_EMBEDDING);
+
+                        assertNotNull(assignment);
+                        assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(assignment.responsibleNodes().size()));
+                    }
+                }
+            }
+
+        }
+    }
+
+    public void testRateLimits_Decrease_OnNodeJoin() {
+        // Start with 2 nodes
+        var initialNodes = 2;
+        var nodeNames = internalCluster().startNodes(initialNodes);
+        ensureStableCluster(initialNodes);
+
+        var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
+
+        for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
+            var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
+            for (var config : configs) {
+                // Get initial assignments and rate limits
+                var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
+                assertEquals(2, initialAssignment.responsibleNodes().size());
+
+                // Add a new node
+                internalCluster().startNode();
+                ensureStableCluster(initialNodes + 1);
+
+                // Get updated assignments
+                var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
+
+                // Verify number of responsible nodes increased
+                assertEquals(3, updatedAssignment.responsibleNodes().size());
+            }
+        }
+    }
+
+    public void testRateLimits_Increase_OnNodeLeave() throws IOException {
+        // Start with max nodes per grouping (=3)
+        int numNodes = DEFAULT_MAX_NODES_PER_GROUPING;
+        var nodeNames = internalCluster().startNodes(numNodes);
+        ensureStableCluster(numNodes);
+
+        var calculator = internalCluster().getInstance(InferenceServiceNodeLocalRateLimitCalculator.class, nodeNames.get(0));
+
+        for (var serviceName : SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.keySet()) {
+            var configs = SERVICE_NODE_LOCAL_RATE_LIMIT_CONFIGS.get(serviceName);
+            for (var config : configs) {
+                // Get initial assignments and rate limits
+                var initialAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
+                assertThat(DEFAULT_MAX_NODES_PER_GROUPING, equalTo(initialAssignment.responsibleNodes().size()));
+
+                // Remove a node
+                var nodeToRemove = nodeNames.get(numNodes - 1);
+                internalCluster().stopNode(nodeToRemove);
+                ensureStableCluster(numNodes - 1);
+
+                // Get updated assignments
+                var updatedAssignment = calculator.getRateLimitAssignment(serviceName, config.taskType());
+
+                // Verify number of responsible nodes decreased
+                assertThat(2, equalTo(updatedAssignment.responsibleNodes().size()));
+            }
+        }
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return Arrays.asList(LocalStateInferencePlugin.class);
+    }
+}

+ 5 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java

@@ -63,6 +63,11 @@ public class AmazonBedrockMockRequestSender implements Sender {
         // do nothing
     }
 
+    @Override
+    public void updateRateLimitDivisor(int rateLimitDivisor) {
+        // do nothing
+    }
+
     @Override
     public void send(
         RequestManager requestCreator,

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java

@@ -53,7 +53,7 @@ public class ElasticInferenceServiceCompletionServiceSettingsTests extends Abstr
             ConfigurationParseContext.REQUEST
         );
 
-        assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(240L))));
+        assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(720L))));
     }
 
     public void testFromMap_MissingModelId_ThrowsException() {