Răsfoiți Sursa

[ML] Adding custom headers support openai text embeddings (#134960)

* Adding custom headers support openai text embeddings

* Update docs/changelog/134960.yaml

* Adding headers to the service api result

* [CI] Auto commit changes from spotless

* Addressing feedback

* Adding transport version change

* [CI] Auto commit changes from spotless

* Cleaning up helpers

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
Jonathan Buttner 2 săptămâni în urmă
părinte
comite
9600127dd7
20 a modificat fișierele cu 554 adăugiri și 584 ștergeri
  1. 5 0
      docs/changelog/134960.yaml
  2. 1 0
      server/src/main/resources/transport/definitions/referable/inference_api_openai_embeddings_headers.csv
  3. 1 1
      server/src/main/resources/transport/upper_bounds/9.2.csv
  4. 2 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
  5. 126 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettings.java
  6. 2 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java
  7. 0 57
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java
  8. 17 89
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java
  9. 3 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java
  10. 30 71
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java
  11. 6 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java
  12. 48 32
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
  13. 238 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java
  14. 5 5
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java
  15. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java
  16. 0 66
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettingsTests.java
  17. 15 138
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java
  18. 6 6
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java
  19. 16 104
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java
  20. 31 3
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java

+ 5 - 0
docs/changelog/134960.yaml

@@ -0,0 +1,5 @@
+pr: 134960
+summary: Adding custom headers support openai text embeddings
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 0
server/src/main/resources/transport/definitions/referable/inference_api_openai_embeddings_headers.csv

@@ -0,0 +1 @@
+9169000

+ 1 - 1
server/src/main/resources/transport/upper_bounds/9.2.csv

@@ -1 +1 @@
-security_stats_endpoint,9168000
+inference_api_openai_embeddings_headers,9169000

+ 2 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -485,9 +485,8 @@ public class OpenAiService extends SenderService {
 
                 configurationMap.put(
                     HEADERS,
-                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION)).setDescription(
-                        "Custom headers to include in the requests to OpenAI."
-                    )
+                    new SettingsConfiguration.Builder(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION, TaskType.CHAT_COMPLETION))
+                        .setDescription("Custom headers to include in the requests to OpenAI.")
                         .setLabel("Custom Headers")
                         .setRequired(false)
                         .setSensitive(false)

+ 126 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettings.java

@@ -0,0 +1,126 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.openai;
+
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskSettings;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
+
+public abstract class OpenAiTaskSettings<T extends OpenAiTaskSettings<T>> implements TaskSettings {
+    private static final Settings EMPTY_SETTINGS = new Settings(null, null);
+
+    private final Settings taskSettings;
+
+    public OpenAiTaskSettings(Map<String, Object> map) {
+        this(fromMap(map));
+    }
+
+    public record Settings(@Nullable String user, @Nullable Map<String, String> headers) {}
+
+    public static Settings createSettings(String user, Map<String, String> stringHeaders) {
+        if (user == null && stringHeaders == null) {
+            return EMPTY_SETTINGS;
+        } else {
+            return new Settings(user, stringHeaders);
+        }
+    }
+
+    private static Settings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+        Map<String, Object> headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
+        var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return createSettings(user, stringHeaders);
+    }
+
+    public OpenAiTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
+        this(new Settings(user, headers));
+    }
+
+    protected OpenAiTaskSettings(Settings taskSettings) {
+        this.taskSettings = Objects.requireNonNull(taskSettings);
+    }
+
+    public String user() {
+        return taskSettings.user();
+    }
+
+    public Map<String, String> headers() {
+        return taskSettings.headers();
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return taskSettings.user() == null && (taskSettings.headers() == null || taskSettings.headers().isEmpty());
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+
+        if (taskSettings.user() != null) {
+            builder.field(USER, taskSettings.user());
+        }
+
+        if (taskSettings.headers() != null && taskSettings.headers().isEmpty() == false) {
+            builder.field(HEADERS, taskSettings.headers());
+        }
+
+        builder.endObject();
+
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        OpenAiTaskSettings<?> that = (OpenAiTaskSettings<?>) o;
+        return Objects.equals(taskSettings, that.taskSettings);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(taskSettings);
+    }
+
+    @Override
+    public T updatedTaskSettings(Map<String, Object> newSettings) {
+        Settings updatedSettings = fromMap(new HashMap<>(newSettings));
+
+        var userToUse = updatedSettings.user() == null ? taskSettings.user() : updatedSettings.user();
+        var headersToUse = updatedSettings.headers() == null ? taskSettings.headers() : updatedSettings.headers();
+        return create(userToUse, headersToUse);
+    }
+
+    protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);
+
+}

+ 2 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java

@@ -35,8 +35,7 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
             return model;
         }
 
-        var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(taskSettings);
-        return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
+        return new OpenAiChatCompletionModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings));
     }
 
     public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) {
@@ -73,7 +72,7 @@ public class OpenAiChatCompletionModel extends OpenAiModel {
             taskType,
             service,
             OpenAiChatCompletionServiceSettings.fromMap(serviceSettings, context),
-            OpenAiChatCompletionTaskSettings.fromMap(taskSettings),
+            new OpenAiChatCompletionTaskSettings(taskSettings),
             DefaultSecretSettings.fromMap(secrets)
         );
     }

+ 0 - 57
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java

@@ -1,57 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.openai.completion;
-
-import org.elasticsearch.common.ValidationException;
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ModelConfigurations;
-
-import java.util.Map;
-
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
-
-/**
- * This class handles extracting OpenAI task settings from a request. The difference between this class and
- * {@link OpenAiChatCompletionTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field
- * is missing. This allows overriding persistent task settings.
- * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse
- * @param headers additional headers to include in the request to the OpenAI API
- */
-public record OpenAiChatCompletionRequestTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
-
-    public static final OpenAiChatCompletionRequestTaskSettings EMPTY_SETTINGS = new OpenAiChatCompletionRequestTaskSettings(null, null);
-
-    /**
-     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
-     * does not throw an error.
-     *
-     * @param map the settings received from a request
-     * @return a {@link OpenAiChatCompletionRequestTaskSettings}
-     */
-    public static OpenAiChatCompletionRequestTaskSettings fromMap(Map<String, Object> map) {
-        if (map.isEmpty()) {
-            return OpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS;
-        }
-
-        ValidationException validationException = new ValidationException();
-
-        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
-        Map<String, Object> headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
-        var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);
-
-        if (validationException.validationErrors().isEmpty() == false) {
-            throw validationException;
-        }
-
-        return new OpenAiChatCompletionRequestTaskSettings(user, stringHeaders);
-    }
-}

