Browse Source

[ML] Default inference endpoint for ELSER (#114164)

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Mike Pellegrini 1 year ago
parent
commit
5b8f9c12d2
29 changed files with 745 additions and 134 deletions
  1. 5 0
      docs/changelog/113873.yaml
  2. 6 1
      docs/reference/rest-api/usage.asciidoc
  3. 9 0
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  4. 24 0
      server/src/main/java/org/elasticsearch/inference/UnparsedModel.java
  5. 2 1
      test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java
  6. 1 0
      x-pack/plugin/build.gradle
  7. 3 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  8. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java
  9. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java
  10. 1 1
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java
  11. 70 0
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java
  12. 15 6
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java
  13. 5 3
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java
  14. 4 7
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java
  15. 6 9
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java
  16. 2 2
      x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java
  17. 210 5
      x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java
  18. 21 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java
  19. 3 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  20. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java
  21. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java
  22. 20 19
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
  23. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
  24. 126 35
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
  25. 26 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java
  26. 83 23
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
  27. 3 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
  28. 84 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java
  29. 6 3
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

+ 5 - 0
docs/changelog/113873.yaml

@@ -0,0 +1,5 @@
+pr: 113873
+summary: Default inference endpoint for ELSER
+area: Machine Learning
+type: enhancement
+issues: []

+ 6 - 1
docs/reference/rest-api/usage.asciidoc

@@ -206,7 +206,12 @@ GET /_xpack/usage
   "inference": {
     "available" : true,
     "enabled" : true,
-    "models" : []
+    "models" : [{
+        "service": "elasticsearch",
+        "task_type": "SPARSE_EMBEDDING",
+        "count": 1
+      }
+    ]
   },
   "logstash" : {
     "available" : true,

+ 9 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -191,4 +191,13 @@ public interface InferenceService extends Closeable {
     default boolean canStream(TaskType taskType) {
         return supportedStreamingTasks().contains(taskType);
     }
+
+    /**
+     * A service can define default configurations that can be
+     * used out of the box without creating an endpoint first.
+     * @return Default configurations provided by this service
+     */
+    default List<UnparsedModel> defaultConfigs() {
+        return List.of();
+    }
 }

+ 24 - 0
server/src/main/java/org/elasticsearch/inference/UnparsedModel.java

@@ -0,0 +1,24 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.inference;
+
+import java.util.Map;
+
+/**
+ * Semi parsed model where inference entity id, task type and service
+ * are known but the settings are not parsed.
+ */
+public record UnparsedModel(
+    String inferenceEntityId,
+    TaskType taskType,
+    String service,
+    Map<String, Object> settings,
+    Map<String, Object> secrets
+) {}

+ 2 - 1
test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

@@ -19,7 +19,8 @@ public enum FeatureFlag {
     TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
     FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
     CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
-    INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
+    INFERENCE_SCALE_TO_ZERO("es.inference_scale_to_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
+    INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
 
     public final String systemProperty;
     public final Version from;

+ 1 - 0
x-pack/plugin/build.gradle

@@ -201,5 +201,6 @@ tasks.named("precommit").configure {
 
 tasks.named("yamlRestTestV7CompatTransform").configure({ task ->
   task.skipTest("security/10_forbidden/Test bulk response with invalid credentials", "warning does not exist for compatibility")
+  task.skipTest("inference/inference_crud/Test get all", "Assertions on number of inference models break due to default configs")
 })
 

+ 3 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -237,7 +237,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             if (numberOfAllocations != null) {
                 return numberOfAllocations;
             } else {
-                if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) {
+                if (adaptiveAllocationsSettings == null
+                    || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null
+                    || adaptiveAllocationsSettings.getMinNumberOfAllocations() == 0) {
                     return DEFAULT_NUM_ALLOCATIONS;
                 } else {
                     return adaptiveAllocationsSettings.getMinNumberOfAllocations();

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java

@@ -147,8 +147,8 @@ public class AdaptiveAllocationsSettings implements ToXContentObject, Writeable
     public ActionRequestValidationException validate() {
         ActionRequestValidationException validationException = new ActionRequestValidationException();
         boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1);
-        if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) {
-            validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null");
+        if (hasMinNumberOfAllocations && minNumberOfAllocations < 0) {
+            validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a non-negative integer or null");
         }
         boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1);
         if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) {

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationSettingsTests.java

@@ -17,7 +17,7 @@ public class AdaptiveAllocationSettingsTests extends AbstractWireSerializingTest
     public static AdaptiveAllocationsSettings testInstance() {
         return new AdaptiveAllocationsSettings(
             randomBoolean() ? null : randomBoolean(),
-            randomBoolean() ? null : randomIntBetween(1, 2),
+            randomBoolean() ? null : randomIntBetween(0, 2),
             randomBoolean() ? null : randomIntBetween(2, 4)
         );
     }

+ 1 - 1
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CustomElandModelIT.java

@@ -85,7 +85,7 @@ public class CustomElandModelIT extends InferenceBaseRestTest {
 
         var inferenceId = "sparse-inf";
         putModel(inferenceId, inferenceConfig, TaskType.SPARSE_EMBEDDING);
-        var results = inferOnMockService(inferenceId, List.of("washing", "machine"));
+        var results = infer(inferenceId, List.of("washing", "machine"));
         deleteModel(inferenceId);
         assertNotNull(results.get("sparse_embedding"));
     }

+ 70 - 0
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultElserIT.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.oneOf;
+
+public class DefaultElserIT extends InferenceBaseRestTest {
+
+    private TestThreadPool threadPool;
+
+    @Before
+    public void createThreadPool() {
+        threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        threadPool.close();
+        super.tearDown();
+    }
+
+    @SuppressWarnings("unchecked")
+    public void testInferCreatesDefaultElser() throws IOException {
+        assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
+        var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
+        assertDefaultElserConfig(model);
+
+        var inputs = List.of("Hello World", "Goodnight moon");
+        var queryParams = Map.of("timeout", "120s");
+        var results = infer(ElasticsearchInternalService.DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, inputs, queryParams);
+        var embeddings = (List<Map<String, Object>>) results.get("sparse_embedding");
+        assertThat(results.toString(), embeddings, hasSize(2));
+    }
+
+    @SuppressWarnings("unchecked")
+    private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_ELSER_ID, modelConfig.get("inference_id"));
+        assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
+        assertEquals(modelConfig.toString(), TaskType.SPARSE_EMBEDDING.toString(), modelConfig.get("task_type"));
+
+        var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
+        assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(oneOf(".elser_model_2", ".elser_model_2_linux-x86_64")));
+        assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
+
+        var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
+        assertThat(
+            modelConfig.toString(),
+            adaptiveAllocations,
+            Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
+        );
+    }
+}

+ 15 - 6
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

@@ -270,7 +270,7 @@ public class InferenceBaseRestTest extends ESRestTestCase {
 
     @SuppressWarnings("unchecked")
     protected Map<String, Object> getModel(String modelId) throws IOException {
-        var endpoint = Strings.format("_inference/%s", modelId);
+        var endpoint = Strings.format("_inference/%s?error_trace", modelId);
         return ((List<Map<String, Object>>) getInternal(endpoint).get("endpoints")).get(0);
     }
 
@@ -293,9 +293,9 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return entityAsMap(response);
     }
 
-    protected Map<String, Object> inferOnMockService(String modelId, List<String> input) throws IOException {
+    protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s", modelId);
-        return inferOnMockServiceInternal(endpoint, input);
+        return inferInternal(endpoint, input, Map.of());
     }
 
     protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
@@ -324,14 +324,23 @@ public class InferenceBaseRestTest extends ESRestTestCase {
         return responseConsumer.events();
     }
 
-    protected Map<String, Object> inferOnMockService(String modelId, TaskType taskType, List<String> input) throws IOException {
+    protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
         var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
-        return inferOnMockServiceInternal(endpoint, input);
+        return inferInternal(endpoint, input, Map.of());
+    }
+
+    protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
+        throws IOException {
+        var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
+        return inferInternal(endpoint, input, queryParameters);
     }
 
-    private Map<String, Object> inferOnMockServiceInternal(String endpoint, List<String> input) throws IOException {
+    private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
         var request = new Request("POST", endpoint);
         request.setJsonEntity(jsonBody(input));
+        if (queryParameters.isEmpty() == false) {
+            request.addParameters(queryParameters);
+        }
         var response = client().performRequest(request);
         assertOkOrCreated(response);
         return entityAsMap(response);

+ 5 - 3
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

@@ -38,10 +38,12 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         }
 
         var getAllModels = getAllModels();
-        assertThat(getAllModels, hasSize(9));
+        int numModels = DefaultElserFeatureFlag.isEnabled() ? 10 : 9;
+        assertThat(getAllModels, hasSize(numModels));
 
         var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
-        assertThat(getSparseModels, hasSize(5));
+        int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5;
+        assertThat(getSparseModels, hasSize(numSparseModels));
         for (var sparseModel : getSparseModels) {
             assertEquals("sparse_embedding", sparseModel.get("task_type"));
         }
@@ -99,7 +101,7 @@ public class InferenceCrudIT extends InferenceBaseRestTest {
         assertEquals(modelId, singleModel.get("inference_id"));
         assertEquals(TaskType.SPARSE_EMBEDDING.toString(), singleModel.get("task_type"));
 
-        var inference = inferOnMockService(modelId, List.of(randomAlphaOfLength(10)));
+        var inference = infer(modelId, List.of(randomAlphaOfLength(10)));
         assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
         deleteModel(modelId);
     }

+ 4 - 7
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockDenseInferenceServiceIT.java

@@ -28,15 +28,12 @@ public class MockDenseInferenceServiceIT extends InferenceBaseRestTest {
         }
 
         List<String> input = List.of(randomAlphaOfLength(10));
-        var inference = inferOnMockService(inferenceEntityId, input);
+        var inference = infer(inferenceEntityId, input);
         assertNonEmptyInferenceResults(inference, 1, TaskType.TEXT_EMBEDDING);
         // Same input should return the same result
-        assertEquals(inference, inferOnMockService(inferenceEntityId, input));
+        assertEquals(inference, infer(inferenceEntityId, input));
         // Different input values should not
-        assertNotEquals(
-            inference,
-            inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
-        );
+        assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
     }
 
     public void testMockServiceWithMultipleInputs() throws IOException {
@@ -44,7 +41,7 @@ public class MockDenseInferenceServiceIT extends InferenceBaseRestTest {
         putModel(inferenceEntityId, mockDenseServiceModelConfig(), TaskType.TEXT_EMBEDDING);
 
         // The response is randomly generated, the input can be anything
-        var inference = inferOnMockService(
+        var inference = infer(
             inferenceEntityId,
             TaskType.TEXT_EMBEDDING,
             List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))

+ 6 - 9
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockSparseInferenceServiceIT.java

@@ -30,15 +30,12 @@ public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
         }
 
         List<String> input = List.of(randomAlphaOfLength(10));
-        var inference = inferOnMockService(inferenceEntityId, input);
+        var inference = infer(inferenceEntityId, input);
         assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
         // Same input should return the same result
-        assertEquals(inference, inferOnMockService(inferenceEntityId, input));
+        assertEquals(inference, infer(inferenceEntityId, input));
         // Different input values should not
-        assertNotEquals(
-            inference,
-            inferOnMockService(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10))))
-        );
+        assertNotEquals(inference, infer(inferenceEntityId, randomValueOtherThan(input, () -> List.of(randomAlphaOfLength(10)))));
     }
 
     public void testMockServiceWithMultipleInputs() throws IOException {
@@ -46,7 +43,7 @@ public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
         putModel(inferenceEntityId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING);
 
         // The response is randomly generated, the input can be anything
-        var inference = inferOnMockService(
+        var inference = infer(
             inferenceEntityId,
             TaskType.SPARSE_EMBEDDING,
             List.of(randomAlphaOfLength(5), randomAlphaOfLength(10), randomAlphaOfLength(15))
@@ -84,7 +81,7 @@ public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
         }
 
         // The response is randomly generated, the input can be anything
-        var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
+        var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10)));
         assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
     }
 
