Browse Source

[ML] Write Chat Completion JSON (#128592)

Most providers write the UnifiedCompletionRequest JSON as we received
it, with some exceptions:
- the modelId can be null and/or overwritten from various locations
- `max_completion_tokens` replaced `max_tokens`, but some providers
  still use the deprecated field name
We will handle these variations using Params. Otherwise, all of the
XContent building code has moved into UnifiedCompletionRequest so it can
be reused across providers.
Pat Whelan 4 months ago
parent
commit
bf0dc6e7f2

+ 196 - 10
server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java

@@ -19,6 +19,10 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ObjectParser;
 import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.ToXContentFragment;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParseException;
 import org.elasticsearch.xcontent.XContentParser;
 
@@ -38,9 +42,68 @@ public record UnifiedCompletionRequest(
     @Nullable ToolChoice toolChoice,
     @Nullable List<Tool> tools,
     @Nullable Float topP
-) implements Writeable {
+) implements Writeable, ToXContentFragment {
+
+    // JSON field names used when this request and its nested types are written as XContent.
+    public static final String NAME_FIELD = "name";
+    public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
+    public static final String TOOL_CALLS_FIELD = "tool_calls";
+    public static final String ID_FIELD = "id";
+    public static final String FUNCTION_FIELD = "function";
+    public static final String ARGUMENTS_FIELD = "arguments";
+    public static final String DESCRIPTION_FIELD = "description";
+    public static final String PARAMETERS_FIELD = "parameters";
+    public static final String STRICT_FIELD = "strict";
+    public static final String TOP_P_FIELD = "top_p";
+    public static final String MESSAGES_FIELD = "messages";
+    private static final String ROLE_FIELD = "role";
+    private static final String CONTENT_FIELD = "content";
+    private static final String STOP_FIELD = "stop";
+    private static final String TEMPERATURE_FIELD = "temperature";
+    private static final String TOOL_CHOICE_FIELD = "tool_choice";
+    private static final String TOOL_FIELD = "tools";
+    private static final String TEXT_FIELD = "text";
+    private static final String TYPE_FIELD = "type";
+    private static final String MODEL_FIELD = "model";
+    private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
+    private static final String MAX_TOKENS_FIELD = "max_tokens";
+
+    /**
+     * We currently allow providers to override the model id that is written to JSON.
+     * Rather than use {@link #model()}, providers are expected to pass in the modelId via
+     * {@link org.elasticsearch.xcontent.ToXContent.Params}.
+     */
+    private static final String MODEL_ID_PARAM = "model_id_value";
+    /**
+     * Some providers only support the now-deprecated {@link #MAX_TOKENS_FIELD}, others have migrated to
+     * {@link #MAX_COMPLETION_TOKENS_FIELD}. Providers are expected to pass in their supported field name.
+     */
+    private static final String MAX_TOKENS_PARAM = "max_tokens_field";
+
+    /**
+     * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
+     * - Key: {@link #MODEL_FIELD}, Value: modelId (left out when modelId is null)
+     * - Key: {@link #MAX_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
+     *
+     * @param modelId the model id to write, or null to leave the model field out of the JSON
+     * @param params  the params consulted for every key that is not set here
+     * @return params carrying the model id (when present) and the deprecated max tokens field name
+     */
+    public static Params withMaxTokens(@Nullable String modelId, Params params) {
+        // Map.entry and Map.of reject null values, and a provider may not have a model id configured,
+        // so the model id override is only registered when one was supplied.
+        Map<String, String> overrides = modelId == null
+            ? Map.of(MAX_TOKENS_PARAM, MAX_TOKENS_FIELD)
+            : Map.of(MODEL_ID_PARAM, modelId, MAX_TOKENS_PARAM, MAX_TOKENS_FIELD);
+        return new DelegatingMapParams(overrides, params);
+    }
 
-    public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {}
+    /**
+     * Creates a {@link org.elasticsearch.xcontent.ToXContent.Params} that causes ToXContent to include the key values:
+     * - Key: {@link #MODEL_FIELD}, Value: modelId (left out when modelId is null)
+     * - Key: {@link #MAX_COMPLETION_TOKENS_FIELD}, Value: {@link #maxCompletionTokens()}
+     *
+     * @param modelId the model id to write, or null to leave the model field out of the JSON
+     * @param params  the params consulted for every key that is not set here
+     * @return params carrying the model id (when present) and the max completion tokens field name
+     */
+    public static Params withMaxCompletionTokensTokens(@Nullable String modelId, Params params) {
+        // Map.entry and Map.of reject null values, so a missing model id must not be put into the map;
+        // toXContent already skips the model field when the param is absent.
+        Map<String, String> overrides = modelId == null
+            ? Map.of(MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD)
+            : Map.of(MODEL_ID_PARAM, modelId, MAX_TOKENS_PARAM, MAX_COMPLETION_TOKENS_FIELD);
+        return new DelegatingMapParams(overrides, params);
+    }
+
+    /** Message content, either {@link ContentObjects} or {@link ContentString}; each implementation writes its own content field. */
+    public sealed interface Content extends NamedWriteable, ToXContent permits ContentObjects, ContentString {}
 
     @SuppressWarnings("unchecked")
     public static final ConstructingObjectParser<UnifiedCompletionRequest, Void> PARSER = new ConstructingObjectParser<>(
@@ -111,9 +174,40 @@ public record UnifiedCompletionRequest(
         out.writeOptionalFloat(topP);
     }
 
+    /**
+     * Writes the request fields into the object the caller has already opened; this method neither
+     * starts nor ends an object. The max tokens field name and the model id are read from {@code params}
+     * and are only written when the corresponding param is present.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.field(MESSAGES_FIELD, messages);
+
+        boolean hasStop = stop != null && stop.isEmpty() == false;
+        if (hasStop) {
+            builder.field(STOP_FIELD, stop);
+        }
+        if (temperature != null) {
+            builder.field(TEMPERATURE_FIELD, temperature);
+        }
+        if (toolChoice != null) {
+            // the tool choice writes its own field name
+            toolChoice.toXContent(builder, params);
+        }
+        boolean hasTools = tools != null && tools.isEmpty() == false;
+        if (hasTools) {
+            builder.field(TOOL_FIELD, tools);
+        }
+        if (topP != null) {
+            builder.field(TOP_P_FIELD, topP);
+        }
+
+        // Some providers only support the now-deprecated max_tokens, others have migrated to
+        // max_completion_tokens, so the field name to use comes from the params.
+        if (maxCompletionTokens != null) {
+            String maxTokensFieldName = params.param(MAX_TOKENS_PARAM);
+            if (maxTokensFieldName != null) {
+                builder.field(maxTokensFieldName, maxCompletionTokens);
+            }
+        }
+
+        // Implementations handle the model id differently (for example, OpenAI keeps a default in its
+        // service settings and overrides it there), so the value to write is passed in via the params.
+        String modelIdOverride = params.param(MODEL_ID_PARAM);
+        if (modelIdOverride != null) {
+            builder.field(MODEL_FIELD, modelIdOverride);
+        }
+        return builder;
+    }
+
     public record Message(Content content, String role, @Nullable String toolCallId, @Nullable List<ToolCall> toolCalls)
         implements
-            Writeable {
+            Writeable,
+            ToXContentObject {
 
         @SuppressWarnings("unchecked")
         static final ConstructingObjectParser<Message, Void> PARSER = new ConstructingObjectParser<>(
@@ -161,6 +255,24 @@ public record UnifiedCompletionRequest(
             out.writeOptionalString(toolCallId);
             out.writeOptionalCollection(toolCalls);
         }
+
+        /**
+         * Writes this message as one JSON object: the optional content, the role, then the optional
+         * tool call id and tool calls.
+         */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            // content is optional; when present it writes its own field
+            if (content != null) {
+                content.toXContent(builder, params);
+            }
+            builder.field(ROLE_FIELD, role);
+            if (toolCallId != null) {
+                builder.field(TOOL_CALL_ID_FIELD, toolCallId);
+            }
+            if (toolCalls != null) {
+                builder.field(TOOL_CALLS_FIELD, toolCalls);
+            }
+            builder.endObject();
+            return builder;
+        }
     }
 
     public record ContentObjects(List<ContentObject> contentObjects) implements Content, NamedWriteable {
@@ -180,9 +292,14 @@ public record UnifiedCompletionRequest(
         public String getWriteableName() {
             return NAME;
         }
+
+        /** Writes the list of content objects under the content field. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(CONTENT_FIELD, contentObjects);
+            return builder;
+        }
     }
 
-    public record ContentObject(String text, String type) implements Writeable {
+    public record ContentObject(String text, String type) implements Writeable, ToXContentObject {
         static final ConstructingObjectParser<ContentObject, Void> PARSER = new ConstructingObjectParser<>(
             ContentObject.class.getSimpleName(),
             args -> new ContentObject((String) args[0], (String) args[1])
@@ -207,6 +324,13 @@ public record UnifiedCompletionRequest(
             return text + ":" + type;
         }
 
+        /** Writes this content object as a JSON object with its text followed by its type. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject().field(TEXT_FIELD, text).field(TYPE_FIELD, type).endObject();
+        }
     }
 
     public record ContentString(String content) implements Content, NamedWriteable {
@@ -234,9 +358,14 @@ public record UnifiedCompletionRequest(
         public String toString() {
             return content;
         }
+
+        /** Writes the plain string content under the content field. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(CONTENT_FIELD, content);
+            return builder;
+        }
     }
 
-    public record ToolCall(String id, FunctionField function, String type) implements Writeable {
+    public record ToolCall(String id, FunctionField function, String type) implements Writeable, ToXContentObject {
 
         static final ConstructingObjectParser<ToolCall, Void> PARSER = new ConstructingObjectParser<>(
             ToolCall.class.getSimpleName(),
@@ -260,7 +389,16 @@ public record UnifiedCompletionRequest(
             out.writeString(type);
         }
 
-        public record FunctionField(String arguments, String name) implements Writeable {
+        /** Writes this tool call as a JSON object: id, then the nested function, then type. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject().field(ID_FIELD, id).field(FUNCTION_FIELD, function).field(TYPE_FIELD, type).endObject();
+        }
+
+        public record FunctionField(String arguments, String name) implements Writeable, ToXContentObject {
             static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
                 "tool_call_function_field",
                 args -> new FunctionField((String) args[0], (String) args[1])
@@ -280,6 +418,14 @@ public record UnifiedCompletionRequest(
                 out.writeString(arguments);
                 out.writeString(name);
             }
+
+            /** Writes the called function as a JSON object with its arguments followed by its name. */
+            @Override
+            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+                return builder.startObject().field(ARGUMENTS_FIELD, arguments).field(NAME_FIELD, name).endObject();
+            }
         }
     }
 