+ 17 - 89
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettings.java

@@ -9,100 +9,44 @@ package org.elasticsearch.xpack.inference.services.openai.completion;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.TaskSettings;
-import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings;
 
 import java.io.IOException;
-import java.util.HashMap;
 import java.util.Map;
-import java.util.Objects;
 
 import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMapRemoveNulls;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.HEADERS;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
 
-public class OpenAiChatCompletionTaskSettings implements TaskSettings {
+public class OpenAiChatCompletionTaskSettings extends OpenAiTaskSettings<OpenAiChatCompletionTaskSettings> {
 
     public static final String NAME = "openai_completion_task_settings";
 
-    public static OpenAiChatCompletionTaskSettings fromMap(Map<String, Object> map) {
-        ValidationException validationException = new ValidationException();
-
-        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
-        var headers = extractOptionalMapRemoveNulls(map, HEADERS, validationException);
-        var stringHeaders = validateMapStringValues(headers, HEADERS, validationException, false, null);
-
-        if (validationException.validationErrors().isEmpty() == false) {
-            throw validationException;
-        }
-
-        return new OpenAiChatCompletionTaskSettings(user, stringHeaders);
+    public OpenAiChatCompletionTaskSettings(Map<String, Object> map) {
+        super(map);
     }
 
-    private final String user;
-    @Nullable
-    private final Map<String, String> headers;
-
     public OpenAiChatCompletionTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
-        this.user = user;
-        this.headers = headers;
+        super(user, headers);
     }
 
     public OpenAiChatCompletionTaskSettings(StreamInput in) throws IOException {
-        this.user = in.readOptionalString();
+        super(readTaskSettingsFromStream(in));
+    }
+
+    private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException {
+        var user = in.readOptionalString();
+
+        Map<String, String> headers;
 
         if (in.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
             headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
         } else {
             headers = null;
         }
-    }
-
-    @Override
-    public boolean isEmpty() {
-        return user == null && (headers == null || headers.isEmpty());
-    }
-
-    public static OpenAiChatCompletionTaskSettings of(
-        OpenAiChatCompletionTaskSettings originalSettings,
-        OpenAiChatCompletionRequestTaskSettings requestSettings
-    ) {
-        var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
-        var headersToUse = requestSettings.headers() == null ? originalSettings.headers : requestSettings.headers();
-        return new OpenAiChatCompletionTaskSettings(userToUse, headersToUse);
-    }
-
-    public String user() {
-        return user;
-    }
 
-    public Map<String, String> headers() {
-        return headers;
-    }
-
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-
-        if (user != null) {
-            builder.field(USER, user);
-        }
-
-        if (headers != null && headers.isEmpty() == false) {
-            builder.field(HEADERS, headers);
-        }
-
-        builder.endObject();
-
-        return builder;
+        return createSettings(user, headers);
     }
 
     @Override
@@ -117,30 +61,14 @@ public class OpenAiChatCompletionTaskSettings implements TaskSettings {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
-        out.writeOptionalString(user);
+        out.writeOptionalString(user());
         if (out.getTransportVersion().onOrAfter(INFERENCE_API_OPENAI_HEADERS)) {
-            out.writeOptionalMap(headers, StreamOutput::writeString, StreamOutput::writeString);
+            out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString);
         }
     }
 
     @Override
-    public boolean equals(Object object) {
-        if (this == object) return true;
-        if (object == null || getClass() != object.getClass()) return false;
-        OpenAiChatCompletionTaskSettings that = (OpenAiChatCompletionTaskSettings) object;
-        return Objects.equals(user, that.user) && Objects.equals(headers, that.headers);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(user, headers);
-    }
-
-    @Override
-    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
-        OpenAiChatCompletionRequestTaskSettings updatedSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(
-            new HashMap<>(newSettings)
-        );
-        return of(this, updatedSettings);
+    protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
+        return new OpenAiChatCompletionTaskSettings(user, headers);
     }
 }

+ 3 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModel.java

@@ -34,8 +34,7 @@ public class OpenAiEmbeddingsModel extends OpenAiModel {
             return model;
         }
 
-        var requestTaskSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings);
-        return new OpenAiEmbeddingsModel(model, OpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings));
+        return new OpenAiEmbeddingsModel(model, model.getTaskSettings().updatedTaskSettings(taskSettings));
     }
 
     public OpenAiEmbeddingsModel(
@@ -53,14 +52,14 @@ public class OpenAiEmbeddingsModel extends OpenAiModel {
             taskType,
             service,
             OpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context),
-            OpenAiEmbeddingsTaskSettings.fromMap(taskSettings, context),
+            new OpenAiEmbeddingsTaskSettings(taskSettings),
             chunkingSettings,
             DefaultSecretSettings.fromMap(secrets)
         );
     }
 
     // Should only be used directly for testing
