Browse Source

Adds support for `input_type` field to Vertex inference service (#116431) (#116673)

* Adding input type to google vertex ai service

* Update docs/changelog/116431.yaml

* PR feedback - backwards compatibility

* Fix lint error

(cherry picked from commit 7039a1dc8c886e23fda47a4b38cbab72746ac8cf)
Ying Mao 11 months ago
parent
commit
2ec5299460
18 changed files with 697 additions and 106 deletions
  1. +5 -0
      docs/changelog/116431.yaml
  2. +1 -0
      server/src/main/java/org/elasticsearch/TransportVersions.java
  3. +4 -2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java
  4. +2 -1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java
  5. +1 -1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java
  6. +33 -4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java
  7. +17 -1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java
  8. +2 -2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java
  9. +43 -8
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java
  10. +22 -5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java
  11. +88 -11
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java
  12. +2 -7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java
  13. +86 -10
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntityTests.java
  14. +30 -6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestTests.java
  15. +64 -26
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java
  16. +101 -3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java
  17. +43 -2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettingsTests.java
  18. +153 -17
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java

+ 5 - 0
docs/changelog/116431.yaml

@@ -0,0 +1,5 @@
+pr: 116431
+summary: Adds support for `input_type` field to Vertex inference service
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -193,6 +193,7 @@ public class TransportVersions {
     public static final TransportVersion ROLE_MONITOR_STATS = def(8_787_00_0);
     public static final TransportVersion DATA_STREAM_INDEX_VERSION_DEPRECATION_CHECK = def(8_788_00_0);
     public static final TransportVersion ADD_COMPATIBILITY_VERSIONS_TO_NODE_INFO = def(8_789_00_0);
+    public static final TransportVersion VERTEX_AI_INPUT_TYPE_ADDED = def(8_790_00_0);
 
     /*
      * STOP! READ THIS FIRST! No, really,

+ 4 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.action.googlevertexai;
 
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiEmbeddingsRequestManager;
@@ -33,9 +34,10 @@ public class GoogleVertexAiActionCreator implements GoogleVertexAiActionVisitor
     }
 
     @Override
-    public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
+    public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType) {
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, taskSettings, inputType);
         var requestManager = new GoogleVertexAiEmbeddingsRequestManager(
-            model,
+            overriddenModel,
             serviceComponents.truncator(),
             serviceComponents.threadPool()
         );

+ 2 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.external.action.googlevertexai;
 
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
@@ -15,7 +16,7 @@ import java.util.Map;
 
 public interface GoogleVertexAiActionVisitor {
 
-    ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
+    ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings, InputType inputType);
 
     ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
 }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequest.java

@@ -40,7 +40,7 @@ public class GoogleVertexAiEmbeddingsRequest implements GoogleVertexAiRequest {
         HttpPost httpPost = new HttpPost(model.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings().autoTruncate()))
+            Strings.toString(new GoogleVertexAiEmbeddingsRequestEntity(truncationResult.input(), model.getTaskSettings()))
                 .getBytes(StandardCharsets.UTF_8)
         );
 

+ 33 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntity.java

@@ -7,23 +7,35 @@
 
 package org.elasticsearch.xpack.inference.external.request.googlevertexai;
 
-import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 
 import java.io.IOException;
 import java.util.List;
 import java.util.Objects;
 
-public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, @Nullable Boolean autoTruncation) implements ToXContentObject {
+import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
+
+public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, GoogleVertexAiEmbeddingsTaskSettings taskSettings)
+    implements
+        ToXContentObject {
 
     private static final String INSTANCES_FIELD = "instances";
     private static final String CONTENT_FIELD = "content";
     private static final String PARAMETERS_FIELD = "parameters";
     private static final String AUTO_TRUNCATE_FIELD = "autoTruncate";
+    private static final String TASK_TYPE_FIELD = "task_type";
+
+    private static final String CLASSIFICATION_TASK_TYPE = "CLASSIFICATION";
+    private static final String CLUSTERING_TASK_TYPE = "CLUSTERING";
+    private static final String RETRIEVAL_DOCUMENT_TASK_TYPE = "RETRIEVAL_DOCUMENT";
+    private static final String RETRIEVAL_QUERY_TASK_TYPE = "RETRIEVAL_QUERY";
 
     public GoogleVertexAiEmbeddingsRequestEntity {
         Objects.requireNonNull(inputs);
+        Objects.requireNonNull(taskSettings);
     }
 
     @Override
@@ -35,16 +47,20 @@ public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, @Nullab
             builder.startObject();
             {
                 builder.field(CONTENT_FIELD, input);
+
+                if (taskSettings.getInputType() != null) {
+                    builder.field(TASK_TYPE_FIELD, convertToString(taskSettings.getInputType()));
+                }
             }
             builder.endObject();
         }
 
         builder.endArray();
 
-        if (autoTruncation != null) {
+        if (taskSettings.autoTruncate() != null) {
             builder.startObject(PARAMETERS_FIELD);
             {
-                builder.field(AUTO_TRUNCATE_FIELD, autoTruncation);
+                builder.field(AUTO_TRUNCATE_FIELD, taskSettings.autoTruncate());
             }
             builder.endObject();
         }
@@ -52,4 +68,17 @@ public record GoogleVertexAiEmbeddingsRequestEntity(List<String> inputs, @Nullab
 
         return builder;
     }
+
+    static String convertToString(InputType inputType) {
+        return switch (inputType) {
+            case INGEST -> RETRIEVAL_DOCUMENT_TASK_TYPE;
+            case SEARCH -> RETRIEVAL_QUERY_TASK_TYPE;
+            case CLASSIFICATION -> CLASSIFICATION_TASK_TYPE;
+            case CLUSTERING -> CLUSTERING_TASK_TYPE;
+            default -> {
+                assert false : invalidInputTypeMessage(inputType);
+                yield null;
+            }
+        };
+    }
 }

+ 17 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiModel.java

@@ -7,13 +7,16 @@
 
 package org.elasticsearch.xpack.inference.services.googlevertexai;
 
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;
 
+import java.net.URI;
 import java.util.Map;
 import java.util.Objects;
 
@@ -21,6 +24,8 @@ public abstract class GoogleVertexAiModel extends Model {
 
     private final GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings;
 
+    protected URI uri;
+
     public GoogleVertexAiModel(
         ModelConfigurations configurations,
         ModelSecrets secrets,
@@ -34,13 +39,24 @@ public abstract class GoogleVertexAiModel extends Model {
     public GoogleVertexAiModel(GoogleVertexAiModel model, ServiceSettings serviceSettings) {
         super(model, serviceSettings);
 
+        uri = model.uri();
+        rateLimitServiceSettings = model.rateLimitServiceSettings();
+    }
+
+    public GoogleVertexAiModel(GoogleVertexAiModel model, TaskSettings taskSettings) {
+        super(model, taskSettings);
+
+        uri = model.uri();
         rateLimitServiceSettings = model.rateLimitServiceSettings();
     }
 
-    public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings);
+    public abstract ExecutableAction accept(GoogleVertexAiActionVisitor creator, Map<String, Object> taskSettings, InputType inputType);
 
     public GoogleVertexAiRateLimitServiceSettings rateLimitServiceSettings() {
         return rateLimitServiceSettings;
     }
 
+    public URI uri() {
+        return uri;
+    }
 }

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java

@@ -210,7 +210,7 @@ public class GoogleVertexAiService extends SenderService {
 
         var actionCreator = new GoogleVertexAiActionCreator(getSender(), getServiceComponents());
 
-        var action = googleVertexAiModel.accept(actionCreator, taskSettings);
+        var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
         action.execute(inputs, timeout, listener);
     }
 
@@ -235,7 +235,7 @@ public class GoogleVertexAiService extends SenderService {
         ).batchRequestsWithListeners(listener);
 
         for (var request : batchedRequests) {
-            var action = googleVertexAiModel.accept(actionCreator, taskSettings);
+            var action = googleVertexAiModel.accept(actionCreator, taskSettings, inputType);
             action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
         }
     }

+ 43 - 8
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModel.java

@@ -11,12 +11,14 @@ import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ChunkingSettings;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SettingsConfiguration;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
+import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption;
 import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
 import org.elasticsearch.xpack.inference.external.action.googlevertexai.GoogleVertexAiActionVisitor;
 import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiUtils;
@@ -29,13 +31,25 @@ import java.net.URISyntaxException;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.stream.Stream;
 
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE;
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
 
 public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
 
-    private URI uri;
+    public static GoogleVertexAiEmbeddingsModel of(
+        GoogleVertexAiEmbeddingsModel model,
+        Map<String, Object> taskSettings,
+        InputType inputType
+    ) {
+        var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(taskSettings);
+        return new GoogleVertexAiEmbeddingsModel(
+            model,
+            GoogleVertexAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings, inputType)
+        );
+    }
 
     public GoogleVertexAiEmbeddingsModel(
         String inferenceEntityId,
@@ -62,6 +76,10 @@ public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
         super(model, serviceSettings);
     }
 
+    public GoogleVertexAiEmbeddingsModel(GoogleVertexAiEmbeddingsModel model, GoogleVertexAiEmbeddingsTaskSettings taskSettings) {
+        super(model, taskSettings);
+    }
+
     // Should only be used directly for testing
     GoogleVertexAiEmbeddingsModel(
         String inferenceEntityId,
@@ -126,13 +144,9 @@ public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
         return (GoogleVertexAiEmbeddingsRateLimitServiceSettings) super.rateLimitServiceSettings();
     }
 
-    public URI uri() {
-        return uri;
-    }
-
     @Override
-    public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
-        return visitor.create(this, taskSettings);
+    public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
+        return visitor.create(this, taskSettings, inputType);
     }
 
     public static URI buildUri(String location, String projectId, String modelId) throws URISyntaxException {
@@ -161,11 +175,32 @@ public class GoogleVertexAiEmbeddingsModel extends GoogleVertexAiModel {
             new LazyInitializable<>(() -> {
                 var configurationMap = new HashMap<String, SettingsConfiguration>();
 
+                configurationMap.put(
+                    INPUT_TYPE,
+                    new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.DROPDOWN)
+                        .setLabel("Input Type")
+                        .setOrder(1)
+                        .setRequired(false)
+                        .setSensitive(false)
+                        .setTooltip("Specifies the type of input passed to the model.")
+                        .setType(SettingsConfigurationFieldType.STRING)
+                        .setOptions(
+                            Stream.of(
+                                InputType.CLASSIFICATION.toString(),
+                                InputType.CLUSTERING.toString(),
+                                InputType.INGEST.toString(),
+                                InputType.SEARCH.toString()
+                            ).map(v -> new SettingsConfigurationSelectOption.Builder().setLabelAndValue(v).build()).toList()
+                        )
+                        .setValue("")
+                        .build()
+                );
+
                 configurationMap.put(
                     AUTO_TRUNCATE,
                     new SettingsConfiguration.Builder().setDisplay(SettingsConfigurationDisplayType.TOGGLE)
                         .setLabel("Auto Truncate")
-                        .setOrder(1)
+                        .setOrder(2)
                         .setRequired(false)
                         .setSensitive(false)
                         .setTooltip("Specifies if the API truncates inputs longer than the maximum token length automatically.")

+ 22 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettings.java

@@ -9,29 +9,46 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
 
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
 
 import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
 
-public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate) {
+public record GoogleVertexAiEmbeddingsRequestTaskSettings(@Nullable Boolean autoTruncate, @Nullable InputType inputType) {
 
-    public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(null);
+    public static final GoogleVertexAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsRequestTaskSettings(
+        null,
+        null
+    );
 
     public static GoogleVertexAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) {
-        if (map.isEmpty()) {
-            return GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
         }
 
         ValidationException validationException = new ValidationException();
 
+        InputType inputType = extractOptionalEnum(
+            map,
+            INPUT_TYPE,
+            ModelConfigurations.TASK_SETTINGS,
+            InputType::fromString,
+            VALID_REQUEST_VALUES,
+            validationException
+        );
+
         Boolean autoTruncate = extractOptionalBoolean(map, GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, validationException);
 
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }
 
-        return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate);
+        return new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate, inputType);
     }
 
 }

+ 88 - 11
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettings.java

@@ -9,19 +9,24 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.xcontent.XContentBuilder;
 
 import java.io.IOException;
+import java.util.EnumSet;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
 
 public class GoogleVertexAiEmbeddingsTaskSettings implements TaskSettings {
 
@@ -29,48 +34,108 @@ public class GoogleVertexAiEmbeddingsTaskSettings implements TaskSettings {
 
     public static final String AUTO_TRUNCATE = "auto_truncate";
 
-    public static final GoogleVertexAiEmbeddingsTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsTaskSettings(
-        Boolean.valueOf(null)
+    public static final String INPUT_TYPE = "input_type";
+
+    static final EnumSet<InputType> VALID_REQUEST_VALUES = EnumSet.of(
+        InputType.INGEST,
+        InputType.SEARCH,
+        InputType.CLASSIFICATION,
+        InputType.CLUSTERING
     );
 
+    public static final GoogleVertexAiEmbeddingsTaskSettings EMPTY_SETTINGS = new GoogleVertexAiEmbeddingsTaskSettings(null, null);
+
     public static GoogleVertexAiEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
+        if (map == null || map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
         ValidationException validationException = new ValidationException();
 
+        InputType inputType = extractOptionalEnum(
+            map,
+            INPUT_TYPE,
+            ModelConfigurations.TASK_SETTINGS,
+            InputType::fromString,
+            VALID_REQUEST_VALUES,
+            validationException
+        );
+
         Boolean autoTruncate = extractOptionalBoolean(map, AUTO_TRUNCATE, validationException);
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
         }
 
-        return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate);
+        return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, inputType);
     }
 
     public static GoogleVertexAiEmbeddingsTaskSettings of(
         GoogleVertexAiEmbeddingsTaskSettings originalSettings,
-        GoogleVertexAiEmbeddingsRequestTaskSettings requestSettings
+        GoogleVertexAiEmbeddingsRequestTaskSettings requestSettings,
+        InputType requestInputType
     ) {
+        var inputTypeToUse = getValidInputType(originalSettings, requestSettings, requestInputType);
         var autoTruncate = requestSettings.autoTruncate() == null ? originalSettings.autoTruncate : requestSettings.autoTruncate();
-        return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate);
+        return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, inputTypeToUse);
+    }
+
+    private static InputType getValidInputType(
+        GoogleVertexAiEmbeddingsTaskSettings originalSettings,
+        GoogleVertexAiEmbeddingsRequestTaskSettings requestTaskSettings,
+        InputType requestInputType
+    ) {
+        InputType inputTypeToUse = originalSettings.inputType;
+
+        if (VALID_REQUEST_VALUES.contains(requestInputType)) {
+            inputTypeToUse = requestInputType;
+        } else if (requestTaskSettings.inputType() != null) {
+            inputTypeToUse = requestTaskSettings.inputType();
+        }
+
+        return inputTypeToUse;
     }
 
+    private final InputType inputType;
     private final Boolean autoTruncate;
 
-    public GoogleVertexAiEmbeddingsTaskSettings(@Nullable Boolean autoTruncate) {
+    public GoogleVertexAiEmbeddingsTaskSettings(@Nullable Boolean autoTruncate, @Nullable InputType inputType) {
+        validateInputType(inputType);
+        this.inputType = inputType;
         this.autoTruncate = autoTruncate;
     }
 
     public GoogleVertexAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
         this.autoTruncate = in.readOptionalBoolean();
+
+        var inputType = (in.getTransportVersion().onOrAfter(TransportVersions.VERTEX_AI_INPUT_TYPE_ADDED))
+            ? in.readOptionalEnum(InputType.class)
+            : null;
+
+        validateInputType(inputType);
+        this.inputType = inputType;
+    }
+
+    private static void validateInputType(InputType inputType) {
+        if (inputType == null) {
+            return;
+        }
+
+        assert VALID_REQUEST_VALUES.contains(inputType) : invalidInputTypeMessage(inputType);
     }
 
     @Override
     public boolean isEmpty() {
-        return autoTruncate == null;
+        return inputType == null && autoTruncate == null;
     }
 
     public Boolean autoTruncate() {
         return autoTruncate;
     }
 
+    public InputType getInputType() {
+        return inputType;
+    }
+
     @Override
     public String getWriteableName() {
         return NAME;
@@ -84,11 +149,19 @@ public class GoogleVertexAiEmbeddingsTaskSettings implements TaskSettings {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeOptionalBoolean(this.autoTruncate);
+
+        if (out.getTransportVersion().onOrAfter(TransportVersions.VERTEX_AI_INPUT_TYPE_ADDED)) {
+            out.writeOptionalEnum(this.inputType);
+        }
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
+        if (inputType != null) {
+            builder.field(INPUT_TYPE, inputType);
+        }
+
         if (autoTruncate != null) {
             builder.field(AUTO_TRUNCATE, autoTruncate);
         }
@@ -101,19 +174,23 @@ public class GoogleVertexAiEmbeddingsTaskSettings implements TaskSettings {
         if (this == object) return true;
         if (object == null || getClass() != object.getClass()) return false;
         GoogleVertexAiEmbeddingsTaskSettings that = (GoogleVertexAiEmbeddingsTaskSettings) object;
-        return Objects.equals(autoTruncate, that.autoTruncate);
+        return Objects.equals(inputType, that.inputType) && Objects.equals(autoTruncate, that.autoTruncate);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(autoTruncate);
+        return Objects.hash(autoTruncate, inputType);
+    }
+
+    public static String invalidInputTypeMessage(InputType inputType) {
+        return Strings.format("received invalid input type value [%s]", inputType.toString());
     }
 
     @Override
     public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
-        GoogleVertexAiEmbeddingsRequestTaskSettings requestSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(
+        GoogleVertexAiEmbeddingsRequestTaskSettings updatedSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(
             new HashMap<>(newSettings)
         );
-        return of(this, requestSettings);
+        return of(this, updatedSettings, updatedSettings.inputType() != null ? updatedSettings.inputType() : this.inputType);
     }
 }

+ 2 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/rerank/GoogleVertexAiRerankModel.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.rerank;
 import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SettingsConfiguration;
@@ -34,8 +35,6 @@ import static org.elasticsearch.xpack.inference.services.googlevertexai.rerank.G
 
 public class GoogleVertexAiRerankModel extends GoogleVertexAiModel {
 
-    private URI uri;
-
     public GoogleVertexAiRerankModel(
         String inferenceEntityId,
         TaskType taskType,
@@ -122,12 +121,8 @@ public class GoogleVertexAiRerankModel extends GoogleVertexAiModel {
         return (GoogleDiscoveryEngineRateLimitServiceSettings) super.rateLimitServiceSettings();
     }
 
-    public URI uri() {
-        return uri;
-    }
-
     @Override
-    public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings) {
+    public ExecutableAction accept(GoogleVertexAiActionVisitor visitor, Map<String, Object> taskSettings, InputType inputType) {
         return visitor.create(this, taskSettings);
     }
 

+ 86 - 10
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestEntityTests.java

@@ -8,10 +8,12 @@
 package org.elasticsearch.xpack.inference.external.request.googlevertexai;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
 
 import java.io.IOException;
 import java.util.List;
@@ -20,8 +22,11 @@ import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhi
 
 public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
 
-    public void testToXContent_SingleEmbeddingRequest_WritesAutoTruncationIfDefined() throws IOException {
-        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), true);
+    public void testToXContent_SingleEmbeddingRequest_WritesAllFields() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc"),
+            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.SEARCH)
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -31,7 +36,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
             {
                 "instances": [
                     {
-                        "content": "abc"
+                        "content": "abc",
+                        "task_type": "RETRIEVAL_QUERY"
                     }
                 ],
                 "parameters": {
@@ -42,7 +48,10 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
     }
 
     public void testToXContent_SingleEmbeddingRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
-        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), null);
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc"),
+            new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.INGEST)
+        );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -52,15 +61,16 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
             {
                 "instances": [
                     {
-                        "content": "abc"
+                        "content": "abc",
+                        "task_type": "RETRIEVAL_DOCUMENT"
                     }
                 ]
             }
             """));
     }
 
-    public void testToXContent_MultipleEmbeddingsRequest_WritesAutoTruncationIfDefined() throws IOException {
-        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), true);
+    public void testToXContent_SingleEmbeddingRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc"), new GoogleVertexAiEmbeddingsTaskSettings(false, null));
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -71,9 +81,35 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                 "instances": [
                     {
                         "content": "abc"
+                    }
+                ],
+                "parameters": {
+                    "autoTruncate": false
+                }
+            }
+            """));
+    }
+
+    public void testToXContent_MultipleEmbeddingsRequest_WritesAllFields() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc", "def"),
+            new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.CLUSTERING)
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "instances": [
+                    {
+                        "content": "abc",
+                        "task_type": "CLUSTERING"
                     },
                     {
-                        "content": "def"
+                        "content": "def",
+                        "task_type": "CLUSTERING"
                     }
                 ],
                 "parameters": {
@@ -83,8 +119,8 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
             """));
     }
 
-    public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
-        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null);
+    public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteInputTypeIfNotDefined() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), new GoogleVertexAiEmbeddingsTaskSettings(true, null));
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         entity.toXContent(builder, null);
@@ -99,8 +135,48 @@ public class GoogleVertexAiEmbeddingsRequestEntityTests extends ESTestCase {
                     {
                         "content": "def"
                     }
+                ],
+                "parameters": {
+                    "autoTruncate": true
+                }
+            }
+            """));
+    }
+
+    public void testToXContent_MultipleEmbeddingsRequest_DoesNotWriteAutoTruncationIfNotDefined() throws IOException {
+        var entity = new GoogleVertexAiEmbeddingsRequestEntity(
+            List.of("abc", "def"),
+            new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION) // auto-truncate null, input type set
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder); // expected: task_type on each instance, no "parameters" object
+
+        assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
+            {
+                "instances": [
+                    {
+                        "content": "abc",
+                        "task_type": "CLASSIFICATION"
+                    },
+                    {
+                        "content": "def",
+                        "task_type": "CLASSIFICATION"
+                    }
                 ]
             }
             """));
     }
+
+    public void testToXContent_ThrowsIfInputIsNull() { // a null input list must be rejected
+        expectThrows(
+            NullPointerException.class, // expected from the constructor call below; toXContent is never invoked
+            () -> new GoogleVertexAiEmbeddingsRequestEntity(null, new GoogleVertexAiEmbeddingsTaskSettings(null, InputType.CLASSIFICATION))
+        );
+    }
+
+    public void testToXContent_ThrowsIfTaskSettingsIsNull() { // null task settings must be rejected at construction
+        expectThrows(NullPointerException.class, () -> new GoogleVertexAiEmbeddingsRequestEntity(List.of("abc", "def"), null));
+    }
 }

+ 30 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googlevertexai/GoogleVertexAiEmbeddingsRequestTests.java

@@ -10,6 +10,7 @@ package org.elasticsearch.xpack.inference.external.request.googlevertexai;
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.Truncator;
@@ -31,11 +32,11 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
 
     private static final String AUTH_HEADER_VALUE = "foo";
 
-    public void testCreateRequest_WithoutDimensionsSet_And_WithoutAutoTruncateSet() throws IOException {
+    public void testCreateRequest_WithoutDimensionsSet_And_WithoutAutoTruncateSet_And_WithoutInputTypeSet() throws IOException {
         var model = "model";
         var input = "input";
 
-        var request = createRequest(model, input, null);
+        var request = createRequest(model, input, null, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -54,7 +55,7 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         var input = "input";
         var autoTruncate = true;
 
-        var request = createRequest(model, input, autoTruncate);
+        var request = createRequest(model, input, autoTruncate, null);
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
@@ -68,11 +69,29 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input")), "parameters", Map.of("autoTruncate", true))));
     }
 
+    public void testCreateRequest_WithInputTypeSet() throws IOException {
+        var model = "model";
+        var input = "input";
+
+        var request = createRequest(model, input, null, InputType.SEARCH); // autoTruncate unset, input type SEARCH
+        var httpRequest = request.createHttpRequest();
+
+        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
+        var httpPost = (HttpPost) httpRequest.httpRequestBase();
+
+        assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
+        assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is(AUTH_HEADER_VALUE));
+
+        var requestMap = entityAsMap(httpPost.getEntity().getContent()); // SEARCH -> "RETRIEVAL_QUERY" expected below
+        assertThat(requestMap, aMapWithSize(1)); // only "instances"; no "parameters" entry
+        assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "input", "task_type", "RETRIEVAL_QUERY")))));
+    }
+
     public void testTruncate_ReducesInputTextSizeByHalf() throws IOException {
         var model = "model";
         var input = "abcd";
 
-        var request = createRequest(model, input, null);
+        var request = createRequest(model, input, null, null);
         var truncatedRequest = request.truncate();
         var httpRequest = truncatedRequest.createHttpRequest();
 
@@ -87,8 +106,13 @@ public class GoogleVertexAiEmbeddingsRequestTests extends ESTestCase {
         assertThat(requestMap, is(Map.of("instances", List.of(Map.of("content", "ab")))));
     }
 
-    private static GoogleVertexAiEmbeddingsRequest createRequest(String modelId, String input, @Nullable Boolean autoTruncate) {
-        var embeddingsModel = GoogleVertexAiEmbeddingsModelTests.createModel(modelId, autoTruncate);
+    private static GoogleVertexAiEmbeddingsRequest createRequest(
+        String modelId,
+        String input,
+        @Nullable Boolean autoTruncate,
+        @Nullable InputType inputType
+    ) {
+        var embeddingsModel = GoogleVertexAiEmbeddingsModelTests.createModel(modelId, autoTruncate, inputType);
 
         return new GoogleVertexAiEmbeddingsWithoutAuthRequest(
             TruncatorTests.createTruncator(),

+ 64 - 26
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiServiceTests.java

@@ -13,8 +13,10 @@ import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.xcontent.XContentHelper;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
@@ -109,7 +111,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                             projectId
                         )
                     ),
-                    new HashMap<>(Map.of()),
+                    getTaskSettingsMap(true, InputType.INGEST),
                     getSecretSettingsMap(serviceAccountJson)
                 ),
                 modelListener
@@ -154,7 +156,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                             projectId
                         )
                     ),
-                    new HashMap<>(Map.of()),
+                    getTaskSettingsMap(true, InputType.INGEST),
                     createRandomChunkingSettingsMap(),
                     getSecretSettingsMap(serviceAccountJson)
                 ),
@@ -200,7 +202,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                             projectId
                         )
                     ),
-                    new HashMap<>(Map.of()),
+                    getTaskSettingsMap(false, InputType.SEARCH),
                     getSecretSettingsMap(serviceAccountJson)
                 ),
                 modelListener
@@ -281,7 +283,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         "project"
                     )
                 ),
-                getTaskSettingsMap(true),
+                getTaskSettingsMap(true, InputType.SEARCH),
                 getSecretSettingsMap("{}")
             );
             config.put("extra_key", "value");
@@ -308,7 +310,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             );
             serviceSettings.put("extra_key", "value");
 
-            var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(true), getSecretSettingsMap("{}"));
+            var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap(true, InputType.CLUSTERING), getSecretSettingsMap("{}"));
 
             var failureListener = getModelListenerForException(
                 ElasticsearchStatusException.class,
@@ -362,7 +364,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         "project"
                     )
                 ),
-                getTaskSettingsMap(true),
+                getTaskSettingsMap(true, null),
                 secretSettings
             );
 
@@ -399,7 +401,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, InputType.SEARCH),
                 getSecretSettingsMap(serviceAccountJson)
             );
 
@@ -417,7 +419,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.SEARCH)));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
     }
@@ -447,7 +449,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, null),
                 createRandomChunkingSettingsMap(),
                 getSecretSettingsMap(serviceAccountJson)
             );
@@ -466,7 +468,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
             assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
@@ -497,7 +499,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, null),
                 getSecretSettingsMap(serviceAccountJson)
             );
 
@@ -515,7 +517,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
             assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
@@ -573,7 +575,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, InputType.INGEST),
                 getSecretSettingsMap(serviceAccountJson)
             );
             persistedConfig.config().put("extra_key", "value");
@@ -592,7 +594,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.INGEST)));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
     }
@@ -625,7 +627,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, null),
                 secretSettingsMap
             );
 
@@ -643,7 +645,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
     }
@@ -676,7 +678,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
 
             var persistedConfig = getPersistedConfigMap(
                 serviceSettingsMap,
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, InputType.CLUSTERING),
                 getSecretSettingsMap(serviceAccountJson)
             );
 
@@ -694,7 +696,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.CLUSTERING)));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
     }
@@ -711,7 +713,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             """;
 
         try (var service = createGoogleVertexAiService()) {
-            var taskSettings = getTaskSettingsMap(autoTruncate);
+            var taskSettings = getTaskSettingsMap(autoTruncate, InputType.SEARCH);
             taskSettings.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
@@ -745,7 +747,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, InputType.SEARCH)));
             assertThat(embeddingsModel.getSecretSettings().serviceAccountJson().toString(), is(serviceAccountJson));
         }
     }
@@ -770,7 +772,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate),
+                getTaskSettingsMap(autoTruncate, null),
                 createRandomChunkingSettingsMap()
             );
 
@@ -783,7 +785,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
             assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
         }
     }
@@ -808,7 +810,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                         true
                     )
                 ),
-                getTaskSettingsMap(autoTruncate)
+                getTaskSettingsMap(autoTruncate, null)
             );
 
             var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -820,7 +822,7 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().location(), is(location));
             assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
             assertThat(embeddingsModel.getServiceSettings().dimensionsSetByUser(), is(Boolean.TRUE));
-            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+            assertThat(embeddingsModel.getTaskSettings(), is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
             assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
         }
     }
@@ -838,12 +840,44 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
                                 {
                                     "task_type": "text_embedding",
                                     "configuration": {
+                                         "input_type": {
+                                             "default_value": null,
+                                             "depends_on": [],
+                                             "display": "dropdown",
+                                             "label": "Input Type",
+                                             "options": [
+                                                 {
+                                                     "label": "classification",
+                                                     "value": "classification"
+                                                 },
+                                                 {
+                                                     "label": "clustering",
+                                                     "value": "clustering"
+                                                 },
+                                                 {
+                                                     "label": "ingest",
+                                                     "value": "ingest"
+                                                 },
+                                                 {
+                                                     "label": "search",
+                                                     "value": "search"
+                                                 }
+                                             ],
+                                             "order": 1,
+                                             "required": false,
+                                             "sensitive": false,
+                                             "tooltip": "Specifies the type of input passed to the model.",
+                                             "type": "str",
+                                             "ui_restrictions": [],
+                                             "validations": [],
+                                             "value": ""
+                                        },
                                         "auto_truncate": {
                                             "default_value": null,
                                             "depends_on": [],
                                             "display": "toggle",
                                             "label": "Auto Truncate",
-                                            "order": 1,
+                                            "order": 2,
                                             "required": false,
                                             "sensitive": false,
                                             "tooltip": "Specifies if the API truncates inputs longer than the maximum token length automatically.",
@@ -1005,11 +1039,15 @@ public class GoogleVertexAiServiceTests extends ESTestCase {
         });
     }
 
-    private static Map<String, Object> getTaskSettingsMap(Boolean autoTruncate) {
+    private static Map<String, Object> getTaskSettingsMap(Boolean autoTruncate, @Nullable InputType inputType) {
         var taskSettings = new HashMap<String, Object>(); // mutable: a caller in this file adds an extra key afterwards
 
         taskSettings.put(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, autoTruncate); // always written, no null check
 
+        if (inputType != null) { // the input type key is omitted entirely when no input type is given
+            taskSettings.put(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, inputType.toString());
+        }
+
         return taskSettings;
     }
 

+ 101 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsModelTests.java

@@ -10,14 +10,18 @@ package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
+import org.hamcrest.MatcherAssert;
 
 import java.net.URI;
 import java.net.URISyntaxException;
+import java.util.Map;
 
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettingsTests.getTaskSettingsMap;
 import static org.hamcrest.Matchers.is;
 
 public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
@@ -45,6 +49,75 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
         );
     }
 
+    public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreEmpty_AndInputTypeIsInvalid() {
+        var model = createModel("model", Boolean.FALSE, InputType.SEARCH); // stored input type: SEARCH
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, Map.of(), InputType.UNSPECIFIED);
+
+        MatcherAssert.assertThat(overriddenModel, is(model)); // empty settings + UNSPECIFIED: equal to the original
+    }
+
+    public void testOverrideWith_DoesNotOverrideAndModelRemainsEqual_WhenSettingsAreNull_AndInputTypeIsInvalid() {
+        var model = createModel("model", Boolean.FALSE, InputType.SEARCH); // stored input type: SEARCH
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, null, InputType.UNSPECIFIED); // null settings map
+
+        MatcherAssert.assertThat(overriddenModel, is(model)); // null settings + UNSPECIFIED: equal to the original
+    }
+
+    public void testOverrideWith_SetsInputTypeToOverride_WhenFieldIsNullInModelTaskSettings_AndNullInRequestTaskSettings() {
+        var model = createModel("model", Boolean.FALSE, null); // no stored input type
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.SEARCH);
+
+        var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // request input type SEARCH expected to be applied
+    }
+
+    public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingStoredTaskSettings() {
+        var model = createModel("model", Boolean.FALSE, InputType.INGEST); // stored input type: INGEST
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.SEARCH);
+
+        var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // request SEARCH expected to win over stored INGEST
+    }
+
+    public void testOverrideWith_SetsInputType_FromRequest_IfValid_OverridingRequestTaskSettings() {
+        var model = createModel("model", Boolean.FALSE, null); // no stored input type
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, InputType.CLUSTERING), InputType.SEARCH);
+
+        var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // SEARCH expected to win over settings-map CLUSTERING
+    }
+
+    public void testOverrideWith_OverridesInputType_WithRequestTaskSettingsSearch_WhenRequestInputTypeIsInvalid() {
+        var model = createModel("model", Boolean.FALSE, InputType.INGEST); // stored input type: INGEST
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, InputType.SEARCH), InputType.UNSPECIFIED);
+
+        var expectedModel = createModel("model", Boolean.FALSE, InputType.SEARCH);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // settings-map SEARCH expected to replace INGEST
+    }
+
+    public void testOverrideWith_DoesNotSetInputType_FromRequest_IfInputTypeIsInvalid() {
+        var model = createModel("model", Boolean.FALSE, null); // no stored input type
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED);
+
+        var expectedModel = createModel("model", Boolean.FALSE, null);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // input type expected to stay unset
+    }
+
+    public void testOverrideWith_DoesNotSetInputType_WhenRequestTaskSettingsIsNull_AndRequestInputTypeIsInvalid() {
+        var model = createModel("model", Boolean.FALSE, InputType.INGEST); // stored input type: INGEST
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, getTaskSettingsMap(null, null), InputType.UNSPECIFIED);
+
+        var expectedModel = createModel("model", Boolean.FALSE, InputType.INGEST);
+        MatcherAssert.assertThat(overriddenModel, is(expectedModel)); // stored INGEST expected to be kept
+    } // NOTE(review): name says settings are null, but a map from getTaskSettingsMap(null, null) is passed
+
+    public void testOverrideWith_DoesNotOverrideModelUri() {
+        var model = createModel("model", Boolean.FALSE, InputType.SEARCH);
+        var overriddenModel = GoogleVertexAiEmbeddingsModel.of(model, Map.of(), null); // empty settings, null input type
+
+        MatcherAssert.assertThat(overriddenModel.uri(), is(model.uri())); // only the URI is compared here
+    }
+
     public static GoogleVertexAiEmbeddingsModel createModel(
         String location,
         String projectId,
@@ -58,12 +131,37 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
             "service",
             uri,
             new GoogleVertexAiEmbeddingsServiceSettings(location, projectId, modelId, false, null, null, null, null),
-            new GoogleVertexAiEmbeddingsTaskSettings(Boolean.FALSE),
+            new GoogleVertexAiEmbeddingsTaskSettings(Boolean.FALSE, null),
             new GoogleVertexAiSecretSettings(new SecureString(serviceAccountJson.toCharArray()))
         );
     }
 
-    public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullable Boolean autoTruncate) {
+    /**
+     * Builds a text-embedding model for tests using fixed identifiers, fixed service settings and a fixed secret.
+     *
+     * @param modelId      model id placed in the service settings
+     * @param autoTruncate auto-truncate flag for the task settings, may be {@code null}
+     * @param inputType    input type for the task settings, may be {@code null}
+     */
+    public static GoogleVertexAiEmbeddingsModel createModel(String modelId, @Nullable Boolean autoTruncate, @Nullable InputType inputType) {
+        var serviceSettings = new GoogleVertexAiEmbeddingsServiceSettings(
+            "location",
+            "projectId",
+            modelId,
+            false,
+            null,
+            null,
+            SimilarityMeasure.DOT_PRODUCT,
+            null
+        );
+        var taskSettings = new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, inputType);
+        var secretSettings = new GoogleVertexAiSecretSettings(new SecureString("testString".toCharArray()));
+
+        return new GoogleVertexAiEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            serviceSettings,
+            taskSettings,
+            null,
+            secretSettings
+        );
+    }
+
+    public static GoogleVertexAiEmbeddingsModel createRandomizedModel(
+        String modelId,
+        @Nullable Boolean autoTruncate,
+        @Nullable InputType inputType
+    ) {
         return new GoogleVertexAiEmbeddingsModel(
             "id",
             TaskType.TEXT_EMBEDDING,
@@ -78,7 +176,7 @@ public class GoogleVertexAiEmbeddingsModelTests extends ESTestCase {
                 SimilarityMeasure.DOT_PRODUCT,
                 null
             ),
-            new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate),
+            new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, inputType),
             null,
             new GoogleVertexAiSecretSettings(new SecureString(randomAlphaOfLength(8).toCharArray()))
         );