@@ -294,7 +440,7 @@ public record UnifiedCompletionRequest(
         throw new XContentParseException("Unsupported token [" + token + "]");
     }
 
-    public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {}
+    /** Tool choice, either {@link ToolChoiceObject} or {@link ToolChoiceString}; each implementation writes its own tool choice field. */
+    public sealed interface ToolChoice extends NamedWriteable, ToXContent permits ToolChoiceObject, ToolChoiceString {}
 
     public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable {
 
@@ -325,7 +471,15 @@ public record UnifiedCompletionRequest(
             return NAME;
         }
 
-        public record FunctionField(String name) implements Writeable {
+        /** Writes the tool choice as a named object holding the type and the nested function. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject(TOOL_CHOICE_FIELD).field(TYPE_FIELD, type).field(FUNCTION_FIELD, function).endObject();
+        }
+
+        public record FunctionField(String name) implements Writeable, ToXContentObject {
             static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
                 "tool_choice_function_field",
                 args -> new FunctionField((String) args[0])
@@ -343,6 +497,11 @@ public record UnifiedCompletionRequest(
             public void writeTo(StreamOutput out) throws IOException {
                 out.writeString(name);
             }
+
+            /** Writes the chosen function as a JSON object containing only its name. */
+            @Override
+            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+                builder.startObject();
+                builder.field(NAME_FIELD, name);
+                return builder.endObject();
+            }
         }
     }
 
