|
@@ -5,12 +5,11 @@
|
|
|
* 2.0.
|
|
* 2.0.
|
|
|
*/
|
|
*/
|
|
|
|
|
|
|
|
-package org.elasticsearch.xpack.inference.integration;
|
|
|
|
|
|
|
+package org.elasticsearch.xpack.inference.mock;
|
|
|
|
|
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.TransportVersion;
|
|
import org.elasticsearch.TransportVersion;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
-import org.elasticsearch.common.Strings;
|
|
|
|
|
import org.elasticsearch.common.ValidationException;
|
|
import org.elasticsearch.common.ValidationException;
|
|
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
|
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
import org.elasticsearch.common.io.stream.StreamInput;
|
|
@@ -28,9 +27,6 @@ import org.elasticsearch.plugins.InferenceServicePlugin;
|
|
|
import org.elasticsearch.plugins.Plugin;
|
|
import org.elasticsearch.plugins.Plugin;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
|
|
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests;
|
|
|
|
|
-import org.elasticsearch.xpack.inference.services.MapParsingUtils;
|
|
|
|
|
|
|
|
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
import java.util.ArrayList;
|
|
@@ -38,9 +34,6 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
import java.util.Set;
|
|
|
|
|
|
|
|
-import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull;
|
|
|
|
|
-import static org.elasticsearch.xpack.inference.services.MapParsingUtils.throwIfNotEmptyMap;
|
|
|
|
|
-
|
|
|
|
|
public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin {
|
|
public class TestInferenceServicePlugin extends Plugin implements InferenceServicePlugin {
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -100,11 +93,12 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
|
|
|
|
|
public abstract static class TestInferenceServiceBase implements InferenceService {
|
|
public abstract static class TestInferenceServiceBase implements InferenceService {
|
|
|
|
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
|
|
private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
|
|
|
Map<String, Object> taskSettingsMap;
|
|
Map<String, Object> taskSettingsMap;
|
|
|
// task settings are optional
|
|
// task settings are optional
|
|
|
if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
|
|
if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
|
|
|
- taskSettingsMap = removeFromMapOrThrowIfNull(settings, ModelConfigurations.TASK_SETTINGS);
|
|
|
|
|
|
|
+ taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
|
|
|
} else {
|
|
} else {
|
|
|
taskSettingsMap = Map.of();
|
|
taskSettingsMap = Map.of();
|
|
|
}
|
|
}
|
|
@@ -117,35 +111,33 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
public TestServiceModel parseRequestConfig(
|
|
public TestServiceModel parseRequestConfig(
|
|
|
String modelId,
|
|
String modelId,
|
|
|
TaskType taskType,
|
|
TaskType taskType,
|
|
|
Map<String, Object> config,
|
|
Map<String, Object> config,
|
|
|
Set<String> platfromArchitectures
|
|
Set<String> platfromArchitectures
|
|
|
) {
|
|
) {
|
|
|
- Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
|
|
|
|
|
|
+ var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
|
var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
|
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
|
|
var secretSettings = TestSecretSettings.fromMap(serviceSettingsMap);
|
|
|
|
|
|
|
|
var taskSettingsMap = getTaskSettingsMap(config);
|
|
var taskSettingsMap = getTaskSettingsMap(config);
|
|
|
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
|
|
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
|
|
|
|
|
|
|
|
- throwIfNotEmptyMap(config, name());
|
|
|
|
|
- throwIfNotEmptyMap(serviceSettingsMap, name());
|
|
|
|
|
- throwIfNotEmptyMap(taskSettingsMap, name());
|
|
|
|
|
-
|
|
|
|
|
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
|
|
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
@Override
|
|
@Override
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
public TestServiceModel parsePersistedConfig(
|
|
public TestServiceModel parsePersistedConfig(
|
|
|
String modelId,
|
|
String modelId,
|
|
|
TaskType taskType,
|
|
TaskType taskType,
|
|
|
Map<String, Object> config,
|
|
Map<String, Object> config,
|
|
|
Map<String, Object> secrets
|
|
Map<String, Object> secrets
|
|
|
) {
|
|
) {
|
|
|
- Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
|
|
|
|
|
- Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
|
|
|
|
|
|
|
+ var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
|
|
|
+ var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
|
|
|
|
|
|
|
|
var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
|
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
|
|
var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
|
|
@@ -165,11 +157,8 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
) {
|
|
) {
|
|
|
switch (model.getConfigurations().getTaskType()) {
|
|
switch (model.getConfigurations().getTaskType()) {
|
|
|
case SPARSE_EMBEDDING -> {
|
|
case SPARSE_EMBEDDING -> {
|
|
|
- var results = new ArrayList<TextExpansionResults>();
|
|
|
|
|
- input.forEach(i -> {
|
|
|
|
|
- int numTokensInResult = Strings.tokenizeToStringArray(i, " ").length;
|
|
|
|
|
- results.add(TextExpansionResultsTests.createRandomResults(numTokensInResult, numTokensInResult));
|
|
|
|
|
- });
|
|
|
|
|
|
|
+ var results = new ArrayList<TestResults>();
|
|
|
|
|
+ input.forEach(i -> { results.add(new TestResults("bar")); });
|
|
|
listener.onResponse(results);
|
|
listener.onResponse(results);
|
|
|
}
|
|
}
|
|
|
default -> listener.onFailure(
|
|
default -> listener.onFailure(
|
|
@@ -227,12 +216,10 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
public static TestServiceSettings fromMap(Map<String, Object> map) {
|
|
public static TestServiceSettings fromMap(Map<String, Object> map) {
|
|
|
ValidationException validationException = new ValidationException();
|
|
ValidationException validationException = new ValidationException();
|
|
|
|
|
|
|
|
- String model = MapParsingUtils.removeAsType(map, "model", String.class);
|
|
|
|
|
|
|
+ String model = (String) map.remove("model");
|
|
|
|
|
|
|
|
if (model == null) {
|
|
if (model == null) {
|
|
|
- validationException.addValidationError(
|
|
|
|
|
- MapParsingUtils.missingSettingErrorMsg("model", ModelConfigurations.SERVICE_SETTINGS)
|
|
|
|
|
- );
|
|
|
|
|
|
|
+ validationException.addValidationError("missing model");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
@@ -275,7 +262,7 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
private static final String NAME = "test_task_settings";
|
|
private static final String NAME = "test_task_settings";
|
|
|
|
|
|
|
|
public static TestTaskSettings fromMap(Map<String, Object> map) {
|
|
public static TestTaskSettings fromMap(Map<String, Object> map) {
|
|
|
- Integer temperature = MapParsingUtils.removeAsType(map, "temperature", Integer.class);
|
|
|
|
|
|
|
+ Integer temperature = (Integer) map.remove("temperature");
|
|
|
return new TestTaskSettings(temperature);
|
|
return new TestTaskSettings(temperature);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -316,10 +303,10 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
public static TestSecretSettings fromMap(Map<String, Object> map) {
|
|
public static TestSecretSettings fromMap(Map<String, Object> map) {
|
|
|
ValidationException validationException = new ValidationException();
|
|
ValidationException validationException = new ValidationException();
|
|
|
|
|
|
|
|
- String apiKey = MapParsingUtils.removeAsType(map, "api_key", String.class);
|
|
|
|
|
|
|
+ String apiKey = (String) map.remove("api_key");
|
|
|
|
|
|
|
|
if (apiKey == null) {
|
|
if (apiKey == null) {
|
|
|
- validationException.addValidationError(MapParsingUtils.missingSettingErrorMsg("api_key", ModelSecrets.SECRET_SETTINGS));
|
|
|
|
|
|
|
+ validationException.addValidationError("missing api_key");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
if (validationException.validationErrors().isEmpty() == false) {
|
|
@@ -356,4 +343,49 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi
|
|
|
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
|
|
return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+ private static class TestResults implements InferenceResults {
|
|
|
|
|
+
|
|
|
|
|
+ private String result;
|
|
|
|
|
+
|
|
|
|
|
+ public TestResults(String result) {
|
|
|
|
|
+ this.result = result;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
|
|
+ builder.field("result", result);
|
|
|
|
|
+ return builder;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public String getWriteableName() {
|
|
|
|
|
+ return "test_result";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void writeTo(StreamOutput out) throws IOException {
|
|
|
|
|
+ out.writeString(result);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public String getResultsField() {
|
|
|
|
|
+ return "result";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public Map<String, Object> asMap() {
|
|
|
|
|
+ return Map.of("result", result);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public Map<String, Object> asMap(String outputField) {
|
|
|
|
|
+ return Map.of(outputField, result);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public Object predictedValue() {
|
|
|
|
|
+ return result;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|