+ 43 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsRequestTaskSettingsTests.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
 
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.test.ESTestCase;
 
 import java.util.HashMap;
@@ -21,9 +23,14 @@ public class GoogleVertexAiEmbeddingsRequestTaskSettingsTests extends ESTestCase
         assertThat(requestTaskSettings, is(GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS));
     }
 
+    /** Expects a {@code null} map to be parsed into the shared {@code EMPTY_SETTINGS} value. */
+    public void testFromMap_ReturnsEmptySettings_IfMapNull() {
+        assertThat(
+            GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(null),
+            is(GoogleVertexAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS)
+        );
+    }
+
     public void testFromMap_DoesNotThrowValidationException_IfAutoTruncateIsMissing() {
         var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("unrelated", true)));
-        assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(null)));
+        assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(null, null)));
     }
 
     public void testFromMap_ExtractsAutoTruncate() {
@@ -31,6 +38,40 @@ public class GoogleVertexAiEmbeddingsRequestTaskSettingsTests extends ESTestCase
         var requestTaskSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(
             new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, autoTruncate))
         );
-        assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate)));
+        assertThat(requestTaskSettings, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(autoTruncate, null)));
+    }
+
+    /** Expects a non-boolean {@code auto_truncate} value to be rejected with a {@link ValidationException}. */
+    public void testFromMap_ThrowsValidationException_IfAutoTruncateIsInvalidValue() {
+        Map<String, Object> badSettings = new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, "invalid"));
+
+        expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(badSettings));
+    }
+
+    /** Expects the string form of {@code INGEST} under the input-type key to be parsed back into the enum value. */
+    public void testFromMap_ExtractsInputType() {
+        Map<String, Object> settingsMap = new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, InputType.INGEST.toString()));
+
+        var parsed = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(settingsMap);
+
+        assertThat(parsed, is(new GoogleVertexAiEmbeddingsRequestTaskSettings(null, InputType.INGEST)));
+    }
+
+    /** Expects an unknown input-type string to be rejected with a {@link ValidationException}. */
+    public void testFromMap_ThrowsValidationException_IfInputTypeIsInvalidValue() {
+        Map<String, Object> badSettings = new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, "abc"));
+
+        expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(badSettings));
+    }
+
+    /** Expects the {@code UNSPECIFIED} input type to be rejected with a {@link ValidationException}. */
+    public void testFromMap_ThrowsValidationException_IfInputTypeIsUnspecified() {
+        Map<String, Object> badSettings = new HashMap<>(
+            Map.of(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())
+        );
+
+        expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(badSettings));
     }
 }

