瀏覽代碼

[ML] Custom service add support for input_type, top_n, and return_documents (#129441)

* Making progress on different request parameters

* Working tests

* Adding custom service validator for rerank

* Fixing embedding bug

* Adding transport version check

* Fixing tests

* Fixing license header

* Fixing writeTo

* Moving file and removing commented code

* Fixing test

* Fixing tests

* Refactoring and tests

* Fixing test
Jonathan Buttner 4 月之前
父節點
當前提交
d9b34d43a5
共有 38 個文件被更改，包括 1372 次插入和 125 次删除
  1. 1 0
      server/src/main/java/module-info.java
  2. 2 0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. 11 0
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  4. 77 0
      server/src/main/java/org/elasticsearch/inference/InputType.java
  5. 6 4
      server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java
  6. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java
  7. 12 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  8. 12 12
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java
  9. 37 9
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java
  10. 40 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
  11. 121 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java
  12. 40 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java
  13. 2 19
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java
  14. 50 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java
  15. 38 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java
  16. 54 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java
  17. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java
  18. 68 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java
  19. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java
  20. 24 8
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java
  21. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java
  22. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java
  23. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
  24. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java
  25. 131 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java
  26. 0 28
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java
  27. 59 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java
  28. 3 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java
  29. 269 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java
  30. 35 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java
  31. 102 12
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java
  32. 39 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java
  33. 33 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java
  34. 27 21
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java
  35. 1 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java
  36. 69 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java
  37. 1 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java
  38. 1 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java

+ 1 - 0
server/src/main/java/module-info.java

@@ -475,6 +475,7 @@ module org.elasticsearch.server {
             org.elasticsearch.serverless.apifiltering;
     exports org.elasticsearch.lucene.spatial;
     exports org.elasticsearch.inference.configuration;
+    exports org.elasticsearch.inference.validation;
     exports org.elasticsearch.monitor.metrics;
     exports org.elasticsearch.plugins.internal.rewriter to org.elasticsearch.inference;
     exports org.elasticsearch.lucene.util.automaton;

+ 2 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -200,6 +200,7 @@ public class TransportVersions {
     public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
     public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);
     public static final TransportVersion STREAMS_LOGS_SUPPORT_8_19 = def(8_841_0_54);
+    public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19 = def(8_841_0_55);
 
     public static final TransportVersion V_9_0_0 = def(9_000_0_09);
     public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
@@ -308,6 +309,7 @@ public class TransportVersions {
     public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);
     public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE = def(9_103_0_00);
     public static final TransportVersion STREAMS_LOGS_SUPPORT = def(9_104_0_00);
+    public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE = def(9_105_0_00);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 11 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 
 import java.io.Closeable;
 import java.util.EnumSet;
@@ -248,4 +249,14 @@ public interface InferenceService extends Closeable {
      * after ensuring the node's internals are set up (for example if this ensures the internal ES client is ready for use).
      */
     default void onNodeStarted() {}
+
+    /**
+     * Get the service integration validator for the given task type.
+     * This allows services to provide custom validation logic.
+     * @param taskType The task type
+     * @return The service integration validator or null if the default should be used
+     */
+    default ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
+        return null;
+    }
 }

+ 77 - 0
server/src/main/java/org/elasticsearch/inference/InputType.java

@@ -10,8 +10,12 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
 
+import java.util.EnumSet;
+import java.util.HashMap;
 import java.util.Locale;
+import java.util.Map;
 
 import static org.elasticsearch.core.Strings.format;
 
@@ -29,6 +33,13 @@ public enum InputType {
     INTERNAL_SEARCH,
     INTERNAL_INGEST;
 
+    private static final EnumSet<InputType> SUPPORTED_REQUEST_VALUES = EnumSet.of(
+        InputType.CLASSIFICATION,
+        InputType.CLUSTERING,
+        InputType.INGEST,
+        InputType.SEARCH
+    );
+
     @Override
     public String toString() {
         return name().toLowerCase(Locale.ROOT);
@@ -57,4 +68,70 @@ public enum InputType {
     public static String invalidInputTypeMessage(InputType inputType) {
         return Strings.format("received invalid input type value [%s]", inputType.toString());
     }
+
+    /**
+     * Ensures that a map used for translating input types is valid. The keys of the map must be the names of
+     * supported {@link InputType} values, and the values are the non-empty strings those input types translate to.
+     * Adds an error to, and then throws, the supplied {@link ValidationException} on the first invalid key or value.
+     *
+     * @param inputTypeTranslation the map of input type translations to validate
+     * @param validationException  a ValidationException to which errors will be added
+     */
+    public static Map<InputType, String> validateInputTypeTranslationValues(
+        Map<String, Object> inputTypeTranslation,
+        ValidationException validationException
+    ) {
+        if (inputTypeTranslation == null || inputTypeTranslation.isEmpty()) {
+            return Map.of();
+        }
+
+        var translationMap = new HashMap<InputType, String>();
+
+        for (var entry : inputTypeTranslation.entrySet()) {
+            var key = entry.getKey();
+            var value = entry.getValue();
+
+            if (value instanceof String == false || Strings.isNullOrEmpty((String) value)) {
+                validationException.addValidationError(
+                    Strings.format(
+                        "Input type translation value for key [%s] must be a String that is "
+                            + "not null and not empty, received: [%s], type: [%s].",
+                        key,
+                        value,
+                        value == null ? "null" : value.getClass().getSimpleName()
+                    )
+                );
+
+                throw validationException;
+            }
+
+            try {
+                var inputTypeKey = InputType.fromStringValidateSupportedRequestValue(key);
+                translationMap.put(inputTypeKey, (String) value);
+            } catch (Exception e) {
+                validationException.addValidationError(
+                    Strings.format(
+                        "Invalid input type translation for key: [%s], is not a valid value. Must be one of %s",
+                        key,
+                        SUPPORTED_REQUEST_VALUES
+                    )
+                );
+
+                throw validationException;
+            }
+        }
+
+        return translationMap;
+    }
+
+    private static InputType fromStringValidateSupportedRequestValue(String name) {
+        var inputType = fromRestString(name);
+        if (SUPPORTED_REQUEST_VALUES.contains(inputType) == false) {
+            throw new IllegalArgumentException(
+                format("Unrecognized input_type [%s], must be one of %s", inputType, SUPPORTED_REQUEST_VALUES)
+            );
+        }
+
+        return inputType;
+    }
 }

+ 6 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ServiceIntegrationValidator.java → server/src/main/java/org/elasticsearch/inference/validation/ServiceIntegrationValidator.java

@@ -1,11 +1,13 @@
 /*
  * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
  */
 
-package org.elasticsearch.xpack.inference.services.validation;
+package org.elasticsearch.inference.validation;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

@@ -216,7 +216,7 @@ public class TransportPutInferenceModelAction extends TransportMasterNodeAction<
             if (skipValidationAndStart) {
                 storeModelListener.onResponse(model);
             } else {
-                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service instanceof ElasticsearchInternalService)
+                ModelValidatorBuilder.buildModelValidator(model.getTaskType(), service)
                     .validate(service, model, timeout, storeModelListener);
             }
         });

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -399,6 +399,18 @@ public final class ServiceUtils {
         return requiredField;
     }
 