@@ -367,9 +526,14 @@ public record UnifiedCompletionRequest(
         public String getWriteableName() {
             return NAME;
         }
+
+        /** Writes the string tool choice value under the tool choice field. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.field(TOOL_CHOICE_FIELD, value);
+            return builder;
+        }
     }
 
-    public record Tool(String type, FunctionField function) implements Writeable {
+    public record Tool(String type, FunctionField function) implements Writeable, ToXContentObject {
 
         static final ConstructingObjectParser<Tool, Void> PARSER = new ConstructingObjectParser<>(
             Tool.class.getSimpleName(),
@@ -391,12 +555,22 @@ public record UnifiedCompletionRequest(
             function.writeTo(out);
         }
 
+        /** Writes this tool as a JSON object with its type followed by the nested function definition. */
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            return builder.startObject().field(TYPE_FIELD, type).field(FUNCTION_FIELD, function).endObject();
+        }
+
         public record FunctionField(
             @Nullable String description,
             String name,
             @Nullable Map<String, Object> parameters,
             @Nullable Boolean strict
-        ) implements Writeable {
+        ) implements Writeable, ToXContentObject {
 
             @SuppressWarnings("unchecked")
             static final ConstructingObjectParser<FunctionField, Void> PARSER = new ConstructingObjectParser<>(
@@ -422,6 +596,18 @@ public record UnifiedCompletionRequest(
                 out.writeGenericMap(parameters);
                 out.writeOptionalBoolean(strict);
             }
+
+            /**
+             * Writes the function definition as a JSON object. Description, name and parameters are
+             * written unconditionally; strict is only written when it is set.
+             */
+            @Override
+            public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+                builder.startObject()
+                    .field(DESCRIPTION_FIELD, description)
+                    .field(NAME_FIELD, name)
+                    .field(PARAMETERS_FIELD, parameters);
+                if (strict != null) {
+                    builder.field(STRICT_FIELD, strict);
+                }
+                return builder.endObject();
+            }
         }
     }
 }