-    OpenAiEmbeddingsModel(
+    public OpenAiEmbeddingsModel(
         String inferenceEntityId,
         TaskType taskType,
         String service,

+ 30 - 71
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettings.java

@@ -9,22 +9,13 @@ package org.elasticsearch.xpack.inference.services.openai.embeddings;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
-import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.TaskSettings;
-import org.elasticsearch.xcontent.XContentBuilder;
-import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettings;
 
 import java.io.IOException;
-import java.util.HashMap;
 import java.util.Map;
-import java.util.Objects;
-
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString;
-import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER;
 
 /**
  * Defines the task settings for the openai service.
@@ -32,68 +23,46 @@ import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFie
  * User is an optional unique identifier representing the end-user, which can help OpenAI to monitor and detect abuse
  *  <a href="https://platform.openai.com/docs/api-reference/embeddings/create">see the openai docs for more details</a>
  */
-public class OpenAiEmbeddingsTaskSettings implements TaskSettings {
+public class OpenAiEmbeddingsTaskSettings extends OpenAiTaskSettings<OpenAiEmbeddingsTaskSettings> {
 
     public static final String NAME = "openai_embeddings_task_settings";
 
-    public static OpenAiEmbeddingsTaskSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
-        ValidationException validationException = new ValidationException();
-
-        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
-        if (validationException.validationErrors().isEmpty() == false) {
-            throw validationException;
-        }
+    // default for testing
+    static final TransportVersion INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS = TransportVersion.fromName(
+        "inference_api_openai_embeddings_headers"
+    );
 
-        return new OpenAiEmbeddingsTaskSettings(user);
+    public OpenAiEmbeddingsTaskSettings(Map<String, Object> map) {
+        super(map);
     }
 
-    /**
-     * Creates a new {@link OpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones
-     * passed in via requestSettings if the fields are not null.
-     * @param originalSettings the original task settings from the inference entity configuration from storage
-     * @param requestSettings the task settings from the request
-     * @return a new {@link OpenAiEmbeddingsTaskSettings}
-     */
-    public static OpenAiEmbeddingsTaskSettings of(
-        OpenAiEmbeddingsTaskSettings originalSettings,
-        OpenAiEmbeddingsRequestTaskSettings requestSettings
-    ) {
-        var userToUse = requestSettings.user() == null ? originalSettings.user : requestSettings.user();
-        return new OpenAiEmbeddingsTaskSettings(userToUse);
+    public OpenAiEmbeddingsTaskSettings(@Nullable String user, @Nullable Map<String, String> headers) {
+        super(user, headers);
     }
 
-    private final String user;
-
-    public OpenAiEmbeddingsTaskSettings(@Nullable String user) {
-        this.user = user;
+    public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
+        super(readTaskSettingsFromStream(in));
     }
 
-    @Override
-    public boolean isEmpty() {
-        return user == null;
-    }
+    private static Settings readTaskSettingsFromStream(StreamInput in) throws IOException {
+        String user;
 
-    public OpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException {
         if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
-            this.user = in.readOptionalString();
+            user = in.readOptionalString();
         } else {
             var discard = in.readString();
-            this.user = in.readOptionalString();
+            user = in.readOptionalString();
         }
-    }
 
-    @Override
-    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startObject();
-        if (user != null) {
-            builder.field(USER, user);
+        Map<String, String> headers;
+
+        if (in.getTransportVersion().supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) {
+            headers = in.readOptionalImmutableMap(StreamInput::readString, StreamInput::readString);
+        } else {
+            headers = null;
         }
-        builder.endObject();
-        return builder;
-    }
 
-    public String user() {
-        return user;
+        return createSettings(user, headers);
     }
 
     @Override
@@ -109,29 +78,19 @@ public class OpenAiEmbeddingsTaskSettings implements TaskSettings {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) {
-            out.writeOptionalString(user);
+            out.writeOptionalString(user());
         } else {
             out.writeString("m"); // write any string
-            out.writeOptionalString(user);
+            out.writeOptionalString(user());
         }
-    }
 
-    @Override
-    public boolean equals(Object o) {
-        if (this == o) return true;
-        if (o == null || getClass() != o.getClass()) return false;
-        OpenAiEmbeddingsTaskSettings that = (OpenAiEmbeddingsTaskSettings) o;
-        return Objects.equals(user, that.user);
-    }
-
-    @Override
-    public int hashCode() {
-        return Objects.hash(user);
+        if (out.getTransportVersion().supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) {
+            out.writeOptionalMap(headers(), StreamOutput::writeString, StreamOutput::writeString);
+        }
     }
 
     @Override
-    public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
-        OpenAiEmbeddingsRequestTaskSettings requestSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(newSettings));
-        return of(this, requestSettings);
+    protected OpenAiEmbeddingsTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
+        return new OpenAiEmbeddingsTaskSettings(user, headers);
     }
 }

+ 6 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequest.java

@@ -60,6 +60,12 @@ public class OpenAiEmbeddingsRequest implements Request {
             httpPost.setHeader(createOrgHeader(org));
         }
 
+        if (model.getTaskSettings().headers() != null) {
+            for (var header : model.getTaskSettings().headers().entrySet()) {
+                httpPost.setHeader(header.getKey(), header.getValue());
+            }
+        }
+
         return new HttpRequest(httpPost, getInferenceEntityId());
     }
 

+ 48 - 32
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -87,7 +87,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettingsTests.getServiceSettingsMap;
-import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getTaskSettingsMap;
+import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettingsTests.getOpenAiTaskSettingsMap;
 import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.Matchers.containsString;
@@ -140,7 +140,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
                     getServiceSettingsMap("model", "url", "org"),
-                    getTaskSettingsMap("user"),
+                    getOpenAiTaskSettingsMap("user"),
                     getSecretSettingsMap("secret")
                 ),
                 modelVerificationListener
@@ -174,7 +174,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                 TaskType.COMPLETION,
                 getRequestConfigMap(
                     getServiceSettingsMap(model, url, organization),
-                    getTaskSettingsMap(user),
+                    getOpenAiTaskSettingsMap(user),
                     getSecretSettingsMap(secret)
                 ),
                 modelVerificationListener
@@ -197,7 +197,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                 TaskType.SPARSE_EMBEDDING,
                 getRequestConfigMap(
                     getServiceSettingsMap("model", "url", "org"),
-                    getTaskSettingsMap("user"),
+                    getOpenAiTaskSettingsMap("user"),
                     getSecretSettingsMap("secret")
                 ),
                 modelVerificationListener
@@ -209,7 +209,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var config = getRequestConfigMap(
                 getServiceSettingsMap("model", "url", "org"),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 getSecretSettingsMap("secret")
             );
             config.put("extra_key", "value");
@@ -234,7 +234,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             var serviceSettings = getServiceSettingsMap("model", "url", "org");
             serviceSettings.put("extra_key", "value");
 
