Browse Source

Add Minimal Service Settings to the Model Registry (#120560) (#120946)

This commit introduces minimal service settings in the model registry, accessible without querying the inference index.
These settings are now available for the default models exposed by the inference service.

The ability to access settings without an inference index query is needed for the semantic text field, as it would benefit from eager validation of configuration during field creation.
This is not feasible currently because retrieving service settings relies on an asynchronous call to the inference index.

### Follow-Up Plans:
1. Extend this capability to include minimal service settings for all newly added models, making them accessible via the cluster state.
2. Update the semantic text field to eagerly retrieve service settings directly from the model registry.

Co-authored-by: Jonathan Buttner <56361221+jonathan-buttner@users.noreply.github.com>
Jim Ferenczi 8 months ago
parent
commit
7e03356faf
15 changed files with 439 additions and 254 deletions
  1. 1 1
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  2. 181 0
      server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java
  3. 43 0
      server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java
  4. 30 14
      x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java
  5. 2 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
  6. 6 135
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java
  7. 33 14
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
  8. 44 21
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java
  9. 4 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
  10. 5 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java
  11. 5 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java
  12. 34 29
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java
  13. 10 25
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java
  14. 4 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java
  15. 37 8
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

+ 1 - 1
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -219,7 +219,7 @@ public interface InferenceService extends Closeable {
         return supportedStreamingTasks().contains(taskType);
     }
 
-    record DefaultConfigId(String inferenceId, TaskType taskType, InferenceService service) {};
+    record DefaultConfigId(String inferenceId, MinimalServiceSettings settings, InferenceService service) {};
 
     /**
      * Get the Ids and task type of any default configurations provided by this service

+ 181 - 0
server/src/main/java/org/elasticsearch/inference/MinimalServiceSettings.java

@@ -0,0 +1,181 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.inference;
+
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import static org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.ElementType;
+import static org.elasticsearch.inference.TaskType.COMPLETION;
+import static org.elasticsearch.inference.TaskType.RERANK;
+import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
+
+/**
+ * Defines the base settings required to configure an inference endpoint.
+ *
+ * These settings are immutable and describe the input and output types that the endpoint will handle.
+ * They capture the essential properties of an inference model, ensuring the endpoint is correctly configured.
+ *
+ * Key properties include:
+ * <ul>
+ *   <li>{@code taskType} - Specifies the type of task the model performs, such as text embedding, sparse embedding or rerank.</li>
+ *   <li>{@code dimensions}, {@code similarity}, and {@code elementType} - Required when the {@code taskType} is
+ *       {@link TaskType#TEXT_EMBEDDING} and rejected otherwise. They define the structure and behavior of embeddings.</li>
+ * </ul>
+ *
+ * @param taskType the type of task the inference model performs.
+ * @param dimensions the number of dimensions for the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
+ * @param similarity the similarity measure used for embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
+ * @param elementType the type of elements in the embeddings, applicable only for {@link TaskType#TEXT_EMBEDDING} (nullable).
+ */
+public record MinimalServiceSettings(
+    TaskType taskType,
+    @Nullable Integer dimensions,
+    @Nullable SimilarityMeasure similarity,
+    @Nullable ElementType elementType
+) implements ToXContentObject {
+
+    public static final String TASK_TYPE_FIELD = "task_type";
+    static final String DIMENSIONS_FIELD = "dimensions";
+    static final String SIMILARITY_FIELD = "similarity";
+    static final String ELEMENT_TYPE_FIELD = "element_type";
+
+    private static final ConstructingObjectParser<MinimalServiceSettings, Void> PARSER = new ConstructingObjectParser<>(
+        "model_settings",
+        true,
+        args -> {
+            TaskType taskType = TaskType.fromString((String) args[0]);
+            Integer dimensions = (Integer) args[1];
+            SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
+            DenseVectorFieldMapper.ElementType elementType = args[3] == null
+                ? null
+                : DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
+            return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
+        }
+    );
+
+    static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
+        PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
+        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ELEMENT_TYPE_FIELD));
+    }
+
+    public static MinimalServiceSettings parse(XContentParser parser) throws IOException {
+        return PARSER.parse(parser, null);
+    }
+
+    public static MinimalServiceSettings textEmbedding(int dimensions, SimilarityMeasure similarity, ElementType elementType) {
+        return new MinimalServiceSettings(TEXT_EMBEDDING, dimensions, similarity, elementType);
+    }
+
+    public static MinimalServiceSettings sparseEmbedding() {
+        return new MinimalServiceSettings(SPARSE_EMBEDDING, null, null, null);
+    }
+
+    public static MinimalServiceSettings rerank() {
+        return new MinimalServiceSettings(RERANK, null, null, null);
+    }
+
+    public static MinimalServiceSettings completion() {
+        return new MinimalServiceSettings(COMPLETION, null, null, null);
+    }
+
+    public MinimalServiceSettings(Model model) {
+        this(
+            model.getTaskType(),
+            model.getServiceSettings().dimensions(),
+            model.getServiceSettings().similarity(),
+            model.getServiceSettings().elementType()
+        );
+    }
+
+    public MinimalServiceSettings(
+        TaskType taskType,
+        @Nullable Integer dimensions,
+        @Nullable SimilarityMeasure similarity,
+        @Nullable ElementType elementType
+    ) {
+        this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
+        this.dimensions = dimensions;
+        this.similarity = similarity;
+        this.elementType = elementType;
+        validate();
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(TASK_TYPE_FIELD, taskType.toString());
+        if (dimensions != null) {
+            builder.field(DIMENSIONS_FIELD, dimensions);
+        }
+        if (similarity != null) {
+            builder.field(SIMILARITY_FIELD, similarity);
+        }
+        if (elementType != null) {
+            builder.field(ELEMENT_TYPE_FIELD, elementType);
+        }
+        return builder.endObject();
+    }
+
+    @Override
+    public String toString() {
+        final StringBuilder sb = new StringBuilder();
+        sb.append("task_type=").append(taskType);
+        if (dimensions != null) {
+            sb.append(", dimensions=").append(dimensions);
+        }
+        if (similarity != null) {
+            sb.append(", similarity=").append(similarity);
+        }
+        if (elementType != null) {
+            sb.append(", element_type=").append(elementType);
+        }
+        return sb.toString();
+    }
+
+    private void validate() {
+        switch (taskType) {
+            case TEXT_EMBEDDING:
+                validateFieldPresent(DIMENSIONS_FIELD, dimensions);
+                validateFieldPresent(SIMILARITY_FIELD, similarity);
+                validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
+                break;
+
+            default:
+                validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
+                validateFieldNotPresent(SIMILARITY_FIELD, similarity);
+                validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
+                break;
+        }
+    }
+
+    private void validateFieldPresent(String field, Object fieldValue) {
+        if (fieldValue == null) {
+            throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
+        }
+    }
+
+    private void validateFieldNotPresent(String field, Object fieldValue) {
+        if (fieldValue != null) {
+            throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
+        }
+    }
+}