+ 153 - 17
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/embeddings/GoogleVertexAiEmbeddingsTaskSettingsTests.java

@@ -8,21 +8,30 @@
 package org.elasticsearch.xpack.inference.services.googlevertexai.embeddings;
 
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.TransportVersions;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.hamcrest.MatcherAssert;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Collections;
+import java.util.EnumSet;
 import java.util.HashMap;
+import java.util.Locale;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.inference.InputTypeTests.randomWithoutUnspecified;
 import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE;
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE;
+import static org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings.VALID_REQUEST_VALUES;
 import static org.hamcrest.Matchers.is;
 
 public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<GoogleVertexAiEmbeddingsTaskSettings> {
@@ -39,6 +48,9 @@ public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSe
         if (newSettings.autoTruncate() != null) {
             newSettingsMap.put(GoogleVertexAiEmbeddingsTaskSettings.AUTO_TRUNCATE, newSettings.autoTruncate());
         }
+        if (newSettings.getInputType() != null) {
+            newSettingsMap.put(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, newSettings.getInputType().toString());
+        }
         GoogleVertexAiEmbeddingsTaskSettings updatedSettings = (GoogleVertexAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
             Collections.unmodifiableMap(newSettingsMap)
         );
@@ -47,56 +59,144 @@ public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSe
         } else {
             assertEquals(newSettings.autoTruncate(), updatedSettings.autoTruncate());
         }