-            var config = getRequestConfigMap(serviceSettings, getTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+            var config = getRequestConfigMap(serviceSettings, getOpenAiTaskSettingsMap("user"), getSecretSettingsMap("secret"));
 
             ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
                 fail("Expected exception, but got model: " + model);
@@ -249,7 +249,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
 
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException {
         try (var service = createOpenAiService()) {
-            var taskSettingsMap = getTaskSettingsMap("user");
+            var taskSettingsMap = getOpenAiTaskSettingsMap("user");
             taskSettingsMap.put("extra_key", "value");
 
             var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), taskSettingsMap, getSecretSettingsMap("secret"));
@@ -270,7 +270,11 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             var secretSettingsMap = getSecretSettingsMap("secret");
             secretSettingsMap.put("extra_key", "value");
 
-            var config = getRequestConfigMap(getServiceSettingsMap("model", "url", "org"), getTaskSettingsMap("user"), secretSettingsMap);
+            var config = getRequestConfigMap(
+                getServiceSettingsMap("model", "url", "org"),
+                getOpenAiTaskSettingsMap("user"),
+                secretSettingsMap
+            );
 
             ActionListener<Model> modelVerificationListener = ActionListener.<Model>wrap((model) -> {
                 fail("Expected exception, but got model: " + model);
@@ -299,7 +303,11 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             service.parseRequestConfig(
                 "id",
                 TaskType.TEXT_EMBEDDING,
-                getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")),
+                getRequestConfigMap(
+                    getServiceSettingsMap("model", null, null),
+                    getOpenAiTaskSettingsMap(null),
+                    getSecretSettingsMap("secret")
+                ),
                 modelVerificationListener
             );
         }
@@ -325,7 +333,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             service.parseRequestConfig(
                 "id",
                 TaskType.COMPLETION,
-                getRequestConfigMap(getServiceSettingsMap(model, null, null), getTaskSettingsMap(null), getSecretSettingsMap(secret)),
+                getRequestConfigMap(getServiceSettingsMap(model, null, null), getOpenAiTaskSettingsMap(null), getSecretSettingsMap(secret)),
                 modelVerificationListener
             );
         }
@@ -349,7 +357,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
                     getServiceSettingsMap("model", "url", "org"),
-                    getTaskSettingsMap("user"),
+                    getOpenAiTaskSettingsMap("user"),
                     getSecretSettingsMap("secret")
                 ),
                 modelVerificationListener
@@ -376,7 +384,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                 TaskType.TEXT_EMBEDDING,
                 getRequestConfigMap(
                     getServiceSettingsMap("model", null, null),
-                    getTaskSettingsMap(null),
+                    getOpenAiTaskSettingsMap(null),
                     createRandomChunkingSettingsMap(),
                     getSecretSettingsMap("secret")
                 ),
@@ -402,7 +410,11 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             service.parseRequestConfig(
                 "id",
                 TaskType.TEXT_EMBEDDING,
-                getRequestConfigMap(getServiceSettingsMap("model", null, null), getTaskSettingsMap(null), getSecretSettingsMap("secret")),
+                getRequestConfigMap(
+                    getServiceSettingsMap("model", null, null),
+                    getOpenAiTaskSettingsMap(null),
+                    getSecretSettingsMap("secret")
+                ),
                 modelVerificationListener
             );
         }
@@ -412,7 +424,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", 100, null, false),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 getSecretSettingsMap("secret")
             );
 
@@ -438,7 +450,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org"),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 getSecretSettingsMap("secret")
             );
 
@@ -463,7 +475,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null),
+                getOpenAiTaskSettingsMap(null),
                 getSecretSettingsMap("secret")
             );
 
@@ -489,7 +501,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null),
+                getOpenAiTaskSettingsMap(null),
                 createRandomChunkingSettingsMap(),
                 getSecretSettingsMap("secret")
             );
@@ -517,7 +529,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null),
+                getOpenAiTaskSettingsMap(null),
                 getSecretSettingsMap("secret")
             );
 
@@ -544,7 +556,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", null, null, true),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 getSecretSettingsMap("secret")
             );
             persistedConfig.config().put("extra_key", "value");
@@ -575,7 +587,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
 
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", null, null, true),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 secretSettingsMap
             );
 
@@ -601,7 +613,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", null, null, true),
-                getTaskSettingsMap("user"),
+                getOpenAiTaskSettingsMap("user"),
                 getSecretSettingsMap("secret")
             );
             persistedConfig.secrets().put("extra_key", "value");
@@ -630,7 +642,11 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
             serviceSettingsMap.put("extra_key", "value");
 
-            var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user"), getSecretSettingsMap("secret"));
+            var persistedConfig = getPersistedConfigMap(
+                serviceSettingsMap,
+                getOpenAiTaskSettingsMap("user"),
+                getSecretSettingsMap("secret")
+            );
 
             var model = service.parsePersistedConfigWithSecrets(
                 "id",
@@ -652,7 +668,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
 
     public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
         try (var service = createOpenAiService()) {
-            var taskSettingsMap = getTaskSettingsMap("user");
+            var taskSettingsMap = getOpenAiTaskSettingsMap("user");
             taskSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(
@@ -683,7 +699,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", null, null, true),
-                getTaskSettingsMap("user")
+                getOpenAiTaskSettingsMap("user")
             );
 
             var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -701,7 +717,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
 
     public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException {
         try (var service = createOpenAiService()) {
-            var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getTaskSettingsMap("user"));
+            var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org"), getOpenAiTaskSettingsMap("user"));
 
             var thrownException = expectThrows(
                 ElasticsearchStatusException.class,
@@ -719,7 +735,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null)
+                getOpenAiTaskSettingsMap(null)
             );
 
             var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -739,7 +755,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null),
+                getOpenAiTaskSettingsMap(null),
                 createRandomChunkingSettingsMap()
             );
 
@@ -761,7 +777,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", null, null, null, null, true),
-                getTaskSettingsMap(null)
+                getOpenAiTaskSettingsMap(null)
             );
 
             var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
@@ -782,7 +798,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
         try (var service = createOpenAiService()) {
             var persistedConfig = getPersistedConfigMap(
                 getServiceSettingsMap("model", "url", "org", null, null, true),
-                getTaskSettingsMap("user")
+                getOpenAiTaskSettingsMap("user")
             );
             persistedConfig.config().put("extra_key", "value");
 
@@ -804,7 +820,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
             var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true);
             serviceSettingsMap.put("extra_key", "value");
 
-            var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user"));
+            var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getOpenAiTaskSettingsMap("user"));
 
             var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
 
@@ -821,7 +837,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
 
     public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException {
         try (var service = createOpenAiService()) {
-            var taskSettingsMap = getTaskSettingsMap("user");
+            var taskSettingsMap = getOpenAiTaskSettingsMap("user");
             taskSettingsMap.put("extra_key", "value");
 
             var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap);
@@ -1644,7 +1660,7 @@ public class OpenAiServiceTests extends InferenceServiceTestCase {
                                     "sensitive": false,
                                     "updatable": true,
                                     "type": "map",
-                                    "supported_task_types": ["completion", "chat_completion"]
+                                    "supported_task_types": ["text_embedding", "completion", "chat_completion"]
                                 }
                             }
                         }

+ 238 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiTaskSettingsTests.java