+ 20 - 13
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java

@@ -8,9 +8,12 @@
 package org.elasticsearch.xpack.core.inference.action;
 
 import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
+import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 
@@ -25,51 +28,51 @@ public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationT
     public void testParseAllFields() throws IOException {
         String requestJson = """
             {
-                "model": "gpt-4o",
                 "messages": [
                   {
-                    "role": "user",
                     "content": [
                         {
                           "text": "some text",
                           "type": "string"
                         }
                     ],
+                    "role": "user",
                     "tool_call_id": "100",
                     "tool_calls": [
                         {
                             "id": "call_62136354",
-                            "type": "function",
                             "function": {
                                 "arguments": "{'order_id': 'order_12345'}",
                                 "name": "get_delivery_date"
-                            }
+                            },
+                            "type": "function"
                         }
                     ]
                   }
                 ],
-                "max_completion_tokens": 100,
                 "stop": ["stop"],
                 "temperature": 0.1,
+                "tool_choice": {
+                  "type": "function",
+                  "function": {
+                    "name": "some function"
+                  }
+                },
                 "tools": [
                   {
                     "type": "function",
                     "function": {
-                      "name": "get_current_weather",
                       "description": "Get the current weather in a given location",
+                      "name": "get_current_weather",
                       "parameters": {
                         "type": "object"
                       }
                     }
                   }
                 ],
-                "tool_choice": {
-                  "type": "function",
-                  "function": {
-                    "name": "some function"
-                  }
-                },
-                "top_p": 0.2
+                "top_p": 0.2,
+                "max_completion_tokens": 100,
+                "model": "gpt-4o"
             }
             """;
 
@@ -115,6 +118,10 @@ public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationT
             );
 
             assertThat(request, is(expected));
+            assertThat(
+                Strings.toString(request, UnifiedCompletionRequest.withMaxCompletionTokensTokens("gpt-4o", ToXContent.EMPTY_PARAMS)),
+                is(XContentHelper.stripWhitespace(requestJson))
+            );
         }
     }
 

+ 1 - 120
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/unified/UnifiedChatCompletionRequestEntity.java

@@ -17,27 +17,8 @@ import java.util.Objects;
 
 public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
 
-    public static final String NAME_FIELD = "name";
-    public static final String TOOL_CALL_ID_FIELD = "tool_call_id";
-    public static final String TOOL_CALLS_FIELD = "tool_calls";
-    public static final String ID_FIELD = "id";
-    public static final String FUNCTION_FIELD = "function";
-    public static final String ARGUMENTS_FIELD = "arguments";
-    public static final String DESCRIPTION_FIELD = "description";
-    public static final String PARAMETERS_FIELD = "parameters";
-    public static final String STRICT_FIELD = "strict";
-    public static final String TOP_P_FIELD = "top_p";
     public static final String STREAM_FIELD = "stream";
     private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n";
