|  | @@ -1,202 +0,0 @@
 | 
	
		
			
				|  |  | -/*
 | 
	
		
			
				|  |  | - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 | 
	
		
			
				|  |  | - * or more contributor license agreements. Licensed under the Elastic License
 | 
	
		
			
				|  |  | - * 2.0; you may not use this file except in compliance with the Elastic License
 | 
	
		
			
				|  |  | - * 2.0.
 | 
	
		
			
				|  |  | - */
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -package org.elasticsearch.xpack.inference.integration;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -import org.elasticsearch.action.support.PlainActionFuture;
 | 
	
		
			
				|  |  | -import org.elasticsearch.client.internal.Client;
 | 
	
		
			
				|  |  | -import org.elasticsearch.common.Strings;
 | 
	
		
			
				|  |  | -import org.elasticsearch.common.bytes.BytesArray;
 | 
	
		
			
				|  |  | -import org.elasticsearch.core.TimeValue;
 | 
	
		
			
				|  |  | -import org.elasticsearch.inference.InferenceResults;
 | 
	
		
			
				|  |  | -import org.elasticsearch.inference.ModelConfigurations;
 | 
	
		
			
				|  |  | -import org.elasticsearch.inference.ModelSecrets;
 | 
	
		
			
				|  |  | -import org.elasticsearch.inference.TaskType;
 | 
	
		
			
				|  |  | -import org.elasticsearch.plugins.Plugin;
 | 
	
		
			
				|  |  | -import org.elasticsearch.test.ESIntegTestCase;
 | 
	
		
			
				|  |  | -import org.elasticsearch.test.SecuritySettingsSourceField;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xcontent.XContentBuilder;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xcontent.XContentFactory;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xcontent.XContentType;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.InferencePlugin;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.action.GetInferenceModelAction;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.action.InferenceAction;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.action.PutInferenceModelAction;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 | 
	
		
			
				|  |  | -import org.junit.Before;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -import java.io.IOException;
 | 
	
		
			
				|  |  | -import java.nio.charset.StandardCharsets;
 | 
	
		
			
				|  |  | -import java.util.Collection;
 | 
	
		
			
				|  |  | -import java.util.List;
 | 
	
		
			
				|  |  | -import java.util.Map;
 | 
	
		
			
				|  |  | -import java.util.concurrent.TimeUnit;
 | 
	
		
			
				|  |  | -import java.util.function.Function;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
 | 
	
		
			
				|  |  | -import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull;
 | 
	
		
			
				|  |  | -import static org.hamcrest.CoreMatchers.is;
 | 
	
		
			
				|  |  | -import static org.hamcrest.Matchers.empty;
 | 
	
		
			
				|  |  | -import static org.hamcrest.Matchers.instanceOf;
 | 
	
		
			
				|  |  | -import static org.hamcrest.Matchers.not;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -public class MockInferenceServiceIT extends ESIntegTestCase {
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private ModelRegistry modelRegistry;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    @Before
 | 
	
		
			
				|  |  | -    public void createComponents() {
 | 
	
		
			
				|  |  | -        modelRegistry = new ModelRegistry(client());
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    @Override
 | 
	
		
			
				|  |  | -    protected Collection<Class<? extends Plugin>> nodePlugins() {
 | 
	
		
			
				|  |  | -        return List.of(InferencePlugin.class, TestInferenceServicePlugin.class);
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    @Override
 | 
	
		
			
				|  |  | -    protected Function<Client, Client> getClientWrapper() {
 | 
	
		
			
				|  |  | -        final Map<String, String> headers = Map.of(
 | 
	
		
			
				|  |  | -            "Authorization",
 | 
	
		
			
				|  |  | -            basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -        // we need to wrap node clients because we do not specify a user for nodes and all requests will use the system
 | 
	
		
			
				|  |  | -        // user. This is ok for internal n2n stuff but the test framework does other things like wiping indices, repositories, etc
 | 
	
		
			
				|  |  | -        // that the system user cannot do. so we wrap the node client with a user that can do these things since the client() calls
 | 
	
		
			
				|  |  | -        // return a node client
 | 
	
		
			
				|  |  | -        return client -> client.filterWithHeader(headers);
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    public void testMockService() {
 | 
	
		
			
				|  |  | -        String modelId = "test-mock";
 | 
	
		
			
				|  |  | -        ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -        assertModelsAreEqual(putModel, readModel);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // The response is randomly generated, the input can be anything
 | 
	
		
			
				|  |  | -        inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, List.of(randomAlphaOfLength(10)));
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    public void testMockServiceWithMultipleInputs() {
 | 
	
		
			
				|  |  | -        String modelId = "test-mock-with-multi-inputs";
 | 
	
		
			
				|  |  | -        ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -        assertModelsAreEqual(putModel, readModel);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // The response is randomly generated, the input can be anything
 | 
	
		
			
				|  |  | -        inferOnMockService(
 | 
	
		
			
				|  |  | -            modelId,
 | 
	
		
			
				|  |  | -            TaskType.SPARSE_EMBEDDING,
 | 
	
		
			
				|  |  | -            List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
 | 
	
		
			
				|  |  | -        String modelId = "test-mock";
 | 
	
		
			
				|  |  | -        putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -        ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThat(readModel.getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) readModel.getServiceSettings();
 | 
	
		
			
				|  |  | -        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON).prettyPrint();
 | 
	
		
			
				|  |  | -        serviceSettings.toXContent(builder, null);
 | 
	
		
			
				|  |  | -        String xContentResult = Strings.toString(builder);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThat(xContentResult, is("""
 | 
	
		
			
				|  |  | -            {
 | 
	
		
			
				|  |  | -              "model" : "my_model"
 | 
	
		
			
				|  |  | -            }"""));
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated() {
 | 
	
		
			
				|  |  | -        String modelId = "test-unparsed";
 | 
	
		
			
				|  |  | -        putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var listener = new PlainActionFuture<ModelRegistry.ModelConfigMap>();
 | 
	
		
			
				|  |  | -        modelRegistry.getUnparsedModelMap(modelId, listener);
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var modelConfig = listener.actionGet(TIMEOUT);
 | 
	
		
			
				|  |  | -        var secretsMap = removeFromMapOrThrowIfNull(modelConfig.secrets(), ModelSecrets.SECRET_SETTINGS);
 | 
	
		
			
				|  |  | -        var secrets = TestInferenceServicePlugin.TestSecretSettings.fromMap(secretsMap);
 | 
	
		
			
				|  |  | -        assertThat(secrets.apiKey(), is("abc64"));
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private ModelConfigurations putMockService(String modelId, String serviceName, TaskType taskType) {
 | 
	
		
			
				|  |  | -        String body = Strings.format("""
 | 
	
		
			
				|  |  | -            {
 | 
	
		
			
				|  |  | -              "service": "%s",
 | 
	
		
			
				|  |  | -              "service_settings": {
 | 
	
		
			
				|  |  | -                "model": "my_model",
 | 
	
		
			
				|  |  | -                "api_key": "abc64"
 | 
	
		
			
				|  |  | -              },
 | 
	
		
			
				|  |  | -              "task_settings": {
 | 
	
		
			
				|  |  | -                "temperature": 3
 | 
	
		
			
				|  |  | -              }
 | 
	
		
			
				|  |  | -            }
 | 
	
		
			
				|  |  | -            """, serviceName);
 | 
	
		
			
				|  |  | -        var request = new PutInferenceModelAction.Request(
 | 
	
		
			
				|  |  | -            taskType.toString(),
 | 
	
		
			
				|  |  | -            modelId,
 | 
	
		
			
				|  |  | -            new BytesArray(body.getBytes(StandardCharsets.UTF_8)),
 | 
	
		
			
				|  |  | -            XContentType.JSON
 | 
	
		
			
				|  |  | -        );
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet();
 | 
	
		
			
				|  |  | -        assertEquals(serviceName, response.getModel().getService());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
 | 
	
		
			
				|  |  | -        var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings();
 | 
	
		
			
				|  |  | -        assertEquals("my_model", serviceSettings.model());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        assertThat(response.getModel().getTaskSettings(), instanceOf(TestInferenceServicePlugin.TestTaskSettings.class));
 | 
	
		
			
				|  |  | -        var taskSettings = (TestInferenceServicePlugin.TestTaskSettings) response.getModel().getTaskSettings();
 | 
	
		
			
				|  |  | -        assertEquals(3, (int) taskSettings.temperature());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        return response.getModel();
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    public ModelConfigurations getModel(String modelId, TaskType taskType) {
 | 
	
		
			
				|  |  | -        var response = client().execute(GetInferenceModelAction.INSTANCE, new GetInferenceModelAction.Request(modelId, taskType.toString()))
 | 
	
		
			
				|  |  | -            .actionGet();
 | 
	
		
			
				|  |  | -        return response.getModel();
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private List<? extends InferenceResults> inferOnMockService(String modelId, TaskType taskType, List<String> input) {
 | 
	
		
			
				|  |  | -        var response = client().execute(InferenceAction.INSTANCE, new InferenceAction.Request(taskType, modelId, input, Map.of()))
 | 
	
		
			
				|  |  | -            .actionGet();
 | 
	
		
			
				|  |  | -        if (taskType == TaskType.SPARSE_EMBEDDING) {
 | 
	
		
			
				|  |  | -            response.getResults().forEach(result -> {
 | 
	
		
			
				|  |  | -                assertThat(result, instanceOf(TextExpansionResults.class));
 | 
	
		
			
				|  |  | -                var teResult = (TextExpansionResults) result;
 | 
	
		
			
				|  |  | -                assertThat(teResult.getWeightedTokens(), not(empty()));
 | 
	
		
			
				|  |  | -            });
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        } else {
 | 
	
		
			
				|  |  | -            fail("test with task type [" + taskType + "] are not supported yet");
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        return response.getResults();
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -    private void assertModelsAreEqual(ModelConfigurations model1, ModelConfigurations model2) {
 | 
	
		
			
				|  |  | -        // The test can't rely on Model::equals as the specific subclass
 | 
	
		
			
				|  |  | -        // may be different. Model loses information about it's implemented
 | 
	
		
			
				|  |  | -        // subtype when it is streamed across the wire.
 | 
	
		
			
				|  |  | -        assertEquals(model1.getModelId(), model2.getModelId());
 | 
	
		
			
				|  |  | -        assertEquals(model1.getService(), model2.getService());
 | 
	
		
			
				|  |  | -        assertEquals(model1.getTaskType(), model2.getTaskType());
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        // TaskSettings and Service settings are named writables so
 | 
	
		
			
				|  |  | -        // the actual implementing class type is not lost when streamed \
 | 
	
		
			
				|  |  | -        assertEquals(model1.getServiceSettings(), model2.getServiceSettings());
 | 
	
		
			
				|  |  | -        assertEquals(model1.getTaskSettings(), model2.getTaskSettings());
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -}
 |