@@ -0,0 +1,238 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.openai;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.anEmptyMap;
+import static org.hamcrest.Matchers.is;
+
+public abstract class OpenAiTaskSettingsTests<T extends OpenAiTaskSettings<T>> extends AbstractBWCWireSerializationTestCase<T> {
+
+    private enum HeadersDefinition {
+        NULL(null),
+        EMPTY(Map.of()),
+        DEFINED(Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)));
+
+        private final Map<String, String> headers;
+
+        HeadersDefinition(@Nullable Map<String, String> headers) {
+            this.headers = headers;
+        }
+    }
+
+    public T createRandom() {
+        var user = randomBoolean() ? null : randomAlphaOfLength(15);
+        var headers = randomFrom(HeadersDefinition.values()).headers;
+
+        return create(user, headers);
+    }
+
+    public void testIsEmpty() {
+        var bothNull = create(null, null);
+        assertTrue(bothNull.isEmpty());
+
+        var nullUserEmptyHeaders = create(null, Map.of());
+        assertTrue(nullUserEmptyHeaders.isEmpty());
+
+        var nullHeaders = create("user", null);
+        assertFalse(nullHeaders.isEmpty());
+
+        var nullUser = create(null, Map.of("K", "v"));
+        assertFalse(nullUser.isEmpty());
+
+        var neitherNull = create("user", Map.of("K", "v"));
+        assertFalse(neitherNull.isEmpty());
+    }
+
+    public void testUpdatedTaskSettings() {
+        var initialSettings = createRandom();
+        var newSettings = createRandom();
+
+        Map<String, Object> newSettingsMap = new HashMap<>();
+        if (newSettings.user() != null) {
+            newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user());
+        }
+
+        if (newSettings.headers() != null) {
+            newSettingsMap.put(OpenAiServiceFields.HEADERS, newSettings.headers());
+        }
+
+        var updatedSettings = initialSettings.updatedTaskSettings(Collections.unmodifiableMap(newSettingsMap));
+
+        if (newSettings.user() == null) {
+            assertEquals(initialSettings.user(), updatedSettings.user());
+        } else {
+            assertEquals(newSettings.user(), updatedSettings.user());
+        }
+
+        if (newSettings.headers() == null) {
+            assertEquals(initialSettings.headers(), updatedSettings.headers());
+        } else {
+            assertEquals(newSettings.headers(), updatedSettings.headers());
+        }
+    }
+
+    public void testUpdatedTaskSettings_ApplyingEmptyHeaders() {
+        var user = "user";
+        var initialSettingsNullHeaders = create(user, null);
+        Map<String, Object> newSettingsMap = Map.of(OpenAiServiceFields.HEADERS, Map.of());
+
+        var updatedSettings = initialSettingsNullHeaders.updatedTaskSettings(newSettingsMap);
+        assertThat(updatedSettings, is(create(user, Map.of())));
+
+        var initialSettingsDefinedHeaders = create(user, Map.of("key", "value"));
+        updatedSettings = initialSettingsDefinedHeaders.updatedTaskSettings(newSettingsMap);
+        assertThat(updatedSettings, is(create(user, Map.of())));
+    }
+
+    public void testUpdatedTaskSettings_KeepsOriginalValuesWithOverridesAreNull() {
+        var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
+
+        assertThat(taskSettings.updatedTaskSettings(Map.of()), is(taskSettings));
+    }
+
+    public void testUpdatedTaskSettings_UsesOverriddenSettings() {
+        var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
+
+        assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.USER, "user2")), is(create("user2", null)));
+    }
+
+    public void testUpdatedTaskSettings_UsesOverriddenSettings_ForHeaders() {
+        var user = "user";
+        var taskSettings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, user)));
+
+        var headers = Map.of("key", "value");
+        assertThat(taskSettings.updatedTaskSettings(Map.of(OpenAiServiceFields.HEADERS, headers)), is(create(user, headers)));
+    }
+
+    public void testFromMap_WithUserAndHeaders() {
+        assertThat(
+            createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, Map.of("key", "value")))),
+            is(create("user", Map.of("key", "value")))
+        );
+    }
+
+    public void testFromMap_UserIsEmptyString() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "")))
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
+        );
+    }
+
+    public void testFromMap_MissingUser_DoesNotThrowException() {
+        var taskSettings = createFromMap(new HashMap<>(Map.of()));
+        assertNull(taskSettings.user());
+    }
+
+    public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
+        var settings = createFromMap(new HashMap<>(Map.of("key", "value")));
+        assertNull(settings.user());
+        assertNull(settings.headers());
+    }
+
+    public void testFromMap_ParsesCorrectly_WhenUserIsNull() {
+        var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value")))));
+
+        assertNull(settings.user());
+        assertThat(settings.headers(), is(Map.of("key", "value")));
+    }
+
+    public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() {
+        var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
+
+        assertThat(settings.user(), is("user"));
+        assertNull(settings.headers());
+    }
+
+    public void testFromMap_ParsesCorrectly_WhenHeadersIsEmptyMap() {
+        var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, Map.of())));
+
+        assertThat(settings.user(), is("user"));
+        assertThat(settings.headers(), anEmptyMap());
+    }
+
+    public void testFromMap_ParsesCorrectly_WhenHeadersMapOfNulls() {
+        var headersMap = new HashMap<String, Object>();
+        headersMap.put("key1", null);
+        headersMap.put("key2", null);
+        var settings = createFromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, headersMap)));
+
+        assertThat(settings.user(), is("user"));
+        assertThat(settings.headers(), anEmptyMap());
+    }
+
+    public void testFromMap_ParsesCorrectly_WhenHeadersContainsAnInteger() {
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> createFromMap(
+                new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", 1))))
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            is(
+                "Validation Failed: 1: Map field [headers] has an entry that is not valid, "
+                    + "[key => 1]. Value type of [1] is not one of [String].;"
+            )
+        );
+    }
+
+    @Override
+    protected T mutateInstance(T instance) throws IOException {
+        var setNull = randomBoolean();
+        var fieldToMutate = randomIntBetween(0, 1);
+
+        return switch (fieldToMutate) {
+            case 0 -> create(
+                instance.user() == null ? randomAlphaOfLength(15) : (setNull ? null : instance.user() + "modified"),
+                instance.headers()
+            );
+            case 1 -> {
+                if (instance.headers() == null) {
+                    yield create(instance.user(), Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)));
+                } else if (setNull) {
+                    yield create(instance.user(), null);
+                } else {
+                    var instanceHeaders = new HashMap<>(instance.headers() == null ? Map.of() : instance.headers());
+                    instanceHeaders.put(randomAlphaOfLength(15), randomAlphaOfLength(15));
+                    yield create(instance.user(), instanceHeaders);
+                }
+            }
+            default -> throw new IllegalStateException("Unexpected value: " + fieldToMutate);
+        };
+    }
+
+    protected abstract T create(@Nullable String user, @Nullable Map<String, String> headers);
+
+    protected abstract T createFromMap(Map<String, Object> map);
+
+    public static Map<String, Object> getOpenAiTaskSettingsMap(@Nullable String user) {
+        var map = new HashMap<String, Object>();
+
+        if (user != null) {
+            map.put(OpenAiServiceFields.USER, user);
+        }
+
+        return map;
+    }
+}