+    public static String extractOptionalEmptyString(Map<String, Object> map, String settingName, ValidationException validationException) {
+        int initialValidationErrorCount = validationException.validationErrors().size();
+        String optionalField = ServiceUtils.removeAsType(map, settingName, String.class, validationException);
+
+        if (validationException.validationErrors().size() > initialValidationErrorCount) {
+            // new validation error occurred
+            return null;
+        }
+
+        return optionalField;
+    }
+
     public static String extractOptionalString(
         Map<String, Object> map,
         String settingName,

+ 12 - 12
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java

@@ -23,10 +23,13 @@ import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
 import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
+import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
+import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
+import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
 import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseEntity;
 
-import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
 
@@ -65,19 +68,16 @@ public class CustomRequestManager extends BaseRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        String query;
-        List<String> input;
+        RequestParameters requestParameters;
         if (inferenceInputs instanceof QueryAndDocsInputs) {
-            QueryAndDocsInputs queryAndDocsInputs = QueryAndDocsInputs.of(inferenceInputs);
-            query = queryAndDocsInputs.getQuery();
-            input = queryAndDocsInputs.getChunks();
+            requestParameters = RerankParameters.of(QueryAndDocsInputs.of(inferenceInputs));
         } else if (inferenceInputs instanceof ChatCompletionInput chatInputs) {
-            query = null;
-            input = chatInputs.getInputs();
+            requestParameters = CompletionParameters.of(chatInputs);
         } else if (inferenceInputs instanceof EmbeddingsInput) {
-            EmbeddingsInput embeddingsInput = EmbeddingsInput.of(inferenceInputs);
-            query = null;
-            input = embeddingsInput.getStringInputs();
+            requestParameters = EmbeddingParameters.of(
+                EmbeddingsInput.of(inferenceInputs),
+                model.getServiceSettings().getInputTypeTranslator()
+            );
         } else {
             listener.onFailure(
                 new ElasticsearchStatusException(
@@ -89,7 +89,7 @@ public class CustomRequestManager extends BaseRequestManager {
         }
 
         try {
-            var request = new CustomRequest(query, input, model);
+            var request = new CustomRequest(requestParameters, model);
             execute(new ExecutableInferenceRequest(requestSender, logger, request, handler, hasRequestCompletedFunction, listener));
         } catch (Exception e) {
             // Intentionally not logging this exception because it could contain sensitive information from the CustomRequest construction

+ 37 - 9
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

@@ -27,19 +27,26 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SettingsConfiguration;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.SenderService;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
-import org.elasticsearch.xpack.inference.services.ServiceUtils;
+import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
 import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
+import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
+import org.elasticsearch.xpack.inference.services.custom.request.RequestParameters;
+import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
+import org.elasticsearch.xpack.inference.services.validation.CustomServiceIntegrationValidator;
 
 import java.util.EnumSet;
 import java.util.HashMap;
@@ -115,13 +122,8 @@ public class CustomService extends SenderService {
      * This does some initial validation with mock inputs to determine if any templates are missing a field to fill them.
      */
     private static void validateConfiguration(CustomModel model) {
-        String query = null;
-        if (model.getTaskType() == TaskType.RERANK) {
-            query = "test query";
-        }
-
         try {
-            new CustomRequest(query, List.of("test input"), model).createHttpRequest();
+            new CustomRequest(createParameters(model), model).createHttpRequest();
         } catch (IllegalStateException e) {
             var validationException = new ValidationException();
             validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage()));
@@ -129,6 +131,20 @@ public class CustomService extends SenderService {
         }
     }
 
+    private static RequestParameters createParameters(CustomModel model) {
+        return switch (model.getTaskType()) {
+            case RERANK -> RerankParameters.of(new QueryAndDocsInputs("test query", List.of("test input")));
+            case COMPLETION -> CompletionParameters.of(new ChatCompletionInput(List.of("test input")));
+            case TEXT_EMBEDDING, SPARSE_EMBEDDING -> EmbeddingParameters.of(
+                new EmbeddingsInput(List.of("test input"), null, null),
+                model.getServiceSettings().getInputTypeTranslator()
+            );
+            default -> throw new IllegalStateException(
+                Strings.format("Unsupported task type [%s] for custom service", model.getTaskType())
+            );
+        };
+    }
+
     private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
         if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
             return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
@@ -257,7 +273,8 @@ public class CustomService extends SenderService {
 
     @Override
     protected void validateInputType(InputType inputType, Model model, ValidationException validationException) {
-        ServiceUtils.validateInputTypeIsUnspecifiedOrInternal(inputType, validationException);
+        // The custom service doesn't do any validation for the input type because if the input type is supported a default
+        // must be supplied within the service settings.
     }
 
     @Override
@@ -327,7 +344,9 @@ public class CustomService extends SenderService {
             serviceSettings.getQueryParameters(),
             serviceSettings.getRequestContentString(),
             serviceSettings.getResponseJsonParser(),
-            serviceSettings.rateLimitSettings()
+            serviceSettings.rateLimitSettings(),
+            serviceSettings.getBatchSize(),
+            serviceSettings.getInputTypeTranslator()
         );
     }
 
@@ -353,4 +372,13 @@ public class CustomService extends SenderService {
             }
         );
     }
+
+    @Override
+    public ServiceIntegrationValidator getServiceIntegrationValidator(TaskType taskType) {
+        if (taskType == TaskType.RERANK) {
+            return new CustomServiceIntegrationValidator();
+        }
+
+        return null;
+    }
 }

+ 40 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

@@ -110,6 +110,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
             context
         );
 
+        var inputTypeTranslator = InputTypeTranslator.fromMap(map, validationException, CustomService.NAME);
         var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException);
 
         if (responseParserMap == null || jsonParserMap == null) {
@@ -131,7 +132,8 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
             requestContentString,
             responseJsonParser,
             rateLimitSettings,
-            batchSize
+            batchSize,
+            inputTypeTranslator
         );
     }
 
@@ -203,6 +205,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
     private final CustomResponseParser responseJsonParser;
     private final RateLimitSettings rateLimitSettings;
     private final int batchSize;
+    private final InputTypeTranslator inputTypeTranslator;
 
     public CustomServiceSettings(
         TextEmbeddingSettings textEmbeddingSettings,
@@ -213,7 +216,17 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         CustomResponseParser responseJsonParser,
         @Nullable RateLimitSettings rateLimitSettings
     ) {
-        this(textEmbeddingSettings, url, headers, queryParameters, requestContentString, responseJsonParser, rateLimitSettings, null);
+        this(
+            textEmbeddingSettings,
+            url,
+            headers,
+            queryParameters,
+            requestContentString,
+            responseJsonParser,
+            rateLimitSettings,
+            null,
+            InputTypeTranslator.EMPTY_TRANSLATOR
+        );
     }
 
     public CustomServiceSettings(
@@ -224,7 +237,8 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         String requestContentString,
         CustomResponseParser responseJsonParser,
         @Nullable RateLimitSettings rateLimitSettings,
-        @Nullable Integer batchSize
+        @Nullable Integer batchSize,
+        InputTypeTranslator inputTypeTranslator
     ) {
         this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
         this.url = Objects.requireNonNull(url);
@@ -234,6 +248,7 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
         this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
         this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE);
+        this.inputTypeTranslator = Objects.requireNonNull(inputTypeTranslator);
     }
 
     public CustomServiceSettings(StreamInput in) throws IOException {
@@ -258,6 +273,13 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         } else {
             batchSize = DEFAULT_EMBEDDING_BATCH_SIZE;
         }
+
+        if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE)
+            || in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) {
+            inputTypeTranslator = new InputTypeTranslator(in);
+        } else {
+            inputTypeTranslator = InputTypeTranslator.EMPTY_TRANSLATOR;
+        }
     }
 
     @Override
@@ -305,6 +327,10 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         return responseJsonParser;
     }
 
+    public InputTypeTranslator getInputTypeTranslator() {
+        return inputTypeTranslator;
+    }
+
     public int getBatchSize() {
         return batchSize;
     }
@@ -352,6 +378,8 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
         }
         builder.endObject();
 
+        inputTypeTranslator.toXContent(builder, params);
+
         rateLimitSettings.toXContent(builder, params);
 
         builder.field(BATCH_SIZE, batchSize);
@@ -390,6 +418,11 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
             || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
             out.writeVInt(batchSize);
         }
+
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE)
+            || out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_INPUT_TYPE_8_19)) {
+            inputTypeTranslator.writeTo(out);
+        }
     }
 
     @Override
@@ -404,7 +437,8 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
             && Objects.equals(requestContentString, that.requestContentString)
             && Objects.equals(responseJsonParser, that.responseJsonParser)
             && Objects.equals(rateLimitSettings, that.rateLimitSettings)
-            && Objects.equals(batchSize, that.batchSize);
+            && Objects.equals(batchSize, that.batchSize)
+            && Objects.equals(inputTypeTranslator, that.inputTypeTranslator);
     }
 
     @Override
@@ -417,7 +451,8 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
             requestContentString,
             responseJsonParser,
             rateLimitSettings,
-            batchSize
+            batchSize,
+            inputTypeTranslator
         );
     }
 

+ 121 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslator.java