+        if (newSettings.getInputType() == null) {
+            assertEquals(initialSettings.getInputType(), updatedSettings.getInputType());
+        } else {
+            assertEquals(newSettings.getInputType(), updatedSettings.getInputType());
+        }
+    }
+
+    /** Expects an empty map to produce settings whose auto-truncate flag and input type are both {@code null}. */
+    public void testFromMap_CreatesEmptySettings_WhenAllFieldsAreNull() {
+        var parsed = GoogleVertexAiEmbeddingsTaskSettings.fromMap(new HashMap<>());
+
+        MatcherAssert.assertThat(parsed, is(new GoogleVertexAiEmbeddingsTaskSettings(null, null)));
+        assertNull(parsed.autoTruncate());
+        assertNull(parsed.getInputType());
+    }
+
+    /** Expects a {@code null} map to produce settings whose auto-truncate flag and input type are both {@code null}. */
+    public void testFromMap_CreatesEmptySettings_WhenMapIsNull() {
+        var parsed = GoogleVertexAiEmbeddingsTaskSettings.fromMap(null);
+
+        MatcherAssert.assertThat(parsed, is(new GoogleVertexAiEmbeddingsTaskSettings(null, null)));
+        assertNull(parsed.autoTruncate());
+        assertNull(parsed.getInputType());
     }
 
     public void testFromMap_AutoTruncateIsSet() {
         var autoTruncate = true;
-        var taskSettingsMap = getTaskSettingsMap(autoTruncate);
+        var taskSettingsMap = getTaskSettingsMap(autoTruncate, null);
         var taskSettings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettingsMap);
 