+ 5 - 5
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java

@@ -44,9 +44,9 @@ import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields;
 import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests.getOpenAiTaskSettingsMap;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
 import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createCompletionModel;
-import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel;
 import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsRequestTaskSettingsTests.createRequestTaskSettingsMap;
 import static org.hamcrest.Matchers.equalTo;
@@ -348,7 +348,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", "user");
             var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool));
-            var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user");
+            var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user");
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -412,7 +412,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             var model = createCompletionModel(getUrl(webServer), "org", "secret", "model", null);
             var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool));
-            var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap(null);
+            var overriddenTaskSettings = getOpenAiTaskSettingsMap(null);
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -475,7 +475,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null);
             var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool));
-            var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user");
+            var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user");
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
@@ -544,7 +544,7 @@ public class OpenAiActionCreatorTests extends ESTestCase {
 
             var model = createCompletionModel(getUrl(webServer), null, "secret", "model", null);
             var actionCreator = new OpenAiActionCreator(sender, createWithEmptySettings(threadPool));
-            var overriddenTaskSettings = getChatCompletionRequestTaskSettingsMap("overridden_user");
+            var overriddenTaskSettings = getOpenAiTaskSettingsMap("overridden_user");
             var action = actionCreator.create(model, overriddenTaskSettings);
 
             PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java

@@ -17,7 +17,7 @@ import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings
 import java.util.List;
 import java.util.Map;
 
-import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests.getOpenAiTaskSettingsMap;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.sameInstance;
 
@@ -25,7 +25,7 @@ public class OpenAiChatCompletionModelTests extends ESTestCase {
 
     public void testOverrideWith_OverridesUser() {
         var model = createCompletionModel("url", "org", "api_key", "model_name", null);
-        var requestTaskSettingsMap = getChatCompletionRequestTaskSettingsMap("user_override");
+        var requestTaskSettingsMap = getOpenAiTaskSettingsMap("user_override");
 
         var overriddenModel = OpenAiChatCompletionModel.of(model, requestTaskSettingsMap);
 

+ 0 - 66
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettingsTests.java

@@ -1,66 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.services.openai.completion;
-
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.test.ESTestCase;
-import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
-
-import java.util.HashMap;
-import java.util.Map;
-
-import static org.hamcrest.Matchers.is;
-
-public class OpenAiChatCompletionRequestTaskSettingsTests extends ESTestCase {
-
-    public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() {
-        var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of()));
-        assertNull(settings.user());
-    }
-
-    public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() {
-        var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "value")));
-        assertNull(settings.user());
-    }
-
-    public void testFromMap_ParsesCorrectly() {
-        var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(
-            new HashMap<>(Map.of(OpenAiServiceFields.USER, "user", OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value"))))
-        );
-
-        assertThat(settings.user(), is("user"));
-        assertThat(settings.headers(), is(Map.of("key", "value")));
-    }
-
-    public void testFromMap_ParsesCorrectly_WhenUserIsNull() {
-        var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(
-            new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, new HashMap<>(Map.of("key", "value"))))
-        );
-
-        assertNull(settings.user());
-        assertThat(settings.headers(), is(Map.of("key", "value")));
-    }
-
-    public void testFromMap_ParsesCorrectly_WhenHeadersIsNull() {
-        var settings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
-
-        assertThat(settings.user(), is("user"));
-        assertNull(settings.headers());
-    }
-
-    public static Map<String, Object> getChatCompletionRequestTaskSettingsMap(@Nullable String user) {
-        var map = new HashMap<String, Object>();
-
-        if (user != null) {
-            map.put(OpenAiServiceFields.USER, user);
-        }
-
-        return map;
-    }
-
-}

+ 15 - 138
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionTaskSettingsTests.java

@@ -8,124 +8,15 @@
 package org.elasticsearch.xpack.inference.services.openai.completion;
 
 import org.elasticsearch.TransportVersion;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
-import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests;
 
-import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
 
 import static org.elasticsearch.TransportVersions.INFERENCE_API_OPENAI_HEADERS;
-import static org.hamcrest.Matchers.is;
 
