
[8.16][ML] Pick best model variant for the default elser endpoint (#114758)

* [ML] Pick best model variant for the default elser endpoint (#114690)

# Conflicts:
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java
#	x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

* fix test

* fix test
David Kyle committed c2cec39b37 (1 year ago)
23 changed files with 448 additions and 370 deletions
  1. server/src/main/java/org/elasticsearch/inference/InferenceService.java (+14, -4)
  2. server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java (+3, -1)
  3. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java (+8, -0)
  4. x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java (+100, -43)
  5. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java (+11, -3)
  6. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java (+8, -7)
  7. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java (+119, -104)
  8. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java (+35, -19)
  9. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java (+57, -67)
  10. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java (+32, -55)
  11. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java (+39, -21)
  12. x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java (+3, -2)
  13. x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java (+3, -2)
  14. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java (+1, -9)
  15. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportMlInfoAction.java (+2, -1)
  16. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java (+2, -2)
  17. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/task/AbstractJobPersistentTasksExecutor.java (+3, -2)
  18. x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/NativeMemoryCalculator.java (+1, -1)
  19. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java (+1, -1)
  20. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java (+2, -2)
  21. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java (+3, -3)
  22. x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/NativeMemoryCalculatorTests.java (+1, -1)
  23. x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml (+0, -20)

+ 14 - 4
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -192,12 +192,22 @@ public interface InferenceService extends Closeable {
         return supportedStreamingTasks().contains(taskType);
     }
 
+    record DefaultConfigId(String inferenceId, TaskType taskType, InferenceService service) {};
+
     /**
-     * A service can define default configurations that can be
-     * used out of the box without creating an endpoint first.
-     * @return Default configurations provided by this service
+     * Get the Ids and task type of any default configurations provided by this service
+     * @return Defaults
      */
-    default List<UnparsedModel> defaultConfigs() {
+    default List<DefaultConfigId> defaultConfigIds() {
         return List.of();
     }
+
+    /**
+     * Call the listener with the default model configurations defined by
+     * the service
+     * @param defaultsListener The listener
+     */
+    default void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
+        defaultsListener.onResponse(List.of());
+    }
 }
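
The interface change splits defaults into two stages: lightweight DefaultConfigId entries advertised up front, and full Model configurations built only when asked for. A minimal sketch of how a service might implement both methods, assuming a hypothetical endpoint id and buildDefaultModel helper:

    @Override
    public List<DefaultConfigId> defaultConfigIds() {
        // Advertise only the id and task type; the full configuration is resolved lazily.
        return List.of(new DefaultConfigId(".my-default-endpoint", TaskType.SPARSE_EMBEDDING, this));
    }

    @Override
    public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
        // Build the complete configuration on demand, e.g. after inspecting the cluster,
        // and hand it to the listener.
        defaultsListener.onResponse(List.of(buildDefaultModel(".my-default-endpoint")));
    }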

+ 3 - 1
server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

@@ -10,6 +10,8 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.threadpool.ThreadPool;
 
 import java.util.List;
@@ -21,7 +23,7 @@ public interface InferenceServiceExtension {
 
     List<Factory> getInferenceServiceFactories();
 
-    record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}
+    record InferenceServiceFactoryContext(Client client, ThreadPool threadPool, ClusterService clusterService, Settings settings) {}
 
     interface Factory {
         /**
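
For extension authors the widened record means the factory can reach the ClusterService and node Settings. A sketch under the assumption that Factory remains a single-method factory over the context (MyInferenceService is hypothetical):

    public class MyInferenceServiceExtension implements InferenceServiceExtension {
        @Override
        public List<Factory> getInferenceServiceFactories() {
            // The context now also carries the ClusterService and node Settings.
            return List.of(context -> new MyInferenceService(
                context.client(),
                context.threadPool(),
                context.clusterService(),
                context.settings()
            ));
        }
    }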

+ 8 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningField.java

@@ -37,6 +37,14 @@ public final class MachineLearningField {
         Setting.Property.NodeScope
     );
 
+    public static final Setting<Integer> MAX_LAZY_ML_NODES = Setting.intSetting(
+        "xpack.ml.max_lazy_ml_nodes",
+        0,
+        0,
+        Setting.Property.OperatorDynamic,
+        Setting.Property.NodeScope
+    );
+
     /**
      * This boolean value indicates if `max_machine_memory_percent` should be ignored and an automatic calculation is used instead.
      *
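
Defining MAX_LAZY_ML_NODES in MachineLearningField lets inference code read the xpack.ml.max_lazy_ml_nodes value without depending on the ml plugin; later in this commit it backs the Elastic Cloud heuristic. A sketch of reading it from the cluster settings:

    // A value greater than zero means ML nodes can be added lazily (autoscaling is
    // available), which this commit treats as "probably running in Elastic Cloud".
    int maxLazyMlNodes = clusterService.getClusterSettings().get(MachineLearningField.MAX_LAZY_ML_NODES);
    boolean scalingAvailable = maxLazyMlNodes > 0;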

+ 100 - 43
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

@@ -11,7 +11,10 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -46,6 +49,7 @@ import java.util.Map;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import static org.hamcrest.CoreMatchers.equalTo;
@@ -56,6 +60,8 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.not;
 import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 
 public class ModelRegistryIT extends ESSingleNodeTestCase {
@@ -121,7 +127,12 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
 
         var elserService = new ElasticsearchInternalService(
-            new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
+            new InferenceServiceExtension.InferenceServiceFactoryContext(
+                mock(Client.class),
+                mock(ThreadPool.class),
+                mock(ClusterService.class),
+                Settings.EMPTY
+            )
         );
         ElasticsearchInternalModel roundTripModel = (ElasticsearchInternalModel) elserService.parsePersistedConfigWithSecrets(
             modelHolder.get().inferenceEntityId(),
@@ -282,18 +293,30 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
     }
 
     public void testGetAllModels_WithDefaults() throws Exception {
-        var service = "foo";
-        var secret = "abc";
+        var serviceName = "foo";
         int configuredModelCount = 10;
         int defaultModelCount = 2;
         int totalModelCount = 12;
 
-        var defaultConfigs = new HashMap<String, UnparsedModel>();
+        var service = mock(InferenceService.class);
+
+        var defaultConfigs = new ArrayList<Model>();
+        var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
         for (int i = 0; i < defaultModelCount; i++) {
             var id = "default-" + i;
-            defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
+            var taskType = randomFrom(TaskType.values());
+            defaultConfigs.add(createModel(id, taskType, serviceName));
+            defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
         }
-        defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
+            listener.onResponse(defaultConfigs);
+            return Void.TYPE;
+        }).when(service).defaultConfigs(any());
+
+        defaultIds.forEach(modelRegistry::addDefaultIds);
 
         AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@@ -301,7 +324,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         var createdModels = new HashMap<String, Model>();
         for (int i = 0; i < configuredModelCount; i++) {
             var id = randomAlphaOfLength(5) + i;
-            var model = createModel(id, randomFrom(TaskType.values()), service);
+            var model = createModel(id, randomFrom(TaskType.values()), serviceName);
             createdModels.put(id, model);
             blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
             assertThat(putModelHolder.get(), is(true));
@@ -315,16 +338,22 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         var getAllModels = modelHolder.get();
         assertReturnModelIsModifiable(modelHolder.get().get(0));
 
+        // same result but configs should have been persisted this time
+        blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
+        assertNull(exceptionHolder.get());
+        assertThat(modelHolder.get(), hasSize(totalModelCount));
+
         // sort in the same order as the returned models
-        var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
+        var ids = new ArrayList<>(defaultIds.stream().map(InferenceService.DefaultConfigId::inferenceId).toList());
         ids.addAll(createdModels.keySet().stream().toList());
         ids.sort(String::compareTo);
+        var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
         for (int i = 0; i < totalModelCount; i++) {
             var id = ids.get(i);
             assertEquals(id, getAllModels.get(i).inferenceEntityId());
             if (id.startsWith("default")) {
-                assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
-                assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
+                assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
+                assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
             } else {
                 assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
                 assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
@@ -333,16 +362,27 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
     }
 
     public void testGetAllModels_OnlyDefaults() throws Exception {
-        var service = "foo";
-        var secret = "abc";
         int defaultModelCount = 2;
+        var serviceName = "foo";
+        var service = mock(InferenceService.class);
 
-        var defaultConfigs = new HashMap<String, UnparsedModel>();
+        var defaultConfigs = new ArrayList<Model>();
+        var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
         for (int i = 0; i < defaultModelCount; i++) {
             var id = "default-" + i;
-            defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
+            var taskType = randomFrom(TaskType.values());
+            defaultConfigs.add(createModel(id, taskType, serviceName));
+            defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
         }
-        defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
+            listener.onResponse(defaultConfigs);
+            return Void.TYPE;
+        }).when(service).defaultConfigs(any());
+
+        defaultIds.forEach(modelRegistry::addDefaultIds);
 
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
         AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
@@ -353,31 +393,42 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         assertReturnModelIsModifiable(modelHolder.get().get(0));
 
         // sort in the same order as the returned models
-        var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
+        var configsById = defaultConfigs.stream().collect(Collectors.toMap(Model::getInferenceEntityId, Function.identity()));
+        var ids = new ArrayList<>(configsById.keySet().stream().toList());
         ids.sort(String::compareTo);
         for (int i = 0; i < defaultModelCount; i++) {
             var id = ids.get(i);
             assertEquals(id, getAllModels.get(i).inferenceEntityId());
-            assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
-            assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
+            assertEquals(configsById.get(id).getTaskType(), getAllModels.get(i).taskType());
+            assertEquals(configsById.get(id).getConfigurations().getService(), getAllModels.get(i).service());
         }
     }
 
     public void testGet_WithDefaults() throws InterruptedException {
-        var service = "foo";
-        var secret = "abc";
+        var serviceName = "foo";
+        var service = mock(InferenceService.class);
+
+        var defaultConfigs = new ArrayList<Model>();
+        var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
 
-        var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
-        var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
+        defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
+        defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
 
-        modelRegistry.addDefaultConfiguration(defaultSparse);
-        modelRegistry.addDefaultConfiguration(defaultText);
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
+            listener.onResponse(defaultConfigs);
+            return Void.TYPE;
+        }).when(service).defaultConfigs(any());
+        defaultIds.forEach(modelRegistry::addDefaultIds);
 
         AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
-        var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
+        var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
+        var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), serviceName);
         blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
         assertThat(putModelHolder.get(), is(true));
         blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
@@ -386,6 +437,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
 
         AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
+        assertNull(exceptionHolder.get());
         assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
         assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
         assertReturnModelIsModifiable(modelHolder.get());
@@ -400,23 +452,32 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
     }
 
     public void testGetByTaskType_WithDefaults() throws Exception {
-        var service = "foo";
-        var secret = "abc";
-
-        var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
-        var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
-        var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
-
-        modelRegistry.addDefaultConfiguration(defaultSparse);
-        modelRegistry.addDefaultConfiguration(defaultText);
-        modelRegistry.addDefaultConfiguration(defaultChat);
+        var serviceName = "foo";
+
+        var defaultSparse = createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
+        var defaultText = createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName);
+        var defaultChat = createModel("default-chat", TaskType.COMPLETION, serviceName);
+
+        var service = mock(InferenceService.class);
+        var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
+        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
+
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            var listener = (ActionListener<List<Model>>) invocation.getArguments()[0];
+            listener.onResponse(List.of(defaultSparse, defaultChat, defaultText));
+            return Void.TYPE;
+        }).when(service).defaultConfigs(any());
+        defaultIds.forEach(modelRegistry::addDefaultIds);
 
         AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
 
-        var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
-        var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
-        var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
+        var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, serviceName);
+        var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, serviceName);
+        var configuredRerank = createModel("configured-rerank", TaskType.RERANK, serviceName);
         blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
         assertThat(putModelHolder.get(), is(true));
         blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
@@ -530,10 +591,6 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         );
     }
 
-    public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
-        return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
-    }
-
     private static class TestModelOfAnyKind extends ModelConfigurations {
 
         record TestModelServiceSettings() implements ServiceSettings {

+ 11 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -212,13 +212,21 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
             );
         }
 
-        var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
+        var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(
+            services.client(),
+            services.threadPool(),
+            services.clusterService(),
+            settings
+        );
+
         // This must be done after the HttpRequestSenderFactory is created so that the services can get the
         // reference correctly
         var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
         registry.init(services.client());
-        for (var service : registry.getServices().values()) {
-            service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration);
+        if (DefaultElserFeatureFlag.isEnabled()) {
+            for (var service : registry.getServices().values()) {
+                service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
+            }
         }
         inferenceServiceRegistry.set(registry);
 

+ 8 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java

@@ -35,7 +35,7 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
         ChunkingSettingsOptions.SENTENCE_OVERLAP.toString()
     );
 
-    private static int DEFAULT_OVERLAP = 0;
+    private static int DEFAULT_OVERLAP = 1;
 
     protected final int maxChunkSize;
     protected int sentenceOverlap = DEFAULT_OVERLAP;
@@ -69,17 +69,18 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
             validationException
         );
 
-        Integer sentenceOverlap = ServiceUtils.extractOptionalPositiveInteger(
+        Integer sentenceOverlap = ServiceUtils.removeAsType(
             map,
             ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
-            ModelConfigurations.CHUNKING_SETTINGS,
+            Integer.class,
             validationException
         );
-
-        if (sentenceOverlap != null && sentenceOverlap > 1) {
+        if (sentenceOverlap == null) {
+            sentenceOverlap = DEFAULT_OVERLAP;
+        } else if (sentenceOverlap > 1 || sentenceOverlap < 0) {
             validationException.addValidationError(
-                ChunkingSettingsOptions.SENTENCE_OVERLAP.toString() + "[" + sentenceOverlap + "] must be either 0 or 1"
-            ); // todo better
+                ChunkingSettingsOptions.SENTENCE_OVERLAP + "[" + sentenceOverlap + "] must be either 0 or 1"
+            );
         }
 
         if (validationException.validationErrors().isEmpty() == false) {
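
After this change sentence_overlap defaults to 1 and an explicit value must be 0 or 1. A hedged sketch of a chunking settings map that would pass the new validation (the strategy and max_chunk_size keys are illustrative):

    // Illustrative chunking settings; sentence_overlap may only be 0 or 1,
    // and omitting it now defaults to 1 rather than 0.
    Map<String, Object> chunkingSettings = new HashMap<>(Map.of(
        "strategy", "sentence",
        "max_chunk_size", 250,
        "sentence_overlap", 1
    ));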

+ 119 - 104
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

@@ -23,15 +23,19 @@ import org.elasticsearch.action.bulk.BulkResponse;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.action.support.GroupedActionListener;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.action.support.WriteRequest;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.client.internal.OriginSettingClient;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.reindex.DeleteByQueryAction;
 import org.elasticsearch.index.reindex.DeleteByQueryRequest;
+import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
@@ -57,6 +61,7 @@ import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
@@ -87,29 +92,33 @@ public class ModelRegistry {
     private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
 
     private final OriginSettingClient client;
-    private Map<String, UnparsedModel> defaultConfigs;
+    private final List<InferenceService.DefaultConfigId> defaultConfigIds;
 
     private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
 
     public ModelRegistry(Client client) {
         this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
-        this.defaultConfigs = new HashMap<>();
+        defaultConfigIds = new ArrayList<>();
     }
 
-    public void addDefaultConfiguration(UnparsedModel serviceDefaultConfig) {
-        if (defaultConfigs.containsKey(serviceDefaultConfig.inferenceEntityId())) {
+    /**
+     * Set the default inference ids provided by the services
+     * @param defaultConfigIds The defaults
+     */
+    public void addDefaultIds(InferenceService.DefaultConfigId defaultConfigIds) {
+        var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
+        if (matched.isPresent()) {
             throw new IllegalStateException(
                 "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
-                    + serviceDefaultConfig.inferenceEntityId()
+                    + defaultConfigIds.inferenceId()
                     + "] declared by service ["
-                    + serviceDefaultConfig.service()
+                    + defaultConfigIds.service().name()
                     + "]. The inference Id is already use by ["
-                    + defaultConfigs.get(serviceDefaultConfig.inferenceEntityId()).service()
+                    + matched.get().service().name()
                     + "] service."
             );
         }
-
-        defaultConfigs.put(serviceDefaultConfig.inferenceEntityId(), serviceDefaultConfig);
+        this.defaultConfigIds.add(defaultConfigIds);
     }
 
     /**
@@ -118,15 +127,15 @@ public class ModelRegistry {
      * @param listener Model listener
      */
     public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
-        if (defaultConfigs.containsKey(inferenceEntityId)) {
-            listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId)));
-            return;
-        }
-
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
-            // There should be a hit for the configurations and secrets
+            // There should be a hit for the configurations
             if (searchResponse.getHits().getHits().length == 0) {
-                delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
+                var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
+                if (maybeDefault.isPresent()) {
+                    getDefaultConfig(maybeDefault.get(), listener);
+                } else {
+                    delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
+                }
                 return;
             }
 
@@ -149,15 +158,15 @@ public class ModelRegistry {
      * @param listener Model listener
      */
     public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
-        if (defaultConfigs.containsKey(inferenceEntityId)) {
-            listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId)));
-            return;
-        }
-
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
-            // There should be a hit for the configurations and secrets
+            // There should be a hit for the configurations
             if (searchResponse.getHits().getHits().length == 0) {
-                delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
+                var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
+                if (maybeDefault.isPresent()) {
+                    getDefaultConfig(maybeDefault.get(), listener);
+                } else {
+                    delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
+                }
                 return;
             }
 
@@ -188,29 +197,9 @@ public class ModelRegistry {
      */
     public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
-            var defaultConfigsForTaskType = defaultConfigs.values()
-                .stream()
-                .filter(m -> m.taskType() == taskType)
-                .map(ModelRegistry::deepCopyDefaultConfig)
-                .toList();
-
-            // Not an error if no models of this task_type
-            if (searchResponse.getHits().getHits().length == 0 && defaultConfigsForTaskType.isEmpty()) {
-                delegate.onResponse(List.of());
-                return;
-            }
-
             var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
-
-            if (defaultConfigsForTaskType.isEmpty() == false) {
-                var allConfigs = new ArrayList<UnparsedModel>();
-                allConfigs.addAll(modelConfigs);
-                allConfigs.addAll(defaultConfigsForTaskType);
-                allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
-                delegate.onResponse(allConfigs);
-            } else {
-                delegate.onResponse(modelConfigs);
-            }
+            var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
+            addAllDefaultConfigsIfMissing(modelConfigs, defaultConfigsForTaskType, delegate);
         });
 
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
@@ -232,19 +221,8 @@ public class ModelRegistry {
      */
     public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
-            var defaults = defaultConfigs.values().stream().map(ModelRegistry::deepCopyDefaultConfig).toList();
-
-            if (searchResponse.getHits().getHits().length == 0 && defaults.isEmpty()) {
-                delegate.onResponse(List.of());
-                return;
-            }
-
             var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
-            var allConfigs = new ArrayList<UnparsedModel>();
-            allConfigs.addAll(foundConfigs);
-            allConfigs.addAll(defaults);
-            allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
-            delegate.onResponse(allConfigs);
+            addAllDefaultConfigsIfMissing(foundConfigs, defaultConfigIds, delegate);
         });
 
         // In theory the index should only contain model config documents
@@ -262,6 +240,67 @@ public class ModelRegistry {
         client.search(modelSearch, searchListener);
     }
 
+    private void addAllDefaultConfigsIfMissing(
+        List<UnparsedModel> foundConfigs,
+        List<InferenceService.DefaultConfigId> matchedDefaults,
+        ActionListener<List<UnparsedModel>> listener
+    ) {
+        var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
+        var missing = matchedDefaults.stream().filter(d -> foundIds.contains(d.inferenceId()) == false).toList();
+
+        if (missing.isEmpty()) {
+            listener.onResponse(foundConfigs);
+        } else {
+            var groupedListener = new GroupedActionListener<UnparsedModel>(
+                missing.size(),
+                listener.delegateFailure((delegate, listOfModels) -> {
+                    var allConfigs = new ArrayList<UnparsedModel>();
+                    allConfigs.addAll(foundConfigs);
+                    allConfigs.addAll(listOfModels);
+                    allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
+                    delegate.onResponse(allConfigs);
+                })
+            );
+
+            for (var required : missing) {
+                getDefaultConfig(required, groupedListener);
+            }
+        }
+    }
+
+    private void getDefaultConfig(InferenceService.DefaultConfigId defaultConfig, ActionListener<UnparsedModel> listener) {
+        defaultConfig.service().defaultConfigs(listener.delegateFailureAndWrap((delegate, models) -> {
+            boolean foundModel = false;
+            for (var m : models) {
+                if (m.getInferenceEntityId().equals(defaultConfig.inferenceId())) {
+                    foundModel = true;
+                    storeDefaultEndpoint(m, () -> listener.onResponse(modelToUnparsedModel(m)));
+                    break;
+                }
+            }
+
+            if (foundModel == false) {
+                listener.onFailure(
+                    new IllegalStateException("Configuration not found for default inference id [" + defaultConfig.inferenceId() + "]")
+                );
+            }
+        }));
+    }
+
+    public void storeDefaultEndpoint(Model preconfigured, Runnable runAfter) {
+        var responseListener = ActionListener.<Boolean>wrap(success -> {
+            logger.debug("Added default inference endpoint [{}]", preconfigured.getInferenceEntityId());
+        }, exception -> {
+            if (exception instanceof ResourceAlreadyExistsException) {
+                logger.debug("Default inference id [{}] already exists", preconfigured.getInferenceEntityId());
+            } else {
+                logger.error("Failed to store default inference id [" + preconfigured.getInferenceEntityId() + "]", exception);
+            }
+        });
+
+        storeModel(preconfigured, ActionListener.runAfter(responseListener, runAfter));
+    }
+
     private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
         var modelConfigs = new ArrayList<ModelConfigMap>();
         for (var hit : hits) {
@@ -578,60 +617,36 @@ public class ModelRegistry {
         }
     }
 
-    private QueryBuilder documentIdQuery(String inferenceEntityId) {
-        return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId)));
-    }
-
-    static UnparsedModel deepCopyDefaultConfig(UnparsedModel other) {
-        // Because the default config uses immutable maps
-        return new UnparsedModel(
-            other.inferenceEntityId(),
-            other.taskType(),
-            other.service(),
-            copySettingsMap(other.settings()),
-            copySecretsMap(other.secrets())
-        );
-    }
-
-    @SuppressWarnings("unchecked")
-    static Map<String, Object> copySettingsMap(Map<String, Object> other) {
-        var result = new HashMap<String, Object>();
-
-        var serviceSettings = (Map<String, Object>) other.get(ModelConfigurations.SERVICE_SETTINGS);
-        if (serviceSettings != null) {
-            var copiedServiceSettings = copyMap1LevelDeep(serviceSettings);
-            result.put(ModelConfigurations.SERVICE_SETTINGS, copiedServiceSettings);
-        }
+    private static UnparsedModel modelToUnparsedModel(Model model) {
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            model.getConfigurations()
+                .toXContent(builder, new ToXContent.MapParams(Map.of(ModelConfigurations.USE_ID_FOR_INDEX, Boolean.TRUE.toString())));
 
-        var taskSettings = (Map<String, Object>) other.get(ModelConfigurations.TASK_SETTINGS);
-        if (taskSettings != null) {
-            var copiedTaskSettings = copyMap1LevelDeep(taskSettings);
-            result.put(ModelConfigurations.TASK_SETTINGS, copiedTaskSettings);
-        }
+            var modelConfigMap = XContentHelper.convertToMap(BytesReference.bytes(builder), false, builder.contentType()).v2();
+            return unparsedModelFromMap(new ModelConfigMap(modelConfigMap, new HashMap<>()));
 
-        var chunkSettings = (Map<String, Object>) other.get(ModelConfigurations.CHUNKING_SETTINGS);
-        if (chunkSettings != null) {
-            var copiedChunkSettings = copyMap1LevelDeep(chunkSettings);
-            result.put(ModelConfigurations.CHUNKING_SETTINGS, copiedChunkSettings);
+        } catch (IOException ex) {
+            throw new ElasticsearchException("[{}] Error serializing inference endpoint configuration", model.getInferenceEntityId(), ex);
         }
+    }
 
-        return result;
+    private QueryBuilder documentIdQuery(String inferenceEntityId) {
+        return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId)));
     }
 
-    static Map<String, Object> copySecretsMap(Map<String, Object> other) {
-        return copyMap1LevelDeep(other);
+    static Optional<InferenceService.DefaultConfigId> idMatchedDefault(
+        String inferenceId,
+        List<InferenceService.DefaultConfigId> defaultConfigIds
+    ) {
+        return defaultConfigIds.stream().filter(defaultConfigId -> defaultConfigId.inferenceId().equals(inferenceId)).findFirst();
     }
 
-    @SuppressWarnings("unchecked")
-    static Map<String, Object> copyMap1LevelDeep(Map<String, Object> other) {
-        var result = new HashMap<String, Object>();
-        for (var entry : other.entrySet()) {
-            if (entry.getValue() instanceof Map<?, ?>) {
-                result.put(entry.getKey(), new HashMap<>((Map<String, Object>) entry.getValue()));
-            } else {
-                result.put(entry.getKey(), entry.getValue());
-            }
-        }
-        return result;
+    static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
+        TaskType taskType,
+        List<InferenceService.DefaultConfigId> defaultConfigIds
+    ) {
+        return defaultConfigIds.stream()
+            .filter(defaultConfigId -> defaultConfigId.taskType().equals(taskType))
+            .collect(Collectors.toList());
     }
 }
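
Default endpoints are now resolved only when a search of the inference index misses, then written back through storeDefaultEndpoint so subsequent reads are served from the index. A small illustration of the new package-private matching helpers (the id and service reference are illustrative):

    var defaults = List.of(new InferenceService.DefaultConfigId(".default-elser", TaskType.SPARSE_EMBEDDING, elserService));

    // Resolve a single inference id to its registered default, if any.
    Optional<InferenceService.DefaultConfigId> byId = ModelRegistry.idMatchedDefault(".default-elser", defaults);

    // Collect every default advertised for a given task type.
    List<InferenceService.DefaultConfigId> sparse = ModelRegistry.taskTypeMatchedDefaults(TaskType.SPARSE_EMBEDDING, defaults);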

+ 35 - 19
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -15,6 +15,7 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.SubscribableListener;
 import org.elasticsearch.client.internal.OriginSettingClient;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceExtension;
@@ -22,6 +23,7 @@ import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.core.ClientHelper;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
@@ -38,7 +40,6 @@ import org.elasticsearch.xpack.inference.InferencePlugin;
 import java.io.IOException;
 import java.util.EnumSet;
 import java.util.List;
-import java.util.Set;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Consumer;
 
@@ -49,14 +50,21 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
 
     protected final OriginSettingClient client;
     protected final ExecutorService inferenceExecutor;
-    protected final Consumer<ActionListener<Set<String>>> platformArch;
+    protected final Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn;
+    private final ClusterService clusterService;
+
+    public enum PreferredModelVariant {
+        LINUX_X86_OPTIMIZED,
+        PLATFORM_AGNOSTIC
+    };
 
     private static final Logger logger = LogManager.getLogger(BaseElasticsearchInternalService.class);
 
     public BaseElasticsearchInternalService(InferenceServiceExtension.InferenceServiceFactoryContext context) {
         this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
         this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
-        this.platformArch = this::platformArchitecture;
+        this.preferredModelVariantFn = this::preferredVariantFromPlatformArchitecture;
+        this.clusterService = context.clusterService();
     }
 
     // For testing.
@@ -66,11 +74,12 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
     // service package.
     public BaseElasticsearchInternalService(
         InferenceServiceExtension.InferenceServiceFactoryContext context,
-        Consumer<ActionListener<Set<String>>> platformArchFn
+        Consumer<ActionListener<PreferredModelVariant>> preferredModelVariantFn
     ) {
         this.client = new OriginSettingClient(context.client(), ClientHelper.INFERENCE_ORIGIN);
         this.inferenceExecutor = context.threadPool().executor(InferencePlugin.UTILITY_THREAD_POOL_NAME);
-        this.platformArch = platformArchFn;
+        this.preferredModelVariantFn = preferredModelVariantFn;
+        this.clusterService = context.clusterService();
     }
 
     /**
@@ -206,31 +215,36 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
     public void close() throws IOException {}
 
     public static String selectDefaultModelVariantBasedOnClusterArchitecture(
-        Set<String> modelArchitectures,
-        String linuxX86OptimisedModel,
+        PreferredModelVariant preferredModelVariant,
+        String linuxX86OptimizedModel,
         String platformAgnosticModel
     ) {
         // choose a default model version based on the cluster architecture
-        boolean homogenous = modelArchitectures.size() == 1;
-        if (homogenous && modelArchitectures.iterator().next().equals("linux-x86_64")) {
+        if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
             // Use the hardware optimized model
-            return linuxX86OptimisedModel;
+            return linuxX86OptimizedModel;
         } else {
             // default to the platform-agnostic model
             return platformAgnosticModel;
         }
     }
 
-    private void platformArchitecture(ActionListener<Set<String>> platformArchitectureListener) {
+    private void preferredVariantFromPlatformArchitecture(ActionListener<PreferredModelVariant> preferredVariantListener) {
         // Find the cluster platform as the service may need that
         // information when creating the model
         MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(
-            platformArchitectureListener.delegateFailureAndWrap((delegate, architectures) -> {
-                if (architectures.isEmpty() && clusterIsInElasticCloud()) {
-                    // In Elastic cloud ml nodes run on Linux x86
-                    delegate.onResponse(Set.of("linux-x86_64"));
+            preferredVariantListener.delegateFailureAndWrap((delegate, architectures) -> {
+                if (architectures.isEmpty() && isClusterInElasticCloud()) {
+                    // There are no ml nodes to check the current arch.
+                    // However, in Elastic cloud ml nodes run on Linux x86
+                    delegate.onResponse(PreferredModelVariant.LINUX_X86_OPTIMIZED);
                 } else {
-                    delegate.onResponse(architectures);
+                    boolean homogenous = architectures.size() == 1;
+                    if (homogenous && architectures.iterator().next().equals("linux-x86_64")) {
+                        delegate.onResponse(PreferredModelVariant.LINUX_X86_OPTIMIZED);
+                    } else {
+                        delegate.onResponse(PreferredModelVariant.PLATFORM_AGNOSTIC);
+                    }
                 }
             }),
             client,
@@ -238,9 +252,11 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
         );
     }
 
-    static boolean clusterIsInElasticCloud() {
-        // use a heuristic to determine if in Elastic cloud.
-        return true; // TODO
+    boolean isClusterInElasticCloud() {
+        // Use the ml lazy node count as a heuristic to determine if in Elastic cloud.
+        // A value > 0 means scaling should be available for ml nodes
+        var maxMlLazyNodes = clusterService.getClusterSettings().get(MachineLearningField.MAX_LAZY_ML_NODES);
+        return maxMlLazyNodes > 0;
     }
 
     public static InferModelAction.Request buildInferenceRequest(
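
The per-node architecture set collapses into the two-value PreferredModelVariant, which the variant-selection helper consumes directly. For example, assuming static imports of the ELSER model id constants from ElserModels:

    // The linux-x86_64 optimized id is returned only for LINUX_X86_OPTIMIZED;
    // any other variant falls back to the platform-agnostic model.
    String modelId = selectDefaultModelVariantBasedOnClusterArchitecture(
        PreferredModelVariant.LINUX_X86_OPTIMIZED,
        ELSER_V2_MODEL_LINUX_X86,
        ELSER_V2_MODEL
    );
    assert ELSER_V2_MODEL_LINUX_X86.equals(modelId);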

+ 57 - 67
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -27,13 +27,13 @@ import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
+import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
@@ -61,7 +61,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFrom
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL;
 import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels.ELSER_V2_MODEL_LINUX_X86;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 
 public class ElasticsearchInternalService extends BaseElasticsearchInternalService {
 
@@ -88,7 +87,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
     // for testing
     ElasticsearchInternalService(
         InferenceServiceExtension.InferenceServiceFactoryContext context,
-        Consumer<ActionListener<Set<String>>> platformArch
+        Consumer<ActionListener<PreferredModelVariant>> platformArch
     ) {
         super(context, platformArch);
     }
@@ -143,13 +142,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                         "Putting elasticsearch service inference endpoints (including elser service) without a model_id field is"
                             + " deprecated and will be removed in a future release. Please specify a model_id field."
                     );
-                    platformArch.accept(
+                    preferredModelVariantFn.accept(
                         modelListener.delegateFailureAndWrap(
-                            (delegate, arch) -> elserCase(
+                            (delegate, preferredModelVariant) -> elserCase(
                                 inferenceEntityId,
                                 taskType,
                                 config,
-                                arch,
+                                preferredModelVariant,
                                 serviceSettingsMap,
                                 chunkingSettings,
                                 modelListener
@@ -160,13 +159,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                     throw new IllegalArgumentException("Error parsing service settings, model_id must be provided");
                 }
             } else if (MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId)) {
-                platformArch.accept(
+                preferredModelVariantFn.accept(
                     modelListener.delegateFailureAndWrap(
-                        (delegate, arch) -> e5Case(
+                        (delegate, preferredModelVariant) -> e5Case(
                             inferenceEntityId,
                             taskType,
                             config,
-                            arch,
+                            preferredModelVariant,
                             serviceSettingsMap,
                             chunkingSettings,
                             modelListener
@@ -174,13 +173,13 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
                     )
                 );
             } else if (ElserModels.isValidModel(modelId)) {
-                platformArch.accept(
+                preferredModelVariantFn.accept(
                     modelListener.delegateFailureAndWrap(
-                        (delegate, arch) -> elserCase(
+                        (delegate, preferredModelVariant) -> elserCase(
                             inferenceEntityId,
                             taskType,
                             config,
-                            arch,
+                            preferredModelVariant,
                             serviceSettingsMap,
                             chunkingSettings,
                             modelListener
@@ -284,7 +283,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> config,
-        Set<String> platformArchitectures,
+        PreferredModelVariant preferredModelVariant,
         Map<String, Object> serviceSettingsMap,
         ChunkingSettings chunkingSettings,
         ActionListener<Model> modelListener
@@ -294,12 +293,12 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         if (esServiceSettingsBuilder.getModelId() == null) {
             esServiceSettingsBuilder.setModelId(
                 selectDefaultModelVariantBasedOnClusterArchitecture(
-                    platformArchitectures,
+                    preferredModelVariant,
                     MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
                     MULTILINGUAL_E5_SMALL_MODEL_ID
                 )
             );
-        } else if (modelVariantValidForArchitecture(platformArchitectures, esServiceSettingsBuilder.getModelId()) == false) {
+        } else if (modelVariantValidForArchitecture(preferredModelVariant, esServiceSettingsBuilder.getModelId()) == false) {
             throw new IllegalArgumentException(
                 "Error parsing request config, model id does not match any models available on this platform. Was ["
                     + esServiceSettingsBuilder.getModelId()
@@ -321,14 +320,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         );
     }
 
-    static boolean modelVariantValidForArchitecture(Set<String> platformArchitectures, String modelId) {
+    static boolean modelVariantValidForArchitecture(PreferredModelVariant modelVariant, String modelId) {
         if (modelId.equals(MULTILINGUAL_E5_SMALL_MODEL_ID)) {
             // platform agnostic model is always compatible
             return true;
         }
         return modelId.equals(
             selectDefaultModelVariantBasedOnClusterArchitecture(
-                platformArchitectures,
+                modelVariant,
                 MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
                 MULTILINGUAL_E5_SMALL_MODEL_ID
             )
@@ -339,14 +338,14 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> config,
-        Set<String> platformArchitectures,
+        PreferredModelVariant preferredModelVariant,
         Map<String, Object> serviceSettingsMap,
         ChunkingSettings chunkingSettings,
         ActionListener<Model> modelListener
     ) {
         var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
         final String defaultModelId = selectDefaultModelVariantBasedOnClusterArchitecture(
-            platformArchitectures,
+            preferredModelVariant,
             ELSER_V2_MODEL_LINUX_X86,
             ELSER_V2_MODEL
         );
@@ -381,14 +380,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             defaultModelId
         );
 
-        if (modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(platformArchitectures, esServiceSettingsBuilder.getModelId())) {
-            throw new IllegalArgumentException(
-                "Error parsing request config, model id does not match any models available on this platform. Was ["
-                    + esServiceSettingsBuilder.getModelId()
-                    + "]"
-            );
-        }
-
         throwIfNotEmptyMap(config, name());
         throwIfNotEmptyMap(serviceSettingsMap, name());
 
@@ -404,19 +395,6 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         );
     }
 
-    private static boolean modelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic(
-        Set<String> platformArchitectures,
-        String modelId
-    ) {
-        return modelId.equals(
-            selectDefaultModelVariantBasedOnClusterArchitecture(
-                platformArchitectures,
-                MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86,
-                MULTILINGUAL_E5_SMALL_MODEL_ID
-            )
-        );
-    }
-
     @Override
     public Model parsePersistedConfigWithSecrets(
         String inferenceEntityId,
@@ -781,37 +759,49 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         return new RankedDocsResults(rankings);
     }
 
+    public List<DefaultConfigId> defaultConfigIds() {
+        return List.of(new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this));
+    }
+
+    /**
+     * Default configurations that can be out of the box without creating an endpoint first.
+     * @param defaultsListener Config listener
+     */
     @Override
-    public List<UnparsedModel> defaultConfigs() {
-        // TODO Chunking settings
-        Map<String, Object> elserSettings = Map.of(
-            ModelConfigurations.SERVICE_SETTINGS,
-            Map.of(
-                ElasticsearchInternalServiceSettings.MODEL_ID,
-                ElserModels.ELSER_V2_MODEL,  // TODO pick model depending on platform
-                ElasticsearchInternalServiceSettings.NUM_THREADS,
+    public void defaultConfigs(ActionListener<List<Model>> defaultsListener) {
+        preferredModelVariantFn.accept(defaultsListener.delegateFailureAndWrap((delegate, preferredModelVariant) -> {
+            if (PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)) {
+                defaultsListener.onResponse(defaultConfigsLinuxOptimized());
+            } else {
+                defaultsListener.onResponse(defaultConfigsPlatformAgnostic());
+            }
+        }));
+    }
+
+    private List<Model> defaultConfigsLinuxOptimized() {
+        return defaultConfigs(true);
+    }
+
+    private List<Model> defaultConfigsPlatformAgnostic() {
+        return defaultConfigs(false);
+    }
+
+    private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
+        var defaultElser = new ElserInternalModel(
+            DEFAULT_ELSER_ID,
+            TaskType.SPARSE_EMBEDDING,
+            NAME,
+            new ElserInternalServiceSettings(
+                null,
                 1,
-                ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
-                Map.of(
-                    "enabled",
-                    Boolean.TRUE,
-                    "min_number_of_allocations",
-                    1,
-                    "max_number_of_allocations",
-                    8   // no max?
-                )
-            )
+                useLinuxOptimizedModel ? ELSER_V2_MODEL_LINUX_X86 : ELSER_V2_MODEL,
+                new AdaptiveAllocationsSettings(Boolean.TRUE, 1, 8)
+            ),
+            ElserMlNodeTaskSettings.DEFAULT,
+            null // default chunking settings
         );
 
-        return List.of(
-            new UnparsedModel(
-                DEFAULT_ELSER_ID,
-                TaskType.SPARSE_EMBEDDING,
-                NAME,
-                elserSettings,
-                Map.of() // no secrets
-            )
-        );
+        return List.of(defaultElser);
     }
 
     @Override

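Note: the hunks above change selectDefaultModelVariantBasedOnClusterArchitecture to take the new PreferredModelVariant enum instead of a set of architecture strings, but the helper's body is not part of this diff. A minimal sketch of a shape consistent with the call sites, assuming the method only branches on the enum:

    // Illustrative sketch only; the committed body is not shown in this diff.
    // Assumes the helper simply keys off the PreferredModelVariant enum.
    static String selectDefaultModelVariantBasedOnClusterArchitecture(
        PreferredModelVariant preferredModelVariant,
        String linuxX86OptimizedModel,
        String platformAgnosticModel
    ) {
        return PreferredModelVariant.LINUX_X86_OPTIMIZED.equals(preferredModelVariant)
            ? linuxX86OptimizedModel
            : platformAgnosticModel;
    }
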
+ 32 - 55
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

@@ -22,6 +22,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
+import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.search.SearchHit;
@@ -35,16 +36,16 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.not;
-import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -292,58 +293,30 @@ public class ModelRegistryTests extends ESTestCase {
         );
     }
 
-    @SuppressWarnings("unchecked")
-    public void testDeepCopyDefaultConfig() {
-        {
-            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", Map.of(), Map.of());
-            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
-            assertThat(copied, not(sameInstance(toCopy)));
-            assertThat(copied.taskType(), is(toCopy.taskType()));
-            assertThat(copied.service(), is(toCopy.service()));
-            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
-            assertThat(copied.secrets(), is(toCopy.secrets()));
-            // Test copied is a modifiable map
-            copied.secrets().put("foo", "bar");
-
-            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
-            assertThat(copied.settings(), is(toCopy.settings()));
-            // Test copied is a modifiable map
-            copied.settings().put("foo", "bar");
-        }
+    public void testIdMatchedDefault() {
+        var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("foo", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("bar", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
 
-        {
-            Map<String, Object> secretsMap = Map.of("secret", "value");
-            Map<String, Object> chunking = Map.of("strategy", "word");
-            Map<String, Object> task = Map.of("user", "name");
-            Map<String, Object> service = Map.of("num_threads", 1, "adaptive_allocations", Map.of("enabled", true));
-            Map<String, Object> settings = Map.of("chunking_settings", chunking, "service_settings", service, "task_settings", task);
-
-            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", settings, secretsMap);
-            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
-            assertThat(copied, not(sameInstance(toCopy)));
-
-            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
-            assertThat(copied.secrets(), is(toCopy.secrets()));
-            // Test copied is a modifiable map
-            copied.secrets().remove("secret");
-
-            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
-            assertThat(copied.settings(), is(toCopy.settings()));
-            // Test copied is a modifiable map
-            var chunkOut = (Map<String, Object>) copied.settings().get("chunking_settings");
-            assertThat(chunkOut, is(chunking));
-            chunkOut.remove("strategy");
-
-            var taskOut = (Map<String, Object>) copied.settings().get("task_settings");
-            assertThat(taskOut, is(task));
-            taskOut.remove("user");
-
-            var serviceOut = (Map<String, Object>) copied.settings().get("service_settings");
-            assertThat(serviceOut, is(service));
-            var adaptiveOut = (Map<String, Object>) serviceOut.remove("adaptive_allocations");
-            assertThat(adaptiveOut, is(Map.of("enabled", true)));
-            adaptiveOut.remove("enabled");
-        }
+        var matched = ModelRegistry.idMatchedDefault("bar", defaultConfigIds);
+        assertEquals(defaultConfigIds.get(1), matched.get());
+        matched = ModelRegistry.idMatchedDefault("baz", defaultConfigIds);
+        assertFalse(matched.isPresent());
+    }
+
+    public void testTaskTypeMatchedDefaults() {
+        var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("s1", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("s2", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("d1", TaskType.TEXT_EMBEDDING, mock(InferenceService.class)));
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("c1", TaskType.COMPLETION, mock(InferenceService.class)));
+
+        var matched = ModelRegistry.taskTypeMatchedDefaults(TaskType.SPARSE_EMBEDDING, defaultConfigIds);
+        assertThat(matched, contains(defaultConfigIds.get(0), defaultConfigIds.get(1)));
+        matched = ModelRegistry.taskTypeMatchedDefaults(TaskType.TEXT_EMBEDDING, defaultConfigIds);
+        assertThat(matched, contains(defaultConfigIds.get(2)));
+        matched = ModelRegistry.taskTypeMatchedDefaults(TaskType.RERANK, defaultConfigIds);
+        assertThat(matched, empty());
     }
 
     public void testDuplicateDefaultIds() {
@@ -351,11 +324,15 @@ public class ModelRegistryTests extends ESTestCase {
         var registry = new ModelRegistry(client);
 
         var id = "my-inference";
+        var mockServiceA = mock(InferenceService.class);
+        when(mockServiceA.name()).thenReturn("service-a");
+        var mockServiceB = mock(InferenceService.class);
+        when(mockServiceB.name()).thenReturn("service-b");
 
-        registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-a", Map.of(), Map.of()));
+        registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomFrom(TaskType.values()), mockServiceA));
         var ise = expectThrows(
             IllegalStateException.class,
-            () -> registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-b", Map.of(), Map.of()))
+            () -> registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomFrom(TaskType.values()), mockServiceB))
         );
         assertThat(
             ise.getMessage(),

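Note: the new tests above exercise ModelRegistry.idMatchedDefault and ModelRegistry.taskTypeMatchedDefaults, whose implementations are not shown in this section. A hedged sketch of a plausible shape, consistent with the assertions (an Optional for the id lookup, a filtered list for the task-type lookup); the DefaultConfigId accessor names are assumptions, not copied from the committed code:

    // Plausible shapes of the helpers the tests exercise; accessor names (inferenceEntityId(),
    // taskType()) are assumed, not taken from this diff.
    static Optional<InferenceService.DefaultConfigId> idMatchedDefault(
        String inferenceEntityId,
        List<InferenceService.DefaultConfigId> defaultConfigIds
    ) {
        return defaultConfigIds.stream()
            .filter(configId -> configId.inferenceEntityId().equals(inferenceEntityId))
            .findFirst();
    }

    static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
        TaskType taskType,
        List<InferenceService.DefaultConfigId> defaultConfigIds
    ) {
        return defaultConfigIds.stream()
            .filter(configId -> configId.taskType().equals(taskType))
            .toList();
    }
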
+ 39 - 21
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java

@@ -14,7 +14,9 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.logging.DeprecationLogger;
+import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
@@ -37,6 +39,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.InferModelAction;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
@@ -170,7 +173,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 
     public void testParseRequestConfig_E5() {
         {
-            var service = createService(mock(Client.class), Set.of("Aarch64"));
+            var service = createService(mock(Client.class), BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC);
             var settings = new HashMap<String, Object>();
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
@@ -197,7 +200,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         }
 
         {
-            var service = createService(mock(Client.class), Set.of("linux-x86_64"));
+            var service = createService(mock(Client.class), BaseElasticsearchInternalService.PreferredModelVariant.LINUX_X86_OPTIMIZED);
             var settings = new HashMap<String, Object>();
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
@@ -230,7 +233,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 
         // Invalid service settings
         {
-            var service = createService(mock(Client.class), Set.of("Aarch64"));
+            var service = createService(mock(Client.class), BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC);
             var settings = new HashMap<String, Object>();
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
@@ -257,7 +260,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         }
 
         {
-            var service = createService(mock(Client.class), Set.of("Aarch64"));
+            var service = createService(mock(Client.class), BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC);
             var settings = new HashMap<String, Object>();
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
@@ -285,7 +288,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
         }
 
         {
-            var service = createService(mock(Client.class), Set.of("Aarch64"));
+            var service = createService(mock(Client.class), BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC);
             var settings = new HashMap<String, Object>();
             settings.put(
                 ModelConfigurations.SERVICE_SETTINGS,
@@ -1377,26 +1380,33 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 
     public void testModelVariantDoesNotMatchArchitecturesAndIsNotPlatformAgnostic() {
         {
-            var architectures = Set.of("Aarch64");
             assertFalse(
-                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+                ElasticsearchInternalService.modelVariantValidForArchitecture(
+                    BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC,
+                    MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
+                )
             );
 
-            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
-        }
-        {
-            var architectures = Set.of("linux-x86_64");
             assertTrue(
-                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+                ElasticsearchInternalService.modelVariantValidForArchitecture(
+                    BaseElasticsearchInternalService.PreferredModelVariant.PLATFORM_AGNOSTIC,
+                    MULTILINGUAL_E5_SMALL_MODEL_ID
+                )
             );
-            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
         }
         {
-            var architectures = Set.of("linux-x86_64", "Aarch64");
-            assertFalse(
-                ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86)
+            assertTrue(
+                ElasticsearchInternalService.modelVariantValidForArchitecture(
+                    BaseElasticsearchInternalService.PreferredModelVariant.LINUX_X86_OPTIMIZED,
+                    MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
+                )
+            );
+            assertTrue(
+                ElasticsearchInternalService.modelVariantValidForArchitecture(
+                    BaseElasticsearchInternalService.PreferredModelVariant.LINUX_X86_OPTIMIZED,
+                    MULTILINGUAL_E5_SMALL_MODEL_ID
+                )
             );
-            assertTrue(ElasticsearchInternalService.modelVariantValidForArchitecture(architectures, MULTILINGUAL_E5_SMALL_MODEL_ID));
         }
     }
 
@@ -1427,12 +1437,20 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
     }
 
     private ElasticsearchInternalService createService(Client client) {
-        var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool);
+        var cs = mock(ClusterService.class);
+        var cSettings = new ClusterSettings(Settings.EMPTY, Set.of(MachineLearningField.MAX_LAZY_ML_NODES));
+        when(cs.getClusterSettings()).thenReturn(cSettings);
+        var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool, cs, Settings.EMPTY);
         return new ElasticsearchInternalService(context);
     }
 
-    private ElasticsearchInternalService createService(Client client, Set<String> architectures) {
-        var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client, threadPool);
-        return new ElasticsearchInternalService(context, l -> l.onResponse(architectures));
+    private ElasticsearchInternalService createService(Client client, BaseElasticsearchInternalService.PreferredModelVariant modelVariant) {
+        var context = new InferenceServiceExtension.InferenceServiceFactoryContext(
+            client,
+            threadPool,
+            mock(ClusterService.class),
+            Settings.EMPTY
+        );
+        return new ElasticsearchInternalService(context, l -> l.onResponse(modelVariant));
     }
 }

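Note: the rewritten createService helpers above wire the service with a fixed PreferredModelVariant supplier, which makes the new listener-based defaultConfigs complete synchronously in tests. An illustrative test sketch (not part of the PR), assuming the class's existing Mockito and Hamcrest imports such as hasSize are available:

    // Illustrative only; exercises the listener-based defaultConfigs() added in this PR,
    // using a service built with a fixed PreferredModelVariant as createService above does.
    public void testDefaultConfigs_ReturnsSingleDefaultElserEndpoint() {
        var service = createService(
            mock(Client.class),
            BaseElasticsearchInternalService.PreferredModelVariant.LINUX_X86_OPTIMIZED
        );

        var future = new PlainActionFuture<List<Model>>();
        service.defaultConfigs(future);   // completes synchronously with the fixed supplier

        List<Model> defaults = future.actionGet();
        assertThat(defaults, hasSize(1)); // only the default ELSER endpoint is expected
    }
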
+ 3 - 2
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java

@@ -16,6 +16,7 @@ import org.elasticsearch.xpack.autoscaling.action.GetAutoscalingCapacityAction;
 import org.elasticsearch.xpack.autoscaling.action.PutAutoscalingPolicyAction;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResult;
 import org.elasticsearch.xpack.autoscaling.capacity.AutoscalingDeciderResults;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
@@ -62,14 +63,14 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
     @Before
     public void putSettings() {
         updateClusterSettings(
-            Settings.builder().put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), 100).put("logger.org.elasticsearch.xpack.ml", "DEBUG")
+            Settings.builder().put(MachineLearningField.MAX_LAZY_ML_NODES.getKey(), 100).put("logger.org.elasticsearch.xpack.ml", "DEBUG")
         );
     }
 
     @After
     public void removeSettings() {
         updateClusterSettings(
-            Settings.builder().putNull(MachineLearning.MAX_LAZY_ML_NODES.getKey()).putNull("logger.org.elasticsearch.xpack.ml")
+            Settings.builder().putNull(MachineLearningField.MAX_LAZY_ML_NODES.getKey()).putNull("logger.org.elasticsearch.xpack.ml")
         );
         cleanUp();
     }

+ 3 - 2
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java

@@ -16,6 +16,7 @@ import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
@@ -84,9 +85,9 @@ public class TooManyJobsIT extends BaseMlIntegTestCase {
         logger.info("Started [{}] nodes", numNodes);
         ensureStableCluster(numNodes);
         ensureTemplatesArePresent();
-        logger.info("[{}] is [{}]", MachineLearning.MAX_LAZY_ML_NODES.getKey(), maxNumberOfLazyNodes);
+        logger.info("[{}] is [{}]", MachineLearningField.MAX_LAZY_ML_NODES.getKey(), maxNumberOfLazyNodes);
         // Set our lazy node number
-        updateClusterSettings(Settings.builder().put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), maxNumberOfLazyNodes));
+        updateClusterSettings(Settings.builder().put(MachineLearningField.MAX_LAZY_ML_NODES.getKey(), maxNumberOfLazyNodes));
         // create and open first job, which succeeds:
         Job.Builder job = createJob("lazy-node-validation-job-1", ByteSizeValue.ofMb(2));
         PutJobAction.Request putJobRequest = new PutJobAction.Request(job);

+ 1 - 9
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -649,14 +649,6 @@ public class MachineLearning extends Plugin
         Property.NodeScope
     );
 
-    public static final Setting<Integer> MAX_LAZY_ML_NODES = Setting.intSetting(
-        "xpack.ml.max_lazy_ml_nodes",
-        0,
-        0,
-        Property.OperatorDynamic,
-        Property.NodeScope
-    );
-
     // Before 8.0.0 this needs to match the max allowed value for xpack.ml.max_open_jobs,
     // as the current node could be running in a cluster where some nodes are still using
     // that setting. From 8.0.0 onwards we have the flexibility to increase it...
@@ -810,7 +802,7 @@ public class MachineLearning extends Plugin
             PROCESS_CONNECT_TIMEOUT,
             CONCURRENT_JOB_ALLOCATIONS,
             MachineLearningField.MAX_MODEL_MEMORY_LIMIT,
-            MAX_LAZY_ML_NODES,
+            MachineLearningField.MAX_LAZY_ML_NODES,
             MAX_MACHINE_MEMORY_PERCENT,
             AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC,
             MAX_OPEN_JOBS_PER_NODE,

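Note: the setting removed here is referenced as MachineLearningField.MAX_LAZY_ML_NODES throughout the rest of this diff; the receiving file is not shown in this section. A sketch of the relocated constant, assuming it keeps the key, default, and bounds of the removed definition:

    // Presumed relocation into MachineLearningField; same key, default and bounds as the
    // definition removed above. The receiving file is not shown in this section.
    public static final Setting<Integer> MAX_LAZY_ML_NODES = Setting.intSetting(
        "xpack.ml.max_lazy_ml_nodes",
        0,
        0,
        Setting.Property.OperatorDynamic,
        Setting.Property.NodeScope
    );
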
+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportMlInfoAction.java

@@ -22,6 +22,7 @@ import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
@@ -162,7 +163,7 @@ public class TransportMlInfoAction extends HandledTransportAction<MlInfoAction.R
                 clusterSettings.get(MachineLearning.ALLOCATED_PROCESSORS_SCALE)
             );
             if (totalMlProcessors.count() > 0) {
-                int potentialExtraProcessors = Math.max(0, clusterSettings.get(MachineLearning.MAX_LAZY_ML_NODES) - mlNodes.size())
+                int potentialExtraProcessors = Math.max(0, clusterSettings.get(MachineLearningField.MAX_LAZY_ML_NODES) - mlNodes.size())
                     * singleNodeProcessors.roundUp();
                 limits.put("total_ml_processors", totalMlProcessors.roundUp() + potentialExtraProcessors);
             }

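Note: the renamed setting feeds the total_ml_processors limit computed above. A quick worked example with purely illustrative numbers:

    // Purely illustrative numbers, not taken from the PR.
    int maxLazyMlNodes = 100;       // clusterSettings.get(MachineLearningField.MAX_LAZY_ML_NODES)
    int currentMlNodes = 2;         // mlNodes.size()
    int processorsPerNode = 16;     // singleNodeProcessors.roundUp()
    int potentialExtraProcessors = Math.max(0, maxLazyMlNodes - currentMlNodes) * processorsPerNode;
    // potentialExtraProcessors == (100 - 2) * 16 == 1568, added to the reported total_ml_processors
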
+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java

@@ -113,7 +113,7 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
         this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
         this.useAuto = MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
         this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
-        this.maxLazyMLNodes = MachineLearning.MAX_LAZY_ML_NODES.get(settings);
+        this.maxLazyMLNodes = MachineLearningField.MAX_LAZY_ML_NODES.get(settings);
         this.maxMLNodeSize = MachineLearning.MAX_ML_NODE_SIZE.get(settings).getBytes();
         this.allocatedProcessorsScale = MachineLearning.ALLOCATED_PROCESSORS_SCALE.get(settings);
         this.client = client;
@@ -125,7 +125,7 @@ public class TrainedModelAssignmentClusterService implements ClusterStateListene
             clusterService.getClusterSettings()
                 .addSettingsUpdateConsumer(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto);
             clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
-            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
+            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearningField.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
             clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_ML_NODE_SIZE, this::setMaxMLNodeSize);
             clusterService.getClusterSettings()
                 .addSettingsUpdateConsumer(MachineLearning.ALLOCATED_PROCESSORS_SCALE, this::setAllocatedProcessorsScale);

+ 3 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/task/AbstractJobPersistentTasksExecutor.java

@@ -24,6 +24,7 @@ import org.elasticsearch.persistent.PersistentTaskParams;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.persistent.PersistentTasksExecutor;
 import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
+import org.elasticsearch.xpack.core.ml.MachineLearningField;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.ml.MachineLearning;
@@ -103,7 +104,7 @@ public abstract class AbstractJobPersistentTasksExecutor<Params extends Persiste
         this.expressionResolver = Objects.requireNonNull(expressionResolver);
         this.maxConcurrentJobAllocations = MachineLearning.CONCURRENT_JOB_ALLOCATIONS.get(settings);
         this.maxMachineMemoryPercent = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
-        this.maxLazyMLNodes = MachineLearning.MAX_LAZY_ML_NODES.get(settings);
+        this.maxLazyMLNodes = MachineLearningField.MAX_LAZY_ML_NODES.get(settings);
         this.maxOpenJobs = MAX_OPEN_JOBS_PER_NODE.get(settings);
         this.useAutoMemoryPercentage = USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
         this.maxNodeMemory = MAX_ML_NODE_SIZE.get(settings).getBytes();
@@ -111,7 +112,7 @@ public abstract class AbstractJobPersistentTasksExecutor<Params extends Persiste
             .addSettingsUpdateConsumer(MachineLearning.CONCURRENT_JOB_ALLOCATIONS, this::setMaxConcurrentJobAllocations);
         clusterService.getClusterSettings()
             .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMachineMemoryPercent);
-        clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearningField.MAX_LAZY_ML_NODES, this::setMaxLazyMLNodes);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAutoMemoryPercentage);
         clusterService.getClusterSettings().addSettingsUpdateConsumer(MAX_ML_NODE_SIZE, this::setMaxNodeSize);

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/utils/NativeMemoryCalculator.java

@@ -22,10 +22,10 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 
 import java.util.OptionalLong;
 
+import static org.elasticsearch.xpack.core.ml.MachineLearningField.MAX_LAZY_ML_NODES;
 import static org.elasticsearch.xpack.core.ml.MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT;
 import static org.elasticsearch.xpack.ml.MachineLearning.MACHINE_MEMORY_NODE_ATTR;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_JVM_SIZE_NODE_ATTR;
-import static org.elasticsearch.xpack.ml.MachineLearning.MAX_LAZY_ML_NODES;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_MACHINE_MEMORY_PERCENT;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_ML_NODE_SIZE;
 

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportStartDataFrameAnalyticsActionTests.java

@@ -130,7 +130,7 @@ public class TransportStartDataFrameAnalyticsActionTests extends ESTestCase {
                 MachineLearning.MAX_MACHINE_MEMORY_PERCENT,
                 MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT,
                 MachineLearning.MAX_ML_NODE_SIZE,
-                MachineLearning.MAX_LAZY_ML_NODES,
+                MachineLearningField.MAX_LAZY_ML_NODES,
                 MachineLearning.MAX_OPEN_JOBS_PER_NODE
             )
         );

+ 2 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java

@@ -127,7 +127,7 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
                 MachineLearning.MAX_MACHINE_MEMORY_PERCENT,
                 MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT,
                 MachineLearning.MAX_OPEN_JOBS_PER_NODE,
-                MachineLearning.MAX_LAZY_ML_NODES,
+                MachineLearningField.MAX_LAZY_ML_NODES,
                 MachineLearning.MAX_ML_NODE_SIZE,
                 MachineLearning.ALLOCATED_PROCESSORS_SCALE
             )
@@ -2079,7 +2079,7 @@ public class TrainedModelAssignmentClusterServiceTests extends ESTestCase {
 
     private TrainedModelAssignmentClusterService createClusterService(int maxLazyNodes) {
         return new TrainedModelAssignmentClusterService(
-            Settings.builder().put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), maxLazyNodes).build(),
+            Settings.builder().put(MachineLearningField.MAX_LAZY_ML_NODES.getKey(), maxLazyNodes).build(),
             clusterService,
             threadPool,
             nodeLoadDetector,

+ 3 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/task/OpenJobPersistentTasksExecutorTests.java

@@ -105,7 +105,7 @@ public class OpenJobPersistentTasksExecutorTests extends ESTestCase {
                     ClusterApplierService.CLUSTER_SERVICE_SLOW_TASK_THREAD_DUMP_TIMEOUT_SETTING,
                     MachineLearning.CONCURRENT_JOB_ALLOCATIONS,
                     MachineLearning.MAX_MACHINE_MEMORY_PERCENT,
-                    MachineLearning.MAX_LAZY_ML_NODES,
+                    MachineLearningField.MAX_LAZY_ML_NODES,
                     MachineLearning.MAX_ML_NODE_SIZE,
                     MachineLearning.MAX_OPEN_JOBS_PER_NODE,
                     MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT
@@ -155,7 +155,7 @@ public class OpenJobPersistentTasksExecutorTests extends ESTestCase {
 
     // An index being unavailable should take precedence over waiting for a lazy node
     public void testGetAssignment_GivenUnavailableIndicesWithLazyNode() {
-        Settings settings = Settings.builder().put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), 1).build();
+        Settings settings = Settings.builder().put(MachineLearningField.MAX_LAZY_ML_NODES.getKey(), 1).build();
 
         ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name"));
         Metadata.Builder metadata = Metadata.builder();
@@ -177,7 +177,7 @@ public class OpenJobPersistentTasksExecutorTests extends ESTestCase {
     }
 
     public void testGetAssignment_GivenLazyJobAndNoGlobalLazyNodes() {
-        Settings settings = Settings.builder().put(MachineLearning.MAX_LAZY_ML_NODES.getKey(), 0).build();
+        Settings settings = Settings.builder().put(MachineLearningField.MAX_LAZY_ML_NODES.getKey(), 0).build();
         ClusterState.Builder csBuilder = ClusterState.builder(new ClusterName("_name"));
         Metadata.Builder metadata = Metadata.builder();
         RoutingTable.Builder routingTable = RoutingTable.builder();

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/NativeMemoryCalculatorTests.java

@@ -36,11 +36,11 @@ import java.util.Map;
 import java.util.Set;
 import java.util.function.BiConsumer;
 
+import static org.elasticsearch.xpack.core.ml.MachineLearningField.MAX_LAZY_ML_NODES;
 import static org.elasticsearch.xpack.core.ml.MachineLearningField.MAX_MODEL_MEMORY_LIMIT;
 import static org.elasticsearch.xpack.core.ml.MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT;
 import static org.elasticsearch.xpack.ml.MachineLearning.MACHINE_MEMORY_NODE_ATTR;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_JVM_SIZE_NODE_ATTR;
-import static org.elasticsearch.xpack.ml.MachineLearning.MAX_LAZY_ML_NODES;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_MACHINE_MEMORY_PERCENT;
 import static org.elasticsearch.xpack.ml.MachineLearning.MAX_ML_NODE_SIZE;
 import static org.elasticsearch.xpack.ml.autoscaling.MlAutoscalingDeciderServiceTests.AUTO_NODE_TIERS_NO_MONITORING;

+ 0 - 20
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

@@ -39,23 +39,3 @@
           }
   - match: { error.reason: "Unknown task_type [bad]" }
 
----
-"Test get all":
-  - do:
-      inference.get:
-        inference_id: "*"
-  - length: { endpoints: 1}
-  - match: { endpoints.0.inference_id: ".elser-2" }
-
-  - do:
-      inference.get:
-        inference_id: _all
-  - length: { endpoints: 1}
-  - match: { endpoints.0.inference_id: ".elser-2" }
-
-  - do:
-      inference.get:
-        inference_id: ""
-  - length: { endpoints: 1}
-  - match: { endpoints.0.inference_id: ".elser-2" }
-