-        assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate)));
+        assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, null)));
     }
 
     public void testFromMap_ThrowsValidationException_IfAutoTruncateIsInvalidValue() {
-        var taskSettings = getTaskSettingsMap("invalid");
+        var taskSettings = getTaskSettingsMap("invalid", null);
 
         expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettings));
     }
 
     public void testFromMap_AutoTruncateIsNull() {
-        var taskSettingsMap = getTaskSettingsMap(null);
+        var taskSettingsMap = getTaskSettingsMap(null, null);
         var taskSettings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettingsMap);
         // needed, because of constructors being ambiguous otherwise
         Boolean nullBoolean = null;
 
-        assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(nullBoolean)));
+        assertThat(taskSettings, is(new GoogleVertexAiEmbeddingsTaskSettings(nullBoolean, null)));
     }
 
-    public void testFromMap_DoesNotThrow_WithEmptyMap() {
-        assertNull(GoogleVertexAiEmbeddingsTaskSettings.fromMap(new HashMap<>()).autoTruncate());
+    /**
+     * Expects an unknown input-type string to fail validation with a message that lists the accepted values
+     * in sorted, lower-case, comma-separated form.
+     */
+    public void testFromMap_ReturnsFailure_WhenInputTypeIsInvalid() {
+        Map<String, Object> badSettings = new HashMap<>(Map.of(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, "abc"));
+        var expectedMessage = Strings.format(
+            "Validation Failed: 1: [task_settings] Invalid value [abc] received. [input_type] must be one of [%s];",
+            getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
+        );
+
+        var exception = expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsTaskSettings.fromMap(badSettings));
+
+        assertThat(exception.getMessage(), is(expectedMessage));
+    }
+
+    /**
+     * Expects the {@code UNSPECIFIED} input type to fail validation with a message that lists the accepted
+     * values in sorted, lower-case, comma-separated form.
+     */
+    public void testFromMap_ReturnsFailure_WhenInputTypeIsUnspecified() {
+        Map<String, Object> badSettings = new HashMap<>(
+            Map.of(GoogleVertexAiEmbeddingsTaskSettings.INPUT_TYPE, InputType.UNSPECIFIED.toString())
+        );
+        var expectedMessage = Strings.format(
+            "Validation Failed: 1: [task_settings] Invalid value [unspecified] received. [input_type] must be one of [%s];",
+            getValidValuesSortedAndCombined(VALID_REQUEST_VALUES)
+        );
+
+        var exception = expectThrows(ValidationException.class, () -> GoogleVertexAiEmbeddingsTaskSettings.fromMap(badSettings));
+
+        assertThat(exception.getMessage(), is(expectedMessage));
     }
 
     public void testOf_UseRequestSettings() {
         var originalAutoTruncate = true;
-        var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate);
+        var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate, null);
 
         var requestAutoTruncate = originalAutoTruncate == false;