-public class OpenAiChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase<OpenAiChatCompletionTaskSettings> {
-
-    public static OpenAiChatCompletionTaskSettings createRandomWithUser() {
-        return new OpenAiChatCompletionTaskSettings(
-            randomBoolean() ? null : randomAlphaOfLength(15),
-            randomBoolean() ? null : Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15))
-        );
-    }
-
-    public void testIsEmpty() {
-        var randomSettings = new OpenAiChatCompletionTaskSettings(
-            randomBoolean() ? null : "username",
-            randomBoolean() ? null : Map.of("key", "value")
-        );
-        var stringRep = Strings.toString(randomSettings);
-
-        assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
-    }
-
-    public void testUpdatedTaskSettings() {
-        var initialSettings = createRandomWithUser();
-        var newSettings = createRandomWithUser();
-
-        Map<String, Object> newSettingsMap = new HashMap<>();
-        if (newSettings.user() != null) {
-            newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user());
-        }
-
-        if (newSettings.headers() != null && newSettings.headers().isEmpty() == false) {
-            newSettingsMap.put(OpenAiServiceFields.HEADERS, newSettings.headers());
-        }
-
-        OpenAiChatCompletionTaskSettings updatedSettings = (OpenAiChatCompletionTaskSettings) initialSettings.updatedTaskSettings(
-            Collections.unmodifiableMap(newSettingsMap)
-        );
-
-        if (newSettings.user() == null) {
-            assertEquals(initialSettings.user(), updatedSettings.user());
-        } else {
-            assertEquals(newSettings.user(), updatedSettings.user());
-        }
-
-        if (newSettings.headers() == null) {
-            assertEquals(initialSettings.headers(), updatedSettings.headers());
-        } else {
-            assertEquals(newSettings.headers(), updatedSettings.headers());
-        }
-    }
-
-    public void testFromMap_WithUser() {
-        assertEquals(
-            new OpenAiChatCompletionTaskSettings("user", null),
-            OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")))
-        );
-    }
-
-    public void testFromMap_UserIsEmptyString() {
-        var thrownException = expectThrows(
-            ValidationException.class,
-            () -> OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "")))
-        );
-
-        assertThat(
-            thrownException.getMessage(),
-            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
-        );
-    }
-
-    public void testFromMap_MissingUser_DoesNotThrowException() {
-        var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of()));
-        assertNull(taskSettings.user());
-    }
-
-    public void testOf_KeepsOriginalValuesWithOverridesAreNull() {
-        var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
-
-        var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(
-            taskSettings,
-            OpenAiChatCompletionRequestTaskSettings.EMPTY_SETTINGS
-        );
-        assertThat(overriddenTaskSettings, is(taskSettings));
-    }
-
-    public void testOf_UsesOverriddenSettings() {
-        var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")));
-
-        var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user2")));
-
-        var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(taskSettings, requestTaskSettings);
-        assertThat(overriddenTaskSettings, is(new OpenAiChatCompletionTaskSettings("user2", null)));
-    }
-
-    public void testOf_UsesOverriddenSettings_ForHeaders() {
-        var user = "user";
-        var taskSettings = OpenAiChatCompletionTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, user)));
-
-        var headers = Map.of("key", "value");
-        var requestTaskSettings = OpenAiChatCompletionRequestTaskSettings.fromMap(
-            new HashMap<>(Map.of(OpenAiServiceFields.HEADERS, headers))
-        );
-
-        var overriddenTaskSettings = OpenAiChatCompletionTaskSettings.of(taskSettings, requestTaskSettings);
-        assertThat(overriddenTaskSettings, is(new OpenAiChatCompletionTaskSettings(user, headers)));
-    }
+public class OpenAiChatCompletionTaskSettingsTests extends OpenAiTaskSettingsTests<OpenAiChatCompletionTaskSettings> {
 
     @Override
     protected Writeable.Reader<OpenAiChatCompletionTaskSettings> instanceReader() {
@@ -134,31 +25,7 @@ public class OpenAiChatCompletionTaskSettingsTests extends AbstractBWCWireSerial
 
     @Override
     protected OpenAiChatCompletionTaskSettings createTestInstance() {
-        return createRandomWithUser();
-    }
-
-    @Override
-    protected OpenAiChatCompletionTaskSettings mutateInstance(OpenAiChatCompletionTaskSettings instance) throws IOException {
-        var setNull = randomBoolean();
-        var fieldToMutate = randomIntBetween(0, 1);
-        return switch (fieldToMutate) {
-            case 0 -> new OpenAiChatCompletionTaskSettings(
-                instance.user() == null ? randomAlphaOfLength(15) : (setNull ? null : instance.user() + "modified"),
-                instance.headers()
-            );
-            case 1 -> {
-                if (instance.headers() == null) {
-                    yield new OpenAiChatCompletionTaskSettings(instance.user(), Map.of(randomAlphaOfLength(15), randomAlphaOfLength(15)));
-                } else if (setNull) {
-                    yield new OpenAiChatCompletionTaskSettings(instance.user(), null);
-                } else {
-                    var instanceHeaders = new HashMap<>(instance.headers() == null ? Map.of() : instance.headers());
-                    instanceHeaders.put(randomAlphaOfLength(15), randomAlphaOfLength(15));
-                    yield new OpenAiChatCompletionTaskSettings(instance.user(), instanceHeaders);
-                }
-            }
-            default -> throw new IllegalStateException("Unexpected value: " + fieldToMutate);
-        };
+        return createRandom();
     }
 
     @Override
@@ -170,6 +37,16 @@ public class OpenAiChatCompletionTaskSettingsTests extends AbstractBWCWireSerial
             return instance;
         }
 
-        return new OpenAiChatCompletionTaskSettings(instance.user(), null);
+        return create(instance.user(), null);
+    }
+
+    @Override
+    protected OpenAiChatCompletionTaskSettings create(@Nullable String user, @Nullable Map<String, String> headers) {
+        return new OpenAiChatCompletionTaskSettings(user, headers);
+    }
+
+    @Override
+    protected OpenAiChatCompletionTaskSettings createFromMap(@Nullable Map<String, Object> map) {
+        return new OpenAiChatCompletionTaskSettings(map);
     }
 }

+ 6 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsModelTests.java

@@ -61,7 +61,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -80,7 +80,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             chunkingSettings,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -98,7 +98,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, null, false, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -117,7 +117,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, 1536, tokenLimit, false, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -137,7 +137,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, false, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
@@ -159,7 +159,7 @@ public class OpenAiEmbeddingsModelTests extends ESTestCase {
             TaskType.TEXT_EMBEDDING,
             "service",
             new OpenAiEmbeddingsServiceSettings(modelName, url, org, similarityMeasure, dimensions, tokenLimit, dimensionsSetByUser, null),
-            new OpenAiEmbeddingsTaskSettings(user),
+            new OpenAiEmbeddingsTaskSettings(user, null),
             null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );

+ 16 - 104
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsTaskSettingsTests.java

@@ -7,107 +7,15 @@
 
 package org.elasticsearch.xpack.inference.services.openai.embeddings;
 
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.Writeable;
-import org.elasticsearch.core.Nullable;
-import org.elasticsearch.test.AbstractWireSerializingTestCase;
-import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
-import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields;
-import org.hamcrest.MatcherAssert;
+import org.elasticsearch.xpack.inference.services.openai.OpenAiTaskSettingsTests;
 
-import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
 
-import static org.hamcrest.Matchers.is;
+import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings.INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS;
 