@@ -102,7 +99,7 @@ public class MockSparseInferenceServiceIT extends InferenceBaseRestTest {
         }
 
         // The response is randomly generated, the input can be anything
-        var inference = inferOnMockService(inferenceEntityId, List.of(randomAlphaOfLength(10)));
+        var inference = infer(inferenceEntityId, List.of(randomAlphaOfLength(10)));
         assertNonEmptyInferenceResults(inference, 1, TaskType.SPARSE_EMBEDDING);
     }
 }

+ 2 - 2
x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/TextEmbeddingCrudIT.java

@@ -38,7 +38,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
         var models = getTrainedModel("_all");
         assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
 
-        Map<String, Object> results = inferOnMockService(
+        Map<String, Object> results = infer(
             inferenceEntityId,
             TaskType.TEXT_EMBEDDING,
             List.of("hello world", "this is the second document")
@@ -57,7 +57,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
             var models = getTrainedModel("_all");
             assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
 
-            Map<String, Object> results = inferOnMockService(
+            Map<String, Object> results = infer(
                 inferenceEntityId,
                 TaskType.TEXT_EMBEDDING,
                 List.of("hello world", "this is the second document")

+ 210 - 5
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

@@ -20,6 +20,7 @@ import org.elasticsearch.inference.SecretSettings;
 import org.elasticsearch.inference.ServiceSettings;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.reindex.ReindexPlugin;
 import org.elasticsearch.test.ESSingleNodeTestCase;
@@ -38,6 +39,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CountDownLatch;
@@ -110,7 +112,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         assertThat(putModelHolder.get(), is(true));
 
         // now get the model
-        AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
+        AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
         assertThat(exceptionHolder.get(), is(nullValue()));
         assertThat(modelHolder.get(), not(nullValue()));
@@ -168,7 +170,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
 
         // get should fail
         deleteResponseHolder.set(false);
-        AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
+        AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getModelWithSecrets("model1", listener), modelHolder, exceptionHolder);
 
         assertThat(exceptionHolder.get(), not(nullValue()));
@@ -194,7 +196,7 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         }
 
         AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
-        AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
+        AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
         assertThat(modelHolder.get(), hasSize(3));
         var sparseIds = sparseAndTextEmbeddingModels.stream()
@@ -235,8 +237,9 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
             assertNull(exceptionHolder.get());
         }
 
-        AtomicReference<List<ModelRegistry.UnparsedModel>> modelHolder = new AtomicReference<>();
+        AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
+        assertNull(exceptionHolder.get());
         assertThat(modelHolder.get(), hasSize(modelCount));
         var getAllModels = modelHolder.get();
 
@@ -264,15 +267,213 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         assertThat(putModelHolder.get(), is(true));
         assertNull(exceptionHolder.get());
 
-        AtomicReference<ModelRegistry.UnparsedModel> modelHolder = new AtomicReference<>();
+        AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
         blockingCall(listener -> modelRegistry.getModelWithSecrets(inferenceEntityId, listener), modelHolder, exceptionHolder);
         assertThat(modelHolder.get().secrets().keySet(), hasSize(1));
         var secretSettings = (Map<String, Object>) modelHolder.get().secrets().get("secret_settings");
         assertThat(secretSettings.get("secret"), equalTo(secret));
+        assertReturnModelIsModifiable(modelHolder.get());
 
         // get model without secrets
         blockingCall(listener -> modelRegistry.getModel(inferenceEntityId, listener), modelHolder, exceptionHolder);
         assertThat(modelHolder.get().secrets().keySet(), empty());
+        assertReturnModelIsModifiable(modelHolder.get());
+    }
+
+    public void testGetAllModels_WithDefaults() throws Exception {
+        var service = "foo";
+        var secret = "abc";
+        int configuredModelCount = 10;
+        int defaultModelCount = 2;
+        int totalModelCount = 12;
+
+        var defaultConfigs = new HashMap<String, UnparsedModel>();
+        for (int i = 0; i < defaultModelCount; i++) {
+            var id = "default-" + i;
+            defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
+        }
+        defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
+
+        AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        var createdModels = new HashMap<String, Model>();
+        for (int i = 0; i < configuredModelCount; i++) {
+            var id = randomAlphaOfLength(5) + i;
+            var model = createModel(id, randomFrom(TaskType.values()), service);
+            createdModels.put(id, model);
+            blockingCall(listener -> modelRegistry.storeModel(model, listener), putModelHolder, exceptionHolder);
+            assertThat(putModelHolder.get(), is(true));
+            assertNull(exceptionHolder.get());
+        }
+
+        AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
+        blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
+        assertNull(exceptionHolder.get());
+        assertThat(modelHolder.get(), hasSize(totalModelCount));
+        var getAllModels = modelHolder.get();
+        assertReturnModelIsModifiable(modelHolder.get().get(0));
+
+        // sort in the same order as the returned models
+        var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
+        ids.addAll(createdModels.keySet().stream().toList());
+        ids.sort(String::compareTo);
+        for (int i = 0; i < totalModelCount; i++) {
+            var id = ids.get(i);
+            assertEquals(id, getAllModels.get(i).inferenceEntityId());
+            if (id.startsWith("default")) {
+                assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
+                assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
+            } else {
+                assertEquals(createdModels.get(id).getTaskType(), getAllModels.get(i).taskType());
+                assertEquals(createdModels.get(id).getConfigurations().getService(), getAllModels.get(i).service());
+            }
+        }
+    }
+
+    public void testGetAllModels_OnlyDefaults() throws Exception {
+        var service = "foo";
+        var secret = "abc";
+        int defaultModelCount = 2;
+
+        var defaultConfigs = new HashMap<String, UnparsedModel>();
+        for (int i = 0; i < defaultModelCount; i++) {
+            var id = "default-" + i;
+            defaultConfigs.put(id, createUnparsedConfig(id, randomFrom(TaskType.values()), service, secret));
+        }
+        defaultConfigs.values().forEach(modelRegistry::addDefaultConfiguration);
+
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+        AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
+        blockingCall(listener -> modelRegistry.getAllModels(listener), modelHolder, exceptionHolder);
+        assertNull(exceptionHolder.get());
+        assertThat(modelHolder.get(), hasSize(2));
+        var getAllModels = modelHolder.get();
+        assertReturnModelIsModifiable(modelHolder.get().get(0));
+
+        // sort in the same order as the returned models
+        var ids = new ArrayList<>(defaultConfigs.keySet().stream().toList());
+        ids.sort(String::compareTo);
+        for (int i = 0; i < defaultModelCount; i++) {
+            var id = ids.get(i);
+            assertEquals(id, getAllModels.get(i).inferenceEntityId());
+            assertEquals(defaultConfigs.get(id).taskType(), getAllModels.get(i).taskType());
+            assertEquals(defaultConfigs.get(id).service(), getAllModels.get(i).service());
+        }
+    }
+
+    public void testGet_WithDefaults() throws InterruptedException {
+        var service = "foo";
+        var secret = "abc";
+
+        var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
+        var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
+
+        modelRegistry.addDefaultConfiguration(defaultSparse);
+        modelRegistry.addDefaultConfiguration(defaultText);
+
+        AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        var configured1 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
+        var configured2 = createModel(randomAlphaOfLength(5) + 1, randomFrom(TaskType.values()), service);
+        blockingCall(listener -> modelRegistry.storeModel(configured1, listener), putModelHolder, exceptionHolder);
+        assertThat(putModelHolder.get(), is(true));
+        blockingCall(listener -> modelRegistry.storeModel(configured2, listener), putModelHolder, exceptionHolder);
+        assertThat(putModelHolder.get(), is(true));
+        assertNull(exceptionHolder.get());
+
+        AtomicReference<UnparsedModel> modelHolder = new AtomicReference<>();
+        blockingCall(listener -> modelRegistry.getModel("default-sparse", listener), modelHolder, exceptionHolder);
+        assertEquals("default-sparse", modelHolder.get().inferenceEntityId());
+        assertEquals(TaskType.SPARSE_EMBEDDING, modelHolder.get().taskType());
+        assertReturnModelIsModifiable(modelHolder.get());
+
+        blockingCall(listener -> modelRegistry.getModel("default-text", listener), modelHolder, exceptionHolder);
+        assertEquals("default-text", modelHolder.get().inferenceEntityId());
+        assertEquals(TaskType.TEXT_EMBEDDING, modelHolder.get().taskType());
+
+        blockingCall(listener -> modelRegistry.getModel(configured1.getInferenceEntityId(), listener), modelHolder, exceptionHolder);
+        assertEquals(configured1.getInferenceEntityId(), modelHolder.get().inferenceEntityId());
+        assertEquals(configured1.getTaskType(), modelHolder.get().taskType());
+    }
+
+    public void testGetByTaskType_WithDefaults() throws Exception {
+        var service = "foo";
+        var secret = "abc";
+
+        var defaultSparse = createUnparsedConfig("default-sparse", TaskType.SPARSE_EMBEDDING, service, secret);
+        var defaultText = createUnparsedConfig("default-text", TaskType.TEXT_EMBEDDING, service, secret);
+        var defaultChat = createUnparsedConfig("default-chat", TaskType.COMPLETION, service, secret);
+
+        modelRegistry.addDefaultConfiguration(defaultSparse);
+        modelRegistry.addDefaultConfiguration(defaultText);
+        modelRegistry.addDefaultConfiguration(defaultChat);
+
+        AtomicReference<Boolean> putModelHolder = new AtomicReference<>();
+        AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
+
+        var configuredSparse = createModel("configured-sparse", TaskType.SPARSE_EMBEDDING, service);
+        var configuredText = createModel("configured-text", TaskType.TEXT_EMBEDDING, service);
+        var configuredRerank = createModel("configured-rerank", TaskType.RERANK, service);
+        blockingCall(listener -> modelRegistry.storeModel(configuredSparse, listener), putModelHolder, exceptionHolder);
+        assertThat(putModelHolder.get(), is(true));
+        blockingCall(listener -> modelRegistry.storeModel(configuredText, listener), putModelHolder, exceptionHolder);
+        assertThat(putModelHolder.get(), is(true));
+        blockingCall(listener -> modelRegistry.storeModel(configuredRerank, listener), putModelHolder, exceptionHolder);
+        assertThat(putModelHolder.get(), is(true));
+        assertNull(exceptionHolder.get());
+
+        AtomicReference<List<UnparsedModel>> modelHolder = new AtomicReference<>();
+        blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.SPARSE_EMBEDDING, listener), modelHolder, exceptionHolder);
+        if (exceptionHolder.get() != null) {
+            throw exceptionHolder.get();
+        }
+        assertNull(exceptionHolder.get());
+        assertThat(modelHolder.get(), hasSize(2));
+        assertEquals("configured-sparse", modelHolder.get().get(0).inferenceEntityId());
+        assertEquals("default-sparse", modelHolder.get().get(1).inferenceEntityId());
+
+        blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.TEXT_EMBEDDING, listener), modelHolder, exceptionHolder);
+        assertThat(modelHolder.get(), hasSize(2));
+        assertEquals("configured-text", modelHolder.get().get(0).inferenceEntityId());
+        assertEquals("default-text", modelHolder.get().get(1).inferenceEntityId());
+        assertReturnModelIsModifiable(modelHolder.get().get(0));
+
+        blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.RERANK, listener), modelHolder, exceptionHolder);
+        assertThat(modelHolder.get(), hasSize(1));
+        assertEquals("configured-rerank", modelHolder.get().get(0).inferenceEntityId());
+
+        blockingCall(listener -> modelRegistry.getModelsByTaskType(TaskType.COMPLETION, listener), modelHolder, exceptionHolder);
+        assertThat(modelHolder.get(), hasSize(1));
+        assertEquals("default-chat", modelHolder.get().get(0).inferenceEntityId());
+        assertReturnModelIsModifiable(modelHolder.get().get(0));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void assertReturnModelIsModifiable(UnparsedModel unparsedModel) {
+        var settings = unparsedModel.settings();
+        if (settings != null) {
+            var serviceSettings = (Map<String, Object>) settings.get("service_settings");
+            if (serviceSettings != null && serviceSettings.size() > 0) {
+                var itr = serviceSettings.entrySet().iterator();
+                itr.next();
+                itr.remove();
+            }
+
+            var taskSettings = (Map<String, Object>) settings.get("task_settings");
+            if (taskSettings != null && taskSettings.size() > 0) {
+                var itr = taskSettings.entrySet().iterator();
+                itr.next();
+                itr.remove();
+            }
+
+            if (unparsedModel.secrets() != null && unparsedModel.secrets().size() > 0) {
+                var itr = unparsedModel.secrets().entrySet().iterator();
+                itr.next();
+                itr.remove();
+            }
+        }
     }
 
     private Model buildElserModelConfig(String inferenceEntityId, TaskType taskType) {
@@ -327,6 +528,10 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         );
     }
 