@@ -0,0 +1,121 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.TreeMap;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEmptyString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+
+public class InputTypeTranslator implements ToXContentFragment, Writeable {
+    public static final String INPUT_TYPE_TRANSLATOR = "input_type";
+    public static final String TRANSLATION = "translation";
+    public static final String DEFAULT = "default";
+    public static final InputTypeTranslator EMPTY_TRANSLATOR = new InputTypeTranslator(null, null);
+
+    public static InputTypeTranslator fromMap(Map<String, Object> map, ValidationException validationException, String serviceName) {
+        if (map == null || map.isEmpty()) {
+            return EMPTY_TRANSLATOR;
+        }
+
+        Map<String, Object> inputTypeTranslation = Objects.requireNonNullElse(
+            extractOptionalMap(map, INPUT_TYPE_TRANSLATOR, ModelConfigurations.SERVICE_SETTINGS, validationException),
+            new HashMap<>(Map.of())
+        );
+
+        Map<String, Object> translationMap = extractOptionalMap(
+            inputTypeTranslation,
+            TRANSLATION,
+            INPUT_TYPE_TRANSLATOR,
+            validationException
+        );
+
+        var validatedTranslation = InputType.validateInputTypeTranslationValues(translationMap, validationException);
+
+        var defaultValue = extractOptionalEmptyString(inputTypeTranslation, DEFAULT, validationException);
+
+        throwIfNotEmptyMap(inputTypeTranslation, INPUT_TYPE_TRANSLATOR, "input_type_translator");
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new InputTypeTranslator(validatedTranslation, defaultValue);
+    }
+
+    private final Map<InputType, String> inputTypeTranslation;
+    private final String defaultValue;
+
+    public InputTypeTranslator(@Nullable Map<InputType, String> inputTypeTranslation, @Nullable String defaultValue) {
+        this.inputTypeTranslation = Objects.requireNonNullElse(inputTypeTranslation, Map.of());
+        this.defaultValue = Objects.requireNonNullElse(defaultValue, "");
+    }
+
+    public InputTypeTranslator(StreamInput in) throws IOException {
+        this.inputTypeTranslation = in.readImmutableMap(keyReader -> keyReader.readEnum(InputType.class), StreamInput::readString);
+        this.defaultValue = in.readString();
+    }
+
+    public Map<InputType, String> getTranslation() {
+        return inputTypeTranslation;
+    }
+
+    public String getDefaultValue() {
+        return defaultValue;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        var sortedMap = new TreeMap<>(inputTypeTranslation);
+
+        builder.startObject(INPUT_TYPE_TRANSLATOR);
+        {
+            builder.startObject(TRANSLATION);
+            for (var entry : sortedMap.entrySet()) {
+                builder.field(entry.getKey().toString(), entry.getValue());
+            }
+            builder.endObject();
+            builder.field(DEFAULT, defaultValue);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeMap(inputTypeTranslation, StreamOutput::writeEnum, StreamOutput::writeString);
+        out.writeString(defaultValue);
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (o == null || getClass() != o.getClass()) return false;
+        InputTypeTranslator that = (InputTypeTranslator) o;
+        return Objects.equals(inputTypeTranslation, that.inputTypeTranslation) && Objects.equals(defaultValue, that.defaultValue);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(inputTypeTranslation, defaultValue);
+    }
+}

+ 40 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParameters.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+
+/**
+ * Request parameters for completion requests. Only the first input is exposed to the request template; when there
+ * are no inputs an empty string is used instead.
+ */
+public class CompletionParameters extends RequestParameters {
+
+    /**
+     * Creates the parameters from a completion input.
+     *
+     * @param completionInput the input to take the request strings from, must not be null
+     * @return the parameters wrapping the input's strings
+     */
+    public static CompletionParameters of(ChatCompletionInput completionInput) {
+        Objects.requireNonNull(completionInput);
+        return new CompletionParameters(completionInput);
+    }
+
+    private CompletionParameters(ChatCompletionInput completionInput) {
+        super(completionInput.getInputs());
+    }
+
+    /**
+     * @return a single-entry map holding the JSON representation of the first input, or of an empty string when
+     * there are no inputs
+     */
+    @Override
+    public Map<String, String> jsonParameters() {
+        String firstInput = inputs.isEmpty() ? "" : inputs.get(0);
+
+        return Map.of(INPUT, toJson(firstInput, INPUT));
+    }
+
+}

+ 2 - 19
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequest.java

@@ -14,7 +14,6 @@ import org.apache.http.client.utils.URIBuilder;
 import org.apache.http.entity.StringEntity;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.SecureString;
-import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.ValidatingSubstitutor;
 import org.elasticsearch.xpack.inference.external.request.HttpRequest;
@@ -26,7 +25,6 @@ import java.net.URI;
 import java.net.URISyntaxException;
 import java.nio.charset.StandardCharsets;
 import java.util.HashMap;
-import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 
@@ -35,15 +33,13 @@ import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSet
 import static org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings.URL;
 
 public class CustomRequest implements Request {
-    private static final String QUERY = "query";
-    private static final String INPUT = "input";
 
     private final URI uri;
     private final ValidatingSubstitutor jsonPlaceholderReplacer;
     private final ValidatingSubstitutor stringPlaceholderReplacer;
     private final CustomModel model;
 
-    public CustomRequest(String query, List<String> input, CustomModel model) {
+    public CustomRequest(RequestParameters requestParams, CustomModel model) {
         this.model = Objects.requireNonNull(model);
 
         var stringOnlyParams = new HashMap<String, String>();
@@ -54,11 +50,7 @@ public class CustomRequest implements Request {
         addJsonStringParams(jsonParams, model.getSecretSettings().getSecretParameters());
         addJsonStringParams(jsonParams, model.getTaskSettings().getParameters());
 
-        if (query != null) {
-            jsonParams.put(QUERY, toJson(query, QUERY));
-        }
-
-        addInputJsonParam(jsonParams, input, model.getTaskType());
+        jsonParams.putAll(requestParams.jsonParameters());
 
         jsonPlaceholderReplacer = new ValidatingSubstitutor(jsonParams, "${", "}");
         stringPlaceholderReplacer = new ValidatingSubstitutor(stringOnlyParams, "${", "}");
@@ -81,14 +73,6 @@ public class CustomRequest implements Request {
         }
     }
 
-    private static void addInputJsonParam(Map<String, String> jsonParams, List<String> input, TaskType taskType) {
-        if (taskType == TaskType.COMPLETION && input.isEmpty() == false) {
-            jsonParams.put(INPUT, toJson(input.get(0), INPUT));
-        } else {
-            jsonParams.put(INPUT, toJson(input, INPUT));
-        }
-    }
-
     private URI buildUri() {
         var replacedUrl = stringPlaceholderReplacer.replace(model.getServiceSettings().getUrl(), URL);
 
@@ -104,7 +88,6 @@ public class CustomRequest implements Request {
         } catch (URISyntaxException e) {
             throw new IllegalStateException(Strings.format("Failed to build URI, error: %s", e.getMessage()), e);
         }
-
     }
 
     @Override

+ 50 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParameters.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+
+/**
+ * Request parameters for embedding requests. In addition to the inputs it supplies an {@code input_type} parameter,
+ * translated through the configured {@link InputTypeTranslator}.
+ */
+public class EmbeddingParameters extends RequestParameters {
+    private static final String INPUT_TYPE = "input_type";
+
+    private final InputType inputType;
+    private final InputTypeTranslator translator;
+
+    /**
+     * Creates the parameters from an embeddings input and the translator to apply to its input type.
+     *
+     * @param embeddingsInput     the input strings and input type, must not be null
+     * @param inputTypeTranslator the translator for the input type, must not be null
+     * @return the parameters for the request
+     */
+    public static EmbeddingParameters of(EmbeddingsInput embeddingsInput, InputTypeTranslator inputTypeTranslator) {
+        Objects.requireNonNull(embeddingsInput);
+        Objects.requireNonNull(inputTypeTranslator);
+        return new EmbeddingParameters(embeddingsInput, inputTypeTranslator);
+    }
+
+    private EmbeddingParameters(EmbeddingsInput embeddingsInput, InputTypeTranslator translator) {
+        super(embeddingsInput.getStringInputs());
+        this.inputType = embeddingsInput.getInputType();
+        this.translator = translator;
+    }
+
+    /**
+     * @return a map with the {@code input_type} parameter set to the translation for the request's input type, or to
+     * the translator's default value when the input type is null or has no translation
+     */
+    @Override
+    protected Map<String, String> taskTypeParameters() {
+        var translations = translator.getTranslation();
+        // The null check must come first: the lookup is only attempted for a non-null input type.
+        boolean hasTranslation = inputType != null && translations.containsKey(inputType);
+        String inputTypeValue = hasTranslation ? translations.get(inputType) : translator.getDefaultValue();
+
+        var additionalParameters = new HashMap<String, String>();
+        additionalParameters.put(INPUT_TYPE, toJson(inputTypeValue, INPUT_TYPE));
+        return additionalParameters;
+    }
+}

+ 38 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RequestParameters.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+
+/**
+ * Base class for the parameters a request makes available as JSON strings. Every request exposes its inputs under
+ * {@link #INPUT}; subclasses contribute further entries through {@link #taskTypeParameters()}.
+ */
+public abstract class RequestParameters {
+
+    public static final String INPUT = "input";
+
+    protected final List<String> inputs;
+
+    /**
+     * @param inputs the input strings for the request, must not be null
+     */
+    public RequestParameters(List<String> inputs) {
+        this.inputs = Objects.requireNonNull(inputs);
+    }
+
+    /**
+     * @return a mutable map of the task specific parameters plus the JSON representation of the inputs; the input
+     * entry is added last, so it replaces any entry a subclass supplied under the same key
+     */
+    Map<String, String> jsonParameters() {
+        Map<String, String> merged = new HashMap<>(taskTypeParameters());
+        merged.put(INPUT, toJson(inputs, INPUT));
+        return merged;
+    }
+
+    /**
+     * @return the parameters specific to a task type; empty unless overridden
+     */
+    protected Map<String, String> taskTypeParameters() {
+        return Map.of();
+    }
+}

+ 54 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParameters.java

@@ -0,0 +1,54 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.common.JsonUtils.toJson;
+
+/**
+ * Request parameters for rerank requests. In addition to the inputs it supplies the query and, when they are set,
+ * the top N and return documents values.
+ */
+public class RerankParameters extends RequestParameters {
+    private static final String QUERY = "query";
+
+    private final QueryAndDocsInputs queryAndDocsInputs;
+
+    /**
+     * Creates the parameters from a query and documents input.
+     *
+     * @param queryAndDocsInputs the query, documents and optional rerank settings, must not be null
+     * @return the parameters for the request
+     */
+    public static RerankParameters of(QueryAndDocsInputs queryAndDocsInputs) {
+        return new RerankParameters(Objects.requireNonNull(queryAndDocsInputs));
+    }
+
+    private RerankParameters(QueryAndDocsInputs queryAndDocsInputs) {
+        super(queryAndDocsInputs.getChunks());
+        this.queryAndDocsInputs = queryAndDocsInputs;
+    }
+
+    /**
+     * @return a map that always contains the query and contains the top N and return documents entries only when
+     * the corresponding value is not null
+     */
+    @Override
+    protected Map<String, String> taskTypeParameters() {
+        var parameters = new HashMap<String, String>();
+        parameters.put(QUERY, toJson(queryAndDocsInputs.getQuery(), QUERY));
+
+        var topN = queryAndDocsInputs.getTopN();
+        if (topN != null) {
+            var topNField = InferenceAction.Request.TOP_N.getPreferredName();
+            parameters.put(topNField, toJson(topN, topNField));
+        }
+
+        var returnDocuments = queryAndDocsInputs.getReturnDocuments();
+        if (returnDocuments != null) {
+            var returnDocumentsField = InferenceAction.Request.RETURN_DOCUMENTS.getPreferredName();
+            parameters.put(returnDocumentsField, toJson(returnDocuments, returnDocumentsField));
+        }
+
+        return parameters;
+    }
+}

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidator.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 
 public class ChatCompletionModelValidator implements ModelValidator {
 

+ 68 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/CustomServiceIntegrationValidator.java

@@ -0,0 +1,68 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.validation;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
+import org.elasticsearch.rest.RestStatus;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This class is slightly different from the SimpleServiceIntegrationValidator in that it sends the topN and return documents in the
+ * request. This is necessary because the custom service may require those templates to be replaced when building the request. Otherwise,
+ * the request will fail to be constructed because it'll have a template that wasn't replaced.
+ */
+public class CustomServiceIntegrationValidator implements ServiceIntegrationValidator {
+    // Fixed input and query sent with the validation request.
+    private static final List<String> TEST_INPUT = List.of("how big");
+    private static final String QUERY = "test query";
+
+    /**
+     * Makes one inference call against the service and reports the outcome to the listener. A non-null result is
+     * passed through unchanged; a null result or a failure is reported as an {@link ElasticsearchStatusException}
+     * with {@link RestStatus#BAD_REQUEST}.
+     *
+     * @param service  the service to call
+     * @param model    the model to validate
+     * @param timeout  the timeout passed to the inference call
+     * @param listener notified with the inference result or the wrapped failure
+     */
+    @Override
+    public void validate(InferenceService service, Model model, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
+        service.infer(
+            model,
+            // A query is only supplied for rerank models; every other task type gets null.
+            model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
+            // NOTE(review): presumably the return documents flag followed by topN, always sent so that templates
+            // referencing them can be replaced - confirm the argument order against InferenceService#infer.
+            true,
+            1,
+            TEST_INPUT,
+            // NOTE(review): looks like the streaming flag followed by empty task settings - confirm.
+            false,
+            Map.of(),
+            InputType.INTERNAL_INGEST,
+            timeout,
+            ActionListener.wrap(r -> {
+                if (r != null) {
+                    listener.onResponse(r);
+                } else {
+                    // A null result is treated as a failed validation rather than being passed on.
+                    listener.onFailure(
+                        new ElasticsearchStatusException(
+                            "Could not complete custom service inference endpoint creation as"
+                                + " validation call to service returned null response.",
+                            RestStatus.BAD_REQUEST
+                        )
+                    );
+                }
+            },
+                // The original exception is kept as the cause of the reported failure.
+                e -> listener.onFailure(
+                    new ElasticsearchStatusException(
+                        "Could not complete custom service inference endpoint creation as validation call to service threw an exception.",
+                        RestStatus.BAD_REQUEST,
+                        e
+                    )
+                )
+            )
+        );
+    }
+}

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ElasticsearchInternalServiceModelValidator.java

@@ -15,6 +15,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;

+ 24 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilder.java

@@ -8,34 +8,50 @@
 package org.elasticsearch.xpack.inference.services.validation;
 
 import org.elasticsearch.core.Strings;
+import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
+import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
+
+import java.util.Objects;
 
 public class ModelValidatorBuilder {
-    public static ModelValidator buildModelValidator(TaskType taskType, boolean isElasticsearchInternalService) {
-        if (isElasticsearchInternalService) {
+    public static ModelValidator buildModelValidator(TaskType taskType, InferenceService service) {
+        if (service instanceof ElasticsearchInternalService) {
             return new ElasticsearchInternalServiceModelValidator(new SimpleServiceIntegrationValidator());
         } else {
-            return buildModelValidatorForTaskType(taskType);
+            return buildModelValidatorForTaskType(taskType, service);
         }
     }
 
-    private static ModelValidator buildModelValidatorForTaskType(TaskType taskType) {
+    private static ModelValidator buildModelValidatorForTaskType(TaskType taskType, InferenceService service) {
         if (taskType == null) {
             throw new IllegalArgumentException("Task type can't be null");
         }
 
+        ServiceIntegrationValidator validatorFromService = null;
+        if (service != null) {
+            validatorFromService = service.getServiceIntegrationValidator(taskType);
+        }
+
         switch (taskType) {
             case TEXT_EMBEDDING -> {
-                return new TextEmbeddingModelValidator(new SimpleServiceIntegrationValidator());
+                return new TextEmbeddingModelValidator(
+                    Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())
+                );
             }
             case COMPLETION -> {
-                return new ChatCompletionModelValidator(new SimpleServiceIntegrationValidator());
+                return new ChatCompletionModelValidator(
+                    Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator())
+                );
             }
             case CHAT_COMPLETION -> {
-                return new ChatCompletionModelValidator(new SimpleChatCompletionServiceIntegrationValidator());
+                return new ChatCompletionModelValidator(
+                    Objects.requireNonNullElse(validatorFromService, new SimpleChatCompletionServiceIntegrationValidator())
+                );
             }
             case SPARSE_EMBEDDING, RERANK, ANY -> {
-                return new SimpleModelValidator(new SimpleServiceIntegrationValidator());
+                return new SimpleModelValidator(Objects.requireNonNullElse(validatorFromService, new SimpleServiceIntegrationValidator()));
             }
             default -> throw new IllegalArgumentException(Strings.format("Can't validate inference model for task type %s", taskType));
         }

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleChatCompletionServiceIntegrationValidator.java

@@ -14,6 +14,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
 

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidator.java

@@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 
 public class SimpleModelValidator implements ModelValidator {
 

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

@@ -16,6 +16,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.rest.RestStatus;
 
 import java.util.List;

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidator.java

@@ -14,6 +14,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;

+ 131 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/InputTypeTests.java

@@ -7,10 +7,13 @@
 
 package org.elasticsearch.xpack.inference;
 
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
 
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.hamcrest.CoreMatchers.is;
@@ -80,4 +83,132 @@ public class InputTypeTests extends ESTestCase {
 
         assertThat(thrownException.getMessage(), is("No enum constant org.elasticsearch.inference.InputType.FOO"));
     }
+
+    public void testValidateInputTypeTranslationValues() {
+        assertThat(
+            InputType.validateInputTypeTranslationValues(
+                Map.of(
+                    InputType.INGEST.toString(),
+                    "ingest_value",
+                    InputType.SEARCH.toString(),
+                    "search_value",
+                    InputType.CLASSIFICATION.toString(),
+                    "classification_value",
+                    InputType.CLUSTERING.toString(),
+                    "clustering_value"
+                ),
+                new ValidationException()
+            ),
+            is(
+                Map.of(
+                    InputType.INGEST,
+                    "ingest_value",
+                    InputType.SEARCH,
+                    "search_value",
+                    InputType.CLASSIFICATION,
+                    "classification_value",
+                    InputType.CLUSTERING,
+                    "clustering_value"
+                )
+            )
+        );
+    }
+
+    public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsNull() {
+        assertThat(InputType.validateInputTypeTranslationValues(null, new ValidationException()), is(Map.of()));
+    }
+
+    public void testValidateInputTypeTranslationValues_ReturnsEmptyMap_WhenTranslationIsAnEmptyMap() {
+        assertThat(InputType.validateInputTypeTranslationValues(Map.of(), new ValidationException()), is(Map.of()));
+    }
+
+    public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsUnspecified() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> InputType.validateInputTypeTranslationValues(
+                Map.of(InputType.INGEST.toString(), "ingest_value", InputType.UNSPECIFIED.toString(), "unspecified_value"),
+                new ValidationException()
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Invalid input type translation for key: [unspecified], is not a valid value. Must be "
+                    + "one of [ingest, search, classification, clustering];"
+            )
+        );
+    }
+
+    public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenInputTypeIsInternal() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> InputType.validateInputTypeTranslationValues(
+                Map.of(InputType.INGEST.toString(), "ingest_value", InputType.INTERNAL_INGEST.toString(), "internal_ingest_value"),
+                new ValidationException()
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Invalid input type translation for key: [internal_ingest], is not a valid value. Must be "
+                    + "one of [ingest, search, classification, clustering];"
+            )
+        );
+    }
+
+    public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNull() {
+        var translation = new HashMap<String, Object>();
+        translation.put(InputType.INGEST.toString(), null);
+
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException())
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Input type translation value for key [ingest] must be a String that "
+                    + "is not null and not empty, received: [null], type: [null].;"
+            )
+        );
+    }
+
+    public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsAnEmptyString() {
+        var translation = new HashMap<String, Object>();
+        translation.put(InputType.INGEST.toString(), "");
+
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException())
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Input type translation value for key [ingest] must be a String that "
+                    + "is not null and not empty, received: [], type: [String].;"
+            )
+        );
+    }
+
+    public void testValidateInputTypeTranslationValues_ThrowsAnException_WhenValueIsNotAString() {
+        var translation = new HashMap<String, Object>();
+        translation.put(InputType.INGEST.toString(), 1);
+
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> InputType.validateInputTypeTranslationValues(translation, new ValidationException())
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Input type translation value for key [ingest] must be a String that "
+                    + "is not null and not empty, received: [1], type: [Integer].;"
+            )
+        );
+    }
 }

