Browse Source

Enable force deletion of inference endpoints for invalid models and when stopping the model deployment fails (#129090)

* Enable force deletion of inference endpoints for invalid models and when stopping the model deployment fails

* Update docs/changelog/129090.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein 2 months ago
parent
commit
9c6cf90456

+ 6 - 0
docs/changelog/129090.yaml

@@ -0,0 +1,6 @@
+pr: 129090
+summary: Enable force inference endpoint deleting for invalid models and after stopping
+  model deployment fails
+area: Machine Learning
+type: enhancement
+issues: []

+ 32 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

@@ -23,6 +23,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
@@ -128,10 +129,38 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
             }
 
             var service = serviceRegistry.getService(unparsedModel.service());
+            Model model;
             if (service.isPresent()) {
-                var model = service.get()
-                    .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
-                service.get().stop(model, listener);
+                try {
+                    model = service.get()
+                        .parsePersistedConfig(unparsedModel.inferenceEntityId(), unparsedModel.taskType(), unparsedModel.settings());
+                } catch (Exception e) {
+                    if (request.isForceDelete()) {
+                        listener.onResponse(true);
+                        return;
+                    } else {
+                        listener.onFailure(
+                            new ElasticsearchStatusException(
+                                Strings.format(
+                                    "Failed to parse model configuration for inference endpoint [%s]",
+                                    request.getInferenceEndpointId()
+                                ),
+                                RestStatus.INTERNAL_SERVER_ERROR,
+                                e
+                            )
+                        );
+                        return;
+                    }
+                }
+                service.get().stop(model, listener.delegateResponse((l, e) -> {
+                    if (request.isForceDelete()) {
+                        l.onResponse(true);
+                    } else {
+                        l.onFailure(e);
+                    }
+                }));
+            } else if (request.isForceDelete()) {
+                listener.onResponse(true);
             } else {
                 listener.onFailure(
                     new ElasticsearchStatusException(

+ 217 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java

@@ -17,8 +17,10 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -32,11 +34,17 @@ import java.util.Map;
 import java.util.Optional;
 
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
@@ -130,4 +138,213 @@ public class TransportDeleteInferenceEndpointActionTests extends ESTestCase {
 
         assertTrue(response.isAcknowledged());
     }
+
+    public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        var mockService = mock(InferenceService.class);
+        mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
+        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(exception.getMessage(), containsString("Failed to parse model configuration for inference endpoint"));
+
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
+        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
+    }
+
+    public void testDeletesUnparsableEndpoint_WhenForceIsTrue() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        var mockService = mock(InferenceService.class);
+        mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService);
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(true);
+            return Void.TYPE;
+        }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var response = listener.actionGet(TIMEOUT);
+        assertTrue(response.isAcknowledged());
+
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService);
+    }
+
+    private void mockUnparsableModel(String inferenceEndpointId, String serviceName, TaskType taskType, InferenceService mockService) {
+        doAnswer(invocationOnMock -> {
+            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
+            return Void.TYPE;
+        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        doThrow(new ElasticsearchStatusException(randomAlphanumericOfLength(10), RestStatus.INTERNAL_SERVER_ERROR)).when(mockService)
+            .parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
+    }
+
+    public void testDeletesEndpointWithNoService_WhenForceIsTrue() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        mockNoService(inferenceEndpointId, serviceName, taskType);
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(true);
+            return Void.TYPE;
+        }).when(mockModelRegistry).deleteModel(anyString(), any());
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var response = listener.actionGet(TIMEOUT);
+        assertTrue(response.isAcknowledged());
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
+    }
+
+    public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        mockNoService(inferenceEndpointId, serviceName, taskType);
+        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(exception.getMessage(), containsString("No service found for this inference endpoint"));
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry);
+    }
+
+    private void mockNoService(String inferenceEndpointId, String serviceName, TaskType taskType) {
+        doAnswer(invocationOnMock -> {
+            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
+            return Void.TYPE;
+        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.empty());
+    }
+
+    public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        var mockService = mock(InferenceService.class);
+        var mockModel = mock(Model.class);
+        mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
+        when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false);
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, false, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
+        assertThat(exception.getMessage(), containsString("Failed to stop model deployment"));
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId));
+        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        verify(mockService).stop(eq(mockModel), any());
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
+    }
+
+    public void testDeletesEndpointIfModelDeploymentStopFails_WhenForceIsTrue() {
+        var inferenceEndpointId = randomAlphaOfLengthBetween(5, 10);
+        var serviceName = randomAlphanumericOfLength(10);
+        var taskType = randomFrom(TaskType.values());
+        var mockService = mock(InferenceService.class);
+        var mockModel = mock(Model.class);
+        mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel);
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(true);
+            return Void.TYPE;
+        }).when(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
+
+        var listener = new PlainActionFuture<DeleteInferenceEndpointAction.Response>();
+        action.masterOperation(
+            mock(Task.class),
+            new DeleteInferenceEndpointAction.Request(inferenceEndpointId, taskType, true, false),
+            ClusterState.EMPTY_STATE,
+            listener
+        );
+
+        var response = listener.actionGet(TIMEOUT);
+        assertTrue(response.isAcknowledged());
+        verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        verify(mockInferenceServiceRegistry).getService(eq(serviceName));
+        verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        verify(mockService).stop(eq(mockModel), any());
+        verify(mockModelRegistry).deleteModel(eq(inferenceEndpointId), any());
+        verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel);
+    }
+
+    private void mockStopDeploymentFails(
+        String inferenceEndpointId,
+        String serviceName,
+        TaskType taskType,
+        InferenceService mockService,
+        Model mockModel
+    ) {
+        doAnswer(invocationOnMock -> {
+            ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
+            listener.onResponse(new UnparsedModel(inferenceEndpointId, taskType, serviceName, Map.of(), Map.of()));
+            return Void.TYPE;
+        }).when(mockModelRegistry).getModel(eq(inferenceEndpointId), any());
+        when(mockInferenceServiceRegistry.getService(serviceName)).thenReturn(Optional.of(mockService));
+        doReturn(mockModel).when(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any());
+        doAnswer(invocationOnMock -> {
+            ActionListener<Boolean> listener = invocationOnMock.getArgument(1);
+            listener.onFailure(new ElasticsearchStatusException("Failed to stop model deployment", RestStatus.INTERNAL_SERVER_ERROR));
+            return Void.TYPE;
+        }).when(mockService).stop(eq(mockModel), any());
+    }
+
 }