+    public static UnparsedModel createUnparsedConfig(String inferenceEntityId, TaskType taskType, String service, String secret) {
+        return new UnparsedModel(inferenceEntityId, taskType, service, Map.of("a", "b"), Map.of("secret", secret));
+    }
+
     private static class TestModelOfAnyKind extends ModelConfigurations {
 
         record TestModelServiceSettings() implements ServiceSettings {

+ 21 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/DefaultElserFeatureFlag.java

@@ -0,0 +1,21 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference;
+
+import org.elasticsearch.common.util.FeatureFlag;
+
+public class DefaultElserFeatureFlag {
+
+    private DefaultElserFeatureFlag() {}
+
+    private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_default_elser");
+
+    public static boolean isEnabled() {
+        return FEATURE_FLAG.isEnabled();
+    }
+}

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -210,6 +210,9 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
         // reference correctly
         var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
         registry.init(services.client());
+        for (var service : registry.getServices().values()) {
+            service.defaultConfigs().forEach(modelRegistry::addDefaultConfiguration);
+        }
         inferenceServiceRegistry.set(registry);
 
         var actionFilter = new ShardBulkInferenceActionFilter(registry, modelRegistry);

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java

@@ -24,6 +24,7 @@ import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
@@ -91,7 +92,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA
         ClusterState state,
         ActionListener<DeleteInferenceEndpointAction.Response> masterListener
     ) {
-        SubscribableListener.<ModelRegistry.UnparsedModel>newForked(modelConfigListener -> {
+        SubscribableListener.<UnparsedModel>newForked(modelConfigListener -> {
             // Get the model from the registry
 
             modelRegistry.getModel(request.getInferenceEndpointId(), modelConfigListener);

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceModelAction.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
@@ -112,7 +113,7 @@ public class TransportGetInferenceModelAction extends HandledTransportAction<
         );
     }
 
-    private GetInferenceModelAction.Response parseModels(List<ModelRegistry.UnparsedModel> unparsedModels) {
+    private GetInferenceModelAction.Response parseModels(List<UnparsedModel> unparsedModels) {
         var parsedModels = new ArrayList<ModelConfigurations>();
 
         for (var unparsedModel : unparsedModels) {

+ 20 - 19
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

@@ -19,6 +19,7 @@ import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
@@ -64,30 +65,16 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
 
-        ActionListener<ModelRegistry.UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                delegate.onFailure(
-                    new ElasticsearchStatusException(
-                        "Unknown service [{}] for model [{}]. ",
-                        RestStatus.INTERNAL_SERVER_ERROR,
-                        unparsedModel.service(),
-                        unparsedModel.inferenceEntityId()
-                    )
-                );
+                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                delegate.onFailure(
-                    new ElasticsearchStatusException(
-                        "Incompatible task_type, the requested type [{}] does not match the model type [{}]",
-                        RestStatus.BAD_REQUEST,
-                        request.getTaskType(),
-                        unparsedModel.taskType()
-                    )
-                );
+                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
                 return;
             }
 
@@ -98,7 +85,6 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
                     unparsedModel.settings(),
                     unparsedModel.secrets()
                 );
-            inferenceStats.incrementRequestCount(model);
             inferOnService(model, request, service.get(), delegate);
         });
 
@@ -112,6 +98,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
         ActionListener<InferenceAction.Response> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
+            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -160,5 +147,19 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
             });
         }
         return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    };
