|
@@ -11,7 +11,10 @@ import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.TransportVersion;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
+import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.io.stream.StreamOutput;
|
|
|
+import org.elasticsearch.common.settings.Settings;
|
|
|
+import org.elasticsearch.inference.InferenceService;
|
|
|
import org.elasticsearch.inference.InferenceServiceExtension;
|
|
|
import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
@@ -46,6 +49,7 @@ import java.util.Map;
|
|
|
import java.util.concurrent.CountDownLatch;
|
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
|
import java.util.function.Consumer;
|
|
|
+import java.util.function.Function;
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
import static org.hamcrest.CoreMatchers.equalTo;
|
|
@@ -56,6 +60,8 @@ import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.instanceOf;
|
|
|
import static org.hamcrest.Matchers.not;
|
|
|
import static org.hamcrest.Matchers.nullValue;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.Mockito.doAnswer;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
|
|
|
public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
@@ -121,7 +127,12 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
|
|
|
|
|
|
var elserService = new ElasticsearchInternalService(
|
|
|
- new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
|
|
|
+ new InferenceServiceExtension.InferenceServiceFactoryContext(
|
|
|
+ mock(Client.class),
|
|
|
+ mock(ThreadPool.class),
|
|
|
+ mock(ClusterService.class),
|
|
|
+ Settings.EMPTY
|
|
|
+ )
|
|
|
);
|
|
|
ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
|
|
|
modelHolder.get().inferenceEntityId(),
|
|
@@ -282,18 +293,30 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
|
|
|
public void testGetAllModels_WithDefaults() throws Exception {
|
|
|
- var service = "foo";
|
|
|
- var secret = "abc";
|
|
|
+ var serviceName = "foo";
|
|
|
int configuredModelCount = 10;
|
|
|
int defaultModelCount = 2;
|
|
|
int totalModelCount = 12;
|
|
|
|
|
|
- var defaultConfigs = new HashMap<String, UnparsedModel>();
|
|
|
+ var service = mock(InferenceService.class);
|
|
|
+
|
|
|
+ var defaultConfigs = new ArrayList<Model>();
|
|
|
+ var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
|
|
|
for (int i = 0; i < defaultModelCount; i++) {
|
|
|
var id = "default-" + i;
|
|
|
- defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
|
|
|
+ var taskType = randomFrom(TaskType.values());
|
|
|
+ defaultConfigs.add(createModel(id, taskType, serviceName));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
|
|
|
}
|
|
|
- defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
|
|
|
+
|
|
|
+ doAnswer(invocation -> {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
|
|
|
+ listener.onResponse(defaultConfigs);
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(service).defaultConfigs(any());
|
|
|
+
|
|
|
+ defaultIds.forEach(modelRegistry::addDefaultIds);
|
|
|
|
|
|
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
@@ -301,7 +324,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
var createdModels = new HashMap<String, Model>();
|
|
|
for (int i = 0; i < configuredModelCount; i++) {
|
|
|
var id = randomAlphaOfLength(5) + i;
|
|
|
- var model = createModel(id, randomFrom(TaskType.values()), service);
|
|
|
+ var model = createModel(id, randomFrom(TaskType.values()), serviceName);
|
|
|
createdModels.put(id, model);
|
|
|
blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
|
|
|
assertThat(putModelHolder.get(), is(true));
|
|
@@ -315,16 +338,22 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
var getAllModels = modelHolder.get();
|
|
|
assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
|
|
|
+ // same result but configs should have been persisted this time
|
|
|
+ blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
+ assertThat(modelHolder.get(), hasSize(totalModelCount));
|
|
|
+
|
|
|
// sort in the same order as the returned models
|
|
|
- var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
|
|
|
+ var ids = new ArrayList<>(defaultIds.stream().map(InferenceService.DefaultConfigId::inferenceId).toList());
|
|
|
ids.addAll(createdModels.keySet().stream().toList());
|
|
|
ids.sort(String::compareTo);
|
|
|
+ var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
|
|
|
for (int i = 0; i < totalModelCount; i++) {
|
|
|
var id = ids.get(i);
|
|
|
assertEquals(id, getAllModels.get(i).inferenceEntityId());
|
|
|
if (id.startsWith("default")) {
|
|
|
- assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
|
|
|
- assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
|
|
|
+ assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
|
|
|
+ assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
|
|
|
} else {
|
|
|
assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
|
|
|
assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
|
|
@@ -333,16 +362,27 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
|
|
|
public void testGetAllModels_OnlyDefaults() throws Exception {
|
|
|
- var service = "foo";
|
|
|
- var secret = "abc";
|
|
|
int defaultModelCount = 2;
|
|
|
+ var serviceName = "foo";
|
|
|
+ var service = mock(InferenceService.class);
|
|
|
|
|
|
- var defaultConfigs = new HashMap<String, UnparsedModel>();
|
|
|
+ var defaultConfigs = new ArrayList<Model>();
|
|
|
+ var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
|
|
|
for (int i = 0; i < defaultModelCount; i++) {
|
|
|
var id = "default-" + i;
|
|
|
- defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
|
|
|
+ var taskType = randomFrom(TaskType.values());
|
|
|
+ defaultConfigs.add(createModel(id, taskType, serviceName));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
|
|
|
}
|
|
|
- defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
|
|
|
+
|
|
|
+ doAnswer(invocation -> {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
|
|
|
+ listener.onResponse(defaultConfigs);
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(service).defaultConfigs(any());
|
|
|
+
|
|
|
+ defaultIds.forEach(modelRegistry::addDefaultIds);
|
|
|
|
|
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
|
|
@@ -353,31 +393,42 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
assertReturnModelIsModifiable(modelHolder.get().get(0));
|
|
|
|
|
|
// sort in the same order as the returned models
|
|
|
- var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
|
|
|
+ var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
|
|
|
+ var ids = new ArrayList<>(configsById.keySet().stream().toList());
|
|
|
ids.sort(String::compareTo);
|
|
|
for (int i = 0; i < defaultModelCount; i++) {
|
|
|
var id = ids.get(i);
|
|
|
assertEquals(id, getAllModels.get(i).inferenceEntityId());
|
|
|
- assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
|
|
|
- assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
|
|
|
+ assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
|
|
|
+ assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
|
|
|
}
|
|
|
}
|
|
|
|
|
|
public void testGet_WithDefaults() throws InterruptedException {
|
|
|
- var service = "foo";
|
|
|
- var secret = "abc";
|
|
|
+ var serviceName = "foo";
|
|
|
+ var service = mock(InferenceService.class);
|
|
|
+
|
|
|
+ var defaultConfigs = new ArrayList<Model>();
|
|
|
+ var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
|
|
|
|
|
|
- var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
|
|
|
- var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
|
|
|
+ defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
|
|
|
+ defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
|
|
|
|
|
|
- modelRegistry.addDefaultConfiguration(defaultSparse);
|
|
|
- modelRegistry.addDefaultConfiguration(defaultText);
|
|
|
+ doAnswer(invocation -> {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
|
|
|
+ listener.onResponse(defaultConfigs);
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(service).defaultConfigs(any());
|
|
|
+ defaultIds.forEach(modelRegistry::addDefaultIds);
|
|
|
|
|
|
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
|
|
|
- var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
|
|
|
- var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
|
|
|
+ var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
|
|
|
+ var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
|
|
|
blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
|
|
|
assertThat(putModelHolder.get(), is(true));
|
|
|
blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
|
|
@@ -386,6 +437,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
|
|
|
AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
|
|
|
blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
|
|
|
+ assertNull(exceptionHolder.get());
|
|
|
assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
|
|
|
assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
|
|
|
assertReturnModelIsModifiable(modelHolder.get());
|
|
@@ -400,23 +452,32 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
|
|
|
public void testGetByTaskType_WithDefaults() throws Exception {
|
|
|
- var service = "foo";
|
|
|
- var secret = "abc";
|
|
|
-
|
|
|
- var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
|
|
|
- var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
|
|
|
- var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
|
|
|
-
|
|
|
- modelRegistry.addDefaultConfiguration(defaultSparse);
|
|
|
- modelRegistry.addDefaultConfiguration(defaultText);
|
|
|
- modelRegistry.addDefaultConfiguration(defaultChat);
|
|
|
+ var serviceName = "foo";
|
|
|
+
|
|
|
+ var defaultSparse = createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
|
|
|
+ var defaultText = createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName);
|
|
|
+ var defaultChat = createModel("default-chat", TaskType.COMPLETION, serviceName);
|
|
|
+
|
|
|
+ var service = mock(InferenceService.class);
|
|
|
+ var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
|
|
|
+ defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
|
|
|
+
|
|
|
+ doAnswer(invocation -> {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
|
|
|
+ listener.onResponse(List.of(defaultSparse, defaultChat, defaultText));
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(service).defaultConfigs(any());
|
|
|
+ defaultIds.forEach(modelRegistry::addDefaultIds);
|
|
|
|
|
|
AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
|
|
|
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
|
|
|
|
|
|
- var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
|
|
|
- var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
|
|
|
- var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
|
|
|
+ var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
|
|
|
+ var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName);
|
|
|
+ var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName);
|
|
|
blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
|
|
|
assertThat(putModelHolder.get(), is(true));
|
|
|
blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
|
|
@@ -530,10 +591,6 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
|
|
|
- return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
|
|
|
- }
|
|
|
-
|
|
|
private static class TestModelOfAnyKind extends ModelConfigurations {
|
|
|
|
|
|
record TestModelServiceSettings() implements ServiceSettings {
|