+ 0 - 28
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/AbstractInferenceServiceTests.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.inference.services;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.support.PlainActionFuture;
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
@@ -460,33 +459,6 @@ public abstract class AbstractInferenceServiceTests extends ESTestCase {
         }
     }
 
-    public void testInfer_ThrowsErrorWhenInputTypeIsSpecified() throws IOException {
-        try (var service = testConfiguration.commonConfig.createService(threadPool, clientManager)) {
-            var listener = new PlainActionFuture<InferenceServiceResults>();
-
-            var exception = expectThrows(
-                ValidationException.class,
-                () -> service.infer(
-                    getInvalidModel("id", "service"),
-                    null,
-                    null,
-                    null,
-                    List.of(""),
-                    false,
-                    new HashMap<>(),
-                    InputType.INGEST,
-                    InferenceAction.Request.DEFAULT_TIMEOUT,
-                    listener
-                )
-            );
-
-            assertThat(
-                exception.getMessage(),
-                is("Validation Failed: 1: Invalid input_type [ingest]. The input_type option is not supported by this service;")
-            );
-        }
-    }
-
     public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
         Assume.assumeTrue(testConfiguration.updateModelConfiguration.isEnabled());
 

+ 59 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettingsTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.XContentBuilder;
@@ -157,7 +158,8 @@ public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTest
                     requestContentString,
                     responseParser,
                     new RateLimitSettings(10_000),