+ 43 - 0
server/src/test/java/org/elasticsearch/inference/MinimalServiceSettingsTests.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.inference;
+
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class MinimalServiceSettingsTests extends AbstractXContentTestCase<MinimalServiceSettings> {
+    @Override
+    protected MinimalServiceSettings createTestInstance() {
+        TaskType taskType = randomFrom(TaskType.values());
+        Integer dimensions = null;
+        SimilarityMeasure similarity = null;
+        DenseVectorFieldMapper.ElementType elementType = null;
+
+        if (taskType == TaskType.TEXT_EMBEDDING) {
+            dimensions = randomIntBetween(2, 1024);
+            similarity = randomFrom(SimilarityMeasure.values());
+            elementType = randomFrom(DenseVectorFieldMapper.ElementType.values());
+        }
+        return new MinimalServiceSettings(taskType, dimensions, similarity, elementType);
+    }
+
+    @Override
+    protected MinimalServiceSettings doParseInstance(XContentParser parser) throws IOException {
+        return MinimalServiceSettings.parse(parser);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}

+ 30 - 14
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

@@ -15,13 +15,16 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.index.IndexNotFoundException;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SecretSettings;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
@@ -34,6 +37,7 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.registry.ModelRegistryTests;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalModel;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettingsTests;
@@ -305,9 +309,9 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
         for (int i = 0; i < defaultModelCount; i++) {
             var id = "default-" + i;
-            var taskType = randomFrom(TaskType.values());
-            defaultConfigs.add(createModel(id, taskType, serviceName));
-            defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
+            var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
+            defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
+            defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
         }
 
         doAnswer(invocation -> {
@@ -371,9 +375,9 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
         for (int i = 0; i < defaultModelCount; i++) {
             var id = "default-" + i;
-            var taskType = randomFrom(TaskType.values());
-            defaultConfigs.add(createModel(id, taskType, serviceName));
-            defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
+            var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
+            defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
+            defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
         }
 
         doAnswer(invocation -> {
@@ -414,9 +418,9 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
         var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
         for (int i = 0; i < defaultModelCount; i++) {
             var id = "default-" + i;
-            var taskType = randomFrom(TaskType.values());
-            defaultConfigs.add(createModel(id, taskType, serviceName));
-            defaultIds.add(new InferenceService.DefaultConfigId(id, taskType, service));
+            var modelSettings = ModelRegistryTests.randomMinimalServiceSettings();
+            defaultConfigs.add(createModel(id, modelSettings.taskType(), serviceName));
+            defaultIds.add(new InferenceService.DefaultConfigId(id, modelSettings, service));
         }
 
         doAnswer(invocation -> {
@@ -452,8 +456,14 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
 
         defaultConfigs.add(createModel("default-sparse", TaskType.SPARSE_EMBEDDING, serviceName));
         defaultConfigs.add(createModel("default-text", TaskType.TEXT_EMBEDDING, serviceName));
-        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
-        defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
+        defaultIds.add(
+            new InferenceService.DefaultConfigId(
+                "default-text",
+                MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+                service
+            )
+        );
 
         doAnswer(invocation -> {
             @SuppressWarnings("unchecked")
@@ -499,9 +509,15 @@ public class ModelRegistryIT extends ESSingleNodeTestCase {
 
         var service = mock(InferenceService.class);
         var defaultIds = new ArrayList<InferenceService.DefaultConfigId>();
-        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", TaskType.SPARSE_EMBEDDING, service));
-        defaultIds.add(new InferenceService.DefaultConfigId("default-text", TaskType.TEXT_EMBEDDING, service));
-        defaultIds.add(new InferenceService.DefaultConfigId("default-chat", TaskType.COMPLETION, service));
+        defaultIds.add(new InferenceService.DefaultConfigId("default-sparse", MinimalServiceSettings.sparseEmbedding(), service));
+        defaultIds.add(
+            new InferenceService.DefaultConfigId(
+                "default-text",
+                MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+                service
+            )
+        );
+        defaultIds.add(new InferenceService.DefaultConfigId("default-chat", MinimalServiceSettings.completion(), service));
 
         doAnswer(invocation -> {
             @SuppressWarnings("unchecked")

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

@@ -35,6 +35,7 @@ import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.rest.RestStatus;
@@ -438,7 +439,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     useLegacyFormat ? inputs : null,
                     new SemanticTextField.InferenceResult(
                         inferenceFieldMetadata.getInferenceId(),
-                        model != null ? new SemanticTextField.ModelSettings(model) : null,
+                        model != null ? new MinimalServiceSettings(model) : null,
                         chunkMap
                     ),
                     indexRequest.getContentType()

+ 6 - 135
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java

@@ -14,11 +14,8 @@ import org.elasticsearch.common.xcontent.XContentParserUtils;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.index.IndexVersions;
-import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.ChunkedInference;
-import org.elasticsearch.inference.Model;
-import org.elasticsearch.inference.SimilarityMeasure;
-import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.DeprecationHandler;
 import org.elasticsearch.xcontent.NamedXContentRegistry;
@@ -38,10 +35,7 @@ import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 
-import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
-import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
 import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
 
@@ -75,117 +69,13 @@ public record SemanticTextField(
     static final String CHUNKED_START_OFFSET_FIELD = "start_offset";
     static final String CHUNKED_END_OFFSET_FIELD = "end_offset";
     static final String MODEL_SETTINGS_FIELD = "model_settings";
-    static final String TASK_TYPE_FIELD = "task_type";
-    static final String DIMENSIONS_FIELD = "dimensions";
-    static final String SIMILARITY_FIELD = "similarity";
-    static final String ELEMENT_TYPE_FIELD = "element_type";
 
-    public record InferenceResult(String inferenceId, ModelSettings modelSettings, Map<String, List<Chunk>> chunks) {}
+    public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map<String, List<Chunk>> chunks) {}
 
     public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {}
 
     public record Offset(String sourceFieldName, int startOffset, int endOffset) {}
 
-    public record ModelSettings(
-        TaskType taskType,
-        Integer dimensions,
-        SimilarityMeasure similarity,
-        DenseVectorFieldMapper.ElementType elementType
-    ) implements ToXContentObject {
-        public ModelSettings(Model model) {
-            this(
-                model.getTaskType(),
-                model.getServiceSettings().dimensions(),
-                model.getServiceSettings().similarity(),
-                model.getServiceSettings().elementType()
-            );
-        }
-
-        public ModelSettings(
-            TaskType taskType,
-            Integer dimensions,
-            SimilarityMeasure similarity,
-            DenseVectorFieldMapper.ElementType elementType
-        ) {
-            this.taskType = Objects.requireNonNull(taskType, "task type must not be null");
-            this.dimensions = dimensions;
-            this.similarity = similarity;
-            this.elementType = elementType;
-            validate();
-        }
-
-        @Override
-        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-            builder.startObject();
-            builder.field(TASK_TYPE_FIELD, taskType.toString());
-            if (dimensions != null) {
-                builder.field(DIMENSIONS_FIELD, dimensions);
-            }
-            if (similarity != null) {
-                builder.field(SIMILARITY_FIELD, similarity);
-            }
-            if (elementType != null) {
-                builder.field(ELEMENT_TYPE_FIELD, elementType);
-            }
-            return builder.endObject();
-        }
-
-        @Override
-        public String toString() {
-            final StringBuilder sb = new StringBuilder();
-            sb.append("task_type=").append(taskType);
-            if (dimensions != null) {
-                sb.append(", dimensions=").append(dimensions);
-            }
-            if (similarity != null) {
-                sb.append(", similarity=").append(similarity);
-            }
-            if (elementType != null) {
-                sb.append(", element_type=").append(elementType);
-            }
-            return sb.toString();
-        }
-
-        private void validate() {
-            switch (taskType) {
-                case TEXT_EMBEDDING:
-                    validateFieldPresent(DIMENSIONS_FIELD, dimensions);
-                    validateFieldPresent(SIMILARITY_FIELD, similarity);
-                    validateFieldPresent(ELEMENT_TYPE_FIELD, elementType);
-                    break;
-                case SPARSE_EMBEDDING:
-                    validateFieldNotPresent(DIMENSIONS_FIELD, dimensions);
-                    validateFieldNotPresent(SIMILARITY_FIELD, similarity);
-                    validateFieldNotPresent(ELEMENT_TYPE_FIELD, elementType);
-                    break;
-
-                default:
-                    throw new IllegalArgumentException(
-                        "Wrong ["
-                            + TASK_TYPE_FIELD
-                            + "], expected "
-                            + TEXT_EMBEDDING
-                            + " or "
-                            + SPARSE_EMBEDDING
-                            + ", got "
-                            + taskType.name()
-                    );
-            }
-        }
-
-        private void validateFieldPresent(String field, Object fieldValue) {
-            if (fieldValue == null) {
-                throw new IllegalArgumentException("required [" + field + "] field is missing for task_type [" + taskType.name() + "]");
-            }
-        }
-
-        private void validateFieldNotPresent(String field, Object fieldValue) {
-            if (fieldValue != null) {
-                throw new IllegalArgumentException("[" + field + "] is not allowed for task_type [" + taskType.name() + "]");
-            }
-        }
-    }
-
     public static String getOriginalTextFieldName(String fieldName) {
         return fieldName + "." + TEXT_FIELD;
     }
@@ -212,7 +102,7 @@ public record SemanticTextField(
         return SEMANTIC_TEXT_FIELD_PARSER.parse(parser, context);
     }
 
-    static ModelSettings parseModelSettingsFromMap(Object node) {
+    static MinimalServiceSettings parseModelSettingsFromMap(Object node) {
         if (node == null) {
             return null;
         }
@@ -224,7 +114,7 @@ public record SemanticTextField(
                 map,
                 XContentType.JSON
             );
-            return MODEL_SETTINGS_PARSER.parse(parser, null);
+            return MinimalServiceSettings.parse(parser);
         } catch (Exception exc) {
             throw new ElasticsearchException(exc);
         }
@@ -307,7 +197,7 @@ public record SemanticTextField(
     private static final ConstructingObjectParser<InferenceResult, ParserContext> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
         INFERENCE_FIELD,
         true,
-        args -> new InferenceResult((String) args[0], (ModelSettings) args[1], (Map<String, List<Chunk>>) args[2])
+        args -> new InferenceResult((String) args[0], (MinimalServiceSettings) args[1], (Map<String, List<Chunk>>) args[2])
     );
 
     private static final ConstructingObjectParser<Chunk, ParserContext> CHUNKS_PARSER = new ConstructingObjectParser<>(
@@ -322,20 +212,6 @@ public record SemanticTextField(
         }
     );
 
-    private static final ConstructingObjectParser<ModelSettings, Void> MODEL_SETTINGS_PARSER = new ConstructingObjectParser<>(
-        MODEL_SETTINGS_FIELD,
-        true,
-        args -> {
-            TaskType taskType = TaskType.fromString((String) args[0]);
-            Integer dimensions = (Integer) args[1];
-            SimilarityMeasure similarity = args[2] == null ? null : SimilarityMeasure.fromString((String) args[2]);
-            DenseVectorFieldMapper.ElementType elementType = args[3] == null
-                ? null
-                : DenseVectorFieldMapper.ElementType.fromString((String) args[3]);
-            return new ModelSettings(taskType, dimensions, similarity, elementType);
-        }
-    );
-
     static {
         SEMANTIC_TEXT_FIELD_PARSER.declareStringArray(optionalConstructorArg(), new ParseField(TEXT_FIELD));
         SEMANTIC_TEXT_FIELD_PARSER.declareObject(constructorArg(), INFERENCE_RESULT_PARSER, new ParseField(INFERENCE_FIELD));
@@ -343,7 +219,7 @@ public record SemanticTextField(
         INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
         INFERENCE_RESULT_PARSER.declareObjectOrNull(
             constructorArg(),
-            (p, c) -> MODEL_SETTINGS_PARSER.parse(p, null),
+            (p, c) -> MinimalServiceSettings.parse(p),
             null,
             new ParseField(MODEL_SETTINGS_FIELD)
         );
@@ -362,11 +238,6 @@ public record SemanticTextField(
             b.copyCurrentStructure(p);
             return BytesReference.bytes(b);
         }, new ParseField(CHUNKED_EMBEDDINGS_FIELD), ObjectParser.ValueType.OBJECT_ARRAY);
-
-        MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(TASK_TYPE_FIELD));
-        MODEL_SETTINGS_PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), new ParseField(DIMENSIONS_FIELD));
-        MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(SIMILARITY_FIELD));
-        MODEL_SETTINGS_PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ELEMENT_TYPE_FIELD));
     }
 
     private static Map<String, List<Chunk>> parseChunksMap(XContentParser parser, ParserContext context) throws IOException {

+ 33 - 14
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java

@@ -59,6 +59,7 @@ import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.search.fetch.StoredFieldsSpec;
 import org.elasticsearch.search.lookup.Source;
@@ -87,6 +88,8 @@ import java.util.Set;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
 
+import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING;
+import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING;
 import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_OFFSET_FIELD;
@@ -162,7 +165,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             }
         });
 
-        private final Parameter<SemanticTextField.ModelSettings> modelSettings = new Parameter<>(
+        private final Parameter<MinimalServiceSettings> modelSettings = new Parameter<>(
             MODEL_SETTINGS_FIELD,
             true,
             () -> null,
@@ -209,7 +212,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             return this;
         }
 
-        public Builder setModelSettings(SemanticTextField.ModelSettings value) {
+        public Builder setModelSettings(MinimalServiceSettings value) {
             this.modelSettings.setValue(value);
             return this;
         }
@@ -240,6 +243,9 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             if (useLegacyFormat && multiFieldsBuilder.hasMultiFields()) {
                 throw new IllegalArgumentException(CONTENT_TYPE + " field [" + leafName() + "] does not support multi-fields");
             }
+            if (modelSettings.get() != null) {
+                validateServiceSettings(modelSettings.get());
+            }
             final String fullName = context.buildFullName(leafName());
 
             if (context.isInNestedContext()) {
@@ -263,9 +269,26 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             );
         }
 
+        private void validateServiceSettings(MinimalServiceSettings settings) { // accepts only the two embedding task types
+            switch (settings.taskType()) {
+                case SPARSE_EMBEDDING, TEXT_EMBEDDING -> { // supported task types: intentionally empty, nothing further is checked here
+                }
+                default -> throw new IllegalArgumentException( // every other task type is rejected; message names expected and actual
+                    "Wrong ["
+                        + MinimalServiceSettings.TASK_TYPE_FIELD
+                        + "], expected "
+                        + TEXT_EMBEDDING
+                        + " or "
+                        + SPARSE_EMBEDDING
+                        + ", got "
+                        + settings.taskType().name()
+                );
+            }
+        }
+
         /**
          * As necessary, copy settings from this builder to the passed-in mapper.
-         * Used to preserve {@link SemanticTextField.ModelSettings} when updating a semantic text mapping to one where the model settings
+         * Used to preserve {@link MinimalServiceSettings} when updating a semantic text mapping to one where the model settings
          * are not specified.
          *
          * @param mapper The mapper
@@ -524,7 +547,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     public static class SemanticTextFieldType extends SimpleMappedFieldType {
         private final String inferenceId;
         private final String searchInferenceId;
-        private final SemanticTextField.ModelSettings modelSettings;
+        private final MinimalServiceSettings modelSettings;
         private final ObjectMapper inferenceField;
         private final boolean useLegacyFormat;
 
@@ -532,7 +555,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             String name,
             String inferenceId,
             String searchInferenceId,
-            SemanticTextField.ModelSettings modelSettings,
+            MinimalServiceSettings modelSettings,
             ObjectMapper inferenceField,
             boolean useLegacyFormat,
             Map<String, String> meta
@@ -567,7 +590,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
             return searchInferenceId == null ? inferenceId : searchInferenceId;
         }
 
-        public SemanticTextField.ModelSettings getModelSettings() {
+        public MinimalServiceSettings getModelSettings() {
             return modelSettings;
         }
 
@@ -881,7 +904,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
         MapperBuilderContext context,
         IndexVersion indexVersionCreated,
         boolean useLegacyFormat,
-        @Nullable SemanticTextField.ModelSettings modelSettings,
+        @Nullable MinimalServiceSettings modelSettings,
         Function<Query, BitSetProducer> bitSetProducer,
         IndexSettings indexSettings
     ) {
@@ -893,7 +916,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     private static NestedObjectMapper.Builder createChunksField(
         IndexVersion indexVersionCreated,
         boolean useLegacyFormat,
-        @Nullable SemanticTextField.ModelSettings modelSettings,
+        @Nullable MinimalServiceSettings modelSettings,
         Function<Query, BitSetProducer> bitSetProducer,
         IndexSettings indexSettings
     ) {
@@ -918,7 +941,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
 
     private static Mapper.Builder createEmbeddingsField(
         IndexVersion indexVersionCreated,
-        SemanticTextField.ModelSettings modelSettings,
+        MinimalServiceSettings modelSettings,
         boolean useLegacyFormat
     ) {
         return switch (modelSettings.taskType()) {
@@ -949,11 +972,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
         };
     }
 
-    private static boolean canMergeModelSettings(
-        SemanticTextField.ModelSettings previous,
-        SemanticTextField.ModelSettings current,
-        Conflicts conflicts
-    ) {
+    private static boolean canMergeModelSettings(MinimalServiceSettings previous, MinimalServiceSettings current, Conflicts conflicts) {
         if (Objects.equals(previous, current)) {
             return true;
         }

+ 44 - 21
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

@@ -36,6 +36,7 @@ import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.reindex.DeleteByQueryAction;
 import org.elasticsearch.index.reindex.DeleteByQueryRequest;
 import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
@@ -56,6 +57,7 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Comparator;
 import java.util.HashMap;
@@ -103,33 +105,33 @@ public class ModelRegistry {
     private static final Logger logger = LogManager.getLogger(ModelRegistry.class);
 
     private final OriginSettingClient client;
-    private final List<InferenceService.DefaultConfigId> defaultConfigIds;
+    private final Map<String, InferenceService.DefaultConfigId> defaultConfigIds;
 
     private final Set<String> preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>());
 
     public ModelRegistry(Client client) {
         this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN);
-        defaultConfigIds = new ArrayList<>();
+        defaultConfigIds = new HashMap<>();
     }
 
     /**
      * Set the default inference ids provided by the services
-     * @param defaultConfigIds The defaults
+     * @param defaultConfigId The default
      */
-    public void addDefaultIds(InferenceService.DefaultConfigId defaultConfigIds) {
-        var matched = idMatchedDefault(defaultConfigIds.inferenceId(), this.defaultConfigIds);
-        if (matched.isPresent()) {
+    public synchronized void addDefaultIds(InferenceService.DefaultConfigId defaultConfigId) {
+        var config = defaultConfigIds.get(defaultConfigId.inferenceId()); // existing registration under the same inference id, if any
+        if (config != null) { // duplicate id: refuse, naming both the new and the already-registered service
             throw new IllegalStateException(
                 "Cannot add default endpoint to the inference endpoint registry with duplicate inference id ["
-                    + defaultConfigIds.inferenceId()
+                    + defaultConfigId.inferenceId()
                     + "] declared by service ["
-                    + defaultConfigIds.service().name()
+                    + defaultConfigId.service().name()
                     + "]. The inference Id is already use by ["
-                    + matched.get().service().name()
+                    + config.service().name()
                     + "] service."
             );
         }
-        this.defaultConfigIds.add(defaultConfigIds);
+        defaultConfigIds.put(defaultConfigId.inferenceId(), defaultConfigId); // registered under its inference id for direct lookup
     }
 
     /**
@@ -141,9 +143,9 @@ public class ModelRegistry {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // There should be a hit for the configurations
             if (searchResponse.getHits().getHits().length == 0) {
-                var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
-                if (maybeDefault.isPresent()) {
-                    getDefaultConfig(true, maybeDefault.get(), listener);
+                var maybeDefault = defaultConfigIds.get(inferenceEntityId);
+                if (maybeDefault != null) {
+                    getDefaultConfig(true, maybeDefault, listener);
                 } else {
                     delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
                 }
@@ -172,9 +174,9 @@ public class ModelRegistry {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             // There should be a hit for the configurations
             if (searchResponse.getHits().getHits().length == 0) {
-                var maybeDefault = idMatchedDefault(inferenceEntityId, defaultConfigIds);
-                if (maybeDefault.isPresent()) {
-                    getDefaultConfig(true, maybeDefault.get(), listener);
+                var maybeDefault = defaultConfigIds.get(inferenceEntityId);
+                if (maybeDefault != null) {
+                    getDefaultConfig(true, maybeDefault, listener);
                 } else {
                     delegate.onFailure(inferenceNotFoundException(inferenceEntityId));
                 }
@@ -196,6 +198,27 @@ public class ModelRegistry {
         client.search(modelSearch, searchListener);
     }
 
+    /**
+     * Retrieves the {@link MinimalServiceSettings} associated with the specified {@code inferenceEntityId}.
+     *
+     * If the {@code inferenceEntityId} is not found, the method behaves as follows:
+     * <ul>
+     *   <li>Returns {@code null} if the id might exist but its configuration is not available locally.</li>
+     *   <li>Throws a {@link ResourceNotFoundException} if it is certain that the id does not exist in the cluster.</li>
+     * </ul>
+     *
+     * @param inferenceEntityId the unique identifier for the inference entity.
+     * @return the {@link MinimalServiceSettings} associated with the provided ID, or {@code null} if unavailable locally.
+     * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster.
+     */
+    public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException {
+        var config = defaultConfigIds.get(inferenceEntityId); // only the registered default endpoints are consulted; no index query
+        if (config != null) {
+            return config.settings();
+        }
+        return null; // not a registered default. NOTE(review): nothing in this body throws ResourceNotFoundException yet - confirm
+    }
+
     private ResourceNotFoundException inferenceNotFoundException(String inferenceEntityId) {
         return new ResourceNotFoundException("Inference endpoint not found [{}]", inferenceEntityId);
     }
@@ -209,7 +232,7 @@ public class ModelRegistry {
     public void getModelsByTaskType(TaskType taskType, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             var modelConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
-            var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds);
+            var defaultConfigsForTaskType = taskTypeMatchedDefaults(taskType, defaultConfigIds.values());
             addAllDefaultConfigsIfMissing(true, modelConfigs, defaultConfigsForTaskType, delegate);
         });
 
@@ -240,7 +263,7 @@ public class ModelRegistry {
     public void getAllModels(boolean persistDefaultEndpoints, ActionListener<List<UnparsedModel>> listener) {
         ActionListener<SearchResponse> searchListener = listener.delegateFailureAndWrap((delegate, searchResponse) -> {
             var foundConfigs = parseHitsAsModels(searchResponse.getHits()).stream().map(ModelRegistry::unparsedModelFromMap).toList();
-            addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds, delegate);
+            addAllDefaultConfigsIfMissing(persistDefaultEndpoints, foundConfigs, defaultConfigIds.values(), delegate);
         });
 
         // In theory the index should only contain model config documents
@@ -261,7 +284,7 @@ public class ModelRegistry {
     private void addAllDefaultConfigsIfMissing(
         boolean persistDefaultEndpoints,
         List<UnparsedModel> foundConfigs,
-        List<InferenceService.DefaultConfigId> matchedDefaults,
+        Collection<InferenceService.DefaultConfigId> matchedDefaults,
         ActionListener<List<UnparsedModel>> listener
     ) {
         var foundIds = foundConfigs.stream().map(UnparsedModel::inferenceEntityId).collect(Collectors.toSet());
@@ -671,10 +694,10 @@ public class ModelRegistry {
 
     static List<InferenceService.DefaultConfigId> taskTypeMatchedDefaults(
         TaskType taskType,
-        List<InferenceService.DefaultConfigId> defaultConfigIds
+        Collection<InferenceService.DefaultConfigId> defaultConfigIds // Collection so callers can pass the registry map's values()
     ) {
         return defaultConfigIds.stream()
-            .filter(defaultConfigId -> defaultConfigId.taskType().equals(taskType))
+            .filter(defaultConfigId -> defaultConfigId.settings().taskType().equals(taskType)) // task type is read from the settings
             .collect(Collectors.toList());
     }
 }

+ 4 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -25,6 +25,7 @@ import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceExtension;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.SettingsConfiguration;
@@ -828,9 +829,9 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
 
     public List<DefaultConfigId> defaultConfigIds() {
         return List.of(
-            new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
-            new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this),
-            new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this)
+            new DefaultConfigId(DEFAULT_ELSER_ID, ElserInternalServiceSettings.minimalServiceSettings(), this),
+            new DefaultConfigId(DEFAULT_E5_ID, MultilingualE5SmallInternalServiceSettings.minimalServiceSettings(), this),
+            new DefaultConfigId(DEFAULT_RERANK_ID, MinimalServiceSettings.rerank(), this) // taken from the MinimalServiceSettings factory
         );
     }
 

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElserInternalServiceSettings.java

@@ -11,6 +11,7 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 
 import java.io.IOException;
@@ -21,6 +22,10 @@ public class ElserInternalServiceSettings extends ElasticsearchInternalServiceSe
 
     public static final String NAME = "elser_mlnode_service_settings";
 
+    public static MinimalServiceSettings minimalServiceSettings() { // delegates to the sparse-embedding factory
+        return MinimalServiceSettings.sparseEmbedding();
+    }
+
     public static Builder fromRequestMap(Map<String, Object> map) {
         ValidationException validationException = new ValidationException();
         var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
 
@@ -24,6 +25,10 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt
     static final int DIMENSIONS = 384;
     static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE;
 
+    public static MinimalServiceSettings minimalServiceSettings() { // text embedding: DIMENSIONS (384), SIMILARITY (COSINE), FLOAT
+        return MinimalServiceSettings.textEmbedding(DIMENSIONS, SIMILARITY, DenseVectorFieldMapper.ElementType.FLOAT);
+    }
+
     public MultilingualE5SmallInternalServiceSettings(ElasticsearchInternalServiceSettings other) {
         super(other);
     }

+ 34 - 29
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java

@@ -53,6 +53,7 @@ import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper;
 import org.elasticsearch.index.mapper.vectors.XFeatureField;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
@@ -292,6 +293,28 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         }
     }
 
+    public void testInvalidTaskTypes() {
+        for (var taskType : TaskType.values()) { // walk every task type the enum defines
+            if (taskType == TaskType.TEXT_EMBEDDING || taskType == TaskType.SPARSE_EMBEDDING) {
+                continue; // skipped: only the remaining task types are expected to fail
+            }
+            Exception e = expectThrows(
+                MapperParsingException.class,
+                () -> createMapperService(
+                    fieldMapping(
+                        b -> b.field("type", "semantic_text")
+                            .field(INFERENCE_ID_FIELD, "test1")
+                            .startObject("model_settings")
+                            .field("task_type", taskType) // the task type under test, set explicitly in model_settings
+                            .endObject()
+                    ),
+                    useLegacyFormat
+                )
+            );
+            assertThat(e.getMessage(), containsString("Failed to parse mapping: Wrong [task_type]"));
+        }
+    }
+
     public void testMultiFieldsSupport() throws IOException {
         if (useLegacyFormat) {
             Exception e = expectThrows(MapperParsingException.class, () -> createMapperService(fieldMapping(b -> {
@@ -392,7 +415,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
             MapperService mapperService = mapperServiceForFieldWithModelSettings(
                 fieldName,
                 inferenceId,
-                new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
+                new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
             );
             assertSemanticTextField(mapperService, fieldName, true);
             assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
@@ -403,7 +426,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
                 fieldName,
                 inferenceId,
                 searchInferenceId,
-                new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
+                new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
             );
             assertSemanticTextField(mapperService, fieldName, true);
             assertInferenceEndpoints(mapperService, fieldName, inferenceId, searchInferenceId);
@@ -523,7 +546,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
             mapperService = mapperServiceForFieldWithModelSettings(
                 fieldName,
                 inferenceId,
-                new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
+                new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
             );
             assertSemanticTextField(mapperService, fieldName, true);
             assertInferenceEndpoints(mapperService, fieldName, inferenceId, inferenceId);
@@ -742,7 +765,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
                         useLegacyFormat,
                         b -> b.startObject("field")
                             .startObject(INFERENCE_FIELD)
-                            .field(MODEL_SETTINGS_FIELD, new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null))
+                            .field(MODEL_SETTINGS_FIELD, new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null))
                             .field(CHUNKS_FIELD, useLegacyFormat ? List.of() : Map.of())
                             .endObject()
                             .endObject()
@@ -804,24 +827,14 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         MapperService floatMapperService = mapperServiceForFieldWithModelSettings(
             fieldName,
             inferenceId,
-            new SemanticTextField.ModelSettings(
-                TaskType.TEXT_EMBEDDING,
-                1024,
-                SimilarityMeasure.COSINE,
-                DenseVectorFieldMapper.ElementType.FLOAT
-            )
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
         );
         assertMapperService.accept(floatMapperService, DenseVectorFieldMapper.ElementType.FLOAT);
 
         MapperService byteMapperService = mapperServiceForFieldWithModelSettings(
             fieldName,
             inferenceId,
-            new SemanticTextField.ModelSettings(
-                TaskType.TEXT_EMBEDDING,
-                1024,
-                SimilarityMeasure.COSINE,
-                DenseVectorFieldMapper.ElementType.BYTE
-            )
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE)
         );
         assertMapperService.accept(byteMapperService, DenseVectorFieldMapper.ElementType.BYTE);
     }
@@ -855,11 +868,8 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         assertThat(ex.getMessage(), containsString("[model_settings] must be set for field [field] when chunks are provided"));
     }
 
-    private MapperService mapperServiceForFieldWithModelSettings(
-        String fieldName,
-        String inferenceId,
-        SemanticTextField.ModelSettings modelSettings
-    ) throws IOException {
+    private MapperService mapperServiceForFieldWithModelSettings(String fieldName, String inferenceId, MinimalServiceSettings modelSettings)
+        throws IOException {
         return mapperServiceForFieldWithModelSettings(fieldName, inferenceId, null, modelSettings);
     }
 
@@ -867,7 +877,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         String fieldName,
         String inferenceId,
         String searchInferenceId,
-        SemanticTextField.ModelSettings modelSettings
+        MinimalServiceSettings modelSettings
     ) throws IOException {
         String mappingParams = "type=semantic_text,inference_id=" + inferenceId;
         if (searchInferenceId != null) {
@@ -914,7 +924,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         MapperService mapperService = mapperServiceForFieldWithModelSettings(
             fieldName,
             inferenceId,
-            new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
+            new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null)
         );
 
         Mapper mapper = mapperService.mappingLookup().getMapper(fieldName);
@@ -931,12 +941,7 @@ public class SemanticTextFieldMapperTests extends MapperTestCase {
         MapperService mapperService = mapperServiceForFieldWithModelSettings(
             fieldName,
             inferenceId,
-            new SemanticTextField.ModelSettings(
-                TaskType.TEXT_EMBEDDING,
-                1024,
-                SimilarityMeasure.COSINE,
-                DenseVectorFieldMapper.ElementType.FLOAT
-            )
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 1024, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT)
         );
 
         Mapper mapper = mapperService.mappingLookup().getMapper(fieldName);

+ 10 - 25
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
@@ -66,7 +67,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
         assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues()));
         assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings()));
         assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size()));
-        SemanticTextField.ModelSettings modelSettings = newInstance.inference().modelSettings();
+        MinimalServiceSettings modelSettings = newInstance.inference().modelSettings();
         for (var entry : newInstance.inference().chunks().entrySet()) {
             var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey());
             assertNotNull(expectedChunks);
@@ -133,53 +134,37 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
 
     public void testModelSettingsValidation() {
         NullPointerException npe = expectThrows(NullPointerException.class, () -> {
-            new SemanticTextField.ModelSettings(null, 10, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
+            new MinimalServiceSettings(null, 10, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
         });
         assertThat(npe.getMessage(), equalTo("task type must not be null"));
 
         IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(
-                TaskType.COMPLETION,
-                10,
-                SimilarityMeasure.COSINE,
-                DenseVectorFieldMapper.ElementType.FLOAT
-            );
+            new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, 10, null, null);
         });
-        assertThat(ex.getMessage(), containsString("Wrong [task_type]"));
-
-        ex = expectThrows(
-            IllegalArgumentException.class,
-            () -> { new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, 10, null, null); }
-        );
         assertThat(ex.getMessage(), containsString("[dimensions] is not allowed"));
 
         ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE, null);
