Browse Source

[Inference API] Fix tests in TransportInferenceActionTests (#121302)

Tim Grein 8 months ago
parent
commit
eee6973389

+ 0 - 6
muted-tests.yml

@@ -359,12 +359,6 @@ tests:
 - class: org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT
   method: test {yaml=indices.get_alias/10_basic/Get aliases via /*/_alias/}
   issue: https://github.com/elastic/elasticsearch/issues/121290
-- class: org.elasticsearch.xpack.inference.action.TransportInferenceActionTests
-  method: testRerouting_HandlesTransportException_FromOtherNode
-  issue: https://github.com/elastic/elasticsearch/issues/121292
-- class: org.elasticsearch.xpack.inference.action.TransportInferenceActionTests
-  method: testRerouting_ToOtherNode
-  issue: https://github.com/elastic/elasticsearch/issues/121293
 - class: org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculatorTests
   issue: https://github.com/elastic/elasticsearch/issues/121294
 - class: org.elasticsearch.env.NodeEnvironmentTests

+ 5 - 5
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

@@ -28,7 +28,7 @@ import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
-import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 import org.junit.Before;
@@ -64,7 +64,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
     protected static final String inferenceId = "inferenceEntityId";
     protected InferenceServiceRegistry serviceRegistry;
     protected InferenceStats inferenceStats;
-    protected InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator;
+    protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
     protected TransportService transportService;
     protected NodeClient nodeClient;
 
@@ -79,7 +79,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         ThreadPool threadPool = mock();
         nodeClient = mock();
         transportService = mock();
-        inferenceServiceNodeLocalRateLimitCalculator = mock();
+        inferenceServiceRateLimitCalculator = mock();
         licenseState = mock();
         modelRegistry = mock();
         serviceRegistry = mock();
@@ -94,7 +94,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
             serviceRegistry,
             inferenceStats,
             streamingTaskManager,
-            inferenceServiceNodeLocalRateLimitCalculator,
+            inferenceServiceRateLimitCalculator,
             nodeClient,
             threadPool
         );
@@ -110,7 +110,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
         StreamingTaskManager streamingTaskManager,
-        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
         NodeClient nodeClient,
         ThreadPool threadPool
     );

+ 17 - 13
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java

@@ -19,7 +19,7 @@ import org.elasticsearch.transport.TransportResponseHandler;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
-import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.common.RateLimitAssignment;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
@@ -50,7 +50,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
         StreamingTaskManager streamingTaskManager,
-        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        InferenceServiceRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
         NodeClient nodeClient,
         ThreadPool threadPool
     ) {
@@ -77,7 +77,7 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
         TaskType unsupportedTaskType = TaskType.COMPLETION;
         mockService(listener -> listener.onResponse(mock()));
 
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, unsupportedTaskType)).thenReturn(false);
 
         var listener = doExecute(unsupportedTaskType);
 
@@ -89,8 +89,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
     public void testNoRerouting_WhenNoGroupingCalculatedYet() {
         mockService(listener -> listener.onResponse(mock()));
 
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
-        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(null);
 
         var listener = doExecute(taskType);
 
@@ -102,8 +102,8 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
     public void testNoRerouting_WhenEmptyNodeList() {
         mockService(listener -> listener.onResponse(mock()));
 
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
-        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(
             new RateLimitAssignment(List.of())
         );
 
@@ -120,10 +120,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 
         // The local node is different to the "other-node" responsible for serviceId
         when(nodeClient.getLocalNodeId()).thenReturn("local-node");
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
         // Requests for serviceId are always routed to "other-node"
         var assignment = new RateLimitAssignment(List.of(otherNode));
-        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 
         mockService(listener -> listener.onResponse(mock()));
         var listener = doExecute(taskType);
@@ -141,9 +141,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 
         // The local node is the only one responsible for serviceId
         when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
         var assignment = new RateLimitAssignment(List.of(localNode));
-        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 
         mockService(listener -> listener.onResponse(mock()));
         var listener = doExecute(taskType);
@@ -158,9 +158,9 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
         when(otherNode.getId()).thenReturn("other-node");
 
         when(nodeClient.getLocalNodeId()).thenReturn("local-node");
-        when(inferenceServiceNodeLocalRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
+        when(inferenceServiceRateLimitCalculator.isTaskTypeReroutingSupported(serviceId, taskType)).thenReturn(true);
         var assignment = new RateLimitAssignment(List.of(otherNode));
-        when(inferenceServiceNodeLocalRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
+        when(inferenceServiceRateLimitCalculator.getRateLimitAssignment(serviceId, taskType)).thenReturn(assignment);
 
         mockService(listener -> listener.onResponse(mock()));
 
@@ -173,6 +173,10 @@ public class TransportInferenceActionTests extends BaseTransportInferenceActionT
 
         var listener = doExecute(taskType);
 
+        // Verify request was rerouted
+        verify(transportService).sendRequest(same(otherNode), eq(InferenceAction.NAME), any(), any());
+        // Verify local execution didn't happen
+        verify(listener, never()).onResponse(any());
         // Verify exception was propagated from "other-node" to "local-node"
         verify(listener).onFailure(same(expectedException));
     }

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java

@@ -18,7 +18,7 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
-import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
+import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
@@ -49,7 +49,7 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenc
         InferenceServiceRegistry serviceRegistry,
         InferenceStats inferenceStats,
         StreamingTaskManager streamingTaskManager,
-        InferenceServiceNodeLocalRateLimitCalculator inferenceServiceNodeLocalRateLimitCalculator,
+        InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator,
         NodeClient nodeClient,
         ThreadPool threadPool
     ) {
@@ -61,7 +61,7 @@ public class TransportUnifiedCompletionActionTests extends BaseTransportInferenc
             serviceRegistry,
             inferenceStats,
             streamingTaskManager,
-            inferenceServiceNodeLocalRateLimitCalculator,
+            inferenceServiceRateLimitCalculator,
             nodeClient,
             threadPool
         );