Browse Source

Add dense vector inference mock service for testing (#105655)

Carlos Delgado 1 year ago
parent
commit
9c72157bb7

+ 30 - 9
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -50,11 +50,11 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
         return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
     }
     }
 
 
-    static String mockServiceModelConfig() {
-        return mockServiceModelConfig(null);
+    static String mockSparseServiceModelConfig() {
+        return mockSparseServiceModelConfig(null);
     }
     }
 
 
-    static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody) {
+    static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody) {
         var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
         var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
         return Strings.format("""
         return Strings.format("""
             {
             {
@@ -72,7 +72,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
             """, taskType);
             """, taskType);
     }
     }
 
 
-    static String mockServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
+    static String mockSparseServiceModelConfig(@Nullable TaskType taskTypeInBody, boolean shouldReturnHiddenField) {
         var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
         var taskType = taskTypeInBody == null ? "" : "\"task_type\": \"" + taskTypeInBody + "\",";
         return Strings.format("""
         return Strings.format("""
             {
             {
@@ -91,6 +91,22 @@ public class InferenceBaseRestTest extends ESRestTestCase {
             """, taskType, shouldReturnHiddenField);
             """, taskType, shouldReturnHiddenField);
     }
     }
 
 
+    /**
+     * Returns the JSON request body used to register a model with the mock dense
+     * ({@code text_embedding}) test service.
+     * <p>
+     * The body names the {@code text_embedding_test_service} service and supplies its service
+     * settings (a model name, an API key and an embedding size of 246) together with an empty
+     * task settings object.
+     *
+     * @return the model configuration as a JSON string
+     */
+    static String mockDenseServiceModelConfig() {
+        return """
+            {
+              "task_type": "text_embedding",
+              "service": "text_embedding_test_service",
+              "service_settings": {
+                "model": "my_dense_vector_model",
+                "api_key": "abc64",
+                "dimensions": 246
+              },
+              "task_settings": {
+              }
+            }
+            """;
+    }
+
     protected void deleteModel(String modelId) throws IOException {
     protected void deleteModel(String modelId) throws IOException {
         var request = new Request("DELETE", "_inference/" + modelId);
         var request = new Request("DELETE", "_inference/" + modelId);
         var response = client().performRequest(request);
         var response = client().performRequest(request);
@@ -200,11 +216,16 @@ public class InferenceBaseRestTest extends ESRestTestCase {
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
     protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int expectedNumberOfResults, TaskType taskType) {
-        if (taskType == TaskType.SPARSE_EMBEDDING) {
-            var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
-            assertThat(results, hasSize(expectedNumberOfResults));
-        } else {
-            fail("test with task type [" + taskType + "] are not supported yet");
+        switch (taskType) {
+            case SPARSE_EMBEDDING -> {
+                var results = (List<Map<String, Object>>) resultMap.get(TaskType.SPARSE_EMBEDDING.toString());
+                assertThat(results, hasSize(expectedNumberOfResults));
+            }
+            case TEXT_EMBEDDING -> {
+                var results = (List<Map<String, Object>>) resultMap.get(TaskType.TEXT_EMBEDDING.toString());
+                assertThat(results, hasSize(expectedNumberOfResults));
+            }
+            default -> fail("test with task type [" + taskType + "] are not supported yet");
         }
         }
     }
     }
 
 

+ 6 - 6
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -25,10 +25,10 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testGet() throws IOException {
     public void testGet() throws IOException {
         for (int i = 0; i < 5; i++) {
         for (int i = 0; i < 5; i++) {
-            putModel("se_model_" + i, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+            putModel("se_model_" + i, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         }
         }
         for (int i = 0; i < 4; i++) {
         for (int i = 0; i < 4; i++) {
-            putModel("te_model_" + i, mockServiceModelConfig(), TaskType.TEXT_EMBEDDING);
+            putModel("te_model_" + i, mockSparseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
         }
         }
 
 
         var getAllModels = (List<Map<String, Object>>) getAllModels().get("models");
         var getAllModels = (List<Map<String, Object>>) getAllModels().get("models");
@@ -59,7 +59,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     }
     }
 
 
     public void testGetModelWithWrongTaskType() throws IOException {
     public void testGetModelWithWrongTaskType() throws IOException {
-        putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
         var e = expectThrows(ResponseException.class, () -> getModels("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
         assertThat(
         assertThat(
             e.getMessage(),
             e.getMessage(),
@@ -68,7 +68,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     }
     }
 
 
     public void testDeleteModelWithWrongTaskType() throws IOException {
     public void testDeleteModelWithWrongTaskType() throws IOException {
-        putModel("sparse_embedding_model", mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        putModel("sparse_embedding_model", mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
         var e = expectThrows(ResponseException.class, () -> deleteModel("sparse_embedding_model", TaskType.TEXT_EMBEDDING));
         assertThat(
         assertThat(
             e.getMessage(),
             e.getMessage(),
@@ -79,7 +79,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testGetModelWithAnyTaskType() throws IOException {
     public void testGetModelWithAnyTaskType() throws IOException {
         String inferenceEntityId = "sparse_embedding_model";
         String inferenceEntityId = "sparse_embedding_model";
-        putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var singleModel = (List<Map<String, Object>>) getModels(inferenceEntityId, TaskType.ANY).get("models");
         var singleModel = (List<Map<String, Object>>) getModels(inferenceEntityId, TaskType.ANY).get("models");
         assertEquals(inferenceEntityId, singleModel.get(0).get("model_id"));
         assertEquals(inferenceEntityId, singleModel.get(0).get("model_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
@@ -88,7 +88,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testApisWithoutTaskType() throws IOException {
     public void testApisWithoutTaskType() throws IOException {
         String modelId = "no_task_type_in_url";
         String modelId = "no_task_type_in_url";
-        putModel(modelId, mockServiceModelConfig(TaskType.SPARSE_EMBEDDING));
+        putModel(modelId, mockSparseServiceModelConfig(TaskType.SPARSE_EMBEDDING));
         var singleModel = (List<Map<String, Object>>) getModel(modelId).get("models");
         var singleModel = (List<Map<String, Object>>) getModel(modelId).get("models");
         assertEquals(modelId, singleModel.get(0).get("model_id"));
         assertEquals(modelId, singleModel.get(0).get("model_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get(0).get("task_type"));

+ 65 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.inference.TaskType;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * REST integration tests that exercise the mock dense (text embedding) inference service
+ * through the model PUT, GET and inference endpoints.
+ */
+public class MockDenseInferenceServiceIT extends InferenceBaseRestTest {
+
+    /**
+     * Registers a dense model, checks that the PUT response and the stored model agree on id,
+     * task type and service, then runs inference with a single input.
+     */
+    @SuppressWarnings("unchecked")
+    public void testMockService() throws IOException {
+        final String entityId = "test-mock";
+        var putResponse = putModel(entityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
+        var getResponse = getModels(entityId, TaskType.TEXT_EMBEDDING);
+        var storedModels = (List<Map<String, Object>>) getResponse.get("models");
+        var storedModel = storedModels.get(0);
+
+        for (var description : List.of(putResponse, storedModel)) {
+            assertEquals(entityId, description.get("model_id"));
+            assertEquals(TaskType.TEXT_EMBEDDING, TaskType.fromString((String) description.get("task_type")));
+            assertEquals("text_embedding_test_service", description.get("service"));
+        }
+
+        // Any text will do as input; only the number of results is checked
+        var inputs = List.of(randomAlphaOfLength(10));
+        var inferenceResponse = inferOnMockService(entityId, inputs);
+        assertNonEmptyInferenceResults(inferenceResponse, 1, TaskType.TEXT_EMBEDDING);
+    }
+
+    /**
+     * Runs inference with three inputs and expects exactly three results back.
+     */
+    public void testMockServiceWithMultipleInputs() throws IOException {
+        final String entityId = "test-mock-with-multi-inputs";
+        putModel(entityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
+
+        // Any text will do as input; only the number of results is checked
+        var inputs = List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15));
+        var inferenceResponse = inferOnMockService(entityId, TaskType.TEXT_EMBEDDING, inputs);
+
+        assertNonEmptyInferenceResults(inferenceResponse, 3, TaskType.TEXT_EMBEDDING);
+    }
+
+    /**
+     * Verifies that neither the GET response nor the PUT response echoes the API key,
+     * while the non-secret model name is still present in both.
+     */
+    @SuppressWarnings("unchecked")
+    public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
+        final String entityId = "test-mock";
+        var putResponse = putModel(entityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
+        var getResponse = getModels(entityId, TaskType.TEXT_EMBEDDING);
+        var storedModels = (List<Map<String, Object>>) getResponse.get("models");
+        var storedModel = storedModels.get(0);
+
+        var storedServiceSettings = (Map<String, Object>) storedModel.get("service_settings");
+        assertNull(storedServiceSettings.get("api_key"));
+        assertNotNull(storedServiceSettings.get("model"));
+
+        var returnedServiceSettings = (Map<String, Object>) putResponse.get("service_settings");
+        assertNull(returnedServiceSettings.get("api_key"));
+        assertNotNull(returnedServiceSettings.get("model"));
+    }
+}

+ 6 - 6
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java → x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java

@@ -15,12 +15,12 @@ import java.util.Map;
 
 
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.is;
 
 
-public class MockInferenceServiceIT extends InferenceBaseRestTest {
+public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
 
 
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testMockService() throws IOException {
     public void testMockService() throws IOException {
         String inferenceEntityId = "test-mock";
         String inferenceEntityId = "test-mock";
-        var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
 
 
@@ -37,7 +37,7 @@ public class MockInferenceServiceIT extends InferenceBaseRestTest {
 
 
     public void testMockServiceWithMultipleInputs() throws IOException {
     public void testMockServiceWithMultipleInputs() throws IOException {
         String inferenceEntityId = "test-mock-with-multi-inputs";
         String inferenceEntityId = "test-mock-with-multi-inputs";
-        putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
 
 
         // The response is randomly generated, the input can be anything
         // The response is randomly generated, the input can be anything
         var inference = inferOnMockService(
         var inference = inferOnMockService(
@@ -52,7 +52,7 @@ public class MockInferenceServiceIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
     public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
         String inferenceEntityId = "test-mock";
         String inferenceEntityId = "test-mock";
-        var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
 
 
@@ -68,7 +68,7 @@ public class MockInferenceServiceIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws IOException {
     public void testMockService_DoesNotReturnHiddenField_InModelResponses() throws IOException {
         String inferenceEntityId = "test-mock";
         String inferenceEntityId = "test-mock";
-        var putModel = putModel(inferenceEntityId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
 
 
@@ -87,7 +87,7 @@ public class MockInferenceServiceIT extends InferenceBaseRestTest {
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOException {
     public void testMockService_DoesReturnHiddenField_InModelResponses() throws IOException {
         String inferenceEntityId = "test-mock";
         String inferenceEntityId = "test-mock";
-        var putModel = putModel(inferenceEntityId, mockServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING);
+        var putModel = putModel(inferenceEntityId, mockSparseServiceModelConfig(null, true), TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var getModels = getModels(inferenceEntityId, TaskType.SPARSE_EMBEDDING);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
         var model = ((List<Map<String, Object>>) getModels.get("models")).get(0);
 
 

+ 206 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java

@@ -0,0 +1,206 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.mock;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ModelSecrets;
+import org.elasticsearch.inference.SecretSettings;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Map;
+
+public abstract class AbstractTestInferenceService implements InferenceService {
+
+    /**
+     * Reports the minimal transport version supported by this mock service, which is always
+     * the current transport version.
+     */
+    @Override
+    public TransportVersion getMinimalSupportedVersion() {
+        return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
+    }
+
+    /**
+     * Removes and returns the optional task settings from a model configuration map.
+     * <p>
+     * The returned map is always non-null and mutable: {@code TestTaskSettings.fromMap} consumes
+     * entries with {@link Map#remove(Object)}, which an immutable {@code Map.of()} would reject
+     * with an {@link UnsupportedOperationException}.
+     *
+     * @param settings the configuration map; the {@link ModelConfigurations#TASK_SETTINGS} entry
+     *                 is removed from it when present
+     * @return the task settings map, or a fresh empty mutable map when the entry is absent or null
+     */
+    @SuppressWarnings("unchecked")
+    protected static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
+        // task settings are optional
+        if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
+            var taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
+            // the key may be present with an explicit null value
+            if (taskSettingsMap != null) {
+                return taskSettingsMap;
+            }
+        }
+
+        return new java.util.HashMap<>();
+    }
+
+    /**
+     * Rebuilds a {@link TestServiceModel} from a persisted configuration and its persisted secrets.
+     * <p>
+     * The service settings entry is removed from {@code config}, the secret settings entry is
+     * removed from {@code secrets}, and the optional task settings are removed from {@code config}.
+     *
+     * @param modelId  the id of the model being restored
+     * @param taskType the task type of the model being restored
+     * @param config   the persisted configuration map (mutated)
+     * @param secrets  the persisted secrets map (mutated)
+     * @return the reconstructed model, including its secret settings
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public TestServiceModel parsePersistedConfigWithSecrets(
+        String modelId,
+        TaskType taskType,
+        Map<String, Object> config,
+        Map<String, Object> secrets
+    ) {
+        // Pull both raw maps out first, then parse them in the same order as before
+        var rawServiceSettings = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
+        var rawSecretSettings = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
+
+        ServiceSettings parsedServiceSettings = getServiceSettingsFromMap(rawServiceSettings);
+        TestSecretSettings parsedSecretSettings = TestSecretSettings.fromMap(rawSecretSettings);
+        TestTaskSettings parsedTaskSettings = TestTaskSettings.fromMap(getTaskSettingsMap(config));
+
+        return new TestServiceModel(modelId, taskType, name(), parsedServiceSettings, parsedTaskSettings, parsedSecretSettings);
+    }
+
+    /**
+     * Rebuilds a {@link TestServiceModel} from a persisted configuration that carries no secrets.
+     * <p>
+     * The service settings entry and the optional task settings entry are removed from
+     * {@code config}; the resulting model is created with {@code null} secret settings.
+     *
+     * @param modelId  the id of the model being restored
+     * @param taskType the task type of the model being restored
+     * @param config   the persisted configuration map (mutated)
+     * @return the reconstructed model without secret settings
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
+        var rawServiceSettings = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
+
+        ServiceSettings parsedServiceSettings = getServiceSettingsFromMap(rawServiceSettings);
+        TestTaskSettings parsedTaskSettings = TestTaskSettings.fromMap(getTaskSettingsMap(config));
+
+        return new TestServiceModel(modelId, taskType, name(), parsedServiceSettings, parsedTaskSettings, null);
+    }
+
+    /**
+     * Converts a raw service settings map into the concrete service's {@link ServiceSettings}.
+     * Both persisted-config parsers in this class delegate to this method.
+     *
+     * @param serviceSettingsMap the map removed from the configuration under
+     *                           {@link ModelConfigurations#SERVICE_SETTINGS}
+     * @return the parsed service settings
+     */
+    protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);
+
+    /**
+     * Reports a successful start straight away: the listener receives {@code true} and the
+     * model argument is not inspected.
+     */
+    @Override
+    public void start(Model model, ActionListener<Boolean> listener) {
+        listener.onResponse(true);
+    }
+
+    /** Does nothing; the method body is empty. */
+    @Override
+    public void close() throws IOException {}
+
+    /**
+     * Model implementation shared by the mock services. It wraps the supplied settings in
+     * {@link ModelConfigurations} and {@link ModelSecrets} and narrows the return types of
+     * the settings accessors with casts.
+     */
+    public static class TestServiceModel extends Model {
+
+        /**
+         * @param modelId         the model id
+         * @param taskType        the task type the model is registered for
+         * @param service         the name of the owning service
+         * @param serviceSettings the parsed service settings
+         * @param taskSettings    the parsed task settings
+         * @param secretSettings  the parsed secret settings; {@code parsePersistedConfig} passes {@code null}
+         */
+        public TestServiceModel(
+            String modelId,
+            TaskType taskType,
+            String service,
+            ServiceSettings serviceSettings,
+            TestTaskSettings taskSettings,
+            TestSecretSettings secretSettings
+        ) {
+            super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
+        }
+
+        // NOTE(review): the constructor accepts any ServiceSettings, yet this accessor casts
+        // unconditionally to the dense service's settings type. A model built with a different
+        // settings implementation would presumably fail here with a ClassCastException - confirm
+        // that no other service calls this accessor.
+        @Override
+        public TestDenseInferenceServiceExtension.TestServiceSettings getServiceSettings() {
+            return (TestDenseInferenceServiceExtension.TestServiceSettings) super.getServiceSettings();
+        }
+
+        /** Returns the task settings cast to {@link TestTaskSettings}. */
+        @Override
+        public TestTaskSettings getTaskSettings() {
+            return (TestTaskSettings) super.getTaskSettings();
+        }
+
+        /** Returns the secret settings cast to {@link TestSecretSettings}. */
+        @Override
+        public TestSecretSettings getSecretSettings() {
+            return (TestSecretSettings) super.getSecretSettings();
+        }
+    }
+
+    /**
+     * Task settings for the mock services, holding a single optional {@code temperature} value.
+     *
+     * @param temperature the optional temperature; may be {@code null}
+     */
+    public record TestTaskSettings(Integer temperature) implements TaskSettings {
+
+        static final String NAME = "test_task_settings";
+
+        /**
+         * Builds the settings from a raw map, removing the {@code temperature} entry from it.
+         *
+         * @param map the raw task settings; must support {@link Map#remove(Object)}
+         * @return the parsed settings, with a {@code null} temperature when the entry is absent
+         */
+        public static TestTaskSettings fromMap(Map<String, Object> map) {
+            Object rawTemperature = map.remove("temperature");
+            return new TestTaskSettings((Integer) rawTemperature);
+        }
+
+        /** Reads the settings from a stream as one optional variable-length int. */
+        public TestTaskSettings(StreamInput in) throws IOException {
+            this(in.readOptionalVInt());
+        }
+
+        /** Writes the temperature as one optional variable-length int. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeOptionalVInt(temperature);
+        }
+
+        /** Renders an object that contains {@code temperature} only when it is set. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            if (temperature != null) {
+                builder.field("temperature", temperature);
+            }
+            return builder.endObject();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            // fine for these tests but will not work for cluster upgrade tests
+            return TransportVersion.current();
+        }
+    }
+
+    /**
+     * Secret settings for the mock services, holding the API key.
+     *
+     * @param apiKey the API key
+     */
+    public record TestSecretSettings(String apiKey) implements SecretSettings {
+
+        static final String NAME = "test_secret_settings";
+
+        /**
+         * Builds the settings from a raw map, removing the {@code api_key} entry from it.
+         *
+         * @param map the raw settings map
+         * @return the parsed secret settings
+         * @throws ValidationException if the map has no {@code api_key}
+         */
+        public static TestSecretSettings fromMap(Map<String, Object> map) {
+            String apiKey = (String) map.remove("api_key");
+            if (apiKey != null) {
+                return new TestSecretSettings(apiKey);
+            }
+
+            var validationException = new ValidationException();
+            validationException.addValidationError("missing api_key");
+            throw validationException;
+        }
+
+        /** Reads the API key from a stream as a single string. */
+        public TestSecretSettings(StreamInput in) throws IOException {
+            this(in.readString());
+        }
+
+        /** Writes the API key as a single string. */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(apiKey);
+        }
+
+        /** Renders an object with a single {@code api_key} field. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field("api_key", apiKey);
+            return builder.endObject();
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            // fine for these tests but will not work for cluster upgrade tests
+            return TransportVersion.current();
+        }
+    }
+}

+ 224 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

@@ -0,0 +1,224 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.mock;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public class TestDenseInferenceServiceExtension implements InferenceServiceExtension {
+    /**
+     * Exposes a single factory that creates the mock dense {@link TestInferenceService}.
+     */
+    @Override
+    public List<Factory> getInferenceServiceFactories() {
+        return List.of(TestInferenceService::new);
+    }
+
+    /**
+     * Mock inference service registered under the name {@code text_embedding_test_service}.
+     * <p>
+     * It answers {@code ANY} and {@code TEXT_EMBEDDING} requests with deterministic embeddings
+     * whose length is taken from the model's service settings; every other task type is rejected
+     * with a {@code BAD_REQUEST} failure.
+     */
+    public static class TestInferenceService extends AbstractTestInferenceService {
+        private static final String NAME = "text_embedding_test_service";
+
+        public TestInferenceService(InferenceServiceFactoryContext context) {}
+
+        @Override
+        public String name() {
+            return NAME;
+        }
+
+        /**
+         * Parses a model creation request. The service settings entry is removed from
+         * {@code config}, and both the service settings and the secret settings are read from
+         * that same map; the optional task settings are removed from {@code config} as well.
+         *
+         * @param modelId               the id of the model being created
+         * @param taskType              the requested task type
+         * @param config                the request configuration (mutated)
+         * @param platformArchitectures not used by this mock
+         * @param parsedModelListener   receives the parsed model
+         */
+        @Override
+        @SuppressWarnings("unchecked")
+        public void parseRequestConfig(
+            String modelId,
+            TaskType taskType,
+            Map<String, Object> config,
+            Set<String> platformArchitectures,
+            ActionListener<Model> parsedModelListener
+        ) {
+            var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
+            var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
+            var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
+
+            var taskSettingsMap = getTaskSettingsMap(config);
+            var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
+
+            parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
+        }
+
+        /**
+         * Produces one embedding per input for {@code ANY} and {@code TEXT_EMBEDDING} models and
+         * fails the listener for any other task type.
+         */
+        @Override
+        public void infer(
+            Model model,
+            List<String> input,
+            Map<String, Object> taskSettings,
+            InputType inputType,
+            ActionListener<InferenceServiceResults> listener
+        ) {
+            switch (model.getConfigurations().getTaskType()) {
+                case ANY, TEXT_EMBEDDING -> listener.onResponse(
+                    makeResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
+                );
+                default -> listener.onFailure(unsupportedTaskType(model));
+            }
+        }
+
+        /**
+         * Produces one single-chunk result per input for {@code ANY} and {@code TEXT_EMBEDDING}
+         * models and fails the listener for any other task type.
+         */
+        @Override
+        public void chunkedInfer(
+            Model model,
+            List<String> input,
+            Map<String, Object> taskSettings,
+            InputType inputType,
+            ChunkingOptions chunkingOptions,
+            ActionListener<List<ChunkedInferenceServiceResults>> listener
+        ) {
+            switch (model.getConfigurations().getTaskType()) {
+                case ANY, TEXT_EMBEDDING -> listener.onResponse(
+                    makeChunkedResults(input, ((TestServiceModel) model).getServiceSettings().dimensions())
+                );
+                default -> listener.onFailure(unsupportedTaskType(model));
+            }
+        }
+
+        /** Builds the BAD_REQUEST failure reported for task types this mock does not handle. */
+        private ElasticsearchStatusException unsupportedTaskType(Model model) {
+            return new ElasticsearchStatusException(
+                TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
+                RestStatus.BAD_REQUEST
+            );
+        }
+
+        /** Builds one embedding per input; each embedding is {@code [0, 1, ..., dimensions - 1]}. */
+        private TextEmbeddingResults makeResults(List<String> input, int dimensions) {
+            List<TextEmbeddingResults.Embedding> embeddings = new ArrayList<>();
+            for (int i = 0; i < input.size(); i++) {
+                List<Float> values = new ArrayList<>();
+                for (int j = 0; j < dimensions; j++) {
+                    values.add((float) j);
+                }
+                embeddings.add(new TextEmbeddingResults.Embedding(values));
+            }
+            return new TextEmbeddingResults(embeddings);
+        }
+
+        /**
+         * Builds one single-chunk result per input; the chunk text is the input itself and its
+         * embedding is {@code [0, 1, ..., dimensions - 1]}, matching {@link #makeResults}.
+         */
+        private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input, int dimensions) {
+            var results = new ArrayList<ChunkedInferenceServiceResults>();
+            for (int i = 0; i < input.size(); i++) {
+                double[] values = new double[dimensions];
+                // Fill every slot: a fixed bound of 5 overran the array whenever dimensions < 5
+                for (int j = 0; j < dimensions; j++) {
+                    values[j] = j;
+                }
+                results.add(
+                    new org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults(
+                        List.of(new ChunkedTextEmbeddingResults.EmbeddingChunk(input.get(i), values))
+                    )
+                );
+            }
+            return results;
+        }
+
+        /** Parses persisted service settings with {@link TestServiceSettings#fromMap}. */
+        @Override
+        protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
+            return TestServiceSettings.fromMap(serviceSettingsMap);
+        }
+    }
+
+    /**
+     * Service settings for the mock dense service.
+     *
+     * @param model      the model name; required by {@link #fromMap}
+     * @param dimensions the embedding length; required by {@link #fromMap}
+     * @param similarity the optional similarity measure; may be {@code null}
+     */
+    public record TestServiceSettings(String model, Integer dimensions, SimilarityMeasure similarity) implements ServiceSettings {
+
+        static final String NAME = "test_text_embedding_service_settings";
+
+        /**
+         * Builds the settings from a raw map, removing the {@code model}, {@code dimensions} and
+         * {@code similarity} entries from it. All problems are collected and reported together.
+         *
+         * @param map the raw service settings map (mutated)
+         * @return the parsed settings
+         * @throws ValidationException if {@code model} or {@code dimensions} is missing, or if
+         *                             {@code similarity} does not name a known measure
+         */
+        public static TestServiceSettings fromMap(Map<String, Object> map) {
+            ValidationException validationException = new ValidationException();
+
+            String model = (String) map.remove("model");
+            if (model == null) {
+                validationException.addValidationError("missing model");
+            }
+
+            Integer dimensions = (Integer) map.remove("dimensions");
+            if (dimensions == null) {
+                validationException.addValidationError("missing dimensions");
+            }
+
+            SimilarityMeasure similarity = null;
+            String similarityStr = (String) map.remove("similarity");
+            if (similarityStr != null) {
+                try {
+                    // Normalise case so values such as "cosine" are accepted as well as "COSINE"
+                    similarity = SimilarityMeasure.valueOf(similarityStr.trim().toUpperCase(java.util.Locale.ROOT));
+                } catch (IllegalArgumentException e) {
+                    validationException.addValidationError("unknown similarity [" + similarityStr + "]");
+                }
+            }
+
+            // Surface the collected errors instead of silently building half-initialised settings
+            if (validationException.validationErrors().isEmpty() == false) {
+                throw validationException;
+            }
+
+            return new TestServiceSettings(model, dimensions, similarity);
+        }
+
+        /** Reads the settings from a stream; the layout mirrors {@link #writeTo}. */
+        public TestServiceSettings(StreamInput in) throws IOException {
+            this(in.readString(), in.readOptionalInt(), in.readOptionalEnum(SimilarityMeasure.class));
+        }
+
+        /** Renders {@code model}, {@code dimensions} and, when set, {@code similarity}. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field("model", model);
+            builder.field("dimensions", dimensions);
+            if (similarity != null) {
+                builder.field("similarity", similarity);
+            }
+            builder.endObject();
+            return builder;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
+        }
+
+        /**
+         * Writes the settings to a stream. {@code dimensions} is written in the optional-int
+         * format because the stream constructor reads it back with {@code readOptionalInt}.
+         */
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(model);
+            out.writeOptionalInt(dimensions);
+            out.writeOptionalEnum(similarity);
+        }
+
+        /**
+         * Returns the filtered view of these settings. None of the fields is a secret, so the
+         * filtered rendering is the same as {@link #toXContent}.
+         */
+        @Override
+        public ToXContentObject getFilteredXContentObject() {
+            return this::toXContent;
+        }
+
+    }
+
+}

+ 14 - 9
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java

@@ -20,20 +20,25 @@ public class TestInferenceServicePlugin extends Plugin {
     @Override
     @Override
     public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
     public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
         return List.of(
         return List.of(
-            new NamedWriteableRegistry.Entry(
-                ServiceSettings.class,
-                TestInferenceServiceExtension.TestServiceSettings.NAME,
-                TestInferenceServiceExtension.TestServiceSettings::new
-            ),
             new NamedWriteableRegistry.Entry(
             new NamedWriteableRegistry.Entry(
                 TaskSettings.class,
                 TaskSettings.class,
-                TestInferenceServiceExtension.TestTaskSettings.NAME,
-                TestInferenceServiceExtension.TestTaskSettings::new
+                AbstractTestInferenceService.TestTaskSettings.NAME,
+                AbstractTestInferenceService.TestTaskSettings::new
             ),
             ),
             new NamedWriteableRegistry.Entry(
             new NamedWriteableRegistry.Entry(
                 SecretSettings.class,
                 SecretSettings.class,
-                TestInferenceServiceExtension.TestSecretSettings.NAME,
-                TestInferenceServiceExtension.TestSecretSettings::new
+                AbstractTestInferenceService.TestSecretSettings.NAME,
+                AbstractTestInferenceService.TestSecretSettings::new
+            ),
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                TestDenseInferenceServiceExtension.TestServiceSettings.NAME,
+                TestDenseInferenceServiceExtension.TestServiceSettings::new
+            ),
+            new NamedWriteableRegistry.Entry(
+                ServiceSettings.class,
+                TestSparseInferenceServiceExtension.TestServiceSettings.NAME,
+                TestSparseInferenceServiceExtension.TestServiceSettings::new
             )
             )
         );
         );
     }
     }

+ 5 - 179
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java → x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

@@ -15,16 +15,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
 import org.elasticsearch.inference.ChunkingOptions;
-import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.ModelSecrets;
-import org.elasticsearch.inference.SecretSettings;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.ServiceSettings;
-import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -40,13 +36,13 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 
 
-public class TestInferenceServiceExtension implements InferenceServiceExtension {
+public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
     @Override
     @Override
     public List<Factory> getInferenceServiceFactories() {
     public List<Factory> getInferenceServiceFactories() {
         return List.of(TestInferenceService::new);
         return List.of(TestInferenceService::new);
     }
     }
 
 
-    public static class TestInferenceService implements InferenceService {
+    public static class TestInferenceService extends AbstractTestInferenceService {
         private static final String NAME = "test_service";
         private static final String NAME = "test_service";
 
 
         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
         public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
@@ -56,31 +52,13 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             return NAME;
             return NAME;
         }
         }
 
 
-        @Override
-        public TransportVersion getMinimalSupportedVersion() {
-            return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
-        }
-
-        @SuppressWarnings("unchecked")
-        private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
-            Map<String, Object> taskSettingsMap;
-            // task settings are optional
-            if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
-                taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
-            } else {
-                taskSettingsMap = Map.of();
-            }
-
-            return taskSettingsMap;
-        }
-
         @Override
         @Override
         @SuppressWarnings("unchecked")
         @SuppressWarnings("unchecked")
         public void parseRequestConfig(
         public void parseRequestConfig(
             String modelId,
             String modelId,
             TaskType taskType,
             TaskType taskType,
             Map<String, Object> config,
             Map<String, Object> config,
-            Set<String> platfromArchitectures,
+            Set<String> platformArchitectures,
             ActionListener<Model> parsedModelListener
             ActionListener<Model> parsedModelListener
         ) {
         ) {
             var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
             var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
@@ -93,39 +71,6 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
             parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
         }
         }
 
 
-        @Override
-        @SuppressWarnings("unchecked")
-        public TestServiceModel parsePersistedConfigWithSecrets(
-            String modelId,
-            TaskType taskType,
-            Map<String, Object> config,
-            Map<String, Object> secrets
-        ) {
-            var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
-            var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
-
-            var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
-            var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
-
-            var taskSettingsMap = getTaskSettingsMap(config);
-            var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
-
-            return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
-        }
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
-            var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
-
-            var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
-
-            var taskSettingsMap = getTaskSettingsMap(config);
-            var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
-
-            return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
-        }
-
         @Override
         @Override
         public void infer(
         public void infer(
             Model model,
             Model model,
@@ -189,42 +134,10 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             return List.of(new ChunkedSparseEmbeddingResults(chunks));
             return List.of(new ChunkedSparseEmbeddingResults(chunks));
         }
         }
 
 
-        @Override
-        public void start(Model model, ActionListener<Boolean> listener) {
-            listener.onResponse(true);
-        }
-
-        @Override
-        public void close() throws IOException {}
-    }
-
-    public static class TestServiceModel extends Model {
-
-        public TestServiceModel(
-            String modelId,
-            TaskType taskType,
-            String service,
-            TestServiceSettings serviceSettings,
-            TestTaskSettings taskSettings,
-            TestSecretSettings secretSettings
-        ) {
-            super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
-        }
-
-        @Override
-        public TestServiceSettings getServiceSettings() {
-            return (TestServiceSettings) super.getServiceSettings();
-        }
-
-        @Override
-        public TestTaskSettings getTaskSettings() {
-            return (TestTaskSettings) super.getTaskSettings();
+        protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
+            return TestServiceSettings.fromMap(serviceSettingsMap);
         }
         }
 
 
-        @Override
-        public TestSecretSettings getSecretSettings() {
-            return (TestSecretSettings) super.getSecretSettings();
-        }
     }
     }
 
 
     public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
     public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
@@ -300,91 +213,4 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             };
             };
         }
         }
     }
     }
-
-    public record TestTaskSettings(Integer temperature) implements TaskSettings {
-
-        static final String NAME = "test_task_settings";
-
-        public static TestTaskSettings fromMap(Map<String, Object> map) {
-            Integer temperature = (Integer) map.remove("temperature");
-            return new TestTaskSettings(temperature);
-        }
-
-        public TestTaskSettings(StreamInput in) throws IOException {
-            this(in.readOptionalVInt());
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeOptionalVInt(temperature);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            if (temperature != null) {
-                builder.field("temperature", temperature);
-            }
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public String getWriteableName() {
-            return NAME;
-        }
-
-        @Override
-        public TransportVersion getMinimalSupportedVersion() {
-            return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
-        }
-    }
-
-    public record TestSecretSettings(String apiKey) implements SecretSettings {
-
-        static final String NAME = "test_secret_settings";
-
-        public static TestSecretSettings fromMap(Map<String, Object> map) {
-            ValidationException validationException = new ValidationException();
-
-            String apiKey = (String) map.remove("api_key");
-
-            if (apiKey == null) {
-                validationException.addValidationError("missing api_key");
-            }
-
-            if (validationException.validationErrors().isEmpty() == false) {
-                throw validationException;
-            }
-
-            return new TestSecretSettings(apiKey);
-        }
-
-        public TestSecretSettings(StreamInput in) throws IOException {
-            this(in.readString());
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeString(apiKey);
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field("api_key", apiKey);
-            builder.endObject();
-            return builder;
-        }
-
-        @Override
-        public String getWriteableName() {
-            return NAME;
-        }
-
-        @Override
-        public TransportVersion getMinimalSupportedVersion() {
-            return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
-        }
-    }
 }
 }

+ 2 - 1
x-pack/plugin/inference/qa/test-service-plugin/src/main/resources/META-INF/services/org.elasticsearch.inference.InferenceServiceExtension

@@ -1 +1,2 @@
-org.elasticsearch.xpack.inference.mock.TestInferenceServiceExtension
+org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension
+org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension