Browse Source

[ML] Make Streaming Results Writeable (#122527)

Make streaming elements extend Writeable and create StreamInput
constructors so we can publish elements across nodes using the transport
layer.

Additional notes:
- Moved optional methods into the InferenceServiceResults interface and
  default them
- StreamingUnifiedChatCompletionResults elements are now all records
Pat Whelan 8 months ago
parent
commit
e74ef2d325
20 changed files with 591 additions and 276 deletions
  1. 31 4
      server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java
  2. 53 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java
  3. 3 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java
  4. 45 34
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java
  5. 119 122
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java
  6. 52 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java
  7. 41 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java
  8. 66 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java
  9. 58 21
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
  10. 20 8
      x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java
  11. 16 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java
  12. 7 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java
  13. 2 3
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java
  14. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java
  15. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java
  16. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java
  17. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java
  18. 4 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java
  19. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java
  20. 61 61
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java

+ 31 - 4
server/src/main/java/org/elasticsearch/inference/InferenceServiceResults.java

@@ -10,8 +10,12 @@
 package org.elasticsearch.inference;
 
 import org.elasticsearch.common.io.stream.NamedWriteable;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.xcontent.ToXContent;
 
+import java.io.IOException;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Flow;
@@ -27,18 +31,39 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
      *
      * <p>For other results like SparseEmbeddingResults, this method can be a pass through to the transformToLegacyFormat.</p>
      */
-    List<? extends InferenceResults> transformToCoordinationFormat();
+    default List<? extends InferenceResults> transformToCoordinationFormat() {
+        throw new UnsupportedOperationException("transformToCoordinationFormat() is not implemented");
+    }
 
     /**
      * Transform the result to match the format required for versions prior to
      * {@link org.elasticsearch.TransportVersions#V_8_12_0}
      */
-    List<? extends InferenceResults> transformToLegacyFormat();
+    default List<? extends InferenceResults> transformToLegacyFormat() {
+        throw new UnsupportedOperationException("transformToLegacyFormat() is not implemented");
+    }
 
     /**
      * Convert the result to a map to aid with test assertions
      */
-    Map<String, Object> asMap();
+    default Map<String, Object> asMap() {
+        throw new UnsupportedOperationException("asMap() is not implemented");
+    }
+
+    default String getWriteableName() {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
+
+    default void writeTo(StreamOutput out) throws IOException {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
+
+    default Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+        assert isStreaming() : "This must be implemented when isStreaming() == false";
+        throw new UnsupportedOperationException("This must be implemented when isStreaming() == false");
+    }
 
     /**
      * Returns {@code true} if these results are streamed as chunks, or {@code false} if these results contain the entire payload.
@@ -52,8 +77,10 @@ public interface InferenceServiceResults extends NamedWriteable, ChunkedToXConte
      * When {@link #isStreaming()} is {@code true}, the InferenceAction.Results will subscribe to this publisher.
      * Implementations should follow the {@link java.util.concurrent.Flow.Publisher} spec to stream the chunks.
      */
-    default Flow.Publisher<? extends ChunkedToXContent> publisher() {
+    default Flow.Publisher<? extends Result> publisher() {
         assert isStreaming() == false : "This must be implemented when isStreaming() == true";
         throw new UnsupportedOperationException("This must be implemented when isStreaming() == true");
     }
+
+    interface Result extends NamedWriteable, ChunkedToXContent {}
 }

+ 53 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/DequeUtils.java

@@ -0,0 +1,53 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.Writeable;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+public final class DequeUtils {
+
+    private DequeUtils() {
+        // util functions only
+    }
+
+    public static <T> Deque<T> readDeque(StreamInput in, Writeable.Reader<T> reader) throws IOException {
+        return in.readCollection(ArrayDeque::new, ((stream, deque) -> deque.offer(reader.read(in))));
+    }
+
+    public static boolean dequeEquals(Deque<?> thisDeque, Deque<?> otherDeque) {
+        if (thisDeque.size() != otherDeque.size()) {
+            return false;
+        }
+        var thisIter = thisDeque.iterator();
+        var otherIter = otherDeque.iterator();
+        while (thisIter.hasNext() && otherIter.hasNext()) {
+            if (thisIter.next().equals(otherIter.next()) == false) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    public static int dequeHashCode(Deque<?> deque) {
+        if (deque == null) {
+            return 0;
+        }
+        return deque.stream().reduce(1, (hashCode, chunk) -> 31 * hashCode + (chunk == null ? 0 : chunk.hashCode()), Integer::sum);
+    }
+
+    public static <T> Deque<T> of(T elem) {
+        var deque = new ArrayDeque<T>(1);
+        deque.offer(elem);
+        return deque;
+    }
+}

+ 3 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java

@@ -16,7 +16,6 @@ import org.elasticsearch.action.ActionType;
 import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
 import org.elasticsearch.core.TimeValue;
@@ -342,7 +341,7 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
 
         private final InferenceServiceResults results;
         private final boolean isStreaming;
-        private final Flow.Publisher<ChunkedToXContent> publisher;
+        private final Flow.Publisher<InferenceServiceResults.Result> publisher;
 
         public Response(InferenceServiceResults results) {
             this.results = results;
@@ -350,7 +349,7 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
             this.publisher = null;
         }
 
-        public Response(InferenceServiceResults results, Flow.Publisher<ChunkedToXContent> publisher) {
+        public Response(InferenceServiceResults results, Flow.Publisher<InferenceServiceResults.Result> publisher) {
             this.results = results;
             this.isStreaming = true;
             this.publisher = publisher;
@@ -434,7 +433,7 @@ public class InferenceAction extends ActionType<InferenceAction.Response> {
          * When the RestResponse is finished with the current chunk, it will request the next chunk using the subscription.
          * If the RestResponse is closed, it will cancel the subscription.
          */
-        public Flow.Publisher<ChunkedToXContent> publisher() {
+        public Flow.Publisher<InferenceServiceResults.Result> publisher() {
             assert isStreaming() : "this should only be called after isStreaming() verifies this object is non-null";
             return publisher;
         }

+ 45 - 34
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResults.java

@@ -8,63 +8,43 @@
 package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.ToXContent;
 
 import java.io.IOException;
 import java.util.Deque;
 import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
 import java.util.concurrent.Flow;
 
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
 import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResults.COMPLETION;
 
 /**
  * Chat Completion results that only contain a Flow.Publisher.
  */
-public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher) implements InferenceServiceResults {
+public record StreamingChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
+    implements
+        InferenceServiceResults {
 
     @Override
     public boolean isStreaming() {
         return true;
     }
 
-    @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+    public record Results(Deque<Result> results) implements InferenceServiceResults.Result {
+        public static final String NAME = "streaming_chat_completion_results";
 
-    @Override
-    public Map<String, Object> asMap() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public String getWriteableName() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+        public Results(StreamInput in) throws IOException {
+            this(readDeque(in, Result::new));
+        }
 
-    public record Results(Deque<Result> results) implements ChunkedToXContent {
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
             return Iterators.concat(
@@ -75,14 +55,45 @@ public record StreamingChatCompletionResults(Flow.Publisher<? extends ChunkedToX
                 ChunkedToXContentHelper.endObject()
             );
         }
+
+        @Override
+        public String getWriteableName() {
+            return NAME;
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeCollection(results);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (o == null || getClass() != o.getClass()) return false;
+            Results other = (Results) o;
+            return dequeEquals(this.results, other.results());
+        }
+
+        @Override
+        public int hashCode() {
+            return dequeHashCode(results);
+        }
     }
 
-    public record Result(String delta) implements ChunkedToXContent {
+    public record Result(String delta) implements ChunkedToXContent, Writeable {
         private static final String RESULT = "delta";
 
+        private Result(StreamInput in) throws IOException {
+            this(in.readString());
+        }
+
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
             return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field(RESULT, delta).endObject());
         }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(delta);
+        }
     }
 }

+ 119 - 122
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java

@@ -8,11 +8,12 @@
 package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.collect.Iterators;
+import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
-import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.ToXContent;
 
@@ -21,15 +22,17 @@ import java.util.Collections;
 import java.util.Deque;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.Flow;
 
 import static org.elasticsearch.common.xcontent.ChunkedToXContentHelper.chunk;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeEquals;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.dequeHashCode;
+import static org.elasticsearch.xpack.core.inference.DequeUtils.readDeque;
 
 /**
  * Chat Completion results that only contain a Flow.Publisher.
  */
-public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends ChunkedToXContent> publisher)
+public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends InferenceServiceResults.Result> publisher)
     implements
         InferenceServiceResults {
 
@@ -60,76 +63,58 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
     }
 
     @Override
-    public List<? extends InferenceResults> transformToCoordinationFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public List<? extends InferenceResults> transformToLegacyFormat() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public Map<String, Object> asMap() {
-        throw new UnsupportedOperationException("Not implemented");
-    }
-
-    @Override
-    public String getWriteableName() {
+    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
         throw new UnsupportedOperationException("Not implemented");
     }
 
-    @Override
-    public void writeTo(StreamOutput out) throws IOException {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+    public record Results(Deque<ChatCompletionChunk> chunks) implements InferenceServiceResults.Result {
+        public static String NAME = "streaming_unified_chat_completion_results";
 
-    @Override
-    public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-        throw new UnsupportedOperationException("Not implemented");
-    }
+        public Results(StreamInput in) throws IOException {
+            this(readDeque(in, ChatCompletionChunk::new));
+        }
 
-    public record Results(Deque<ChatCompletionChunk> chunks) implements ChunkedToXContent {
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
             return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params)));
         }
-    }
-
-    public static class ChatCompletionChunk implements ChunkedToXContent {
-        private final String id;
-
-        public String getId() {
-            return id;
-        }
 
-        public List<Choice> getChoices() {
-            return choices;
+        @Override
+        public String getWriteableName() {
+            return NAME;
         }
 
-        public String getModel() {
-            return model;
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeCollection(chunks, StreamOutput::writeWriteable);
         }
 
-        public String getObject() {
-            return object;
+        @Override
+        public boolean equals(Object o) {
+            if (o == null || getClass() != o.getClass()) return false;
+            Results results = (Results) o;
+            return dequeEquals(chunks, results.chunks());
         }
 
-        public Usage getUsage() {
-            return usage;
+        @Override
+        public int hashCode() {
+            return dequeHashCode(chunks);
         }
+    }
 
-        private final List<Choice> choices;
-        private final String model;
-        private final String object;
-        private final ChatCompletionChunk.Usage usage;
-
-        public ChatCompletionChunk(String id, List<Choice> choices, String model, String object, ChatCompletionChunk.Usage usage) {
-            this.id = id;
-            this.choices = choices;
-            this.model = model;
-            this.object = object;
-            this.usage = usage;
+    public record ChatCompletionChunk(String id, List<Choice> choices, String model, String object, ChatCompletionChunk.Usage usage)
+        implements
+            ChunkedToXContent,
+            Writeable {
+
+        private ChatCompletionChunk(StreamInput in) throws IOException {
+            this(
+                in.readString(),
+                in.readOptionalCollectionAsList(Choice::new),
+                in.readString(),
+                in.readString(),
+                in.readOptional(Usage::new)
+            );
         }
 
         @Override
@@ -152,7 +137,23 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
             );
         }
 
-        public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) implements ChunkedToXContentObject {
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(id);
+            out.writeOptionalCollection(choices);
+            out.writeString(model);
+            out.writeString(object);
+            out.writeOptionalWriteable(usage);
+        }
+
+        public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index)
+            implements
+                ChunkedToXContentObject,
+                Writeable {
+
+            private Choice(StreamInput in) throws IOException {
+                this(new Delta(in), in.readOptionalString(), in.readInt());
+            }
 
             /*
               choices: Array<{
@@ -172,17 +173,22 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
                 );
             }
 
-            public static class Delta {
-                private final String content;
-                private final String refusal;
-                private final String role;
-                private List<ToolCall> toolCalls;
-
-                public Delta(String content, String refusal, String role, List<ToolCall> toolCalls) {
-                    this.content = content;
-                    this.refusal = refusal;
-                    this.role = role;
-                    this.toolCalls = toolCalls;
+            @Override
+            public void writeTo(StreamOutput out) throws IOException {
+                out.writeWriteable(delta);
+                out.writeOptionalString(finishReason);
+                out.writeInt(index);
+            }
+
+            public record Delta(String content, String refusal, String role, List<ToolCall> toolCalls) implements Writeable {
+
+                private Delta(StreamInput in) throws IOException {
+                    this(
+                        in.readOptionalString(),
+                        in.readOptionalString(),
+                        in.readOptionalString(),
+                        in.readOptionalCollectionAsList(ToolCall::new)
+                    );
                 }
 
                 /*
@@ -214,49 +220,26 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
 
                 }
 
-                public String getContent() {
-                    return content;
-                }
-
-                public String getRefusal() {
-                    return refusal;
-                }
-
-                public String getRole() {
-                    return role;
-                }
-
-                public List<ToolCall> getToolCalls() {
-                    return toolCalls;
+                @Override
+                public void writeTo(StreamOutput out) throws IOException {
+                    out.writeOptionalString(content);
+                    out.writeOptionalString(refusal);
+                    out.writeOptionalString(role);
+                    out.writeOptionalCollection(toolCalls);
                 }
 
-                public static class ToolCall implements ChunkedToXContentObject {
-                    private final int index;
-                    private final String id;
-                    public ChatCompletionChunk.Choice.Delta.ToolCall.Function function;
-                    private final String type;
-
-                    public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) {
-                        this.index = index;
-                        this.id = id;
-                        this.function = function;
-                        this.type = type;
-                    }
-
-                    public int getIndex() {
-                        return index;
-                    }
-
-                    public String getId() {
-                        return id;
-                    }
-
-                    public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() {
-                        return function;
-                    }
-
-                    public String getType() {
-                        return type;
+                public record ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type)
+                    implements
+                        ChunkedToXContentObject,
+                        Writeable {
+
+                    private ToolCall(StreamInput in) throws IOException {
+                        this(
+                            in.readInt(),
+                            in.readOptionalString(),
+                            in.readOptional(ChatCompletionChunk.Choice.Delta.ToolCall.Function::new),
+                            in.readOptionalString()
+                        );
                     }
 
                     /*
@@ -280,8 +263,8 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
                             content = Iterators.concat(
                                 content,
                                 ChunkedToXContentHelper.startObject(FUNCTION_FIELD),
-                                optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()),
-                                optionalField(FUNCTION_NAME_FIELD, function.getName()),
+                                optionalField(FUNCTION_ARGUMENTS_FIELD, function.arguments()),
+                                optionalField(FUNCTION_NAME_FIELD, function.name()),
                                 ChunkedToXContentHelper.endObject()
                             );
                         }
@@ -294,28 +277,42 @@ public record StreamingUnifiedChatCompletionResults(Flow.Publisher<? extends Chu
                         return content;
                     }
 
-                    public static class Function {
-                        private final String arguments;
-                        private final String name;
+                    @Override
+                    public void writeTo(StreamOutput out) throws IOException {
+                        out.writeInt(index);
+                        out.writeOptionalString(id);
+                        out.writeOptionalWriteable(function);
+                        out.writeOptionalString(type);
+                    }
 
-                        public Function(String arguments, String name) {
-                            this.arguments = arguments;
-                            this.name = name;
-                        }
+                    public record Function(String arguments, String name) implements Writeable {
 
-                        public String getArguments() {
-                            return arguments;
+                        private Function(StreamInput in) throws IOException {
+                            this(in.readOptionalString(), in.readOptionalString());
                         }
 
-                        public String getName() {
-                            return name;
+                        @Override
+                        public void writeTo(StreamOutput out) throws IOException {
+                            out.writeOptionalString(arguments);
+                            out.writeOptionalString(name);
                         }
                     }
                 }
             }
         }
 
-        public record Usage(int completionTokens, int promptTokens, int totalTokens) {}
+        public record Usage(int completionTokens, int promptTokens, int totalTokens) implements Writeable {
+            private Usage(StreamInput in) throws IOException {
+                this(in.readInt(), in.readInt(), in.readInt());
+            }
+
+            @Override
+            public void writeTo(StreamOutput out) throws IOException {
+                out.writeInt(completionTokens);
+                out.writeInt(promptTokens);
+                out.writeInt(totalTokens);
+            }
+        }
 
         private static Iterator<ToXContent> optionalField(String name, String value) {
             if (value == null) {

+ 52 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/DequeUtilsTests.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference;
+
+import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.test.ESTestCase;
+
+import java.io.IOException;
+
+import static org.mockito.Mockito.mock;
+
+public class DequeUtilsTests extends ESTestCase {
+
+    public void testEqualsAndHashCodeWithSameObject() {
+        var someObject = mock();
+        var dequeOne = DequeUtils.of(someObject);
+        var dequeTwo = DequeUtils.of(someObject);
+        assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+        assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+    }
+
+    public void testEqualsAndHashCodeWithEqualsObject() {
+        var dequeOne = DequeUtils.of("the same string");
+        var dequeTwo = DequeUtils.of("the same string");
+        assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+        assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+    }
+
+    public void testNotEqualsAndHashCode() {
+        var dequeOne = DequeUtils.of(mock());
+        var dequeTwo = DequeUtils.of(mock());
+        assertFalse(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+        assertNotEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+    }
+
+    public void testReadFromStream() throws IOException {
+        var dequeOne = DequeUtils.of("this is a string");
+        var out = new BytesStreamOutput();
+        out.writeStringCollection(dequeOne);
+        var in = new ByteArrayStreamInput(out.bytes().array());
+        var dequeTwo = DequeUtils.readDeque(in, StreamInput::readString);
+        assertTrue(DequeUtils.dequeEquals(dequeOne, dequeTwo));
+        assertEquals(DequeUtils.dequeHashCode(dequeOne), DequeUtils.dequeHashCode(dequeTwo));
+    }
+}

+ 41 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingChatCompletionResultsTests.java

@@ -0,0 +1,41 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+
+public class StreamingChatCompletionResultsTests extends AbstractWireSerializingTestCase<StreamingChatCompletionResults.Results> {
+    @Override
+    protected Writeable.Reader<StreamingChatCompletionResults.Results> instanceReader() {
+        return StreamingChatCompletionResults.Results::new;
+    }
+
+    @Override
+    protected StreamingChatCompletionResults.Results createTestInstance() {
+        var results = new ArrayDeque<StreamingChatCompletionResults.Result>();
+        for (int i = 0; i < randomIntBetween(1, 10); i++) {
+            results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5)));
+        }
+        return new StreamingChatCompletionResults.Results(results);
+    }
+
+    @Override
+    protected StreamingChatCompletionResults.Results mutateInstance(StreamingChatCompletionResults.Results instance) throws IOException {
+        var results = new ArrayDeque<>(instance.results());
+        if (randomBoolean()) {
+            results.pop();
+        } else {
+            results.offer(new StreamingChatCompletionResults.Result(randomAlphanumericOfLength(5)));
+        }
+        return new StreamingChatCompletionResults.Results(results);
+    }
+}

+ 66 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java

@@ -10,7 +10,8 @@
 package org.elasticsearch.xpack.core.inference.results;
 
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.json.JsonXContent;
 
@@ -18,8 +19,10 @@ import java.io.IOException;
 import java.util.ArrayDeque;
 import java.util.Deque;
 import java.util.List;
+import java.util.function.Supplier;
 
-public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
+public class StreamingUnifiedChatCompletionResultsTests extends AbstractWireSerializingTestCase<
+    StreamingUnifiedChatCompletionResults.Results> {
 
     public void testResults_toXContentChunked() throws IOException {
         String expected = """
@@ -195,4 +198,65 @@ public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase {
         assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim());
     }
 
+    @Override
+    protected Writeable.Reader<StreamingUnifiedChatCompletionResults.Results> instanceReader() {
+        return StreamingUnifiedChatCompletionResults.Results::new;
+    }
+
+    @Override
+    protected StreamingUnifiedChatCompletionResults.Results createTestInstance() {
+        var results = new ArrayDeque<StreamingUnifiedChatCompletionResults.ChatCompletionChunk>();
+        for (int i = 0; i < randomIntBetween(1, 3); i++) {
+            results.offer(randomChatCompletionChunk());
+        }
+        return new StreamingUnifiedChatCompletionResults.Results(results);
+    }
+
+    private static StreamingUnifiedChatCompletionResults.ChatCompletionChunk randomChatCompletionChunk() {
+        Supplier<String> randomOptionalString = () -> randomBoolean() ? null : randomAlphanumericOfLength(5);
+        return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk(
+            randomAlphanumericOfLength(5),
+            randomBoolean() ? null : randomList(randomInt(5), () -> {
+                return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice(
+                    new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta(
+                        randomOptionalString.get(),
+                        randomOptionalString.get(),
+                        randomOptionalString.get(),
+                        randomBoolean() ? null : randomList(randomInt(5), () -> {
+                            return new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall(
+                                randomInt(5),
+                                randomOptionalString.get(),
+                                randomBoolean()
+                                    ? null
+                                    : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function(
+                                        randomOptionalString.get(),
+                                        randomOptionalString.get()
+                                    ),
+                                randomOptionalString.get()
+                            );
+                        })
+                    ),
+                    randomOptionalString.get(),
+                    randomInt(5)
+                );
+            }),
+            randomAlphanumericOfLength(5),
+            randomAlphanumericOfLength(5),
+            randomBoolean()
+                ? null
+                : new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(randomInt(5), randomInt(5), randomInt(5))
+        );
+    }
+
+    @Override
+    protected StreamingUnifiedChatCompletionResults.Results mutateInstance(StreamingUnifiedChatCompletionResults.Results instance)
+        throws IOException {
+        var results = new ArrayDeque<>(instance.chunks());
+        if (randomBoolean()) {
+            results.pop();
+        } else {
+            results.add(randomChatCompletionChunk());
+        }
+        return new StreamingUnifiedChatCompletionResults.Results(results); // immutable
+    }
 }

+ 58 - 21
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

@@ -14,7 +14,6 @@ import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.LazyInitializable;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
@@ -30,6 +29,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnifiedCompletionRequest;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.ToXContent;
 import org.elasticsearch.xcontent.ToXContentObject;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
@@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatComple
 import java.io.IOException;
 import java.util.EnumSet;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -157,10 +158,31 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
             });
         }
 
-        private ChunkedToXContent completionChunk(String delta) {
-            return params -> ChunkedToXContentHelper.chunk(
-                (b, p) -> b.startObject().startArray(COMPLETION).startObject().field("delta", delta).endObject().endArray().endObject()
-            );
+        private InferenceServiceResults.Result completionChunk(String delta) {
+            return new InferenceServiceResults.Result() {
+                @Override
+                public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+                    return ChunkedToXContentHelper.chunk(
+                        (b, p) -> b.startObject()
+                            .startArray(COMPLETION)
+                            .startObject()
+                            .field("delta", delta)
+                            .endObject()
+                            .endArray()
+                            .endObject()
+                    );
+                }
+
+                @Override
+                public void writeTo(StreamOutput out) throws IOException {
+                    out.writeString(delta);
+                }
+
+                @Override
+                public String getWriteableName() {
+                    return "test_completionChunk";
+                }
+            };
         }
 
         private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) {
@@ -198,22 +220,37 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
           "object": "chat.completion.chunk"
         }
          */
-        private ChunkedToXContent unifiedCompletionChunk(String delta) {
-            return params -> ChunkedToXContentHelper.chunk(
-                (b, p) -> b.startObject()
-                    .field("id", "id")
-                    .startArray("choices")
-                    .startObject()
-                    .startObject("delta")
-                    .field("content", delta)
-                    .endObject()
-                    .field("index", 0)
-                    .endObject()
-                    .endArray()
-                    .field("model", "gpt-4o-2024-08-06")
-                    .field("object", "chat.completion.chunk")
-                    .endObject()
-            );
+        private InferenceServiceResults.Result unifiedCompletionChunk(String delta) {
+            return new InferenceServiceResults.Result() {
+                @Override
+                public String getWriteableName() {
+                    return "test_unifiedCompletionChunk";
+                }
+
+                @Override
+                public void writeTo(StreamOutput out) throws IOException {
+                    out.writeString(delta);
+                }
+
+                @Override
+                public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
+                    return ChunkedToXContentHelper.chunk(
+                        (b, p) -> b.startObject()
+                            .field("id", "id")
+                            .startArray("choices")
+                            .startObject()
+                            .startObject("delta")
+                            .field("content", delta)
+                            .endObject()
+                            .field("index", 0)
+                            .endObject()
+                            .endArray()
+                            .field("model", "gpt-4o-2024-08-06")
+                            .field("object", "chat.completion.chunk")
+                            .endObject()
+                    );
+                }
+            };
         }
 
         @Override

+ 20 - 8
x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListenerTests.java

@@ -33,7 +33,6 @@ import org.elasticsearch.common.settings.IndexScopedSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.SettingsFilter;
 import org.elasticsearch.common.util.CollectionUtils;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.inference.InferenceResults;
@@ -179,14 +178,14 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
     }
 
     private static class StreamingInferenceServiceResults implements InferenceServiceResults {
-        private final Flow.Publisher<ChunkedToXContent> publisher;
+        private final Flow.Publisher<InferenceServiceResults.Result> publisher;
 
-        private StreamingInferenceServiceResults(Flow.Publisher<ChunkedToXContent> publisher) {
+        private StreamingInferenceServiceResults(Flow.Publisher<InferenceServiceResults.Result> publisher) {
             this.publisher = publisher;
         }
 
         @Override
-        public Flow.Publisher<ChunkedToXContent> publisher() {
+        public Flow.Publisher<InferenceServiceResults.Result> publisher() {
             return publisher;
         }
 
@@ -224,7 +223,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
         }
     }
 
-    private static class RandomPublisher implements Flow.Publisher<ChunkedToXContent> {
+    private static class RandomPublisher implements Flow.Publisher<InferenceServiceResults.Result> {
         private final int requestCount;
         private final boolean withError;
 
@@ -234,7 +233,7 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
         }
 
         @Override
-        public void subscribe(Flow.Subscriber<? super ChunkedToXContent> subscriber) {
+        public void subscribe(Flow.Subscriber<? super InferenceServiceResults.Result> subscriber) {
             var resultCount = new AtomicInteger(requestCount);
             subscriber.onSubscribe(new Flow.Subscription() {
                 @Override
@@ -256,12 +255,25 @@ public class ServerSentEventsRestActionListenerTests extends ESIntegTestCase {
         }
     }
 
-    private static class RandomString implements ChunkedToXContent {
+    private record RandomString(String randomString) implements InferenceServiceResults.Result {
+        RandomString() {
+            this(randomUnicodeOfLengthBetween(2, 20));
+        }
+
         @Override
         public Iterator<? extends ToXContent> toXContentChunked(ToXContent.Params params) {
-            var randomString = randomUnicodeOfLengthBetween(2, 20);
             return ChunkedToXContentHelper.chunk((b, p) -> b.startObject().field("delta", randomString).endObject());
         }
+
+        @Override
+        public String getWriteableName() {
+            return "test_RandomString";
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(randomString);
+        }
     }
 
     private static class SingleInferenceServiceResults implements InferenceServiceResults {

+ 16 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -23,6 +23,8 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloa
 import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
 import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
 import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
@@ -484,6 +486,20 @@ public class InferenceNamedWriteablesProvider {
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(InferenceServiceResults.class, RankedDocsResults.NAME, RankedDocsResults::new)
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                StreamingChatCompletionResults.Results.class,
+                StreamingChatCompletionResults.Results.NAME,
+                StreamingChatCompletionResults.Results::new
+            )
+        );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                StreamingUnifiedChatCompletionResults.Results.class,
+                StreamingUnifiedChatCompletionResults.Results.NAME,
+                StreamingUnifiedChatCompletionResults.Results::new
+            )
+        );
     }
 
     private static void addCustomElandWriteables(final List<NamedWriteableRegistry.Entry> namedWriteables) {

+ 7 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

@@ -19,7 +19,6 @@ import org.elasticsearch.common.Randomness;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -276,7 +275,10 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
         inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
             if (request.isStreaming()) {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                var taskProcessor = streamingTaskManager.<InferenceServiceResults.Result>create(
+                    STREAMING_INFERENCE_TASK_TYPE,
+                    STREAMING_TASK_ACTION
+                );
                 inferenceResults.publisher().subscribe(taskProcessor);
 
                 var instrumentedStream = new PublisherWithMetrics(timer, model);
@@ -295,7 +297,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         }));
     }
 
-    protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
         return upstream;
     }
 
@@ -349,7 +351,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         );
     }
 
-    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+    private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
 
         private final InferenceTimer timer;
         private final Model model;
@@ -360,7 +362,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         }
 
         @Override
-        protected void next(ChunkedToXContent item) {
+        protected void next(InferenceServiceResults.Result item) {
             downstream().onNext(item);
         }
 

+ 2 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

@@ -11,7 +11,6 @@ import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.client.internal.node.NodeClient;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -103,7 +102,7 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
      * as {@link UnifiedChatCompletionException}.
      */
     @Override
-    protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<ChunkedToXContent, ChunkedToXContent> upstream) {
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
         return downstream -> {
             upstream.subscribe(new Flow.Subscriber<>() {
                 @Override
@@ -112,7 +111,7 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
                 }
 
                 @Override
-                public void onNext(ChunkedToXContent item) {
+                public void onNext(T item) {
                     downstream.onNext(item);
                 }
 

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java

@@ -15,7 +15,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.inference.InferenceServiceResults;
 
 import java.time.Instant;
 import java.util.concurrent.Flow;
@@ -23,7 +23,8 @@ import java.util.concurrent.Flow;
 public interface AmazonBedrockClient {
     void converse(ConverseRequest converseRequest, ActionListener<ConverseResponse> responseListener) throws ElasticsearchException;
 
-    Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest converseStreamRequest) throws ElasticsearchException;
+    Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest converseStreamRequest)
+        throws ElasticsearchException;
 
     void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener<InvokeModelResponse> responseListener)
         throws ElasticsearchException;

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java

@@ -26,10 +26,10 @@ import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.SpecialPermission;
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel;
 import org.reactivestreams.FlowAdapters;
@@ -88,7 +88,8 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {
     }
 
     @Override
-    public Flow.Publisher<? extends ChunkedToXContent> converseStream(ConverseStreamRequest request) throws ElasticsearchException {
+    public Flow.Publisher<? extends InferenceServiceResults.Result> converseStream(ConverseStreamRequest request)
+        throws ElasticsearchException {
         var awsResponseProcessor = new AmazonBedrockStreamingChatProcessor(threadPool);
         internalClient.converseStream(
             request,

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java

@@ -9,8 +9,8 @@ package org.elasticsearch.xpack.inference.external.openai;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentParser;
 import org.elasticsearch.xcontent.XContentParserConfiguration;
@@ -101,7 +101,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.
  *     </code>
  * </pre>
  */
-public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
+public class OpenAiStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, InferenceServiceResults.Result> {
     private static final Logger log = LogManager.getLogger(OpenAiStreamingProcessor.class);
     private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in OpenAI chat completions response";
 

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java

@@ -9,7 +9,6 @@ package org.elasticsearch.xpack.inference.external.openai;
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
 import org.elasticsearch.xcontent.ConstructingObjectParser;
 import org.elasticsearch.xcontent.ParseField;
@@ -34,7 +33,9 @@ import java.util.function.BiFunction;
 import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
 import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
 
-public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, ChunkedToXContent> {
+public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor<
+    Deque<ServerSentEvent>,
+    StreamingUnifiedChatCompletionResults.Results> {
     public static final String FUNCTION_FIELD = "function";
     private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class);
 

+ 4 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java

@@ -10,9 +10,9 @@ package org.elasticsearch.xpack.inference.external.request.amazonbedrock.complet
 import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
 import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
 
-import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xpack.core.common.socket.SocketAccess;
 import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient;
@@ -82,7 +82,9 @@ public class AmazonBedrockChatCompletionRequest extends AmazonBedrockRequest {
         this.executeRequest(awsBedrockClient);
     }
 
-    public Flow.Publisher<? extends ChunkedToXContent> executeStreamChatCompletionRequest(AmazonBedrockBaseClient awsBedrockClient) {
+    public Flow.Publisher<? extends InferenceServiceResults.Result> executeStreamChatCompletionRequest(
+        AmazonBedrockBaseClient awsBedrockClient
+    ) {
         var converseStreamRequest = ConverseStreamRequest.builder()
             .modelId(amazonBedrockModel.model())
             .messages(getConverseMessageList(requestEntity.messages()));

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

@@ -308,7 +308,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
             }
 
             @Override
-            public void onNext(ChunkedToXContent item) {
+            public void onNext(InferenceServiceResults.Result item) {
 
             }
 
@@ -332,7 +332,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         }));
     }
 
-    protected Flow.Publisher<ChunkedToXContent> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
+    protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
         mockService(true, Set.of(), listener -> {
             Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
             doAnswer(innerAns -> {

+ 61 - 61
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java

@@ -65,31 +65,31 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
                 .parse(parser);
 
             // Assertions to verify the parsed object
-            assertEquals("example_id", chunk.getId());
-            assertEquals("example_model", chunk.getModel());
-            assertEquals("chat.completion.chunk", chunk.getObject());
-            assertNotNull(chunk.getUsage());
-            assertEquals(50, chunk.getUsage().completionTokens());
-            assertEquals(20, chunk.getUsage().promptTokens());
-            assertEquals(70, chunk.getUsage().totalTokens());
+            assertEquals("example_id", chunk.id());
+            assertEquals("example_model", chunk.model());
+            assertEquals("chat.completion.chunk", chunk.object());
+            assertNotNull(chunk.usage());
+            assertEquals(50, chunk.usage().completionTokens());
+            assertEquals(20, chunk.usage().promptTokens());
+            assertEquals(70, chunk.usage().totalTokens());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.getChoices();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
             assertEquals(1, choices.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
-            assertEquals("example_content", choice.delta().getContent());
-            assertNull(choice.delta().getRefusal());
-            assertEquals("assistant", choice.delta().getRole());
+            assertEquals("example_content", choice.delta().content());
+            assertNull(choice.delta().refusal());
+            assertEquals("assistant", choice.delta().role());
             assertEquals("stop", choice.finishReason());
             assertEquals(0, choice.index());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = choice.delta().getToolCalls();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = choice.delta().toolCalls();
             assertEquals(1, toolCalls.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
-            assertEquals(1, toolCall.getIndex());
-            assertEquals("tool_call_id", toolCall.getId());
-            assertEquals("example_function_name", toolCall.getFunction().getName());
-            assertEquals("example_arguments", toolCall.getFunction().getArguments());
-            assertEquals("function", toolCall.getType());
+            assertEquals(1, toolCall.index());
+            assertEquals("tool_call_id", toolCall.id());
+            assertEquals("example_function_name", toolCall.function().name());
+            assertEquals("example_arguments", toolCall.function().arguments());
+            assertEquals("function", toolCall.type());
         } catch (IOException e) {
             fail();
         }
@@ -143,40 +143,40 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
                 .parse(parser);
 
             // Assertions to verify the parsed object
-            assertEquals("example_id", chunk.getId());
-            assertEquals("example_model", chunk.getModel());
-            assertEquals("chat.completion.chunk", chunk.getObject());
-            assertNull(chunk.getUsage());
+            assertEquals("example_id", chunk.id());
+            assertEquals("example_model", chunk.model());
+            assertEquals("chat.completion.chunk", chunk.object());
+            assertNull(chunk.usage());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.getChoices();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
             assertEquals(2, choices.size());
 
             // First choice assertions
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0);
-            assertNull(firstChoice.delta().getContent());
-            assertNull(firstChoice.delta().getRefusal());
-            assertEquals("assistant", firstChoice.delta().getRole());
-            assertTrue(firstChoice.delta().getToolCalls().isEmpty());
+            assertNull(firstChoice.delta().content());
+            assertNull(firstChoice.delta().refusal());
+            assertEquals("assistant", firstChoice.delta().role());
+            assertTrue(firstChoice.delta().toolCalls().isEmpty());
             assertNull(firstChoice.finishReason());
             assertEquals(0, firstChoice.index());
 
             // Second choice assertions
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1);
-            assertEquals("example_content", secondChoice.delta().getContent());
-            assertEquals("example_refusal", secondChoice.delta().getRefusal());
-            assertEquals("user", secondChoice.delta().getRole());
+            assertEquals("example_content", secondChoice.delta().content());
+            assertEquals("example_refusal", secondChoice.delta().refusal());
+            assertEquals("user", secondChoice.delta().role());
             assertEquals("stop", secondChoice.finishReason());
             assertEquals(1, secondChoice.index());
 
             List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = secondChoice.delta()
-                .getToolCalls();
+                .toolCalls();
             assertEquals(1, toolCalls.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
-            assertEquals(1, toolCall.getIndex());
-            assertNull(toolCall.getId());
-            assertEquals("example_function_name", toolCall.getFunction().getName());
-            assertNull(toolCall.getFunction().getArguments());
-            assertEquals("function", toolCall.getType());
+            assertEquals(1, toolCall.index());
+            assertNull(toolCall.id());
+            assertEquals("example_function_name", toolCall.function().name());
+            assertNull(toolCall.function().arguments());
+            assertEquals("function", toolCall.type());
         } catch (IOException e) {
             fail();
         }
@@ -221,31 +221,31 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
                 .parse(parser);
 
             // Assertions to verify the parsed object
-            assertEquals(chatCompletionChunkId, chunk.getId());
-            assertEquals(chatCompletionChunkModel, chunk.getModel());
-            assertEquals("chat.completion.chunk", chunk.getObject());
-            assertNotNull(chunk.getUsage());
-            assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens());
-            assertEquals(usagePromptTokens, chunk.getUsage().promptTokens());
-            assertEquals(usageTotalTokens, chunk.getUsage().totalTokens());
+            assertEquals(chatCompletionChunkId, chunk.id());
+            assertEquals(chatCompletionChunkModel, chunk.model());
+            assertEquals("chat.completion.chunk", chunk.object());
+            assertNotNull(chunk.usage());
+            assertEquals(usageCompletionTokens, chunk.usage().completionTokens());
+            assertEquals(usagePromptTokens, chunk.usage().promptTokens());
+            assertEquals(usageTotalTokens, chunk.usage().totalTokens());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.getChoices();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
             assertEquals(1, choices.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
-            assertEquals(choiceContent, choice.delta().getContent());
-            assertNull(choice.delta().getRefusal());
-            assertEquals(choiceRole, choice.delta().getRole());
+            assertEquals(choiceContent, choice.delta().content());
+            assertNull(choice.delta().refusal());
+            assertEquals(choiceRole, choice.delta().role());
             assertEquals(choiceFinishReason, choice.finishReason());
             assertEquals(choiceIndex, choice.index());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = choice.delta().getToolCalls();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall> toolCalls = choice.delta().toolCalls();
             assertEquals(1, toolCalls.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0);
-            assertEquals(toolCallIndex, toolCall.getIndex());
-            assertEquals(toolCallId, toolCall.getId());
-            assertEquals(toolCallFunctionName, toolCall.getFunction().getName());
-            assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments());
-            assertEquals(toolCallType, toolCall.getType());
+            assertEquals(toolCallIndex, toolCall.index());
+            assertEquals(toolCallId, toolCall.id());
+            assertEquals(toolCallFunctionName, toolCall.function().name());
+            assertEquals(toolCallFunctionArguments, toolCall.function().arguments());
+            assertEquals(toolCallType, toolCall.type());
         }
     }
 
@@ -273,20 +273,20 @@ public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase {
                 .parse(parser);
 
             // Assertions to verify the parsed object
-            assertEquals(chatCompletionChunkId, chunk.getId());
-            assertEquals(chatCompletionChunkModel, chunk.getModel());
-            assertEquals("chat.completion.chunk", chunk.getObject());
-            assertNull(chunk.getUsage());
+            assertEquals(chatCompletionChunkId, chunk.id());
+            assertEquals(chatCompletionChunkModel, chunk.model());
+            assertEquals("chat.completion.chunk", chunk.object());
+            assertNull(chunk.usage());
 
-            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.getChoices();
+            List<StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice> choices = chunk.choices();
             assertEquals(1, choices.size());
             StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0);
-            assertNull(choice.delta().getContent());
-            assertNull(choice.delta().getRefusal());
-            assertNull(choice.delta().getRole());
+            assertNull(choice.delta().content());
+            assertNull(choice.delta().refusal());
+            assertNull(choice.delta().role());
             assertNull(choice.finishReason());
             assertEquals(choiceIndex, choice.index());
-            assertTrue(choice.delta().getToolCalls().isEmpty());
+            assertTrue(choice.delta().toolCalls().isEmpty());
         }
     }