-                    11
+                    11,
+                    InputTypeTranslator.EMPTY_TRANSLATOR
                 )
             )
         );
@@ -580,6 +582,56 @@ public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTest
                         "text_embeddings": "$.result.embeddings[*].embedding"
                     }
                 },
+                "input_type": {
+                    "translation": {},
+                    "default": ""
+                },
+                "rate_limit": {
+                    "requests_per_minute": 10000
+                },
+                "batch_size": 10
+            }
+            """);
+
+        assertThat(xContentResult, is(expected));
+    }
+
+    public void testXContent_WithInputTypeTranslationValues() throws IOException {
+        var entity = new CustomServiceSettings(
+            CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
+            "http://www.abc.com",
+            Map.of("key", "value"),
+            null,
+            "string",
+            new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"),
+            null,
+            null,
+            new InputTypeTranslator(Map.of(InputType.SEARCH, "do_search", InputType.INGEST, "do_ingest"), "a_default")
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        var expected = XContentHelper.stripWhitespace("""
+            {
+                "url": "http://www.abc.com",
+                "headers": {
+                    "key": "value"
+                },
+                "request": "string",
+                "response": {
+                    "json_parser": {
+                        "text_embeddings": "$.result.embeddings[*].embedding"
+                    }
+                },
+                "input_type": {
+                    "translation": {
+                        "ingest": "do_ingest",
+                        "search": "do_search"
+                    },
+                    "default": "a_default"
+                },
                 "rate_limit": {
                     "requests_per_minute": 10000
                 },
@@ -599,7 +651,8 @@ public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTest
             "string",
             new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"),
             null,
-            11
+            11,
+            InputTypeTranslator.EMPTY_TRANSLATOR
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -618,6 +671,10 @@ public class CustomServiceSettingsTests extends AbstractBWCWireSerializationTest
                         "text_embeddings": "$.result.embeddings[*].embedding"
                     }
                 },
+                "input_type": {
+                    "translation": {},
+                    "default": ""
+                },
                 "rate_limit": {
                     "requests_per_minute": 10000
                 },

+ 3 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

@@ -147,7 +147,7 @@ public class CustomServiceTests extends AbstractInferenceServiceTests {
         assertThat(customModel.getServiceSettings().getResponseJsonParser(), instanceOf(CompletionResponseParser.class));
     }
 
