Browse Source

Adding endpoint creation validation to ElasticInferenceService (#117642)

* Adding endpoint creation validation to ElasticInferenceService

* Fix unit tests

* Update docs/changelog/117642.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein 8 months ago
parent
commit
bea8df3c8e

+ 5 - 0
docs/changelog/117642.yaml

@@ -0,0 +1,5 @@
+pr: 117642
+summary: Adding endpoint creation validation to `ElasticInferenceService`
+area: Machine Learning
+type: enhancement
+issues: []

+ 3 - 17
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -54,6 +54,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
+import org.elasticsearch.xpack.inference.services.validation.ModelValidatorBuilder;
 import org.elasticsearch.xpack.inference.telemetry.TraceContext;
 
 import java.util.ArrayList;
@@ -557,11 +558,8 @@ public class ElasticInferenceService extends SenderService {
 
     @Override
     public void checkModelConfig(Model model, ActionListener<Model> listener) {
-        if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel embeddingsModel) {
-            listener.onResponse(updateModelWithEmbeddingDetails(embeddingsModel));
-        } else {
-            listener.onResponse(model);
-        }
+        // TODO: Remove this function once all services have been updated to use the new model validators
+        ModelValidatorBuilder.buildModelValidator(model.getTaskType()).validate(this, model, listener);
     }
 
     private static List<ChunkedInference> translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) {
@@ -576,18 +574,6 @@ public class ElasticInferenceService extends SenderService {
         }
     }
 
-    private ElasticInferenceServiceSparseEmbeddingsModel updateModelWithEmbeddingDetails(
-        ElasticInferenceServiceSparseEmbeddingsModel model
-    ) {
-        ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(
-            model.getServiceSettings().modelId(),
-            model.getServiceSettings().maxInputTokens(),
-            model.getServiceSettings().rateLimitSettings()
-        );
-
-        return new ElasticInferenceServiceSparseEmbeddingsModel(model, serviceSettings);
-    }
-
     private TraceContext getCurrentTraceInfo() {
         var threadPool = getServiceComponents().threadPool();
 

+ 14 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -317,7 +317,21 @@ public class ElasticInferenceServiceTests extends ESTestCase {
 
     public void testCheckModelConfig_ReturnsNewModelReference() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
         try (var service = createService(senderFactory, getUrl(webServer))) {
+            String responseJson = """
+                {
+                    "data": [
+                        {
+                            "hello": 2.1259406,
+                            "greet": 1.7073475
+                        }
+                    ]
+                }
+                """;
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
             var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(getUrl(webServer), "my-model-id");
             PlainActionFuture<Model> listener = new PlainActionFuture<>();
             service.checkModelConfig(model, listener);