+    }
+
+    private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
+        return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
+    }
+
+    private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) {
+        return new ElasticsearchStatusException(
+            "Incompatible task_type, the requested type [{}] does not match the model type [{}]",
+            RestStatus.BAD_REQUEST,
+            requested,
+            expected
+        );
+    }
+
 }

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

@@ -35,6 +35,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
@@ -211,9 +212,9 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             final Releasable onFinish
         ) {
             if (inferenceProvider == null) {
-                ActionListener<ModelRegistry.UnparsedModel> modelLoadingListener = new ActionListener<>() {
+                ActionListener<UnparsedModel> modelLoadingListener = new ActionListener<>() {
                     @Override
-                    public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
+                    public void onResponse(UnparsedModel unparsedModel) {
                         var service = inferenceServiceRegistry.getService(unparsedModel.service());
                         if (service.isEmpty() == false) {
                             var provider = new InferenceProvider(

+ 126 - 35
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

@@ -32,6 +32,7 @@ import org.elasticsearch.index.reindex.DeleteByQueryRequest;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
@@ -48,6 +49,8 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
@@ -58,32 +61,19 @@ import static org.elasticsearch.core.Strings.format;
 public class ModelRegistry {
     public record ModelConfigMap(Map<String, Object> config, Map<String, Object> secrets) {}
 
-    /**
-     * Semi parsed model where inference entity id, task type and service
-     * are known but the settings are not parsed.
-     */
-    public record UnparsedModel(
-        String inferenceEntityId,
-        TaskType taskType,
-        String service,
-        Map<String, Object> settings,
-        Map<String, Object> secrets
-    ) {
-
-        public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) {
-            if (modelConfigMap.config() == null) {
-                throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST);
-            }
-            String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(
-                modelConfigMap.config(),
-                ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME
-            );
-            String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE);
-            String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME);
-            TaskType taskType = TaskType.fromString(taskTypeStr);
-
-            return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets());
+    public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) {
+        if (modelConfigMap.config() == null) {
+            throw new ElasticsearchStatusException("Missing config map", RestStatus.BAD_REQUEST);
         }
+        String inferenceEntityId = ServiceUtils.removeStringOrThrowIfNull(
+            modelConfigMap.config(),
+            ModelConfigurations.INDEX_ONLY_ID_FIELD_NAME
+        );
+        String service = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), ModelConfigurations.SERVICE);
+        String taskTypeStr = ServiceUtils.removeStringOrThrowIfNull(modelConfigMap.config(), TaskType.NAME);
+        TaskType taskType = TaskType.fromString(taskTypeStr);
+
+        return new UnparsedModel(inferenceEntityId, taskType, service, modelConfigMap.config(), modelConfigMap.secrets());
     }
 
     private static final String TASK_TYPE_FIELD = "task_type";
@@ -91,9 +81,27 @@ public class ModelRegistry {
     private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
 
     private final OriginSettingClient client;
+    private Map<String, UnparsedModel> defaultConfigs;
 
     public ModelRegistry(Client client) {
         this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
+        this.defaultConfigs = new HashMap<>();
+    }
+
+    public void addDefaultConfiguration(UnparsedModel serviceDefaultConfig) {
+        if (defaultConfigs.containsKey(serviceDefaultConfig.inferenceEntityId())) {
+            throw new IllegalStateException(
+                "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
+                    + serviceDefaultConfig.inferenceEntityId()
+                    + "] declared by service ["
+                    + serviceDefaultConfig.service()
+                    + "]. The inference Id is already use by ["
+                    + defaultConfigs.get(serviceDefaultConfig.inferenceEntityId()).service()
+                    + "] service."
+            );
+        }
+
+        defaultConfigs.put(serviceDefaultConfig.inferenceEntityId(), serviceDefaultConfig);
     }
 
     /**
@@ -102,6 +110,11 @@ public class ModelRegistry {
      * @param listener Model listener
      */
     public void getModelWithSecrets(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
+        if (defaultConfigs.containsKey(inferenceEntityId)) {
+            listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId)));
+            return;
+        }
+
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // There should be a hit for the configurations and secrets
             if (searchResponse.getHits().getHits().length == 0) {
@@ -109,7 +122,7 @@ public class ModelRegistry {
                 return;
             }
 
-            delegate.onResponse(UnparsedModel.unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
+            delegate.onResponse(unparsedModelFromMap(createModelConfigMap(searchResponse.getHits(), inferenceEntityId)));
         });
 
         QueryBuilder queryBuilder = documentIdQuery(inferenceEntityId);
@@ -128,6 +141,11 @@ public class ModelRegistry {
      * @param listener Model listener
      */
     public void getModel(String inferenceEntityId, ActionListener<UnparsedModel> listener) {
+        if (defaultConfigs.containsKey(inferenceEntityId)) {
+            listener.onResponse(deepCopyDefaultConfig(defaultConfigs.get(inferenceEntityId)));
+            return;
+        }
+
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // There should be a hit for the configurations and secrets
             if (searchResponse.getHits().getHits().length == 0) {
@@ -135,7 +153,7 @@ public class ModelRegistry {
                 return;
             }
 
-            var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList();
+            var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
             assert modelConfigs.size() == 1;
             delegate.onResponse(modelConfigs.get(0));
         });
@@ -162,14 +180,29 @@ public class ModelRegistry {
      */
     public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
+            var defaultConfigsForTaskType = defaultConfigs.values()
+                .stream()
+                .filter(m -> m.taskType() == taskType)
+                .map(ModelRegistry::deepCopyDefaultConfig)
+                .toList();
+
             // Not an error if no models of this task_type
-            if (searchResponse.getHits().getHits().length == 0) {
+            if (searchResponse.getHits().getHits().length == 0 && defaultConfigsForTaskType.isEmpty()) {
                 delegate.onResponse(List.of());
                 return;
             }
 
-            var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList();
-            delegate.onResponse(modelConfigs);
+            var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
+
+            if (defaultConfigsForTaskType.isEmpty() == false) {
+                var allConfigs = new ArrayList<UnparsedModel>();
+                allConfigs.addAll(modelConfigs);
+                allConfigs.addAll(defaultConfigsForTaskType);
+                allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
+                delegate.onResponse(allConfigs);
+            } else {
+                delegate.onResponse(modelConfigs);
+            }
         });
 
         QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.termsQuery(TASK_TYPE_FIELD, taskType.toString()));
@@ -191,14 +224,19 @@ public class ModelRegistry {
      */
     public void getAllModels(ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
-            // Not an error if no models of this task_type
-            if (searchResponse.getHits().getHits().length == 0) {
+            var defaults = defaultConfigs.values().stream().map(ModelRegistry::deepCopyDefaultConfig).toList();
+
+            if (searchResponse.getHits().getHits().length == 0 && defaults.isEmpty()) {
                 delegate.onResponse(List.of());
                 return;
             }
 
-            var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(UnparsedModel::unparsedModelFromMap).toList();
-            delegate.onResponse(modelConfigs);
+            var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
+            var allConfigs = new ArrayList<UnparsedModel>();
+            allConfigs.addAll(foundConfigs);
+            allConfigs.addAll(defaults);
+            allConfigs.sort(Comparator.comparing(UnparsedModel::inferenceEntityId));
+            delegate.onResponse(allConfigs);
         });
 
         // In theory the index should only contain model config documents
@@ -216,7 +254,7 @@ public class ModelRegistry {
         client.search(modelSearch, searchListener);
     }
 
-    private List<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
+    private ArrayList<ModelConfigMap> parseHitsAsModels(SearchHits hits) {
         var modelConfigs = new ArrayList<ModelConfigMap>();
         for (var hit : hits) {
             modelConfigs.add(new ModelConfigMap(hit.getSourceAsMap(), Map.of()));
@@ -393,4 +431,57 @@ public class ModelRegistry {
     private QueryBuilder documentIdQuery(String inferenceEntityId) {
         return QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(Model.documentId(inferenceEntityId)));
     }
+
+    static UnparsedModel deepCopyDefaultConfig(UnparsedModel other) {
+        // Because the default config uses immutable maps
+        return new UnparsedModel(
+            other.inferenceEntityId(),
+            other.taskType(),
+            other.service(),
+            copySettingsMap(other.settings()),
+            copySecretsMap(other.secrets())
+        );
+    }
+
+    @SuppressWarnings("unchecked")
+    static Map<String, Object> copySettingsMap(Map<String, Object> other) {
+        var result = new HashMap<String, Object>();
+
+        var serviceSettings = (Map<String, Object>) other.get(ModelConfigurations.SERVICE_SETTINGS);
+        if (serviceSettings != null) {
+            var copiedServiceSettings = copyMap1LevelDeep(serviceSettings);
+            result.put(ModelConfigurations.SERVICE_SETTINGS, copiedServiceSettings);
+        }
+
+        var taskSettings = (Map<String, Object>) other.get(ModelConfigurations.TASK_SETTINGS);
+        if (taskSettings != null) {
+            var copiedTaskSettings = copyMap1LevelDeep(taskSettings);
+            result.put(ModelConfigurations.TASK_SETTINGS, copiedTaskSettings);
+        }
+
+        var chunkSettings = (Map<String, Object>) other.get(ModelConfigurations.CHUNKING_SETTINGS);
+        if (chunkSettings != null) {
+            var copiedChunkSettings = copyMap1LevelDeep(chunkSettings);
+            result.put(ModelConfigurations.CHUNKING_SETTINGS, copiedChunkSettings);
+        }
+
+        return result;
+    }
+
+    static Map<String, Object> copySecretsMap(Map<String, Object> other) {
+        return copyMap1LevelDeep(other);
+    }
+
+    @SuppressWarnings("unchecked")
+    static Map<String, Object> copyMap1LevelDeep(Map<String, Object> other) {
+        var result = new HashMap<String, Object>();
+        for (var entry : other.entrySet()) {
+            if (entry.getValue() instanceof Map<?, ?>) {
+                result.put(entry.getKey(), new HashMap<>((Map<String, Object>) entry.getValue()));
+            } else {
+                result.put(entry.getKey(), entry.getValue());
+            }
+        }
+        return result;
+    }
 }

+ 26 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.SubscribableListener;
@@ -31,6 +32,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
+import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 
 import java.io.IOException;
@@ -80,7 +82,6 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
     @Override
     public void start(Model model, ActionListener<Boolean> finalListener) {
         if (model instanceof ElasticsearchInternalModel esModel) {
-
             if (supportedTaskTypes().contains(model.getTaskType()) == false) {
                 finalListener.onFailure(
                     new IllegalStateException(TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()))
@@ -149,7 +150,7 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
         }
     }
 
-    private void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
+    protected void putBuiltInModel(String modelId, ActionListener<Boolean> listener) {
         var input = new TrainedModelInput(List.<String>of("text_field")); // by convention text_field is used
         var config = TrainedModelConfig.builder().setInput(input).setModelId(modelId).validate(true).build();
         PutTrainedModelAction.Request putRequest = new PutTrainedModelAction.Request(config, false, true);
@@ -258,4 +259,27 @@ public abstract class BaseElasticsearchInternalService implements InferenceServi
         request.setChunked(chunk);
         return request;
     }
+
+    protected abstract boolean isDefaultId(String inferenceId);
+
+    protected void maybeStartDeployment(
+        ElasticsearchInternalModel model,
+        Exception e,
+        InferModelAction.Request request,
+        ActionListener<InferModelAction.Response> listener
+    ) {
+        if (DefaultElserFeatureFlag.isEnabled() == false) {
+            listener.onFailure(e);
+            return;
+        }
+
+        if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+            this.start(
+                model,
+                listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })
+            );
+        } else {
+            listener.onFailure(e);
+        }
+    }
 }