+            new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, SimilarityMeasure.COSINE, null);
         });
         assertThat(ex.getMessage(), containsString("[similarity] is not allowed"));
 
         ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, DenseVectorFieldMapper.ElementType.FLOAT);
+            new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, DenseVectorFieldMapper.ElementType.FLOAT);
         });
         assertThat(ex.getMessage(), containsString("[element_type] is not allowed"));
 
         ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(
-                TaskType.TEXT_EMBEDDING,
-                null,
-                SimilarityMeasure.COSINE,
-                DenseVectorFieldMapper.ElementType.FLOAT
-            );
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT);
         });
         assertThat(ex.getMessage(), containsString("required [dimensions] field is missing"));
 
         ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, null, DenseVectorFieldMapper.ElementType.FLOAT);
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 10, null, DenseVectorFieldMapper.ElementType.FLOAT);
         });
         assertThat(ex.getMessage(), containsString("required [similarity] field is missing"));
 
         ex = expectThrows(IllegalArgumentException.class, () -> {
-            new SemanticTextField.ModelSettings(TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE, null);
+            new MinimalServiceSettings(TaskType.TEXT_EMBEDDING, 10, SimilarityMeasure.COSINE, null);
         });
         assertThat(ex.getMessage(), containsString("required [element_type] field is missing"));
     }
