|
@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SecretSettings;
|
|
|
import org.elasticsearch.inference.ServiceSettings;
|
|
|
import org.elasticsearch.inference.TaskSettings;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
+import org.elasticsearch.inference.UnparsedModel;
|
|
|
import org.elasticsearch.plugins.Plugin;
|
|
|
import org.elasticsearch.reindex.ReindexPlugin;
|
|
|
import org.elasticsearch.test.ESSingleNodeTestCase;
|
|
@@ -38,6 +39,7 @@ import java.io.IOException;
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.Collection;
|
|
|
import java.util.Comparator;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.concurrent.CountDownLatch;
|
|
@@ -110,7 +112,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
assertThat(putModelHolder.get(), is(true));
|
|
|
|
|
|
// now get the model
|
|
|
- AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
|
|
|
assertThat(exceptionHolder.get(), is(nullValue()));
|
|
|
assertThat(modelHolder.get(), not(nullValue()));
|
|
@@ -168,7 +170,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
|
|
|
// get should fail
|
|
|
deleteResponseHolder.set(false);
|
|
|
- AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);
|
|
|
|
|
|
assertThat(exceptionHolder.get(), not(nullValue()));
|
|
@@ -194,7 +196,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
|
|
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
- AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
|
|
|
assertThat(modelHolder.get(), hasSize(3));
|
|
|
var sparseIds = sparseAndTextEmbeddingModels.stream()
|
|
@@ -235,8 +237,9 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
assertNull(exceptionHolder.get());
|
|
|
}
|
|
|
|
|
|
- AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
assertThat(modelHolder.get(), hasSize(modelCount));
|
|
|
var getAllModels = modelHolder.get();
|
|
|
|
|
@@ -264,15 +267,213 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
assertThat(putModelHolder.get(), is(true));
|
|
|
assertNull(exceptionHolder.get());
|
|
|
|
|
|
- AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
|
|
|
assertThat(modelHolder.get().secrets().keySet(), hasSize(1));
|
|
|
var secretSettings = (Map<String, Object>) modelHolder.get().secrets().get("secret_settings");
|
|
|
assertThat(secretSettings.get("secret"), equalTo(secret));
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get());
|
|
|
|
|
|
// get model without secrets
|
|
|
blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder);
|
|
|
assertThat(modelHolder.get().secrets().keySet(), empty());
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetAllModels_WithDefaults() throws Exception {
|
|
|
+ var service = "foo";
|
|
|
+ var secret = "abc";
|
|
|
+ int configuredModelCount = 10;
|
|
|
+ int defaultModelCount = 2;
|
|
|
+ int totalModelCount = 12;
|
|
|
+
|
|
|
+ var defaultConfigs = new HashMap<String, UnparsedModel>();
|
|
|
+ for (int i = 0; i < defaultModelCount; i++) {
|
|
|
+ var id = "default-" + i;
|
|
|
+ defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
|
|
|
+ }
|
|
|
+ defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
|
|
|
+
|
|
|
+ AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
+
|
|
|
+ var createdModels = new HashMap<String, Model>();
|
|
|
+ for (int i = 0; i < configuredModelCount; i++) {
|
|
|
+ var id = randomAlphaOfLength(5) + i;
|
|
|
+ var model = createModel(id, randomFrom(TaskType.values()), service);
|
|
|
+ createdModels.put(id, model);
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+ }
|
|
|
+
|
|
|
+ AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
+ blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+ assertThat(modelHolder.get(), hasSize(totalModelCount));
|
|
|
+ var getAllModels = modelHolder.get();
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
+
|
|
|
+ // sort in the same order as the returned models
|
|
|
+ var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
|
|
|
+ ids.addAll(createdModels.keySet().stream().toList());
|
|
|
+ ids.sort(String::compareTo);
|
|
|
+ for (int i = 0; i < totalModelCount; i++) {
|
|
|
+ var id = ids.get(i);
|
|
|
+ assertEquals(id, getAllModels.get(i).inferenceEntityId());
|
|
|
+ if (id.startsWith("default")) {
|
|
|
+ assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
|
|
|
+ assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
|
|
|
+ } else {
|
|
|
+ assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
|
|
|
+ assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetAllModels_OnlyDefaults() throws Exception {
|
|
|
+ var service = "foo";
|
|
|
+ var secret = "abc";
|
|
|
+ int defaultModelCount = 2;
|
|
|
+
|
|
|
+ var defaultConfigs = new HashMap<String, UnparsedModel>();
|
|
|
+ for (int i = 0; i < defaultModelCount; i++) {
|
|
|
+ var id = "default-" + i;
|
|
|
+ defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
|
|
|
+ }
|
|
|
+ defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
|
|
|
+
|
|
|
+ AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
+ blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+ assertThat(modelHolder.get(), hasSize(2));
|
|
|
+ var getAllModels = modelHolder.get();
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
+
|
|
|
+ // sort in the same order as the returned models
|
|
|
+ var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
|
|
|
+ ids.sort(String::compareTo);
|
|
|
+ for (int i = 0; i < defaultModelCount; i++) {
|
|
|
+ var id = ids.get(i);
|
|
|
+ assertEquals(id, getAllModels.get(i).inferenceEntityId());
|
|
|
+ assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
|
|
|
+ assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGet_WithDefaults() throws InterruptedException {
|
|
|
+ var service = "foo";
|
|
|
+ var secret = "abc";
|
|
|
+
|
|
|
+ var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
|
|
|
+ var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
|
|
|
+
|
|
|
+ modelRegistry.addDefaultConfiguration(defaultSparse);
|
|
|
+ modelRegistry.addDefaultConfiguration(defaultText);
|
|
|
+
|
|
|
+ AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
+
|
|
|
+ var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
|
|
|
+ var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+
|
|
|
+ AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
+ blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
|
|
|
+ assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
|
|
|
+ assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get());
|
|
|
+
|
|
|
+ blockingCall(listener -> modelRegistry.getModel("default-text", listener), modelHolder, exceptionHolder);
|
|
|
+ assertEquals("default-text", modelHolder.get().inferenceEntityId());
|
|
|
+ assertEquals(TaskType.TEXT_EMBEDDING, modelHolder.get().taskType());
|
|
|
+
|
|
|
+ blockingCall(listener -> modelRegistry.getModel(configured1.getInferenceEntityId(), listener), modelHolder, exceptionHolder);
|
|
|
+ assertEquals(configured1.getInferenceEntityId(), modelHolder.get().inferenceEntityId());
|
|
|
+ assertEquals(configured1.getTaskType(), modelHolder.get().taskType());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testGetByTaskType_WithDefaults() throws Exception {
|
|
|
+ var service = "foo";
|
|
|
+ var secret = "abc";
|
|
|
+
|
|
|
+ var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
|
|
|
+ var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
|
|
|
+ var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
|
|
|
+
|
|
|
+ modelRegistry.addDefaultConfiguration(defaultSparse);
|
|
|
+ modelRegistry.addDefaultConfiguration(defaultText);
|
|
|
+ modelRegistry.addDefaultConfiguration(defaultChat);
|
|
|
+
|
|
|
+ AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
+ AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
+
|
|
|
+ var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
|
|
|
+ var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
|
|
|
+ var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ blockingCall(listener -> modelRegistry.storeModel(configuredRerank, listener), putModelHolder, exceptionHolder);
|
|
|
+ assertThat(putModelHolder.get(), is(true));
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+
|
|
|
+ AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
|
+ blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
|
|
|
+ if (exceptionHolder.get() != null) {
|
|
|
+ throw exceptionHolder.get();
|
|
|
+ }
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+ assertThat(modelHolder.get(), hasSize(2));
|
|
|
+ assertEquals("configured-sparse", modelHolder.get().get(0).inferenceEntityId());
|
|
|
+ assertEquals("default-sparse", modelHolder.get().get(1).inferenceEntityId());
|
|
|
+
|
|
|
+ blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder);
|
|
|
+ assertThat(modelHolder.get(), hasSize(2));
|
|
|
+ assertEquals("configured-text", modelHolder.get().get(0).inferenceEntityId());
|
|
|
+ assertEquals("default-text", modelHolder.get().get(1).inferenceEntityId());
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
+
|
|
|
+ blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.RERANK, listener), modelHolder, exceptionHolder);
|
|
|
+ assertThat(modelHolder.get(), hasSize(1));
|
|
|
+ assertEquals("configured-rerank", modelHolder.get().get(0).inferenceEntityId());
|
|
|
+
|
|
|
+ blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener), modelHolder, exceptionHolder);
|
|
|
+ assertThat(modelHolder.get(), hasSize(1));
|
|
|
+ assertEquals("default-chat", modelHolder.get().get(0).inferenceEntityId());
|
|
|
+ assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
|
|
|
+ var settings = unparsedModel.settings();
|
|
|
+ if (settings != null) {
|
|
|
+ var serviceSettings = (Map<String, Object>) settings.get("service_settings");
|
|
|
+ if (serviceSettings != null && serviceSettings.size() > 0) {
|
|
|
+ var itr = serviceSettings.entrySet().iterator();
|
|
|
+ itr.next();
|
|
|
+ itr.remove();
|
|
|
+ }
|
|
|
+
|
|
|
+ var taskSettings = (Map<String, Object>) settings.get("task_settings");
|
|
|
+ if (taskSettings != null && taskSettings.size() > 0) {
|
|
|
+ var itr = taskSettings.entrySet().iterator();
|
|
|
+ itr.next();
|
|
|
+ itr.remove();
|
|
|
+ }
|
|
|
+
|
|
|
+ if (unparsedModel.secrets() != null && unparsedModel.secrets().size() > 0) {
|
|
|
+ var itr = unparsedModel.secrets().entrySet().iterator();
|
|
|
+ itr.next();
|
|
|
+ itr.remove();
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) {
|
|
@@ -327,6 +528,10 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
+ public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
|
|
|
+ return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
|
|
|
+ }
|
|
|
+
|
|
|
private static class TestModelOfAnyKind extends ModelConfigurations {
|
|
|
|
|
|
record TestModelServiceSettings() implements ServiceSettings {
|