-    public static final String MESSAGES_FIELD = "messages";
-    private static final String ROLE_FIELD = "role";
-    private static final String CONTENT_FIELD = "content";
-    private static final String STOP_FIELD = "stop";
-    private static final String TEMPERATURE_FIELD = "temperature";
-    private static final String TOOL_CHOICE_FIELD = "tool_choice";
-    private static final String TOOL_FIELD = "tools";
-    private static final String TEXT_FIELD = "text";
-    private static final String TYPE_FIELD = "type";
     private static final String STREAM_OPTIONS_FIELD = "stream_options";
     private static final String INCLUDE_USAGE_FIELD = "include_usage";
 
@@ -55,111 +36,11 @@ public class UnifiedChatCompletionRequestEntity implements ToXContentFragment {
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
-        builder.startArray(MESSAGES_FIELD);
-        {
-            for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) {
-                builder.startObject();
-                {
-                    switch (message.content()) {
-                        case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content());
-                        case UnifiedCompletionRequest.ContentObjects contentObjects -> {
-                            builder.startArray(CONTENT_FIELD);
-                            for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) {
-                                builder.startObject();
-                                builder.field(TEXT_FIELD, contentObject.text());
-                                builder.field(TYPE_FIELD, contentObject.type());
-                                builder.endObject();
-                            }
-                            builder.endArray();
-                        }
-                        case null -> {
-                            // do nothing because content is optional
-                        }
-                    }
-
-                    builder.field(ROLE_FIELD, message.role());
-                    if (message.toolCallId() != null) {
-                        builder.field(TOOL_CALL_ID_FIELD, message.toolCallId());
-                    }
-                    if (message.toolCalls() != null) {
-                        builder.startArray(TOOL_CALLS_FIELD);
-                        for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) {
-                            builder.startObject();
-                            {
-                                builder.field(ID_FIELD, toolCall.id());
-                                builder.startObject(FUNCTION_FIELD);
-                                {
-                                    builder.field(ARGUMENTS_FIELD, toolCall.function().arguments());
-                                    builder.field(NAME_FIELD, toolCall.function().name());
-                                }
-                                builder.endObject();
-                                builder.field(TYPE_FIELD, toolCall.type());
-                            }
-                            builder.endObject();
-                        }
-                        builder.endArray();
-                    }
-                }
-                builder.endObject();
-            }
-        }
-        builder.endArray();
+        unifiedRequest.toXContent(builder, params);
 
         // Underlying providers expect OpenAI to only return 1 possible choice.
         builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1);
 
-        if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) {
-            builder.field(STOP_FIELD, unifiedRequest.stop());
-        }
-        if (unifiedRequest.temperature() != null) {
-            builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature());
-        }
-        if (unifiedRequest.toolChoice() != null) {
-            if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) {
-                builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value());
-            } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) {
-                builder.startObject(TOOL_CHOICE_FIELD);
-                {
-                    builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type());
-                    builder.startObject(FUNCTION_FIELD);
-                    {
-                        builder.field(
-                            NAME_FIELD,
-                            ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name()
-                        );
-                    }
-                    builder.endObject();
-                }
-                builder.endObject();
-            }
-        }
-        boolean usesTools = unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false;
-
-        if (usesTools) {
-            builder.startArray(TOOL_FIELD);
-            for (UnifiedCompletionRequest.Tool tool : unifiedRequest.tools()) {
-                builder.startObject();
-                {
-                    builder.field(TYPE_FIELD, tool.type());
-                    builder.startObject(FUNCTION_FIELD);
-                    {
-                        builder.field(DESCRIPTION_FIELD, tool.function().description());
-                        builder.field(NAME_FIELD, tool.function().name());
-                        builder.field(PARAMETERS_FIELD, tool.function().parameters());
-                        if (tool.function().strict() != null) {
-                            builder.field(STRICT_FIELD, tool.function().strict());
-                        }
-                    }
-                    builder.endObject();
-                }
-                builder.endObject();
-            }
-            builder.endArray();
-        }
-        if (unifiedRequest.topP() != null) {
-            builder.field(TOP_P_FIELD, unifiedRequest.topP());
-        }
-
         builder.field(STREAM_FIELD, stream);
         if (stream) {
             builder.startObject(STREAM_OPTIONS_FIELD);

+ 5 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/deepseek/request/DeepSeekChatCompletionRequest.java

@@ -14,6 +14,7 @@ import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
@@ -63,13 +64,10 @@ public class DeepSeekChatCompletionRequest implements Request {
         var modelId = Objects.requireNonNullElseGet(unifiedChatInput.getRequest().model(), model::model);
         try (var builder = JsonXContent.contentBuilder()) {
             builder.startObject();
-            new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(builder, ToXContent.EMPTY_PARAMS);
-            builder.field(MODEL_FIELD, modelId);
-
-            if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
-                builder.field(MAX_TOKENS, unifiedChatInput.getRequest().maxCompletionTokens());
-            }
-
+            new UnifiedChatCompletionRequestEntity(unifiedChatInput).toXContent(
+                builder,
+                UnifiedCompletionRequest.withMaxTokens(modelId, ToXContent.EMPTY_PARAMS)
+            );
             builder.endObject();
             return new ByteArrayEntity(Strings.toString(builder).getBytes(StandardCharsets.UTF_8));
         } catch (IOException e) {

+ 2 - 12
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/request/ElasticInferenceServiceUnifiedChatCompletionRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.elastic.request;
 
+import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -16,15 +17,10 @@ import java.io.IOException;
 import java.util.Objects;
 
 public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implements ToXContentObject {
-    private static final String MODEL_FIELD = "model";
-    private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens";
-
-    private final UnifiedChatInput unifiedChatInput;
     private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
     private final String modelId;
 
     public ElasticInferenceServiceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, String modelId) {
-        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
         this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
         this.modelId = Objects.requireNonNull(modelId);
     }
@@ -32,13 +28,7 @@ public class ElasticInferenceServiceUnifiedChatCompletionRequestEntity implement
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        unifiedRequestEntity.toXContent(builder, params);
-        builder.field(MODEL_FIELD, modelId);
-
-        if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
-            builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
-        }
-
+        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxCompletionTokensTokens(modelId, params));
         builder.endObject();
 
         return builder;

