|
@@ -15,16 +15,12 @@ import org.elasticsearch.common.io.stream.StreamInput;
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
import org.elasticsearch.inference.ChunkingOptions;
|
|
import org.elasticsearch.inference.ChunkingOptions;
|
|
-import org.elasticsearch.inference.InferenceService;
|
|
|
|
import org.elasticsearch.inference.InferenceServiceExtension;
|
|
import org.elasticsearch.inference.InferenceServiceExtension;
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
import org.elasticsearch.inference.InputType;
|
|
import org.elasticsearch.inference.InputType;
|
|
import org.elasticsearch.inference.Model;
|
|
import org.elasticsearch.inference.Model;
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
-import org.elasticsearch.inference.ModelSecrets;
|
|
|
|
-import org.elasticsearch.inference.SecretSettings;
|
|
|
|
import org.elasticsearch.inference.ServiceSettings;
|
|
import org.elasticsearch.inference.ServiceSettings;
|
|
-import org.elasticsearch.inference.TaskSettings;
|
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
import org.elasticsearch.inference.TaskType;
|
|
import org.elasticsearch.rest.RestStatus;
|
|
import org.elasticsearch.rest.RestStatus;
|
|
import org.elasticsearch.xcontent.ToXContentObject;
|
|
import org.elasticsearch.xcontent.ToXContentObject;
|
|
@@ -40,13 +36,13 @@ import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
import java.util.Set;
|
|
import java.util.Set;
|
|
|
|
|
|
-public class TestInferenceServiceExtension implements InferenceServiceExtension {
|
|
|
|
|
|
+public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
|
|
@Override
|
|
@Override
|
|
public List<Factory> getInferenceServiceFactories() {
|
|
public List<Factory> getInferenceServiceFactories() {
|
|
return List.of(TestInferenceService::new);
|
|
return List.of(TestInferenceService::new);
|
|
}
|
|
}
|
|
|
|
|
|
- public static class TestInferenceService implements InferenceService {
|
|
|
|
|
|
+ public static class TestInferenceService extends AbstractTestInferenceService {
|
|
private static final String NAME = "test_service";
|
|
private static final String NAME = "test_service";
|
|
|
|
|
|
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
|
|
public TestInferenceService(InferenceServiceExtension.InferenceServiceFactoryContext context) {}
|
|
@@ -56,31 +52,13 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
|
|
return NAME;
|
|
return NAME;
|
|
}
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
|
- public TransportVersion getMinimalSupportedVersion() {
|
|
|
|
- return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
- private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
|
|
|
|
- Map<String, Object> taskSettingsMap;
|
|
|
|
- // task settings are optional
|
|
|
|
- if (settings.containsKey(ModelConfigurations.TASK_SETTINGS)) {
|
|
|
|
- taskSettingsMap = (Map<String, Object>) settings.remove(ModelConfigurations.TASK_SETTINGS);
|
|
|
|
- } else {
|
|
|
|
- taskSettingsMap = Map.of();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return taskSettingsMap;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
@Override
|
|
@Override
|
|
@SuppressWarnings("unchecked")
|
|
@SuppressWarnings("unchecked")
|
|
public void parseRequestConfig(
|
|
public void parseRequestConfig(
|
|
String modelId,
|
|
String modelId,
|
|
TaskType taskType,
|
|
TaskType taskType,
|
|
Map<String, Object> config,
|
|
Map<String, Object> config,
|
|
- Set<String> platfromArchitectures,
|
|
|
|
|
|
+ Set<String> platformArchitectures,
|
|
ActionListener<Model> parsedModelListener
|
|
ActionListener<Model> parsedModelListener
|
|
) {
|
|
) {
|
|
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
@@ -93,39 +71,6 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
|
|
parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
|
|
parsedModelListener.onResponse(new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings));
|
|
}
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
- public TestServiceModel parsePersistedConfigWithSecrets(
|
|
|
|
- String modelId,
|
|
|
|
- TaskType taskType,
|
|
|
|
- Map<String, Object> config,
|
|
|
|
- Map<String, Object> secrets
|
|
|
|
- ) {
|
|
|
|
- var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
|
|
- var secretSettingsMap = (Map<String, Object>) secrets.remove(ModelSecrets.SECRET_SETTINGS);
|
|
|
|
-
|
|
|
|
- var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
|
|
- var secretSettings = TestSecretSettings.fromMap(secretSettingsMap);
|
|
|
|
-
|
|
|
|
- var taskSettingsMap = getTaskSettingsMap(config);
|
|
|
|
- var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
|
|
|
|
-
|
|
|
|
- return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- @SuppressWarnings("unchecked")
|
|
|
|
- public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String, Object> config) {
|
|
|
|
- var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);
|
|
|
|
-
|
|
|
|
- var serviceSettings = TestServiceSettings.fromMap(serviceSettingsMap);
|
|
|
|
-
|
|
|
|
- var taskSettingsMap = getTaskSettingsMap(config);
|
|
|
|
- var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);
|
|
|
|
-
|
|
|
|
- return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, null);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
@Override
|
|
@Override
|
|
public void infer(
|
|
public void infer(
|
|
Model model,
|
|
Model model,
|
|
@@ -189,42 +134,10 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
|
|
return List.of(new ChunkedSparseEmbeddingResults(chunks));
|
|
return List.of(new ChunkedSparseEmbeddingResults(chunks));
|
|
}
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
|
- public void start(Model model, ActionListener<Boolean> listener) {
|
|
|
|
- listener.onResponse(true);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public void close() throws IOException {}
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- public static class TestServiceModel extends Model {
|
|
|
|
-
|
|
|
|
- public TestServiceModel(
|
|
|
|
- String modelId,
|
|
|
|
- TaskType taskType,
|
|
|
|
- String service,
|
|
|
|
- TestServiceSettings serviceSettings,
|
|
|
|
- TestTaskSettings taskSettings,
|
|
|
|
- TestSecretSettings secretSettings
|
|
|
|
- ) {
|
|
|
|
- super(new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secretSettings));
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public TestServiceSettings getServiceSettings() {
|
|
|
|
- return (TestServiceSettings) super.getServiceSettings();
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public TestTaskSettings getTaskSettings() {
|
|
|
|
- return (TestTaskSettings) super.getTaskSettings();
|
|
|
|
|
|
+ protected ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap) {
|
|
|
|
+ return TestServiceSettings.fromMap(serviceSettingsMap);
|
|
}
|
|
}
|
|
|
|
|
|
- @Override
|
|
|
|
- public TestSecretSettings getSecretSettings() {
|
|
|
|
- return (TestSecretSettings) super.getSecretSettings();
|
|
|
|
- }
|
|
|
|
}
|
|
}
|
|
|
|
|
|
public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
|
|
public record TestServiceSettings(String model, String hiddenField, boolean shouldReturnHiddenField) implements ServiceSettings {
|
|
@@ -300,91 +213,4 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
|
|
};
|
|
};
|
|
}
|
|
}
|
|
}
|
|
}
|
|
-
|
|
|
|
- public record TestTaskSettings(Integer temperature) implements TaskSettings {
|
|
|
|
-
|
|
|
|
- static final String NAME = "test_task_settings";
|
|
|
|
-
|
|
|
|
- public static TestTaskSettings fromMap(Map<String, Object> map) {
|
|
|
|
- Integer temperature = (Integer) map.remove("temperature");
|
|
|
|
- return new TestTaskSettings(temperature);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- public TestTaskSettings(StreamInput in) throws IOException {
|
|
|
|
- this(in.readOptionalVInt());
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public void writeTo(StreamOutput out) throws IOException {
|
|
|
|
- out.writeOptionalVInt(temperature);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
|
- builder.startObject();
|
|
|
|
- if (temperature != null) {
|
|
|
|
- builder.field("temperature", temperature);
|
|
|
|
- }
|
|
|
|
- builder.endObject();
|
|
|
|
- return builder;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public String getWriteableName() {
|
|
|
|
- return NAME;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public TransportVersion getMinimalSupportedVersion() {
|
|
|
|
- return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- public record TestSecretSettings(String apiKey) implements SecretSettings {
|
|
|
|
-
|
|
|
|
- static final String NAME = "test_secret_settings";
|
|
|
|
-
|
|
|
|
- public static TestSecretSettings fromMap(Map<String, Object> map) {
|
|
|
|
- ValidationException validationException = new ValidationException();
|
|
|
|
-
|
|
|
|
- String apiKey = (String) map.remove("api_key");
|
|
|
|
-
|
|
|
|
- if (apiKey == null) {
|
|
|
|
- validationException.addValidationError("missing api_key");
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- if (validationException.validationErrors().isEmpty() == false) {
|
|
|
|
- throw validationException;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- return new TestSecretSettings(apiKey);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- public TestSecretSettings(StreamInput in) throws IOException {
|
|
|
|
- this(in.readString());
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public void writeTo(StreamOutput out) throws IOException {
|
|
|
|
- out.writeString(apiKey);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
|
|
|
- builder.startObject();
|
|
|
|
- builder.field("api_key", apiKey);
|
|
|
|
- builder.endObject();
|
|
|
|
- return builder;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public String getWriteableName() {
|
|
|
|
- return NAME;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public TransportVersion getMinimalSupportedVersion() {
|
|
|
|
- return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests
|
|
|
|
- }
|
|
|
|
- }
|
|
|
|
}
|
|
}
|