Browse Source

[ML] Adding missing onFailure call for Inference API start model request (#126930) (#126943)

* Adding missing onFailure call

* Update docs/changelog/126930.yaml
Jonathan Buttner 6 months ago
parent
commit
1bcb2e2b11

+ 5 - 0
docs/changelog/126930.yaml

@@ -0,0 +1,5 @@
+pr: 126930
+summary: Adding missing `onFailure` call for Inference API start model request
+area: Machine Learning
+type: bug
+issues: []

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -106,7 +106,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
                 })
                 .<Boolean>andThen((l2, modelDidPut) -> {
                     var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
-                    var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
+                    var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
                     client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
                 })
                 .addListener(finalListener);

+ 2 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalModel.java

@@ -105,6 +105,8 @@ public abstract class ElasticsearchInternalModel extends Model {
                         && statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
                         // Deployment is already started
                         listener.onResponse(Boolean.TRUE);
+                    } else {
+                        listener.onFailure(e);
                     }
                     return;
                 }

+ 47 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -37,6 +37,7 @@ import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.ParseField;
@@ -49,13 +50,16 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.MachineLearningField;
+import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
@@ -1832,6 +1836,49 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         }
     }
 
+    public void testStart_OnFailure_WhenTimeoutOccurs() throws IOException {
+        var model = new ElserInternalModel(
+            "inference_id",
+            TaskType.SPARSE_EMBEDDING,
+            "elasticsearch",
+            new ElserInternalServiceSettings(
+                new ElasticsearchInternalServiceSettings(1, 1, "id", new AdaptiveAllocationsSettings(false, 0, 0), null)
+            ),
+            new ElserMlNodeTaskSettings(),
+            null
+        );
+
+        var client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadPool);
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<GetTrainedModelsAction.Response> listener = invocationOnMock.getArgument(2);
+            var builder = GetTrainedModelsAction.Response.builder();
+            builder.setModels(List.of(mock(TrainedModelConfig.class)));
+            builder.setTotalCount(1);
+
+            listener.onResponse(builder.build());
+            return Void.TYPE;
+        }).when(client).execute(eq(GetTrainedModelsAction.INSTANCE), any(), any());
+
+        doAnswer(invocationOnMock -> {
+            ActionListener<CreateTrainedModelAssignmentAction.Response> listener = invocationOnMock.getArgument(2);
+            listener.onFailure(new ElasticsearchStatusException("failed", RestStatus.GATEWAY_TIMEOUT));
+            return Void.TYPE;
+        }).when(client).execute(eq(StartTrainedModelDeploymentAction.INSTANCE), any(), any());
+
+        try (var service = createService(client)) {
+            var actionListener = new PlainActionFuture<Boolean>();
+            service.start(model, TimeValue.timeValueSeconds(30), actionListener);
+            var exception = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> actionListener.actionGet(TimeValue.timeValueSeconds(30))
+            );
+
+            assertThat(exception.getMessage(), is("failed"));
+        }
+    }
+
     private ElasticsearchInternalService createService(Client client) {
         var cs = mock(ClusterService.class);
         var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));