+ 83 - 23
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -26,6 +26,7 @@ import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
@@ -73,6 +74,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
     );
 
+    public static final String DEFAULT_ELSER_ID = ".elser-2";
+
     private static final Logger logger = LogManager.getLogger(ElasticsearchInternalService.class);
     private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(ElasticsearchInternalService.class);
 
@@ -100,6 +103,17 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         Map<String, Object> config,
         ActionListener<Model> modelListener
     ) {
+        if (inferenceEntityId.equals(DEFAULT_ELSER_ID)) {
+            modelListener.onFailure(
+                new ElasticsearchStatusException(
+                    "[{}] is a reserved inference Id. Cannot create a new inference endpoint with a reserved Id",
+                    RestStatus.BAD_REQUEST,
+                    inferenceEntityId
+                )
+            );
+            return;
+        }
+
         try {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMap(config, ModelConfigurations.TASK_SETTINGS);
@@ -459,20 +473,24 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         TimeValue timeout,
         ActionListener<InferenceServiceResults> listener
     ) {
-        var taskType = model.getConfigurations().getTaskType();
-        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
-            inferTextEmbedding(model, input, inputType, timeout, listener);
-        } else if (TaskType.RERANK.equals(taskType)) {
-            inferRerank(model, query, input, inputType, timeout, taskSettings, listener);
-        } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
-            inferSparseEmbedding(model, input, inputType, timeout, listener);
+        if (model instanceof ElasticsearchInternalModel esModel) {
+            var taskType = model.getConfigurations().getTaskType();
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                inferTextEmbedding(esModel, input, inputType, timeout, listener);
+            } else if (TaskType.RERANK.equals(taskType)) {
+                inferRerank(esModel, query, input, inputType, timeout, taskSettings, listener);
+            } else if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+                inferSparseEmbedding(esModel, input, inputType, timeout, listener);
+            } else {
+                throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+            }
         } else {
-            throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
+            listener.onFailure(notElasticsearchModelException(model));
         }
     }
 
     public void inferTextEmbedding(
-        Model model,
+        ElasticsearchInternalModel model,
         List<String> inputs,
         InputType inputType,
         TimeValue timeout,
@@ -487,17 +505,19 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             false
         );
 
