
Adding inference endpoint validation for AzureAiStudioService (#113713) (#116347)

* Adding inference endpoint validation for AzureAiStudioService

* Run spotlessApply

* Update docs/changelog/113713.yaml

* Remove isInClusterService from InferenceService

* Run spotless apply

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein committed 11 months ago
commit 0ac7f65096

+ 5 - 0
docs/changelog/113713.yaml

@@ -0,0 +1,5 @@
+pr: 113713
+summary: Adding inference endpoint validation for `AzureAiStudioService`
+area: Machine Learning
+type: enhancement
+issues: []

+ 9 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -178,6 +178,15 @@ public interface InferenceService extends Closeable {
         return model;
     }
 
+    /**
+     * Update a chat completion model's max tokens if required. The default behaviour is to just return the model.
+     * @param model The original model without updated chat completion details
+     * @return The model with updated chat completion details
+     */
+    default Model updateModelWithChatCompletionDetails(Model model) {
+        return model;
+    }
+
     /**
      * Defines the version required across all clusters to use this service
      * @return {@link TransportVersion} specifying the version
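
For orientation, a minimal sketch of how a service could override this new hook; the concrete override for Azure AI Studio appears later in this diff. MyCompletionModel, its task settings accessors, and withMaxNewTokens are hypothetical placeholders rather than part of this change, and the error helper is the one added to ServiceUtils below.

    // Sketch only: MyCompletionModel and withMaxNewTokens are hypothetical placeholders.
    @Override
    public Model updateModelWithChatCompletionDetails(Model model) {
        if (model instanceof MyCompletionModel completionModel) {
            var taskSettings = completionModel.getTaskSettings();
            var maxNewTokens = taskSettings.maxNewTokens() == null ? DEFAULT_MAX_NEW_TOKENS : taskSettings.maxNewTokens();
            return new MyCompletionModel(completionModel, taskSettings.withMaxNewTokens(maxNewTokens));
        }
        throw ServiceUtils.invalidModelTypeForUpdateModelWithChatCompletionDetails(model.getClass());
    }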

+ 9 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -209,6 +209,15 @@ public final class ServiceUtils {
         );
     }
 
+    public static ElasticsearchStatusException invalidModelTypeForUpdateModelWithChatCompletionDetails(
+        Class<? extends Model> invalidModelType
+    ) {
+        throw new ElasticsearchStatusException(
+            Strings.format("Can't update chat completion details for model with unexpected type %s", invalidModelType),
+            RestStatus.BAD_REQUEST
+        );
+    }
+
     public static String missingSettingErrorMsg(String settingName, String scope) {
         return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
     }
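
This helper mirrors the existing invalidModelTypeForUpdateModelWithEmbeddingDetails and is what a service throws when it is handed a model of the wrong type, as in the Azure AI Studio change further down:

    // Usage (taken from the AzureAiStudioService hunk below):
    throw ServiceUtils.invalidModelTypeForUpdateModelWithChatCompletionDetails(model.getClass());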

+ 39 - 48
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

@@ -49,6 +49,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -315,62 +316,52 @@ public class AzureAiStudioService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
+    }
+
+    @Override
+    public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
         if (model instanceof AzureAiStudioEmbeddingsModel embeddingsModel) {
-            ServiceUtils.getEmbeddingSize(
-                model,
-                this,
-                listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size)))
+            var serviceSettings = embeddingsModel.getServiceSettings();
+            var similarityFromModel = serviceSettings.similarity();
+            var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
+
+            var updatedServiceSettings = new AzureAiStudioEmbeddingsServiceSettings(
+                serviceSettings.target(),
+                serviceSettings.provider(),
+                serviceSettings.endpointType(),
+                embeddingSize,
+                serviceSettings.dimensionsSetByUser(),
+                serviceSettings.maxInputTokens(),
+                similarityToUse,
+                serviceSettings.rateLimitSettings()
             );
-        } else if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
-            listener.onResponse(updateChatCompletionModelConfig(chatCompletionModel));
+
+            return new AzureAiStudioEmbeddingsModel(embeddingsModel, updatedServiceSettings);
         } else {
-            listener.onResponse(model);
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithEmbeddingDetails(model.getClass());
         }
     }
 
-    private AzureAiStudioEmbeddingsModel updateEmbeddingModelConfig(AzureAiStudioEmbeddingsModel embeddingsModel, int embeddingsSize) {
-        if (embeddingsModel.getServiceSettings().dimensionsSetByUser()
-            && embeddingsModel.getServiceSettings().dimensions() != null
-            && embeddingsModel.getServiceSettings().dimensions() != embeddingsSize) {
-            throw new ElasticsearchStatusException(
-                Strings.format(
-                    "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. "
-                        + "Please recreate the [%s] configuration with the correct dimensions",
-                    embeddingsSize,
-                    embeddingsModel.getServiceSettings().dimensions(),
-                    embeddingsModel.getConfigurations().getInferenceEntityId()
-                ),
-                RestStatus.BAD_REQUEST
+    @Override
+    public Model updateModelWithChatCompletionDetails(Model model) {
+        if (model instanceof AzureAiStudioChatCompletionModel chatCompletionModel) {
+            var taskSettings = chatCompletionModel.getTaskSettings();
+            var modelMaxNewTokens = taskSettings.maxNewTokens();
+            var maxNewTokensToUse = modelMaxNewTokens == null ? DEFAULT_MAX_NEW_TOKENS : modelMaxNewTokens;
+
+            var updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
+                taskSettings.temperature(),
+                taskSettings.topP(),
+                taskSettings.doSample(),
+                maxNewTokensToUse
             );
-        }
-
-        var similarityFromModel = embeddingsModel.getServiceSettings().similarity();
-        var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;
-
-        AzureAiStudioEmbeddingsServiceSettings serviceSettings = new AzureAiStudioEmbeddingsServiceSettings(
-            embeddingsModel.getServiceSettings().target(),
-            embeddingsModel.getServiceSettings().provider(),
-            embeddingsModel.getServiceSettings().endpointType(),
-            embeddingsSize,
-            embeddingsModel.getServiceSettings().dimensionsSetByUser(),
-            embeddingsModel.getServiceSettings().maxInputTokens(),
-            similarityToUse,
-            embeddingsModel.getServiceSettings().rateLimitSettings()
-        );
-
-        return new AzureAiStudioEmbeddingsModel(embeddingsModel, serviceSettings);
-    }
 
-    private AzureAiStudioChatCompletionModel updateChatCompletionModelConfig(AzureAiStudioChatCompletionModel chatCompletionModel) {
-        var modelMaxNewTokens = chatCompletionModel.getTaskSettings().maxNewTokens();
-        var maxNewTokensToUse = modelMaxNewTokens == null ? DEFAULT_MAX_NEW_TOKENS : modelMaxNewTokens;
-        var updatedTaskSettings = new AzureAiStudioChatCompletionTaskSettings(
-            chatCompletionModel.getTaskSettings().temperature(),
-            chatCompletionModel.getTaskSettings().topP(),
-            chatCompletionModel.getTaskSettings().doSample(),
-            maxNewTokensToUse
-        );
-        return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
+            return new AzureAiStudioChatCompletionModel(chatCompletionModel, updatedTaskSettings);
+        } else {
+            throw ServiceUtils.invalidModelTypeForUpdateModelWithChatCompletionDetails(model.getClass());
+        }
     }
 
     private static void checkProviderAndEndpointTypeForTask(
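
The net effect of this rewrite: checkModelConfig now only delegates to the validator built for the model's task type, and the old per-model update logic becomes the two updateModel...Details overrides, with unset values falling back to defaults. A behaviour sketch follows; the variable names and the embedding size are illustrative, not from this change.

    Model e = service.updateModelWithEmbeddingDetails(embeddingsModel, 1536);
    // similarity() falls back to SimilarityMeasure.DOT_PRODUCT when it was null
    Model c = service.updateModelWithChatCompletionDetails(chatCompletionModel);
    // maxNewTokens() falls back to DEFAULT_MAX_NEW_TOKENS when it was null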

+ 32 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+
+public class ChatCompletionModelValidator implements ModelValidator {
+
+    private final ServiceIntegrationValidator serviceIntegrationValidator;
+
+    public ChatCompletionModelValidator(ServiceIntegrationValidator serviceIntegrationValidator) {
+        this.serviceIntegrationValidator = serviceIntegrationValidator;
+    }
+
+    @Override
+    public void validate(InferenceService service, Model model, ActionListener<Model> listener) {
+        serviceIntegrationValidator.validate(service, model, listener.delegateFailureAndWrap((delegate, r) -> {
+            delegate.onResponse(postValidate(service, model));
+        }));
+    }
+
+    private Model postValidate(InferenceService service, Model model) {
+        return service.updateModelWithChatCompletionDetails(model);
+    }
+}
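
The listener wiring above forwards a failure from the integration call straight to the caller, discards the raw inference result on success, and responds with the model enriched by updateModelWithChatCompletionDetails. Spelled out with a plain listener, the flow is roughly equivalent to this sketch:

    // Rough equivalent of the delegateFailureAndWrap call (sketch, not part of this change):
    serviceIntegrationValidator.validate(service, model, ActionListener.wrap(
        ignoredResults -> listener.onResponse(service.updateModelWithChatCompletionDetails(model)),
        listener::onFailure
    ));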

+ 4 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

@@ -20,7 +20,10 @@ public class ModelValidatorBuilder {
             case TEXT_EMBEDDING -> {
                 return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator());
             }
-            case SPARSE_EMBEDDING, RERANK, COMPLETION, ANY -> {
+            case COMPLETION -> {
+                return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator());
+            }
+            case SPARSE_EMBEDDING, RERANK, ANY -> {
                 return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
             }
             default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model of for task type %s ", taskType));
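
COMPLETION endpoints now get the new validator instead of the plain SimpleModelValidator, so validation runs the service integration check and then enriches the model. A short wiring sketch:

    // What the builder now returns for a completion endpoint:
    ModelValidator validator = ModelValidatorBuilder.buildModelValidator(TaskType.COMPLETION);
    // -> new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator())

The service-side checkModelConfig shown earlier for Azure AI Studio funnels through this same builder.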

+ 107 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

@@ -53,6 +53,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
 import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
 import org.hamcrest.CoreMatchers;
 import org.hamcrest.MatcherAssert;
 import org.hamcrest.Matchers;
@@ -973,6 +974,112 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         }
     }
 
+    public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = AzureAiStudioChatCompletionModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomFrom(AzureAiStudioProvider.values()),
+                randomFrom(AzureAiStudioEndpointType.values()),
+                randomAlphaOfLength(10)
+            );
+            assertThrows(
+                ElasticsearchStatusException.class,
+                () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
+            );
+        }
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
+        testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
+    }
+
+    private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var embeddingSize = randomNonNegativeInt();
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomFrom(AzureAiStudioProvider.values()),
+                randomFrom(AzureAiStudioEndpointType.values()),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                randomBoolean(),
+                randomNonNegativeInt(),
+                similarityMeasure,
+                randomAlphaOfLength(10),
+                RateLimitSettingsTests.createRandom()
+            );
+
+            Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
+
+            SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
+            assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
+            assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
+        }
+    }
+
+    public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = AzureAiStudioEmbeddingsModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomFrom(AzureAiStudioProvider.values()),
+                randomFrom(AzureAiStudioEndpointType.values()),
+                randomAlphaOfLength(10),
+                randomNonNegativeInt(),
+                randomBoolean(),
+                randomNonNegativeInt(),
+                randomFrom(SimilarityMeasure.values()),
+                randomAlphaOfLength(10),
+                RateLimitSettingsTests.createRandom()
+            );
+            assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithChatCompletionDetails(model); });
+        }
+    }
+
+    public void testUpdateModelWithChatCompletionDetails_NullMaxNewTokensInOriginalModel() throws IOException {
+        testUpdateModelWithChatCompletionDetails_Successful(null);
+    }
+
+    public void testUpdateModelWithChatCompletionDetails_NonNullMaxNewTokensInOriginalModel() throws IOException {
+        testUpdateModelWithChatCompletionDetails_Successful(randomNonNegativeInt());
+    }
+
+    private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = AzureAiStudioChatCompletionModelTests.createModel(
+                randomAlphaOfLength(10),
+                randomAlphaOfLength(10),
+                randomFrom(AzureAiStudioProvider.values()),
+                randomFrom(AzureAiStudioEndpointType.values()),
+                randomAlphaOfLength(10),
+                randomDouble(),
+                randomDouble(),
+                randomBoolean(),
+                maxNewTokens,
+                RateLimitSettingsTests.createRandom()
+            );
+
+            Model updatedModel = service.updateModelWithChatCompletionDetails(model);
+            assertThat(updatedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
+            AzureAiStudioChatCompletionTaskSettings updatedTaskSettings = (AzureAiStudioChatCompletionTaskSettings) updatedModel
+                .getTaskSettings();
+            Integer expectedMaxNewTokens = maxNewTokens == null
+                ? AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS
+                : maxNewTokens;
+            assertEquals(expectedMaxNewTokens, updatedTaskSettings.maxNewTokens());
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOException {
         var sender = mock(Sender.class);
 

+ 92 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java

@@ -0,0 +1,92 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.test.ESTestCase;
+import org.junit.Before;
+import org.mockito.Mock;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+import static org.mockito.MockitoAnnotations.openMocks;
+
+public class ChatCompletionModelValidatorTests extends ESTestCase {
+    @Mock
+    private ServiceIntegrationValidator mockServiceIntegrationValidator;
+    @Mock
+    private InferenceService mockInferenceService;
+    @Mock
+    private InferenceServiceResults mockInferenceServiceResults;
+    @Mock
+    private Model mockModel;
+    @Mock
+    private ActionListener<Model> mockActionListener;
+
+    private ChatCompletionModelValidator underTest;
+
+    @Before
+    public void setup() {
+        openMocks(this);
+
+        underTest = new ChatCompletionModelValidator(mockServiceIntegrationValidator);
+    }
+
+    public void testValidate_ServiceIntegrationValidatorThrowsException() {
+        doThrow(ElasticsearchStatusException.class).when(mockServiceIntegrationValidator)
+            .validate(eq(mockInferenceService), eq(mockModel), any());
+
+        assertThrows(
+            ElasticsearchStatusException.class,
+            () -> { underTest.validate(mockInferenceService, mockModel, mockActionListener); }
+        );
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockInferenceServiceResults,
+            mockModel,
+            mockActionListener
+        );
+    }
+
+    public void testValidate_ChatCompletionDetailsUpdated() {
+        when(mockActionListener.delegateFailureAndWrap(any())).thenCallRealMethod();
+        when(mockInferenceService.updateModelWithChatCompletionDetails(mockModel)).thenReturn(mockModel);
+        doAnswer(ans -> {
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(2);
+            responseListener.onResponse(mockInferenceServiceResults);
+            return null;
+        }).when(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any());
+
+        underTest.validate(mockInferenceService, mockModel, mockActionListener);
+
+        verify(mockServiceIntegrationValidator).validate(eq(mockInferenceService), eq(mockModel), any());
+        verify(mockActionListener).delegateFailureAndWrap(any());
+        verify(mockActionListener).onResponse(mockModel);
+        verify(mockInferenceService).updateModelWithChatCompletionDetails(mockModel);
+        verifyNoMoreInteractions(
+            mockServiceIntegrationValidator,
+            mockInferenceService,
+            mockInferenceServiceResults,
+            mockModel,
+            mockActionListener
+        );
+    }
+}

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java

@@ -34,7 +34,7 @@ public class ModelValidatorBuilderTests extends ESTestCase {
             TaskType.RERANK,
             SimpleModelValidator.class,
             TaskType.COMPLETION,
-            SimpleModelValidator.class,
+            ChatCompletionModelValidator.class,
             TaskType.ANY,
             SimpleModelValidator.class
         );