Browse Source

Rest Tests for the mock inference service (#102469)

Adds a qa project to the inference plugin that holds the REST tests,
and moves the mock service plugin TestInferenceServicePlugin into its own qa project
David Kyle 1 year ago
parent
commit
aa35eb5215

+ 5 - 1
x-pack/plugin/inference/build.gradle

@@ -8,12 +8,16 @@ apply plugin: 'elasticsearch.internal-es-plugin'
 apply plugin: 'elasticsearch.internal-cluster-test'
 apply plugin: 'elasticsearch.internal-cluster-test'
 
 
 esplugin {
 esplugin {
-  name 'inference'
+  name 'x-pack-inference'
   description 'Configuration and evaluation of inference models'
   description 'Configuration and evaluation of inference models'
   classname 'org.elasticsearch.xpack.inference.InferencePlugin'
   classname 'org.elasticsearch.xpack.inference.InferencePlugin'
   extendedPlugins = ['x-pack-core']
   extendedPlugins = ['x-pack-core']
 }
 }
 
 
+base {
+  archivesName = 'x-pack-inference'
+}
+
 dependencies {
 dependencies {
   implementation project(path: ':libs:elasticsearch-logging')
   implementation project(path: ':libs:elasticsearch-logging')
   compileOnly project(":server")
   compileOnly project(":server")

+ 0 - 0
x-pack/plugin/inference/qa/build.gradle


+ 11 - 0
x-pack/plugin/inference/qa/inference-service-tests/build.gradle

@@ -0,0 +1,11 @@
+apply plugin: 'elasticsearch.internal-java-rest-test' // presumably supplies the javaRestTest task and configurations used below - TODO confirm
+
+dependencies {
+  compileOnly project(':x-pack:plugin:core')
+  javaRestTestImplementation project(path: xpackModule('inference')) // puts the inference plugin on the REST test classpath
+  clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin') // the mock inference service plugin from the sibling qa project
+}
+
+tasks.named("javaRestTest").configure {
+  usesDefaultDistribution() // the test class also requests DistributionType.DEFAULT for its cluster
+}

+ 174 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockInferenceServiceIT.java

@@ -0,0 +1,174 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.apache.http.util.EntityUtils;
+import org.elasticsearch.client.Request;
+import org.elasticsearch.client.Response;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.settings.SecureString;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.test.cluster.ElasticsearchCluster;
+import org.elasticsearch.test.cluster.local.distribution.DistributionType;
+import org.elasticsearch.test.rest.ESRestTestCase;
+import org.junit.ClassRule;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.not;
+
+public class MockInferenceServiceIT extends ESRestTestCase {
+
+    @ClassRule
+    public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
+        .distribution(DistributionType.DEFAULT)
+        .setting("xpack.license.self_generated.type", "trial")
+        .setting("xpack.security.enabled", "true")
+        .plugin("org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin")
+        .user("x_pack_rest_user", "x-pack-test-password")
+        .build();
+
+    @Override
+    protected String getTestRestCluster() {
+        return cluster.getHttpAddresses();
+    }
+
+    @Override
+    protected Settings restClientSettings() {
+        String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
+        return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
+    }
+
+    static String mockServiceModelConfig() {
+        return Strings.format("""
+            {
+              "service": "test_service",
+              "service_settings": {
+                "model": "my_model",
+                "api_key": "abc64"
+              },
+              "task_settings": {
+                "temperature": 3
+              }
+            }
+            """);
+    }
+
+    public void testMockService() throws IOException {
+        String modelId = "test-mock";
+        var putModel = putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var getModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
+        assertEquals(putModel, getModel);
+
+        for (var modelMap : List.of(putModel, getModel)) {
+            assertEquals(modelId, modelMap.get("model_id"));
+            assertEquals(TaskType.SPARSE_EMBEDDING, TaskType.fromString((String) modelMap.get("task_type")));
+            assertEquals("test_service", modelMap.get("service"));
+        }
+
+        // The response is randomly generated, the input can be anything
+        var inference = inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
+        assertNonEmptyInferenceResults(inference, TaskType.SPARSE_EMBEDDING);
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testMockServiceWithMultipleInputs() throws IOException {
+        String modelId = "test-mock-with-multi-inputs";
+        var putModel = putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+
+        // The response is randomly generated, the input can be anything
+        var inference = inferOnMockService(
+            modelId,
+            TaskType.SPARSE_EMBEDDING,
+            List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
+        );
+
+        var tokens = (List<Map<String, Object>>) inference.get("inference_results");
+        assertThat(tokens, hasSize(3));
+        assertNonEmptyInferenceResults(inference, TaskType.SPARSE_EMBEDDING);
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
+        String modelId = "test-mock";
+        var putModel = putModel(modelId, mockServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
+        var getModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
+
+        var serviceSettings = (Map<String, Object>) getModel.get("service_settings");
+        assertNull(serviceSettings.get("api_key"));
+        assertNotNull(serviceSettings.get("model"));
+
+        var putServiceSettings = (Map<String, Object>) putModel.get("service_settings");
+        assertNull(putServiceSettings.get("api_key"));
+        assertNotNull(putServiceSettings.get("model"));
+    }
+
+    private Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
+        String endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        var request = new Request("PUT", endpoint);
+        request.setJsonEntity(modelConfig);
+        var reponse = client().performRequest(request);
+        assertOkWithErrorMessage(reponse);
+        return entityAsMap(reponse);
+    }
+
+    public Map<String, Object> getModel(String modelId, TaskType taskType) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        var request = new Request("GET", endpoint);
+        var reponse = client().performRequest(request);
+        assertOkWithErrorMessage(reponse);
+        return entityAsMap(reponse);
+    }
+
+    private Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
+        var request = new Request("POST", endpoint);
+
+        var bodyBuilder = new StringBuilder("{\"input\": [");
+        for (var in : input) {
+            bodyBuilder.append('"').append(in).append('"').append(',');
+        }
+        // remove last comma
+        bodyBuilder.deleteCharAt(bodyBuilder.length() - 1);
+        bodyBuilder.append("]}");
+
+        System.out.println("body_request:" + bodyBuilder);
+        request.setJsonEntity(bodyBuilder.toString());
+        var reponse = client().performRequest(request);
+        assertOkWithErrorMessage(reponse);
+        return entityAsMap(reponse);
+    }
+
+    @SuppressWarnings("unchecked")
+    protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, TaskType taskType) {
+        if (taskType == TaskType.SPARSE_EMBEDDING) {
+            var tokens = (List<Map<String, Object>>) resultMap.get("inference_results");
+            tokens.forEach(result -> { assertThat(result.keySet(), not(empty())); });
+        } else {
+            fail("test with task type [" + taskType + "] are not supported yet");
+        }
+    }
+
+    protected static void assertOkWithErrorMessage(Response response) throws IOException {
+        int statusCode = response.getStatusLine().getStatusCode();
+        if (statusCode == 200 || statusCode == 201) {
+            return;
+        }
+
+        String responseStr = EntityUtils.toString(response.getEntity());
+        assertThat(responseStr, response.getStatusLine().getStatusCode(), anyOf(equalTo(200), equalTo(201)));
+    }
+}

+ 19 - 0
x-pack/plugin/inference/qa/test-service-plugin/build.gradle

@@ -0,0 +1,19 @@
+
+apply plugin: 'elasticsearch.base-internal-es-plugin' // builds this project as a plugin, configured by the esplugin block below
+apply plugin: 'elasticsearch.internal-java-rest-test' // presumably supplies the javaRestTest task configured at the bottom - TODO confirm
+
+esplugin {
+  name 'inference-service-test'
+  description 'A mock inference service'
+  classname 'org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin'
+}
+
+dependencies {
+  compileOnly project(':x-pack:plugin:core')
+  compileOnly project(':x-pack:plugin:inference')
+  compileOnly project(':x-pack:plugin:ml')
+}
+
+tasks.named("javaRestTest").configure {
+  usesDefaultDistribution()
+}

+ 60 - 28
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/TestInferenceServicePlugin.java → x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServicePlugin.java

@@ -5,12 +5,11 @@
  * 2.0.
  * 2.0.
  */
  */
 
 
-package org.elasticsearch.xpack.inference.integration;
+package org.elasticsearch.xpack.inference.mock;
 
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -28,9 +27,6 @@ import org.elasticsearch.plugins.InferenceServicePlugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
-import org.elasticsearch.xpack.inference.services.MapParsingUtils;
 
 
 import java.io.IOException;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.ArrayList;
@@ -38,9 +34,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 
 
-import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull;
-import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap;
-
 public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin {
 public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin {
 
 
     @Override
     @Override
@@ -100,11 +93,12 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
 
 
     public abstract static class TestInferenceServiceBase implements InferenceService {
     public abstract static class TestInferenceServiceBase implements InferenceService {
 
 
+        @SuppressWarnings("unchecked")
         private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
         private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
             Map<String, Object> taskSettingsMap;
             Map<String, Object> taskSettingsMap;
             // task settings are optional
             // task settings are optional
             if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
             if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
-                taskSettingsMap = removeFromMapOrThrowIfNull(settings, ModelConfigurations.TASK_SETTINGS);
+                taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
             } else {
             } else {
                 taskSettingsMap = Map.of();
                 taskSettingsMap = Map.of();
             }
             }
@@ -117,35 +111,33 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
         }
         }
 
 
         @Override
         @Override
+        @SuppressWarnings("unchecked")
         public TestServiceModel parseRequestConfig(
         public TestServiceModel parseRequestConfig(
             String modelId,
             String modelId,
             TaskType taskType,
             TaskType taskType,
             Map<String, Object> config,
             Map<String, Object> config,
             Set<String> platfromArchitectures
             Set<String> platfromArchitectures
         ) {
         ) {
-            Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
+            var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
             var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
             var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
             var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
             var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
 
 
             var taskSettingsMap = getTaskSettingsMap(config);
             var taskSettingsMap = getTaskSettingsMap(config);
             var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
             var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
 
 
-            throwIfNotEmptyMap(config, name());
-            throwIfNotEmptyMap(serviceSettingsMap, name());
-            throwIfNotEmptyMap(taskSettingsMap, name());
-
             return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
             return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
         }
         }
 
 
         @Override
         @Override
+        @SuppressWarnings("unchecked")
         public TestServiceModel parsePersistedConfig(
         public TestServiceModel parsePersistedConfig(
             String modelId,
             String modelId,
             TaskType taskType,
             TaskType taskType,
             Map<String, Object> config,
             Map<String, Object> config,
             Map<String, Object> secrets
             Map<String, Object> secrets
         ) {
         ) {
-            Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
-            Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
+            var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
+            var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
 
 
             var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
             var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
             var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
             var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
@@ -165,11 +157,8 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
         ) {
         ) {
             switch (model.getConfigurations().getTaskType()) {
             switch (model.getConfigurations().getTaskType()) {
                 case SPARSE_EMBEDDING -> {
                 case SPARSE_EMBEDDING -> {
-                    var results = new ArrayList<TextExpansionResults>();
-                    input.forEach(i -> {
-                        int numTokensInResult = Strings.tokenizeToStringArray(i, " ").length;
-                        results.add(TextExpansionResultsTests.createRandomResults(numTokensInResult, numTokensInResult));
-                    });
+                    var results = new ArrayList<TestResults>();
+                    input.forEach(i -> { results.add(new TestResults("bar")); });
                     listener.onResponse(results);
                     listener.onResponse(results);
                 }
                 }
                 default -> listener.onFailure(
                 default -> listener.onFailure(
@@ -227,12 +216,10 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
         public static TestServiceSettings fromMap(Map<String, Object> map) {
         public static TestServiceSettings fromMap(Map<String, Object> map) {
             ValidationException validationException = new ValidationException();
             ValidationException validationException = new ValidationException();
 
 
-            String model = MapParsingUtils.removeAsType(map, "model", String.class);
+            String model = (String) map.remove("model");
 
 
             if (model == null) {
             if (model == null) {
-                validationException.addValidationError(
-                    MapParsingUtils.missingSettingErrorMsg("model", ModelConfigurations.SERVICE_SETTINGS)
-                );
+                validationException.addValidationError("missing model");
             }
             }
 
 
             if (validationException.validationErrors().isEmpty() == false) {
             if (validationException.validationErrors().isEmpty() == false) {
@@ -275,7 +262,7 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
         private static final String NAME = "test_task_settings";
         private static final String NAME = "test_task_settings";
 
 
         public static TestTaskSettings fromMap(Map<String, Object> map) {
         public static TestTaskSettings fromMap(Map<String, Object> map) {
-            Integer temperature = MapParsingUtils.removeAsType(map, "temperature", Integer.class);
+            Integer temperature = (Integer) map.remove("temperature");
             return new TestTaskSettings(temperature);
             return new TestTaskSettings(temperature);
         }
         }
 
 
@@ -316,10 +303,10 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
         public static TestSecretSettings fromMap(Map<String, Object> map) {
         public static TestSecretSettings fromMap(Map<String, Object> map) {
             ValidationException validationException = new ValidationException();
             ValidationException validationException = new ValidationException();
 
 
-            String apiKey = MapParsingUtils.removeAsType(map, "api_key", String.class);
+            String apiKey = (String) map.remove("api_key");
 
 
             if (apiKey == null) {
             if (apiKey == null) {
-                validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("api_key", ModelSecrets.SECRET_SETTINGS));
+                validationException.addValidationError("missing api_key");
             }
             }
 
 
             if (validationException.validationErrors().isEmpty() == false) {
             if (validationException.validationErrors().isEmpty() == false) {
@@ -356,4 +343,49 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
             return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
             return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
         }
         }
     }
     }
+
+    /**
+     * Minimal {@link InferenceResults} implementation returned by the mock service.
+     * It wraps a single string that is reported under the {@code "result"} field.
+     */
+    private static class TestResults implements InferenceResults {
+
+        // Assigned once in the constructor and never mutated, so it is declared final
+        private final String result;
+
+        /**
+         * @param result the value reported by every accessor of this object
+         */
+        TestResults(String result) {
+            this.result = result;
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field("result", result);
+            return builder;
+        }
+
+        @Override
+        public String getWriteableName() {
+            return "test_result";
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(result);
+        }
+
+        @Override
+        public String getResultsField() {
+            return "result";
+        }
+
+        @Override
+        public Map<String, Object> asMap() {
+            return Map.of("result", result);
+        }
+
+        /**
+         * @param outputField the key to store the result under instead of the default {@code "result"}
+         */
+        @Override
+        public Map<String, Object> asMap(String outputField) {
+            return Map.of(outputField, result);
+        }
+
+        @Override
+        public Object predictedValue() {
+            return result;
+        }
+    }
 }
 }

+ 0 - 202
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java

@@ -1,202 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.integration;
-
-import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.client.internal.Client;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesArray;
-import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.InferenceResults;
-import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.ModelSecrets;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.plugins.Plugin;
-import org.elasticsearch.test.ESIntegTestCase;
-import org.elasticsearch.test.SecuritySettingsSourceField;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xcontent.XContentFactory;
-import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
-import org.elasticsearch.xpack.inference.InferencePlugin;
-import org.elasticsearch.xpack.inference.action.GetInferenceModelAction;
-import org.elasticsearch.xpack.inference.action.InferenceAction;
-import org.elasticsearch.xpack.inference.action.PutInferenceModelAction;
-import org.elasticsearch.xpack.inference.registry.ModelRegistry;
-import org.junit.Before;
-
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-import java.util.Collection;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Function;
-
-import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
-import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.instanceOf;
-import static org.hamcrest.Matchers.not;
-
-public class MockInferenceServiceIT extends ESIntegTestCase {
-
-    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
-
-    private ModelRegistry modelRegistry;
-
-    @Before
-    public void createComponents() {
-        modelRegistry = new ModelRegistry(client());
-    }
-
-    @Override
-    protected Collection<Class<? extends Plugin>> nodePlugins() {
-        return List.of(InferencePlugin.class, TestInferenceServicePlugin.class);
-    }
-
-    @Override
-    protected Function<Client, Client> getClientWrapper() {
-        final Map<String, String> headers = Map.of(
-            "Authorization",
-            basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)
-        );
-        // we need to wrap node clients because we do not specify a user for nodes and all requests will use the system
-        // user. This is ok for internal n2n stuff but the test framework does other things like wiping indices, repositories, etc
-        // that the system user cannot do. so we wrap the node client with a user that can do these things since the client() calls
-        // return a node client
-        return client -> client.filterWithHeader(headers);
-    }
-
-    public void testMockService() {
-        String modelId = "test-mock";
-        ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
-        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
-        assertModelsAreEqual(putModel, readModel);
-
-        // The response is randomly generated, the input can be anything
-        inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
-    }
-
-    public void testMockServiceWithMultipleInputs() {
-        String modelId = "test-mock-with-multi-inputs";
-        ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
-        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
-        assertModelsAreEqual(putModel, readModel);
-
-        // The response is randomly generated, the input can be anything
-        inferOnMockService(
-            modelId,
-            TaskType.SPARSE_EMBEDDING,
-            List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
-        );
-    }
-
-    public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
-        String modelId = "test-mock";
-        putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
-        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
-
-        assertThat(readModel.getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
-
-        var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) readModel.getServiceSettings();
-        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON).prettyPrint();
-        serviceSettings.toXContent(builder, null);
-        String xContentResult = Strings.toString(builder);
-
-        assertThat(xContentResult, is("""
-            {
-              "model" : "my_model"
-            }"""));
-    }
-
-    public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated() {
-        String modelId = "test-unparsed";
-        putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
-
-        var listener = new PlainActionFuture<ModelRegistry.ModelConfigMap>();
-        modelRegistry.getUnparsedModelMap(modelId, listener);
-
-        var modelConfig = listener.actionGet(TIMEOUT);
-        var secretsMap = removeFromMapOrThrowIfNull(modelConfig.secrets(), ModelSecrets.SECRET_SETTINGS);
-        var secrets = TestInferenceServicePlugin.TestSecretSettings.fromMap(secretsMap);
-        assertThat(secrets.apiKey(), is("abc64"));
-    }
-
-    private ModelConfigurations putMockService(String modelId, String serviceName, TaskType taskType) {
-        String body = Strings.format("""
-            {
-              "service": "%s",
-              "service_settings": {
-                "model": "my_model",
-                "api_key": "abc64"
-              },
-              "task_settings": {
-                "temperature": 3
-              }
-            }
-            """, serviceName);
-        var request = new PutInferenceModelAction.Request(
-            taskType.toString(),
-            modelId,
-            new BytesArray(body.getBytes(StandardCharsets.UTF_8)),
-            XContentType.JSON
-        );
-
-        var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet();
-        assertEquals(serviceName, response.getModel().getService());
-
-        assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
-        var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings();
-        assertEquals("my_model", serviceSettings.model());
-
-        assertThat(response.getModel().getTaskSettings(), instanceOf(TestInferenceServicePlugin.TestTaskSettings.class));
-        var taskSettings = (TestInferenceServicePlugin.TestTaskSettings) response.getModel().getTaskSettings();
-        assertEquals(3, (int) taskSettings.temperature());
-
-        return response.getModel();
-    }
-
-    public ModelConfigurations getModel(String modelId, TaskType taskType) {
-        var response = client().execute(GetInferenceModelAction.INSTANCE, new GetInferenceModelAction.Request(modelId, taskType.toString()))
-            .actionGet();
-        return response.getModel();
-    }
-
-    private List<? extends InferenceResults> inferOnMockService(String modelId, TaskType taskType, List<String> input) {
-        var response = client().execute(InferenceAction.INSTANCE, new InferenceAction.Request(taskType, modelId, input, Map.of()))
-            .actionGet();
-        if (taskType == TaskType.SPARSE_EMBEDDING) {
-            response.getResults().forEach(result -> {
-                assertThat(result, instanceOf(TextExpansionResults.class));
-                var teResult = (TextExpansionResults) result;
-                assertThat(teResult.getWeightedTokens(), not(empty()));
-            });
-
-        } else {
-            fail("test with task type [" + taskType + "] are not supported yet");
-        }
-
-        return response.getResults();
-    }
-
-    private void assertModelsAreEqual(ModelConfigurations model1, ModelConfigurations model2) {
-        // The test can't rely on Model::equals as the specific subclass
-        // may be different. Model loses information about it's implemented
-        // subtype when it is streamed across the wire.
-        assertEquals(model1.getModelId(), model2.getModelId());
-        assertEquals(model1.getService(), model2.getService());
-        assertEquals(model1.getTaskType(), model2.getTaskType());
-
-        // TaskSettings and Service settings are named writables so
-        // the actual implementing class type is not lost when streamed \
-        assertEquals(model1.getServiceSettings(), model2.getServiceSettings());
-        assertEquals(model1.getTaskSettings(), model2.getTaskSettings());
-    }
-}

+ 3 - 2
x-pack/plugin/inference/src/main/java/module-info.java

@@ -5,7 +5,7 @@
  * 2.0.
  * 2.0.
  */
  */
 
 
-module org.elasticsearch.xpack.inference {
+module org.elasticsearch.inference {
     requires org.elasticsearch.base;
     requires org.elasticsearch.base;
     requires org.elasticsearch.server;
     requires org.elasticsearch.server;
     requires org.elasticsearch.xcontent;
     requires org.elasticsearch.xcontent;
@@ -18,8 +18,9 @@ module org.elasticsearch.xpack.inference {
     requires org.apache.httpcomponents.httpcore.nio;
     requires org.apache.httpcomponents.httpcore.nio;
     requires org.apache.lucene.core;
     requires org.apache.lucene.core;
 
 
-    exports org.elasticsearch.xpack.inference.rest;
     exports org.elasticsearch.xpack.inference.action;
     exports org.elasticsearch.xpack.inference.action;
     exports org.elasticsearch.xpack.inference.registry;
     exports org.elasticsearch.xpack.inference.registry;
+    exports org.elasticsearch.xpack.inference.rest;
+    exports org.elasticsearch.xpack.inference.services;
     exports org.elasticsearch.xpack.inference;
     exports org.elasticsearch.xpack.inference;
 }
 }