|
|
@@ -19,7 +19,7 @@ import org.elasticsearch.transport.TransportResponseHandler;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
|
|
-import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
|
|
|
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
|
|
|
import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
|
|
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
|
|
import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
|
|
|
@@ -50,7 +50,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
InferenceServiceRegistry serviceRegistry,
|
|
|
InferenceStats inferenceStats,
|
|
|
StreamingTaskManager streamingTaskManager,
|
|
|
- InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
|
|
|
+ InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
|
|
|
NodeClient nodeClient,
|
|
|
ThreadPool threadPool
|
|
|
) {
|
|
|
@@ -77,7 +77,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
TaskType unsupportedTaskType = TaskType.COMPLETION;
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
|
|
|
|
|
|
var listener = doExecute(unsupportedTaskType);
|
|
|
|
|
|
@@ -89,8 +89,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
public void testNoRerouting_WhenNoGroupingCalculatedYet() {
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
+ when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
|
|
|
|
|
|
var listener = doExecute(taskType);
|
|
|
|
|
|
@@ -102,8 +102,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
public void testNoRerouting_WhenEmptyNodeList() {
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
+ when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
|
|
|
new RateLimitAssignment(List.of())
|
|
|
);
|
|
|
|
|
|
@@ -120,10 +120,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
|
|
|
// The local node is different to the "other-node" responsible for serviceId
|
|
|
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
// Requests for serviceId are always routed to "other-node"
|
|
|
var assignment = new RateLimitAssignment(List.of(otherNode));
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
+ when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
var listener = doExecute(taskType);
|
|
|
@@ -141,9 +141,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
|
|
|
// The local node is the only one responsible for serviceId
|
|
|
when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
var assignment = new RateLimitAssignment(List.of(localNode));
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
+ when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
var listener = doExecute(taskType);
|
|
|
@@ -158,9 +158,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
when(otherNode.getId()).thenReturn("other-node");
|
|
|
|
|
|
when(nodeClient.getLocalNodeId()).thenReturn("local-node");
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
+ when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
|
|
|
var assignment = new RateLimitAssignment(List.of(otherNode));
|
|
|
- when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
+ when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
|
|
|
|
|
|
mockService(listener -> listener.onResponse(mock()));
|
|
|
|
|
|
@@ -173,6 +173,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
|
|
|
|
|
|
var listener = doExecute(taskType);
|
|
|
|
|
|
+ // Verify request was rerouted
|
|
|
+ verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
|
|
|
+ // Verify local execution didn't happen
|
|
|
+ verify(listener, never()).onResponse(any());
|
|
|
// Verify exception was propagated from "other-node" to "local-node"
|
|
|
verify(listener).onFailure(same(expectedException));
|
|
|
}
|