-    private static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
+    public static SenderService createService(ThreadPool threadPool, HttpClientManager clientManager) {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         return new CustomService(senderFactory, createWithEmptySettings(threadPool));
     }
@@ -278,7 +278,8 @@ public class CustomServiceTests extends AbstractInferenceServiceTests {
                 "{\"input\":${input}}",
                 parser,
                 new RateLimitSettings(10_000),
-                batchSize
+                batchSize,
+                InputTypeTranslator.EMPTY_TRANSLATOR
             ),
             new CustomTaskSettings(Map.of("key", "test_value")),
             new CustomSecretSettings(Map.of("test_key", new SecureString("test_value".toCharArray()))),

+ 269 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/InputTypeTranslatorTests.java

@@ -0,0 +1,269 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.core.Tuple.tuple;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests for {@link InputTypeTranslator}: parsing from a settings map, XContent output, and the wire-serialization
+ * checks inherited from {@link AbstractBWCWireSerializationTestCase}.
+ */
+public class InputTypeTranslatorTests extends AbstractBWCWireSerializationTestCase<InputTypeTranslator> {
+
+    /** Input types that {@link #createRandom()} picks translation keys from. */
+    private static final List<InputType> RANDOM_TRANSLATION_KEYS = List.of(
+        InputType.CLASSIFICATION,
+        InputType.CLUSTERING,
+        InputType.INGEST,
+        InputType.SEARCH
+    );
+
+    /**
+     * Creates a translator whose translation map is either empty or a random map requested with a size between
+     * 0 and 5, and whose default value is a random five-character string.
+     */
+    public static InputTypeTranslator createRandom() {
+        final Map<InputType, String> translation;
+        if (randomBoolean()) {
+            translation = randomMap(0, 5, () -> tuple(randomFrom(RANDOM_TRANSLATION_KEYS), randomAlphaOfLength(5)));
+        } else {
+            translation = Map.of();
+        }
+
+        return new InputTypeTranslator(translation, randomAlphaOfLength(5));
+    }
+
+    /** All four supported keys plus a default are parsed into the matching {@link InputType} entries. */
+    public void testFromMap() {
+        var settings = translatorSettings(
+            Map.of(
+                "CLASSIFICATION",
+                "test_value",
+                "CLUSTERING",
+                "test_value_2",
+                "INGEST",
+                "test_value_3",
+                "SEARCH",
+                "test_value_4"
+            ),
+            "default_value"
+        );
+
+        var expectedTranslator = new InputTypeTranslator(
+            Map.of(
+                InputType.CLASSIFICATION,
+                "test_value",
+                InputType.CLUSTERING,
+                "test_value_2",
+                InputType.INGEST,
+                "test_value_3",
+                InputType.SEARCH,
+                "test_value_4"
+            ),
+            "default_value"
+        );
+
+        assertThat(parse(settings), is(expectedTranslator));
+    }
+
+    /** A null map and an empty map are both expected to yield {@link InputTypeTranslator#EMPTY_TRANSLATOR}. */
+    public void testFromMap_Null_EmptyMap_Returns_EmptySettings() {
+        var fromNullMap = InputTypeTranslator.fromMap(null, null, null);
+        assertThat(fromNullMap, is(InputTypeTranslator.EMPTY_TRANSLATOR));
+
+        var fromEmptyMap = InputTypeTranslator.fromMap(Map.of(), null, null);
+        assertThat(fromEmptyMap, is(InputTypeTranslator.EMPTY_TRANSLATOR));
+    }
+
+    /** A non-string translation value is rejected with a validation error. */
+    public void testFromMap_Throws_IfValueIsNotAString() {
+        var settings = translatorSettings(Map.of("CLASSIFICATION", 12345), null);
+
+        var thrown = expectThrows(ValidationException.class, () -> parse(settings));
+
+        assertThat(
+            thrown.getMessage(),
+            is(
+                "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] "
+                    + "must be a String that is not null and not empty, received: [12345], type: [Integer].;"
+            )
+        );
+    }
+
+    /** An empty-string translation value is rejected with a validation error. */
+    public void testFromMap_Throws_IfValueIsEmptyString() {
+        var settings = translatorSettings(Map.of("CLASSIFICATION", ""), null);
+
+        var thrown = expectThrows(ValidationException.class, () -> parse(settings));
+
+        assertThat(
+            thrown.getMessage(),
+            is(
+                "Validation Failed: 1: Input type translation value for key [CLASSIFICATION] "
+                    + "must be a String that is not null and not empty, received: [], type: [String].;"
+            )
+        );
+    }
+
+    /** Unlike translation values, the default value is allowed to be an empty string. */
+    public void testFromMap_DoesNotThrow_ForAnEmptyDefaultValue() {
+        var settings = translatorSettings(Map.of("CLASSIFICATION", "value"), "");
+
+        var parsed = parse(settings);
+
+        assertThat(parsed, is(new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "value"), "")));
+    }
+
+    /** A translation key that is not a recognised input type is rejected with a validation error. */
+    public void testFromMap_Throws_IfKeyIsInvalid() {
+        var settings = translatorSettings(Map.of("CLASSIFICATION", "test_value", "invalid_key", "another_value"), null);
+
+        var thrown = expectThrows(ValidationException.class, () -> parse(settings));
+
+        assertThat(
+            thrown.getMessage(),
+            is(
+                "Validation Failed: 1: Invalid input type translation for key: [invalid_key]"
+                    + ", is not a valid value. Must be one of [ingest, search, classification, clustering];"
+            )
+        );
+    }
+
+    /** When the settings map lacks the translator field, the result equals a translator with no entries and a null default. */
+    public void testFromMap_DefaultsToEmptyMap_WhenField_DoesNotExist() {
+        var unrelatedSettings = new HashMap<String, Object>();
+        unrelatedSettings.put("key", new HashMap<>(Map.of("test_key", "test_value")));
+
+        assertThat(parse(unrelatedSettings), is(new InputTypeTranslator(Map.of(), null)));
+    }
+
+    /** The translation map and default value are rendered under the {@code input_type} object. */
+    public void testXContent() throws IOException {
+        var translator = new InputTypeTranslator(Map.of(InputType.CLASSIFICATION, "test_value"), "default");
+
+        String actualJson = toJson(translator);
+
+        var expectedJson = XContentHelper.stripWhitespace("""
+            {
+                "input_type": {
+                    "translation": {
+                        "classification": "test_value"
+                    },
+                    "default": "default"
+                }
+            }
+            """);
+
+        assertThat(actualJson, is(expectedJson));
+    }
+
+    /** A translator built with no entries and a null default renders an empty translation object and an empty default. */
+    public void testXContent_EmptyTranslator() throws IOException {
+        var translator = new InputTypeTranslator(Map.of(), null);
+
+        String actualJson = toJson(translator);
+
+        var expectedJson = XContentHelper.stripWhitespace("""
+            {
+                "input_type": {
+                    "translation": {},
+                    "default": ""
+                }
+            }
+            """);
+
+        assertThat(actualJson, is(expectedJson));
+    }
+
+    @Override
+    protected Writeable.Reader<InputTypeTranslator> instanceReader() {
+        return in -> new InputTypeTranslator(in);
+    }
+
+    @Override
+    protected InputTypeTranslator createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected InputTypeTranslator mutateInstance(InputTypeTranslator instance) throws IOException {
+        // Keep drawing random translators until one differs from the given instance.
+        return randomValueOtherThan(instance, () -> createRandom());
+    }
+
+    /**
+     * Wraps the given parameters under the {@link CustomTaskSettings#PARAMETERS} key; a null argument yields an empty map.
+     * NOTE(review): nothing in this class calls this helper; it looks carried over from the task-settings tests. Confirm
+     * whether it is still needed here.
+     */
+    public static Map<String, Object> getTaskSettingsMap(@Nullable Map<String, Object> parameters) {
+        var taskSettings = new HashMap<String, Object>();
+        if (parameters == null) {
+            return taskSettings;
+        }
+
+        taskSettings.put(CustomTaskSettings.PARAMETERS, parameters);
+        return taskSettings;
+    }
+
+    @Override
+    protected InputTypeTranslator mutateInstanceForVersion(InputTypeTranslator instance, TransportVersion version) {
+        // The instance is returned unchanged for every transport version.
+        return instance;
+    }
+
+    /**
+     * Builds a mutable settings map holding the translator object, which contains a mutable copy of
+     * {@code translation} and, when {@code defaultValue} is non-null, the default entry.
+     */
+    private static Map<String, Object> translatorSettings(Map<String, Object> translation, @Nullable String defaultValue) {
+        var translatorFields = new HashMap<String, Object>();
+        translatorFields.put(InputTypeTranslator.TRANSLATION, new HashMap<>(translation));
+        if (defaultValue != null) {
+            translatorFields.put(InputTypeTranslator.DEFAULT, defaultValue);
+        }
+
+        var settings = new HashMap<String, Object>();
+        settings.put(InputTypeTranslator.INPUT_TYPE_TRANSLATOR, translatorFields);
+        return settings;
+    }
+
+    /** Parses {@code settings} with a fresh {@link ValidationException} and the scope {@code "name"}. */
+    private static InputTypeTranslator parse(Map<String, Object> settings) {
+        return InputTypeTranslator.fromMap(settings, new ValidationException(), "name");
+    }
+
+    /** Serializes the translator inside an enclosing JSON object and returns the resulting string. */
+    private static String toJson(InputTypeTranslator translator) throws IOException {
+        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+        jsonBuilder.startObject();
+        translator.toXContent(jsonBuilder, null);
+        jsonBuilder.endObject();
+
+        return Strings.toString(jsonBuilder);
+    }
+}

