Browse Source

Revert endpoint creation validation for ELSER and E5 (#126792) (#126801)

* Revert endpoint creation validation for ELSER and E5

* Update docs/changelog/126792.yaml

* Revert start model deployment being in TransportPutInferenceModelAction

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein 6 months ago
parent
commit
156d5da28e

+ 5 - 0
docs/changelog/126792.yaml

@@ -0,0 +1,5 @@
+pr: 126792
+summary: Revert endpoint creation validation for ELSER and E5
+area: Machine Learning
+type: bug
+issues: []

+ 25 - 16
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

@@ -195,23 +195,19 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
         ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
             (delegate, verifiedModel) -> modelRegistry.storeModel(
                 verifiedModel,
-                ActionListener.wrap(
-                    r -> listener.onResponse(new PutInferenceModelAction.Response(verifiedModel.getConfigurations())),
-                    e -> {
-                        if (e.getCause() instanceof StrictDynamicMappingException
-                            && e.getCause().getMessage().contains("chunking_settings")) {
-                            delegate.onFailure(
-                                new ElasticsearchStatusException(
-                                    "One or more nodes in your cluster does not support chunking_settings. "
-                                        + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
-                                    RestStatus.BAD_REQUEST
-                                )
-                            );
-                        } else {
-                            delegate.onFailure(e);
-                        }
+                ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
+                    if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
+                        delegate.onFailure(
+                            new ElasticsearchStatusException(
+                                "One or more nodes in your cluster does not support chunking_settings. "
+                                    + "Please update all nodes in your cluster to the latest version to use chunking_settings.",
+                                RestStatus.BAD_REQUEST
+                            )
+                        );
+                    } else {
+                        delegate.onFailure(e);
                     }
-                ),
+                }),
                 timeout
             )
         );
@@ -228,6 +224,19 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
         service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
     }
 