-        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(requestAutoTruncate);
+        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(requestAutoTruncate, null);
 
-        assertThat(GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings).autoTruncate(), is(requestAutoTruncate));
+        assertThat(
+            GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings, null).autoTruncate(),
+            is(requestAutoTruncate)
+        );
+    }
+
+    /**
+     * Expects a valid request-level input type ({@code INGEST}) to win over the stored {@code SEARCH} value
+     * when the request task settings themselves carry no input type.
+     */
+    public void testOf_UseRequestSettings_AndRequestInputType() {
+        var storedSettings = new GoogleVertexAiEmbeddingsTaskSettings(true, InputType.SEARCH);
+        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(false, null);
+
+        var merged = GoogleVertexAiEmbeddingsTaskSettings.of(storedSettings, requestTaskSettings, InputType.INGEST);
+
+        assertThat(merged.getInputType(), is(InputType.INGEST));
+    }
 
     public void testOf_UseOriginalSettings() {
         var originalAutoTruncate = true;
-        var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate);
+        var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate, null);
 
-        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(null);
+        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(null, null);
 
-        assertThat(GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings).autoTruncate(), is(originalAutoTruncate));
+        assertThat(
+            GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings, null).autoTruncate(),
+            is(originalAutoTruncate)
+        );
+    }
+
+    /**
+     * When neither the request task settings nor the request itself supply a value, both the stored
+     * auto-truncate flag and the stored input type are expected to be carried over unchanged.
+     */
+    public void testOf_UseOriginalSettings_WithInputType() {
+        var originalAutoTruncate = true;
+        var originalInputType = InputType.INGEST;
+        var originalSettings = new GoogleVertexAiEmbeddingsTaskSettings(originalAutoTruncate, originalInputType);
+
+        var requestTaskSettings = new GoogleVertexAiEmbeddingsRequestTaskSettings(null, null);
+
+        var mergedSettings = GoogleVertexAiEmbeddingsTaskSettings.of(originalSettings, requestTaskSettings, null);
+
+        assertThat(mergedSettings.autoTruncate(), is(originalAutoTruncate));
+        // The input type is the subject of this test: it must survive the merge as well.
+        assertThat(mergedSettings.getInputType(), is(originalInputType));
+    }
 
     public void testToXContent_WritesAutoTruncateIfNotNull() throws IOException {
-        var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(true));
+        var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(true, null));
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         settings.toXContent(builder, null);
@@ -107,7 +207,7 @@ public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSe
     }
 
     public void testToXContent_DoesNotWriteAutoTruncateIfNull() throws IOException {
-        var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(null));
+        var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(getTaskSettingsMap(null, null));
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
         settings.toXContent(builder, null);
