|  | @@ -19,7 +19,7 @@ import org.elasticsearch.transport.TransportResponseHandler;
 | 
	
		
			
				|  |  |  import org.elasticsearch.transport.TransportService;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 | 
	
	
		
			
				|  | @@ -50,7 +50,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |          InferenceServiceRegistry serviceRegistry,
 | 
	
		
			
				|  |  |          InferenceStats inferenceStats,
 | 
	
		
			
				|  |  |          StreamingTaskManager streamingTaskManager,
 | 
	
		
			
				|  |  | -        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
 | 
	
		
			
				|  |  | +        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
 | 
	
		
			
				|  |  |          NodeClient nodeClient,
 | 
	
		
			
				|  |  |          ThreadPool threadPool
 | 
	
		
			
				|  |  |      ) {
 | 
	
	
		
			
				|  | @@ -77,7 +77,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |          TaskType unsupportedTaskType = TaskType.COMPLETION;
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          var listener = doExecute(unsupportedTaskType);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -89,8 +89,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |      public void testNoRerouting_WhenNoGroupingCalculatedYet() {
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          var listener = doExecute(taskType);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -102,8 +102,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |      public void testNoRerouting_WhenEmptyNodeList() {
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
 | 
	
		
			
				|  |  |              new RateLimitAssignment(List.of())
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -120,10 +120,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          // The local node is different to the "other-node" responsible for serviceId
 | 
	
		
			
				|  |  |          when(nodeClient.getLocalNodeId()).thenReturn("local-node");
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  |          // Requests for serviceId are always routed to "other-node"
 | 
	
		
			
				|  |  |          var assignment = new RateLimitAssignment(List.of(otherNode));
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |          var listener = doExecute(taskType);
 | 
	
	
		
			
				|  | @@ -141,9 +141,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          // The local node is the only one responsible for serviceId
 | 
	
		
			
				|  |  |          when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  |          var assignment = new RateLimitAssignment(List.of(localNode));
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |          var listener = doExecute(taskType);
 | 
	
	
		
			
				|  | @@ -158,9 +158,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |          when(otherNode.getId()).thenReturn("other-node");
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          when(nodeClient.getLocalNodeId()).thenReturn("local-node");
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
 | 
	
		
			
				|  |  |          var assignment = new RateLimitAssignment(List.of(otherNode));
 | 
	
		
			
				|  |  | -        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  | +        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          mockService(listener -> listener.onResponse(mock()));
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -173,6 +173,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          var listener = doExecute(taskType);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        // Verify request was rerouted
 | 
	
		
			
				|  |  | +        verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
 | 
	
		
			
				|  |  | +        // Verify local execution didn't happen
 | 
	
		
			
				|  |  | +        verify(listener, never()).onResponse(any());
 | 
	
		
			
				|  |  |          // Verify exception was propagated from "other-node" to "local-node"
 | 
	
		
			
				|  |  |          verify(listener).onFailure(same(expectedException));
 | 
	
		
			
				|  |  |      }
 |