+    private void startInferenceEndpoint(
+        InferenceService service,
+        TimeValue timeout,
+        Model model,
+        ActionListener<PutInferenceModelAction.Response> listener
+    ) {
+        if (skipValidationAndStart) {
+            listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
+        } else {
+            service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
+        }
+    }
+
     private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
         try (
             XContentParser parser = XContentHelper.createParser(

+ 70 - 30
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java

@@ -9,51 +9,91 @@ package org.elasticsearch.xpack.inference.services.validation;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
 
 public class ElasticsearchInternalServiceModelValidator implements ModelValidator {
 
-    ModelValidator modelValidator;
+    private final ServiceIntegrationValidator serviceIntegrationValidator;
 
-    public ElasticsearchInternalServiceModelValidator(ModelValidator modelValidator) {
-        this.modelValidator = modelValidator;
+    public ElasticsearchInternalServiceModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
+        this.serviceIntegrationValidator = serviceIntegrationValidator;
     }
 
     @Override
     public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<Model> listener) {
-        service.start(model, timeout, ActionListener.wrap((modelDeploymentStarted) -> {
-            if (modelDeploymentStarted) {
-                try {
-                    modelValidator.validate(service, model, timeout, listener.delegateResponse((l, exception) -> {
-                        stopModelDeployment(service, model, l, exception);
-                    }));
-                } catch (Exception e) {
-                    stopModelDeployment(service, model, listener, e);
-                }
-            } else {
-                listener.onFailure(
-                    new ElasticsearchStatusException("Could not deploy model for inference endpoint", RestStatus.INTERNAL_SERVER_ERROR)
+        if (model instanceof CustomElandEmbeddingModel elandModel && elandModel.getTaskType() == TaskType.TEXT_EMBEDDING) {
+            var temporaryModelWithModelId = new CustomElandEmbeddingModel(
+                elandModel.getServiceSettings().modelId(),
+                elandModel.getTaskType(),
+                elandModel.getConfigurations().getService(),
+                elandModel.getServiceSettings(),
+                elandModel.getConfigurations().getChunkingSettings()
+            );
+
+            serviceIntegrationValidator.validate(
+                service,
+                temporaryModelWithModelId,
+                timeout,
+                listener.delegateFailureAndWrap((delegate, r) -> {
+                    delegate.onResponse(postValidate(service, model, r));
+                })
+            );
+        } else {
+            listener.onResponse(model);
+        }
+    }
+
+    private Model postValidate(InferenceService service, Model model, InferenceServiceResults results) {
+        if (results instanceof TextEmbeddingResults<?> embeddingResults) {
+            var serviceSettings = model.getServiceSettings();
+            var dimensions = serviceSettings.dimensions();
+            int embeddingSize = getEmbeddingSize(embeddingResults);
+
+            if (Boolean.TRUE.equals(serviceSettings.dimensionsSetByUser())
+                && dimensions != null
+                && (dimensions.equals(embeddingSize) == false)) {
+                throw new ElasticsearchStatusException(
+                    Strings.format(
+                        "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
+                            + "Please recreate the [%s] configuration with the correct dimensions",
+                        embeddingResults.getFirstEmbeddingSize(),
+                        serviceSettings.dimensions(),
+                        model.getInferenceEntityId()
+                    ),
+                    RestStatus.BAD_REQUEST
                 );
             }
-        }, listener::onFailure));
+
+            return service.updateModelWithEmbeddingDetails(model, embeddingSize);
+        } else {
+            throw new ElasticsearchStatusException(
+                "Validation call did not return expected results type."
+                    + "Expected a result of type ["
+                    + TextEmbeddingFloatResults.NAME
+                    + "] got ["
+                    + (results == null ? "null" : results.getWriteableName())
+                    + "]",
+                RestStatus.BAD_REQUEST
+            );
+        }
     }
 
-    private void stopModelDeployment(InferenceService service, Model model, ActionListener<Model> listener, Exception e) {
-        service.stop(
-            model,
-            ActionListener.wrap(
-                (v) -> listener.onFailure(e),
-                (ex) -> listener.onFailure(
-                    new ElasticsearchStatusException(
-                        "Model validation failed and model deployment could not be stopped",
-                        RestStatus.INTERNAL_SERVER_ERROR,
-                        ex
-                    )
-                )
-            )
-        );
+    private int getEmbeddingSize(TextEmbeddingResults<?> embeddingResults) {
+        int embeddingSize;
+        try {
+            embeddingSize = embeddingResults.getFirstEmbeddingSize();
+        } catch (Exception e) {
+            throw new ElasticsearchStatusException("Could not determine embedding size", RestStatus.BAD_REQUEST, e);
+        }
+        return embeddingSize;
     }
 }

+ 2 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

@@ -12,11 +12,10 @@ import org.elasticsearch.inference.TaskType;
 
 public class ModelValidatorBuilder {
     public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
-        var modelValidator = buildModelValidatorForTaskType(taskType);
         if (isElasticsearchInternalService) {
-            return new ElasticsearchInternalServiceModelValidator(modelValidator);
+            return new ElasticsearchInternalServiceModelValidator(new SimpleServiceIntegrationValidator());
         } else {
-            return modelValidator;
+            return buildModelValidatorForTaskType(taskType);
         }
     }
 

+ 191 - 105
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidatorTests.java

@@ -11,17 +11,26 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
+import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandEmbeddingModel;
+import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
 import org.junit.Before;
-import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 
+import java.util.List;
+
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
@@ -34,11 +43,11 @@ public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase
         "Model validation failed and model deployment could not be stopped";
 
     @Mock
-    private ModelValidator mockModelValidator;
+    private SimpleServiceIntegrationValidator mockServiceIntegrationValidator;
     @Mock
     private InferenceService mockInferenceService;
     @Mock
-    private Model mockModel;
+    private CustomElandEmbeddingModel mockCustomElandEmbeddingModel;
     @Mock
     private ActionListener<Model> mockActionListener;
 
@@ -48,150 +57,227 @@ public class ElasticsearchInternalServiceModelValidatorTests extends ESTestCase
     public void setup() {
         openMocks(this);
 
-        underTest = new ElasticsearchInternalServiceModelValidator(mockModelValidator);
+        underTest = new ElasticsearchInternalServiceModelValidator(mockServiceIntegrationValidator);
 
         when(mockActionListener.delegateResponse(any())).thenCallRealMethod();
+        when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod();
     }
 
-    public void testValidate_ModelDeploymentThrowsException() {
-        doThrow(ElasticsearchStatusException.class).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
+    public void testValidate_NonElandModelSkipsValidation() {
+        var mockModel = mock(Model.class);
+        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
 
-        assertThrows(
-            ElasticsearchStatusException.class,
-            () -> { underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener); }
+        verify(mockActionListener).onResponse(mockModel);
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockModel
         );
+    }
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+    public void testValidate_ElandModelWithNonTextEmbeddingTaskTypeSkipsValidation() {
+        when(mockCustomElandEmbeddingModel.getTaskType()).thenReturn(randomFrom(List.of(TaskType.RERANK, TaskType.SPARSE_EMBEDDING)));
+
+        underTest.validate(mockInferenceService, mockCustomElandEmbeddingModel, TIMEOUT, mockActionListener);
+
+        verify(mockActionListener).onResponse(mockCustomElandEmbeddingModel);
+        verify(mockCustomElandEmbeddingModel).getTaskType();
+        verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockCustomElandEmbeddingModel, mockActionListener);
     }
 
-    public void testValidate_ModelDeploymentReturnsFalse() {
-        mockModelDeployment(false);
+    public void testValidate_ElandTextEmbeddingModelValidationThrowsException() {
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(
+            mockServiceIntegrationValidator
+        ).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        assertThrows(ElasticsearchStatusException.class, () -> {
+            underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
+        });
     }
 
-    public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsStopped() {
-        mockModelDeployment(true);
-        doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
-            .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        mockModelStop(true);
+    public void testValidate_ElandTextEmbeddingModelValidationFails() {
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        doAnswer(ans -> {
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
+            return null;
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockInferenceService).stop(eq(mockModel), any());
-        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).delegateResponse(any());
-        verifyMockActionListenerAfterStopModelDeployment(true);
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
+        verifyNoMoreInteractions(mockServiceIntegrationValidator, mockInferenceService, mockCustomElandEmbeddingModel, mockActionListener);
     }
 
-    public void testValidate_ModelValidatorThrowsExceptionAndModelDeploymentIsNotStopped() {
-        mockModelDeployment(true);
-        doThrow(new ElasticsearchStatusException("Model Validator Exception", RestStatus.INTERNAL_SERVER_ERROR)).when(mockModelValidator)
-            .validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        mockModelStop(false);
+    public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserValid() {
+        var dimensions = randomIntBetween(1, 10);
+        var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+        var mockUpdatedModel = mock(CustomElandEmbeddingModel.class);
+        when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions);
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        doAnswer(ans -> {
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockInferenceServiceResults);
+            return null;
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        when(mockInferenceService.updateModelWithEmbeddingDetails(eq(customElandEmbeddingModel), eq(dimensions))).thenReturn(
+            mockUpdatedModel
+        );
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockInferenceService).stop(eq(mockModel), any());
-        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).delegateResponse(any());
-        verifyMockActionListenerAfterStopModelDeployment(false);
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockInferenceService).updateModelWithEmbeddingDetails(eq(customElandEmbeddingModel), eq(dimensions));
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onResponse(mockUpdatedModel);
+        verify(mockInferenceServiceResults).getFirstEmbeddingSize();
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockUpdatedModel,
+            mockInferenceServiceResults
+        );
     }
 
-    public void testValidate_ModelValidationFailsAndModelDeploymentIsStopped() {
-        mockModelDeployment(true);
+    public void testValidate_ElandTextEmbeddingModelValidationSucceedsAndDimensionsSetByUserInvalid() {
+        var dimensions = randomIntBetween(1, 10);
+        var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+        when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(
+            randomValueOtherThan(dimensions, () -> randomIntBetween(1, 10))
+        );
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions);
+
         doAnswer(ans -> {
-            ActionListener<Model> responseListener = ans.getArgument(3);
-            responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockInferenceServiceResults);
             return null;
-        }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        mockModelStop(true);
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockInferenceService).stop(eq(mockModel), any());
-        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).delegateResponse(any());
-        verifyMockActionListenerAfterStopModelDeployment(true);
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
+        verify(mockInferenceServiceResults, times(2)).getFirstEmbeddingSize();
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockInferenceServiceResults
+        );
     }
 