@@ -117,6 +217,25 @@ public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSe
             {}"""));
     }
 
+    /** Expects both {@code input_type} and {@code auto_truncate} to appear in the serialized JSON when set. */
+    public void testToXContent_WritesInputTypeIfNotNull() throws IOException {
+        var taskSettingsMap = getTaskSettingsMap(true, InputType.INGEST);
+        var settings = GoogleVertexAiEmbeddingsTaskSettings.fromMap(taskSettingsMap);
+
+        XContentBuilder jsonBuilder = XContentFactory.contentBuilder(XContentType.JSON);
+        settings.toXContent(jsonBuilder, null);
+
+        assertThat(Strings.toString(jsonBuilder), is("""
+            {"input_type":"ingest","auto_truncate":true}"""));
+    }
+
+    /**
+     * Expects constructing task settings with {@code UNSPECIFIED} to trip an assertion. Despite the test
+     * name, the failure is raised by the constructor call; {@code toXContent} is never invoked here.
+     */
+    public void testToXContent_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() {
+        AssertionError failure = expectThrows(
+            AssertionError.class,
+            () -> new GoogleVertexAiEmbeddingsTaskSettings(false, InputType.UNSPECIFIED)
+        );
+
+        assertThat(failure.getMessage(), is("received invalid input type value [unspecified]"));
+    }
+
     @Override
     protected Writeable.Reader<GoogleVertexAiEmbeddingsTaskSettings> instanceReader() {
         return GoogleVertexAiEmbeddingsTaskSettings::new;
@@ -137,20 +256,37 @@ public class GoogleVertexAiEmbeddingsTaskSettingsTests extends AbstractBWCWireSe
         GoogleVertexAiEmbeddingsTaskSettings instance,
         TransportVersion version
     ) {
+        if (version.before(TransportVersions.VERTEX_AI_INPUT_TYPE_ADDED)) {
+            // default to null input type if node is on a version before input type was introduced
+            return new GoogleVertexAiEmbeddingsTaskSettings(instance.autoTruncate(), null);
+        }
         return instance;
     }
 
     private static GoogleVertexAiEmbeddingsTaskSettings createRandom() {
-        return new GoogleVertexAiEmbeddingsTaskSettings(randomFrom(new Boolean[] { null, randomBoolean() }));
+        var inputType = randomBoolean() ? randomWithoutUnspecified() : null;
+        var autoTruncate = randomFrom(new Boolean[] { null, randomBoolean() });
+        return new GoogleVertexAiEmbeddingsTaskSettings(autoTruncate, inputType);
+    }
+
+    /**
+     * Renders the given enum values as lower-case strings (using {@link Locale#ROOT}), sorts them in natural
+     * string order and joins them with {@code ", "}.
+     */
+    private static <E extends Enum<E>> String getValidValuesSortedAndCombined(EnumSet<E> validValues) {
+        String[] sortedNames = validValues.stream()
+            .map(value -> value.toString().toLowerCase(Locale.ROOT))
+            .sorted()
+            .toArray(String[]::new);
+
+        return String.join(", ", sortedNames);
+    }
 
-    private static Map<String, Object> getTaskSettingsMap(@Nullable Object autoTruncate) {
+    public static Map<String, Object> getTaskSettingsMap(@Nullable Object autoTruncate, @Nullable InputType inputType) {
         var map = new HashMap<String, Object>();
 
         if (autoTruncate != null) {
             map.put(AUTO_TRUNCATE, autoTruncate);
         }
 
+        if (inputType != null) {
+            map.put(INPUT_TYPE, inputType.toString());
+        }
+
         return map;
     }
 }