+ 35 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CompletionParametersTests.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.custom.request.RequestParameters.INPUT;
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests the JSON parameters that {@link CompletionParameters} builds from a {@link ChatCompletionInput}.
+ */
+public class CompletionParametersTests extends ESTestCase {
+
+    /** A single input is expected as a quoted JSON string under the input key. */
+    public void testJsonParameters_SingleValue() {
+        assertInputParameter(List.of("hello"), "\"hello\"");
+    }
+
+    /** With several inputs, only the first one is expected in the parameters. */
+    public void testJsonParameters_RetrievesFirstEntryFromList() {
+        assertInputParameter(List.of("hello", "hi"), "\"hello\"");
+    }
+
+    /** With no inputs, an empty quoted JSON string is expected. */
+    public void testJsonParameters_EmptyList() {
+        assertInputParameter(List.of(), "\"\"");
+    }
+
+    /** Builds completion parameters for {@code inputs} and checks that the only JSON parameter is the expected input value. */
+    private static void assertInputParameter(List<String> inputs, String expectedJsonValue) {
+        var completionParameters = CompletionParameters.of(new ChatCompletionInput(inputs));
+
+        assertThat(completionParameters.jsonParameters(), is(Map.of(INPUT, expectedJsonValue)));
+    }
+}

+ 102 - 12
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java

@@ -14,14 +14,18 @@ import org.elasticsearch.common.io.Streams;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
 import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
 import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
 import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
+import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
 import org.elasticsearch.xpack.inference.services.custom.QueryParameters;
 import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
 import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
@@ -46,7 +50,8 @@ public class CustomRequestTests extends ESTestCase {
         Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}"));
         var requestContentString = """
             {
-                "input": ${input}
+                "input": ${input},
+                "input_type": ${input_type}
             }
             """;
 
@@ -62,7 +67,9 @@ public class CustomRequestTests extends ESTestCase {
             new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
             requestContentString,
             new TextEmbeddingResponseParser("$.result.embeddings"),
-            new RateLimitSettings(10_000)
+            new RateLimitSettings(10_000),
+            null,
+            new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
         );
 
         var model = CustomModelTests.createModel(
@@ -73,7 +80,13 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var request = new CustomRequest(null, List.of("abc", "123"), model);
+        var request = new CustomRequest(
+            EmbeddingParameters.of(
+                new EmbeddingsInput(List.of("abc", "123"), null, null),
+                model.getServiceSettings().getInputTypeTranslator()
+            ),
+            model
+        );
         var httpRequest = request.createHttpRequest();
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
 
@@ -84,18 +97,20 @@ public class CustomRequestTests extends ESTestCase {
 
         var expectedBody = XContentHelper.stripWhitespace("""
             {
-              "input": ["abc", "123"]
+              "input": ["abc", "123"],
+              "input_type": "default"
             }
             """);
 
         assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
     }
 
-    public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() {
+    public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOException {
         var inferenceId = "inferenceId";
         var requestContentString = """
             {
-                "input": ${input}
+                "input": ${input},
+                "input_type": ${input_type}
             }
             """;
 
@@ -115,7 +130,9 @@ public class CustomRequestTests extends ESTestCase {
             ),
             requestContentString,
             new TextEmbeddingResponseParser("$.result.embeddings"),
-            new RateLimitSettings(10_000)
+            new RateLimitSettings(10_000),
+            null,
+            new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
         );
 
         var model = CustomModelTests.createModel(
@@ -126,7 +143,13 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var request = new CustomRequest(null, List.of("abc", "123"), model);
+        var request = new CustomRequest(
+            EmbeddingParameters.of(
+                new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST),
+                model.getServiceSettings().getInputTypeTranslator()
+            ),
+            model
+        );
         var httpRequest = request.createHttpRequest();
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
 
@@ -136,6 +159,14 @@ public class CustomRequestTests extends ESTestCase {
             // To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/
             is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80")
         );
+
+        var expectedBody = XContentHelper.stripWhitespace("""
+            {
+              "input": ["abc", "123"],
+              "input_type": "value"
+            }
+            """);
+        assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
     }
 
     public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException {
@@ -173,7 +204,13 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var request = new CustomRequest(null, List.of("abc", "123"), model);
+        var request = new CustomRequest(
+            EmbeddingParameters.of(
+                new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH),
+                model.getServiceSettings().getInputTypeTranslator()
+            ),
+            model
+        );
         var httpRequest = request.createHttpRequest();
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
 
@@ -220,7 +257,7 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var request = new CustomRequest("query string", List.of("abc", "123"), model);
+        var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model);
         var httpRequest = request.createHttpRequest();
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
 
@@ -236,6 +273,56 @@ public class CustomRequestTests extends ESTestCase {
         assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
     }
 