@@ -285,7 +270,7 @@ public class SemanticTextFieldTests extends AbstractXContentTestCase<SemanticTex
             useLegacyFormat ? inputs : null,
             new SemanticTextField.InferenceResult(
                 model.getInferenceEntityId(),
-                new SemanticTextField.ModelSettings(model),
+                new MinimalServiceSettings(model),
                 Map.of(fieldName, chunks)
             ),
             contentType

+ 4 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java

@@ -40,6 +40,7 @@ import org.elasticsearch.index.query.QueryRewriteContext;
 import org.elasticsearch.index.query.SearchExecutionContext;
 import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
 import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.Plugin;
@@ -351,10 +352,10 @@ public class SemanticQueryBuilderTests extends AbstractQueryTestCase<SemanticQue
         DenseVectorFieldMapper.ElementType denseVectorElementType,
         boolean useLegacyFormat
     ) throws IOException {
-        SemanticTextField.ModelSettings modelSettings = switch (inferenceResultType) {
+        var modelSettings = switch (inferenceResultType) {
             case NONE -> null;
-            case SPARSE_EMBEDDING -> new SemanticTextField.ModelSettings(TaskType.SPARSE_EMBEDDING, null, null, null);
-            case TEXT_EMBEDDING -> new SemanticTextField.ModelSettings(
+            case SPARSE_EMBEDDING -> new MinimalServiceSettings(TaskType.SPARSE_EMBEDDING, null, null, null);
+            case TEXT_EMBEDDING -> new MinimalServiceSettings(
                 TaskType.TEXT_EMBEDDING,
                 TEXT_EMBEDDING_DIMENSION_COUNT,
                 SimilarityMeasure.COSINE,

+ 37 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java

@@ -22,7 +22,10 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.engine.VersionConflictEngineException;
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
 import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.MinimalServiceSettings;
+import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
 import org.elasticsearch.search.SearchHit;
@@ -294,8 +297,12 @@ public class ModelRegistryTests extends ESTestCase {
 
     public void testIdMatchedDefault() {
         var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("foo", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("bar", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
+        defaultConfigIds.add(
+            new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
+        );
+        defaultConfigIds.add(
+            new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
+        );
 
         var matched = ModelRegistry.idMatchedDefault("bar", defaultConfigIds);
         assertEquals(defaultConfigIds.get(1), matched.get());
@@ -305,10 +312,20 @@ public class ModelRegistryTests extends ESTestCase {
 
     public void testTaskTypeMatchedDefaults() {
         var defaultConfigIds = new ArrayList<InferenceService.DefaultConfigId>();
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("s1", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("s2", TaskType.SPARSE_EMBEDDING, mock(InferenceService.class)));
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("d1", TaskType.TEXT_EMBEDDING, mock(InferenceService.class)));
-        defaultConfigIds.add(new InferenceService.DefaultConfigId("c1", TaskType.COMPLETION, mock(InferenceService.class)));
+        defaultConfigIds.add(
+            new InferenceService.DefaultConfigId("s1", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
+        );
+        defaultConfigIds.add(
+            new InferenceService.DefaultConfigId("s2", MinimalServiceSettings.sparseEmbedding(), mock(InferenceService.class))
+        );
+        defaultConfigIds.add(
+            new InferenceService.DefaultConfigId(
+                "d1",
+                MinimalServiceSettings.textEmbedding(384, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT),
+                mock(InferenceService.class)
+            )
+        );
+        defaultConfigIds.add(new InferenceService.DefaultConfigId("c1", MinimalServiceSettings.completion(), mock(InferenceService.class)));
 
         var matched = ModelRegistry.taskTypeMatchedDefaults(TaskType.SPARSE_EMBEDDING, defaultConfigIds);
         assertThat(matched, contains(defaultConfigIds.get(0), defaultConfigIds.get(1)));
@@ -328,10 +345,10 @@ public class ModelRegistryTests extends ESTestCase {
         var mockServiceB = mock(InferenceService.class);
         when(mockServiceB.name()).thenReturn("service-b");
 
-        registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomFrom(TaskType.values()), mockServiceA));
+        registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomMinimalServiceSettings(), mockServiceA));
         var ise = expectThrows(
             IllegalStateException.class,
-            () -> registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomFrom(TaskType.values()), mockServiceB))
+            () -> registry.addDefaultIds(new InferenceService.DefaultConfigId(id, randomMinimalServiceSettings(), mockServiceB))
         );
         assertThat(
             ise.getMessage(),
@@ -385,4 +402,16 @@ public class ModelRegistryTests extends ESTestCase {
 
         return searchResponse;
     }
+
+    public static MinimalServiceSettings randomMinimalServiceSettings() {
+        TaskType type = randomFrom(TaskType.values());
+        if (type == TaskType.TEXT_EMBEDDING) {
+            return MinimalServiceSettings.textEmbedding(
+                randomIntBetween(2, 384),
+                randomFrom(SimilarityMeasure.values()),
+                randomFrom(DenseVectorFieldMapper.ElementType.values())
+            );
+        }
+        return new MinimalServiceSettings(type, null, null, null);
+    }
 }