-    public void testValidate_ModelValidationFailsAndModelDeploymentIsNotStopped() {
-        mockModelDeployment(true);
+    public void testValidate_ElandTextEmbeddingAndValidationReturnsInvalidResultsType() {
+        var dimensions = randomIntBetween(1, 10);
+        var mockInferenceServiceResults = mock(InferenceServiceResults.class);
+        when(mockInferenceServiceResults.getWriteableName()).thenReturn(randomAlphaOfLength(10));
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(true, dimensions);
+
         doAnswer(ans -> {
-            ActionListener<Model> responseListener = ans.getArgument(3);
-            responseListener.onFailure(new ElasticsearchStatusException("Model validation failed", RestStatus.INTERNAL_SERVER_ERROR));
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockInferenceServiceResults);
             return null;
-        }).when(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        mockModelStop(false);
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockInferenceService).stop(eq(mockModel), any());
-        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).delegateResponse(any());
-        verifyMockActionListenerAfterStopModelDeployment(false);
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
+        verify(mockInferenceServiceResults).getWriteableName();
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockInferenceServiceResults
+        );
     }
 
-    public void testValidate_ModelValidationSucceeds() {
-        mockModelDeployment(true);
-        mockModelStop(true);
+    public void testValidate_ElandTextEmbeddingModelDimensionsNotSetByUser() {
+        var dimensions = randomIntBetween(1, 10);
+        var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+        when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenReturn(dimensions);
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
 
-        underTest.validate(mockInferenceService, mockModel, TIMEOUT, mockActionListener);
+        var mockUpdatedModel = mock(CustomElandEmbeddingModel.class);
+        when(mockInferenceService.updateModelWithEmbeddingDetails(eq(customElandEmbeddingModel), eq(dimensions))).thenReturn(
+            mockUpdatedModel
+        );
 
-        verify(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-        verify(mockModelValidator).validate(eq(mockInferenceService), eq(mockModel), eq(TIMEOUT), any());
-        verify(mockActionListener).delegateResponse(any());
-        verifyNoMoreInteractions(mockModelValidator, mockInferenceService, mockModel, mockActionListener);
+        doAnswer(ans -> {
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockInferenceServiceResults);
+            return null;
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onResponse(mockUpdatedModel);
+        verify(mockInferenceService).updateModelWithEmbeddingDetails(eq(customElandEmbeddingModel), eq(dimensions));
+        verify(mockInferenceServiceResults).getFirstEmbeddingSize();
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockInferenceServiceResults
+        );
     }
 
-    private void mockModelDeployment(boolean modelDeploymentStarted) {
+    public void testValidate_ElandTextEmbeddingModelAndEmbeddingSizeRetrievalThrowsException() {
+        var mockInferenceServiceResults = mock(TextEmbeddingResults.class);
+        when(mockInferenceServiceResults.getFirstEmbeddingSize()).thenThrow(ElasticsearchStatusException.class);
+        CustomElandEmbeddingModel customElandEmbeddingModel = createCustomElandEmbeddingModel(false, null);
+
         doAnswer(ans -> {
-            ActionListener<Boolean> responseListener = ans.getArgument(2);
-            responseListener.onResponse(modelDeploymentStarted);
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(3);
+            responseListener.onResponse(mockInferenceServiceResults);
             return null;
-        }).when(mockInferenceService).start(eq(mockModel), eq(TIMEOUT), any());
-    }
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
 
-    private void mockModelStop(boolean modelDeploymentStopped) {
-        if (modelDeploymentStopped) {
-            doAnswer(ans -> {
-                ActionListener<Void> responseListener = ans.getArgument(1);
-                responseListener.onResponse(null);
-                return null;
-            }).when(mockInferenceService).stop(eq(mockModel), any());
-        } else {
-            doAnswer(ans -> {
-                ActionListener<Void> responseListener = ans.getArgument(1);
-                responseListener.onFailure(new ElasticsearchStatusException("Model stop failed", RestStatus.INTERNAL_SERVER_ERROR));
-                return null;
-            }).when(mockInferenceService).stop(eq(mockModel), any());
-        }
+        underTest.validate(mockInferenceService, customElandEmbeddingModel, TIMEOUT, mockActionListener);
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), any(), eq(TIMEOUT), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onFailure(any(ElasticsearchStatusException.class));
+        verify(mockInferenceServiceResults).getFirstEmbeddingSize();
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockCustomElandEmbeddingModel,
+            mockActionListener,
+            mockInferenceServiceResults
+        );
     }
 
-    private void verifyMockActionListenerAfterStopModelDeployment(boolean modelDeploymentStopped) {
-        verify(mockInferenceService).stop(eq(mockModel), any());
-        ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);
-        verify(mockActionListener).onFailure(exceptionCaptor.capture());
-        assertTrue(exceptionCaptor.getValue() instanceof ElasticsearchStatusException);
-        assertEquals(RestStatus.INTERNAL_SERVER_ERROR, ((ElasticsearchStatusException) exceptionCaptor.getValue()).status());
-
-        if (modelDeploymentStopped) {
-            assertFalse(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
-        } else {
-            assertTrue(exceptionCaptor.getValue().getMessage().contains(MODEL_VALIDATION_AND_STOP_FAILED_MESSAGE));
+    private CustomElandEmbeddingModel createCustomElandEmbeddingModel(boolean areDimensionsSetByUser, Integer dimensions) {
+        var mockServiceSettings = mock(CustomElandInternalTextEmbeddingServiceSettings.class);
+        when(mockServiceSettings.modelId()).thenReturn(randomAlphaOfLength(10));
+        when(mockServiceSettings.dimensionsSetByUser()).thenReturn(areDimensionsSetByUser);
+        if (dimensions != null) {
+            when(mockServiceSettings.dimensions()).thenReturn(dimensions);
         }
+
+        return new CustomElandEmbeddingModel(
+            randomAlphaOfLength(10),
+            TaskType.TEXT_EMBEDDING,
+            randomAlphaOfLength(10),
+            mockServiceSettings,
+            ChunkingSettingsTests.createRandomChunkingSettings()
+        );
     }
 }