+    /**
+     * Checks that the {@code ${return_documents}} and {@code ${top_n}} placeholders of a rerank request template
+     * are filled from the values carried by the {@link QueryAndDocsInputs}.
+     */
+    public void testCreateRequest_HandlesQuery_WithReturnDocsAndTopN() throws IOException {
+        var requestTemplate = """
+            {
+                "input": ${input},
+                "query": ${query},
+                "return_documents": ${return_documents},
+                "top_n": ${top_n}
+            }
+            """;
+
+        var rerankSettings = new CustomServiceSettings(
+            CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
+            "http://www.elastic.co",
+            null,
+            null,
+            requestTemplate,
+            new RerankResponseParser("$.result.score"),
+            new RateLimitSettings(10_000)
+        );
+
+        var rerankModel = CustomModelTests.createModel(
+            "inference_id",
+            TaskType.RERANK,
+            rerankSettings,
+            new CustomTaskSettings(Map.of()),
+            new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
+        );
+
+        var rerankInputs = new QueryAndDocsInputs("query string", List.of("abc", "123"), false, 2, false);
+        var customRequest = new CustomRequest(RerankParameters.of(rerankInputs), rerankModel);
+
+        var requestBase = customRequest.createHttpRequest().httpRequestBase();
+        assertThat(requestBase, instanceOf(HttpPost.class));
+
+        var postRequest = (HttpPost) requestBase;
+        var actualBody = convertToString(postRequest.getEntity().getContent());
+
+        assertThat(actualBody, is(XContentHelper.stripWhitespace("""
+            {
+              "input": ["abc", "123"],
+              "query": "query string",
+              "return_documents": false,
+              "top_n": 2
+            }
+            """)));
+    }
+
     public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException {
         var inferenceId = "inference_id";
         var requestContentString = """
@@ -262,7 +349,7 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var request = new CustomRequest(null, List.of("abc", "123"), model);
+        var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model);
         var exception = expectThrows(IllegalStateException.class, request::createHttpRequest);
         assertThat(
             exception.getMessage(),
@@ -299,7 +386,10 @@ public class CustomRequestTests extends ESTestCase {
             new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
         );
 
-        var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model));
+        var exception = expectThrows(
+            IllegalStateException.class,
+            () -> new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model)
+        );
         assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^"));
     }
 

+ 39 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/EmbeddingParametersTests.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests the {@code input_type} task parameter that {@link EmbeddingParameters} derives from an
+ * {@link InputTypeTranslator}.
+ */
+public class EmbeddingParametersTests extends ESTestCase {
+
+    /** When the translator has no entry for the requested input type, its default value is expected. */
+    public void testTaskTypeParameters_UsesDefaultValue() {
+        var ingestInput = new EmbeddingsInput(List.of("input"), null, InputType.INGEST);
+        var translatorWithoutEntries = new InputTypeTranslator(Map.of(), "default");
+
+        var embeddingParameters = EmbeddingParameters.of(ingestInput, translatorWithoutEntries);
+
+        assertThat(embeddingParameters.taskTypeParameters(), is(Map.of("input_type", "\"default\"")));
+    }
+
+    /** When the translator has an entry for the requested input type, that entry is expected instead of the default. */
+    public void testTaskTypeParameters_UsesMappedValue() {
+        var ingestInput = new EmbeddingsInput(List.of("input"), null, InputType.INGEST);
+        var translatorWithIngestEntry = new InputTypeTranslator(Map.of(InputType.INGEST, "ingest_value"), "default");
+
+        var embeddingParameters = EmbeddingParameters.of(ingestInput, translatorWithIngestEntry);
+
+        assertThat(embeddingParameters.taskTypeParameters(), is(Map.of("input_type", "\"ingest_value\"")));
+    }
+}

+ 33 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/RerankParametersTests.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.custom.request;
+
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+/**
+ * Tests the task-type parameters that {@link RerankParameters} derives from a {@link QueryAndDocsInputs}.
+ */
+public class RerankParametersTests extends ESTestCase {
+
+    /** The query plus the supplied {@code top_n} and {@code return_documents} values are all expected. */
+    public void testTaskTypeParameters() {
+        var rerankParameters = RerankParameters.of(new QueryAndDocsInputs("query_value", List.of("doc1", "doc2"), true, 5, false));
+
+        assertThat(
+            rerankParameters.taskTypeParameters(),
+            is(Map.of("query", "\"query_value\"", "top_n", "5", "return_documents", "true"))
+        );
+    }
+
+    /** When only the query and documents are supplied, just the query parameter is expected. */
+    public void testTaskTypeParameters_WithoutOptionalFields() {
+        var rerankParameters = RerankParameters.of(new QueryAndDocsInputs("query_value", List.of("doc1", "doc2")));
+
+        assertThat(rerankParameters.taskTypeParameters(), is(Map.of("query", "\"query_value\"")));
+    }
+}

+ 27 - 21
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/response/CustomResponseEntityTests.java

@@ -17,8 +17,14 @@ import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
 import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
+import org.elasticsearch.xpack.inference.services.custom.request.CompletionParameters;
 import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest;
+import org.elasticsearch.xpack.inference.services.custom.request.EmbeddingParameters;
+import org.elasticsearch.xpack.inference.services.custom.request.RerankParameters;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
@@ -53,10 +59,13 @@ public class CustomResponseEntityTests extends ESTestCase {
             }
             """;
 
+        var model = CustomModelTests.getTestModel(
+            TaskType.TEXT_EMBEDDING,
+            new TextEmbeddingResponseParser("$.result.embeddings[*].embedding")
+        );
         var request = new CustomRequest(
-            null,
-            List.of("abc"),
-            CustomModelTests.getTestModel(TaskType.TEXT_EMBEDDING, new TextEmbeddingResponseParser("$.result.embeddings[*].embedding"))
+            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()),
+            model
         );
         InferenceServiceResults results = CustomResponseEntity.fromResponse(
             request,
@@ -98,17 +107,18 @@ public class CustomResponseEntityTests extends ESTestCase {
             }
             """;
 
-        var request = new CustomRequest(
-            null,
-            List.of("abc"),
-            CustomModelTests.getTestModel(
-                TaskType.SPARSE_EMBEDDING,
-                new SparseEmbeddingResponseParser(
-                    "$.result.sparse_embeddings[*].embedding[*].tokenId",
-                    "$.result.sparse_embeddings[*].embedding[*].weight"
-                )
+        var model = CustomModelTests.getTestModel(
+            TaskType.SPARSE_EMBEDDING,
+            new SparseEmbeddingResponseParser(
+                "$.result.sparse_embeddings[*].embedding[*].tokenId",
+                "$.result.sparse_embeddings[*].embedding[*].weight"
             )
         );
+        var request = new CustomRequest(
+            EmbeddingParameters.of(new EmbeddingsInput(List.of("abc"), null, null), model.getServiceSettings().getInputTypeTranslator()),
+            model
+
+        );
 
         InferenceServiceResults results = CustomResponseEntity.fromResponse(
             request,
@@ -152,14 +162,11 @@ public class CustomResponseEntityTests extends ESTestCase {
             }
             """;
 
-        var request = new CustomRequest(
-            null,
-            List.of("abc"),
-            CustomModelTests.getTestModel(
-                TaskType.RERANK,
-                new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null)
-            )
+        var model = CustomModelTests.getTestModel(
+            TaskType.RERANK,
+            new RerankResponseParser("$.result.scores[*].score", "$.result.scores[*].index", null)
         );
+        var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query", List.of("doc1", "doc2"))), model);
 
         InferenceServiceResults results = CustomResponseEntity.fromResponse(
             request,
@@ -193,8 +200,7 @@ public class CustomResponseEntityTests extends ESTestCase {
             """;
 
         var request = new CustomRequest(
-            null,
-            List.of("abc"),
+            CompletionParameters.of(new ChatCompletionInput(List.of("abc"))),
             CustomModelTests.getTestModel(TaskType.COMPLETION, new CompletionResponseParser("$.result.text"))
         );
 

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ChatCompletionModelValidatorTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.test.ESTestCase;
 import org.junit.Before;
 import org.mockito.Mock;

+ 69 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/ModelValidatorBuilderTests.java

@@ -7,21 +7,88 @@
 
 package org.elasticsearch.xpack.inference.services.validation;
 
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
+import org.elasticsearch.xpack.inference.services.custom.CustomServiceTests;
+import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
+import org.junit.After;
+import org.junit.Before;
 
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.hamcrest.Matchers.isA;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 public class ModelValidatorBuilderTests extends ESTestCase {
+
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Override
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @Override
+    @After
+    public void tearDown() throws Exception {
+        super.tearDown();
+        clientManager.close();
+        terminate(threadPool);
+    }
+
+    public void testCustomServiceValidator() {
+        var service = CustomServiceTests.createService(threadPool, clientManager);
+        var validator = ModelValidatorBuilder.buildModelValidator(TaskType.RERANK, service);
+        var mockService = mock(InferenceService.class);
+        validator.validate(
+            mockService,
+            CustomModelTests.getTestModel(TaskType.RERANK, new RerankResponseParser("score")),
+            null,
+            ActionListener.noop()
+        );
+
+        verify(mockService, times(1)).infer(
+            any(),
+            eq("test query"),
+            eq(true),
+            eq(1),
+            eq(List.of("how big")),
+            eq(false),
+            eq(Map.of()),
+            eq(InputType.INTERNAL_INGEST),
+            any(),
+            any()
+        );
+        verifyNoMoreInteractions(mockService);
+    }
+
     public void testBuildModelValidator_NullTaskType() {
-        assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, false); });
+        assertThrows(IllegalArgumentException.class, () -> { ModelValidatorBuilder.buildModelValidator(null, null); });
     }
 
     public void testBuildModelValidator_ValidTaskType() {
         taskTypeToModelValidatorClassMap().forEach((taskType, modelValidatorClass) -> {
-            assertThat(ModelValidatorBuilder.buildModelValidator(taskType, false), isA(modelValidatorClass));
+            assertThat(ModelValidatorBuilder.buildModelValidator(taskType, null), isA(modelValidatorClass));
         });
     }
 

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleModelValidatorTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.junit.Before;

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/TextEmbeddingModelValidatorTests.java

@@ -14,6 +14,7 @@ import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.validation.ServiceIntegrationValidator;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;