+ 2 - 15
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/request/completion/HuggingFaceUnifiedChatCompletionRequestEntity.java

@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.inference.services.huggingface.request.completion;
 
+import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -18,15 +19,10 @@ import java.util.Objects;
 
 public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContentObject {
 
-    private static final String MODEL_FIELD = "model";
-    private static final String MAX_TOKENS_FIELD = "max_tokens";
-
-    private final UnifiedChatInput unifiedChatInput;
     private final HuggingFaceChatCompletionModel model;
     private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
 
     public HuggingFaceUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, HuggingFaceChatCompletionModel model) {
-        this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput);
         this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
         this.model = Objects.requireNonNull(model);
     }
@@ -34,16 +30,7 @@ public class HuggingFaceUnifiedChatCompletionRequestEntity implements ToXContent
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        unifiedRequestEntity.toXContent(builder, params);
-
-        if (model.getServiceSettings().modelId() != null) {
-            builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
-        }
-
-        if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
-            builder.field(MAX_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
-        }
-
+        unifiedRequestEntity.toXContent(builder, UnifiedCompletionRequest.withMaxTokens(model.getServiceSettings().modelId(), params));
         builder.endObject();
 
         return builder;

+ 5 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/request/OpenAiUnifiedChatCompletionRequestEntity.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.services.openai.request;
 
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
@@ -36,18 +37,15 @@ public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObjec
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        unifiedRequestEntity.toXContent(builder, params);
-
-        builder.field(MODEL_FIELD, model.getServiceSettings().modelId());
+        unifiedRequestEntity.toXContent(
+            builder,
+            UnifiedCompletionRequest.withMaxCompletionTokensTokens(model.getServiceSettings().modelId(), params)
+        );
 
         if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) {
             builder.field(USER_FIELD, model.getTaskSettings().user());
         }
 
-        if (unifiedChatInput.getRequest().maxCompletionTokens() != null) {
-            builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedChatInput.getRequest().maxCompletionTokens());
-        }
-
         builder.endObject();
 
         return builder;

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java

@@ -50,7 +50,7 @@ import static org.elasticsearch.test.ESTestCase.randomFrom;
 import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
 import static org.hamcrest.CoreMatchers.is;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.mock;
@@ -232,7 +232,7 @@ public final class Utils {
             var actualParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, actual);
             var expectedParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, expected);
         ) {
-            assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered()));
+            assertThat(actualParser.map().entrySet(), containsInAnyOrder(expectedParser.map().entrySet().toArray()));
         }
     }
 }