-public class OpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<OpenAiEmbeddingsTaskSettings> {
-
-    public static OpenAiEmbeddingsTaskSettings createRandomWithUser() {
-        return new OpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15));
-    }
-
-    /**
-     * The created settings can have the user set to null.
-     */
-    public static OpenAiEmbeddingsTaskSettings createRandom() {
-        var user = randomBoolean() ? randomAlphaOfLength(15) : null;
-        return new OpenAiEmbeddingsTaskSettings(user);
-    }
-
-    public void testIsEmpty() {
-        var randomSettings = new OpenAiEmbeddingsTaskSettings(randomBoolean() ? null : "username");
-        var stringRep = Strings.toString(randomSettings);
-        assertEquals(stringRep, randomSettings.isEmpty(), stringRep.equals("{}"));
-    }
-
-    public void testUpdatedTaskSettings() {
-        var initialSettings = createRandom();
-        var newSettings = createRandom();
-        Map<String, Object> newSettingsMap = new HashMap<>();
-        if (newSettings.user() != null) {
-            newSettingsMap.put(OpenAiServiceFields.USER, newSettings.user());
-        }
-        OpenAiEmbeddingsTaskSettings updatedSettings = (OpenAiEmbeddingsTaskSettings) initialSettings.updatedTaskSettings(
-            Collections.unmodifiableMap(newSettingsMap)
-        );
-        if (newSettings.user() == null) {
-            assertEquals(initialSettings.user(), updatedSettings.user());
-        } else {
-            assertEquals(newSettings.user(), updatedSettings.user());
-        }
-    }
-
-    public void testFromMap_WithUser() {
-        assertEquals(
-            new OpenAiEmbeddingsTaskSettings("user"),
-            OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")), ConfigurationParseContext.REQUEST)
-        );
-    }
-
-    public void testFromMap_UserIsEmptyString() {
-        var thrownException = expectThrows(
-            ValidationException.class,
-            () -> OpenAiEmbeddingsTaskSettings.fromMap(
-                new HashMap<>(Map.of(OpenAiServiceFields.USER, "")),
-                ConfigurationParseContext.REQUEST
-            )
-        );
-
-        MatcherAssert.assertThat(
-            thrownException.getMessage(),
-            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
-        );
-    }
-
-    public void testFromMap_MissingUser_DoesNotThrowException() {
-        var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.PERSISTENT);
-        assertNull(taskSettings.user());
-    }
-
-    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
-        var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(
-            new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")),
-            ConfigurationParseContext.PERSISTENT
-        );
-
-        var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, OpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS);
-        MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
-    }
-
-    public void testOverrideWith_UsesOverriddenSettings() {
-        var taskSettings = OpenAiEmbeddingsTaskSettings.fromMap(
-            new HashMap<>(Map.of(OpenAiServiceFields.USER, "user")),
-            ConfigurationParseContext.PERSISTENT
-        );
-
-        var requestTaskSettings = OpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user2")));
-
-        var overriddenTaskSettings = OpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings);
-        MatcherAssert.assertThat(overriddenTaskSettings, is(new OpenAiEmbeddingsTaskSettings("user2")));
-    }
+public class OpenAiEmbeddingsTaskSettingsTests extends OpenAiTaskSettingsTests<OpenAiEmbeddingsTaskSettings> {
 
     @Override
     protected Writeable.Reader<OpenAiEmbeddingsTaskSettings> instanceReader() {
@@ -116,21 +24,25 @@ public class OpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTe
 
     @Override
     protected OpenAiEmbeddingsTaskSettings createTestInstance() {
-        return createRandomWithUser();
+        return createRandom();
     }
 
     @Override
-    protected OpenAiEmbeddingsTaskSettings mutateInstance(OpenAiEmbeddingsTaskSettings instance) throws IOException {
-        return randomValueOtherThan(instance, OpenAiEmbeddingsTaskSettingsTests::createRandomWithUser);
+    protected OpenAiEmbeddingsTaskSettings create(String user, Map<String, String> headers) {
+        return new OpenAiEmbeddingsTaskSettings(user, headers);
     }
 
-    public static Map<String, Object> getTaskSettingsMap(@Nullable String user) {
-        var map = new HashMap<String, Object>();
+    @Override
+    protected OpenAiEmbeddingsTaskSettings createFromMap(Map<String, Object> map) {
+        return new OpenAiEmbeddingsTaskSettings(map);
+    }
 
-        if (user != null) {
-            map.put(OpenAiServiceFields.USER, user);
+    @Override
+    protected OpenAiEmbeddingsTaskSettings mutateInstanceForVersion(OpenAiEmbeddingsTaskSettings instance, TransportVersion version) {
+        if (version.supports(INFERENCE_API_OPENAI_EMBEDDINGS_HEADERS)) {
+            return instance;
         }
 
-        return map;
+        return create(instance.user(), null);
     }
 }

+ 31 - 3
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiEmbeddingsRequestTests.java

@@ -9,16 +9,24 @@ package org.elasticsearch.xpack.inference.services.openai.request;
 
 import org.apache.http.HttpHeaders;
 import org.apache.http.client.methods.HttpPost;
+import org.elasticsearch.common.settings.SecureString;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.inference.common.Truncator;
 import org.elasticsearch.xpack.inference.common.TruncatorTests;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings;
+import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsTaskSettings;
+import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
 
 import java.io.IOException;
+import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.List;
+import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.services.openai.OpenAiUtils.ORGANIZATION_HEADER;
@@ -28,17 +36,37 @@ import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 
 public class OpenAiEmbeddingsRequestTests extends ESTestCase {
-    public void testCreateRequest_WithUrlOrganizationUserDefined() throws URISyntaxException, IOException {
-        var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user");
+    public void testCreateRequest_WithUrlOrganizationUser_AndCustomHeadersDefined() throws IOException {
+
+        var headerKey = "key";
+        var headerValue = "value";
+
+        var model = new OpenAiEmbeddingsModel(
+            "id",
+            TaskType.TEXT_EMBEDDING,
+            "service",
+            new OpenAiEmbeddingsServiceSettings("model", URI.create("www.elastic.co"), "org", null, null, null, false, null),
+            new OpenAiEmbeddingsTaskSettings("user", Map.of(headerKey, headerValue)),
+            null,
+            new DefaultSecretSettings(new SecureString("secret".toCharArray()))
+        );
+
+        var request = new OpenAiEmbeddingsRequest(
+            TruncatorTests.createTruncator(),
+            new Truncator.TruncationResult(List.of("abc"), new boolean[] { false }),
+            model
+        );
+
         var httpRequest = request.createHttpRequest();
 
         assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
         var httpPost = (HttpPost) httpRequest.httpRequestBase();
 
-        assertThat(httpPost.getURI().toString(), is("www.google.com"));
+        assertThat(httpPost.getURI().toString(), is("www.elastic.co"));
         assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType()));
         assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret"));
         assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org"));
+        assertThat(httpPost.getLastHeader(headerKey).getValue(), is(headerValue));
 
         var requestMap = entityAsMap(httpPost.getEntity().getContent());
         assertThat(requestMap, aMapWithSize(3));