-        client.execute(
-            InferModelAction.INSTANCE,
-            request,
-            listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
-            )
+        ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
+            (l, inferenceResult) -> l.onResponse(InferenceTextEmbeddingFloatResults.of(inferenceResult.getInferenceResults()))
+        );
+
+        var maybeDeployListener = mlResultsListener.delegateResponse(
+            (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
         );
+
+        client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
     }
 
     public void inferSparseEmbedding(
-        Model model,
+        ElasticsearchInternalModel model,
         List<String> inputs,
         InputType inputType,
         TimeValue timeout,
@@ -512,17 +532,19 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
             false
         );
 
-        client.execute(
-            InferModelAction.INSTANCE,
-            request,
-            listener.delegateFailureAndWrap(
-                (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults()))
-            )
+        ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
+            (l, inferenceResult) -> l.onResponse(SparseEmbeddingResults.of(inferenceResult.getInferenceResults()))
+        );
+
+        var maybeDeployListener = mlResultsListener.delegateResponse(
+            (l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
         );
+
+        client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
     }
 
     public void inferRerank(
-        Model model,
+        ElasticsearchInternalModel model,
         String query,
         List<String> inputs,
         InputType inputType,
@@ -671,4 +693,42 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         Collections.sort(rankings);
         return new RankedDocsResults(rankings);
     }
+
+    @Override
+    public List<UnparsedModel> defaultConfigs() {
+        // TODO Chunking settings
+        Map<String, Object> elserSettings = Map.of(
+            ModelConfigurations.SERVICE_SETTINGS,
+            Map.of(
+                ElasticsearchInternalServiceSettings.MODEL_ID,
+                ElserModels.ELSER_V2_MODEL,  // TODO pick model depending on platform
+                ElasticsearchInternalServiceSettings.NUM_THREADS,
+                1,
+                ElasticsearchInternalServiceSettings.ADAPTIVE_ALLOCATIONS,
+                Map.of(
+                    "enabled",
+                    Boolean.TRUE,
+                    "min_number_of_allocations",
+                    1,
+                    "max_number_of_allocations",
+                    8   // no max?
+                )
+            )
+        );
+
+        return List.of(
+            new UnparsedModel(
+                DEFAULT_ELSER_ID,
+                TaskType.SPARSE_EMBEDDING,
+                NAME,
+                elserSettings,
+                Map.of() // no secrets
+            )
+        );
+    }
+
+    @Override
+    protected boolean isDefaultId(String inferenceId) {
+        return DEFAULT_ELSER_ID.equals(inferenceId);
+    }
 }

+ 3 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

@@ -26,6 +26,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
@@ -266,12 +267,11 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ModelRegistry modelRegistry = mock(ModelRegistry.class);
         Answer<?> unparsedModelAnswer = invocationOnMock -> {
             String id = (String) invocationOnMock.getArguments()[0];
-            ActionListener<ModelRegistry.UnparsedModel> listener = (ActionListener<ModelRegistry.UnparsedModel>) invocationOnMock
-                .getArguments()[1];
+            ActionListener<UnparsedModel> listener = (ActionListener<UnparsedModel>) invocationOnMock.getArguments()[1];
             var model = modelMap.get(id);
             if (model != null) {
                 listener.onResponse(
-                    new ModelRegistry.UnparsedModel(
+                    new UnparsedModel(
                         model.getInferenceEntityId(),
                         model.getTaskType(),
                         model.getServiceSettings().model(),

+ 84 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

@@ -23,6 +23,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchResponseUtils;
@@ -38,9 +39,12 @@ import java.util.Map;
 import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.empty;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
@@ -68,7 +72,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         ResourceNotFoundException exception = expectThrows(ResourceNotFoundException.class, () -> listener.actionGet(TIMEOUT));
@@ -82,7 +86,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> listener.actionGet(TIMEOUT));
@@ -99,7 +103,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -116,7 +120,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         IllegalStateException exception = expectThrows(IllegalStateException.class, () -> listener.actionGet(TIMEOUT));
@@ -150,7 +154,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModelWithSecrets("1", listener);
 
         var modelConfig = listener.actionGet(TIMEOUT);
@@ -179,7 +183,7 @@ public class ModelRegistryTests extends ESTestCase {
 
         var registry = new ModelRegistry(client);
 
-        var listener = new PlainActionFuture<ModelRegistry.UnparsedModel>();
+        var listener = new PlainActionFuture<UnparsedModel>();
         registry.getModel("1", listener);
 
         registry.getModel("1", listener);
@@ -288,6 +292,80 @@ public class ModelRegistryTests extends ESTestCase {
         );
     }
 
+    @SuppressWarnings("unchecked")
+    public void testDeepCopyDefaultConfig() {
+        {
+            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", Map.of(), Map.of());
+            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
+            assertThat(copied, not(sameInstance(toCopy)));
+            assertThat(copied.taskType(), is(toCopy.taskType()));
+            assertThat(copied.service(), is(toCopy.service()));
+            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
+            assertThat(copied.secrets(), is(toCopy.secrets()));
+            // Test copied is a modifiable map
+            copied.secrets().put("foo", "bar");
+
+            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
+            assertThat(copied.settings(), is(toCopy.settings()));
+            // Test copied is a modifiable map
+            copied.settings().put("foo", "bar");
+        }
+
+        {
+            Map<String, Object> secretsMap = Map.of("secret", "value");
+            Map<String, Object> chunking = Map.of("strategy", "word");
+            Map<String, Object> task = Map.of("user", "name");
+            Map<String, Object> service = Map.of("num_threads", 1, "adaptive_allocations", Map.of("enabled", true));
+            Map<String, Object> settings = Map.of("chunking_settings", chunking, "service_settings", service, "task_settings", task);
+
+            var toCopy = new UnparsedModel("tocopy", randomFrom(TaskType.values()), "service-a", settings, secretsMap);
+            var copied = ModelRegistry.deepCopyDefaultConfig(toCopy);
+            assertThat(copied, not(sameInstance(toCopy)));
+
+            assertThat(copied.secrets(), not(sameInstance(toCopy.secrets())));
+            assertThat(copied.secrets(), is(toCopy.secrets()));
+            // Test copied is a modifiable map
+            copied.secrets().remove("secret");
+
+            assertThat(copied.settings(), not(sameInstance(toCopy.settings())));
+            assertThat(copied.settings(), is(toCopy.settings()));
+            // Test copied is a modifiable map
+            var chunkOut = (Map<String, Object>) copied.settings().get("chunking_settings");
+            assertThat(chunkOut, is(chunking));
+            chunkOut.remove("strategy");
+
+            var taskOut = (Map<String, Object>) copied.settings().get("task_settings");
+            assertThat(taskOut, is(task));
+            taskOut.remove("user");
+
+            var serviceOut = (Map<String, Object>) copied.settings().get("service_settings");
+            assertThat(serviceOut, is(service));
+            var adaptiveOut = (Map<String, Object>) serviceOut.remove("adaptive_allocations");
+            assertThat(adaptiveOut, is(Map.of("enabled", true)));
+            adaptiveOut.remove("enabled");
+        }
+    }
+
+    public void testDuplicateDefaultIds() {
+        var client = mockBulkClient();
+        var registry = new ModelRegistry(client);
+
+        var id = "my-inference";
+
+        registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-a", Map.of(), Map.of()));
+        var ise = expectThrows(
+            IllegalStateException.class,
+            () -> registry.addDefaultConfiguration(new UnparsedModel(id, randomFrom(TaskType.values()), "service-b", Map.of(), Map.of()))
+        );
+        assertThat(
+            ise.getMessage(),
+            containsString(
+                "Cannot add default endpoint to the inference endpoint registry with duplicate inference id [my-inference] declared by "
+                    + "service [service-b]. The inference Id is already use by [service-a] service."
+            )
+        );
+    }
+
     private Client mockBulkClient() {
         var client = mockClient();
         when(client.prepareBulk()).thenReturn(new BulkRequestBuilder(client));

+ 6 - 3
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_crud.yml

@@ -44,15 +44,18 @@
   - do:
       inference.get:
         inference_id: "*"
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }
 
   - do:
       inference.get:
         inference_id: _all
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }
 
   - do:
       inference.get:
         inference_id: ""
-  - length: { endpoints: 0}
+  - length: { endpoints: 1}
+  - match: { endpoints.0.inference_id: ".elser-2" }