Explorar el Código

[ML] Stream Anthropic Completion (#114321) (#114499)

Enable chat completion streaming responses for Anthropic's server sent
events.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Pat Whelan hace 1 año
padre
commit
c6d59e8890
Se han modificado 12 ficheros con 423 adiciones y 13 borrados
  1. 5 0
      docs/changelog/114321.yaml
  2. 24 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java
  3. 125 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java
  4. 5 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java
  5. 9 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java
  6. 9 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java
  7. 6 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java
  8. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java
  9. 170 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessorTests.java
  10. 4 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntityTests.java
  11. 2 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestTests.java
  12. 63 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

+ 5 - 0
docs/changelog/114321.yaml

@@ -0,0 +1,5 @@
+pr: 114321
+summary: Stream Anthropic Completion
+area: Machine Learning
+type: enhancement
+issues: []

+ 24 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandler.java

@@ -9,14 +9,20 @@ package org.elasticsearch.xpack.inference.external.anthropic;
 
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
 import org.elasticsearch.xpack.inference.external.http.HttpResult;
 import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
 import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
 import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
 import org.elasticsearch.xpack.inference.external.request.Request;
 import org.elasticsearch.xpack.inference.external.response.ErrorMessageResponseEntity;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
 
+import java.util.concurrent.Flow;
+
 import static org.elasticsearch.xpack.inference.external.http.HttpUtils.checkForEmptyBody;
 import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown;
 
@@ -41,8 +47,11 @@ public class AnthropicResponseHandler extends BaseResponseHandler {
 
     static final String SERVER_BUSY = "Received an Anthropic server is temporarily overloaded status code";
 
-    public AnthropicResponseHandler(String requestType, ResponseParser parseFunction) {
+    private final boolean canHandleStreamingResponses;
+
+    public AnthropicResponseHandler(String requestType, ResponseParser parseFunction, boolean canHandleStreamingResponses) {
         super(requestType, parseFunction, ErrorMessageResponseEntity::fromResponse);
+        this.canHandleStreamingResponses = canHandleStreamingResponses;
     }
 
     @Override
@@ -52,6 +61,20 @@ public class AnthropicResponseHandler extends BaseResponseHandler {
         checkForEmptyBody(throttlerManager, logger, request, result);
     }
 
+    @Override
+    public boolean canHandleStreamingResponses() {
+        return canHandleStreamingResponses;
+    }
+
+    @Override
+    public InferenceServiceResults parseResult(Request request, Flow.Publisher<HttpResult> flow) {
+        var sseProcessor = new ServerSentEventProcessor(new ServerSentEventParser());
+        var anthropicProcessor = new AnthropicStreamingProcessor();
+        sseProcessor.subscribe(anthropicProcessor);
+        flow.subscribe(sseProcessor);
+        return new StreamingChatCompletionResults(anthropicProcessor);
+    }
+
     /**
      * Validates the status code throws an RetryException if not in the range [200, 300).
      *

+ 125 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessor.java

@@ -0,0 +1,125 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.anthropic;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Optional;
+
+import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken;
+import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField;
+
+public class AnthropicStreamingProcessor extends DelegatingProcessor<Deque<ServerSentEvent>, StreamingChatCompletionResults.Results> {
+    private static final Logger log = LogManager.getLogger(AnthropicStreamingProcessor.class);
+    private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Anthropic chat completions response";
+
+    @Override
+    protected void next(Deque<ServerSentEvent> item) throws Exception {
+        if (item.isEmpty()) {
+            upstream().request(1);
+            return;
+        }
+
+        var results = new ArrayDeque<StreamingChatCompletionResults.Result>(item.size());
+        for (var event : item) {
+            if (event.name() == ServerSentEventField.DATA && event.hasValue()) {
+                try (var parser = parser(event.value())) {
+                    var eventType = eventType(parser);
+                    switch (eventType) {
+                        case "error" -> {
+                            onError(parseError(parser));
+                            return;
+                        }
+                        case "content_block_start" -> {
+                            parseStartBlock(parser).ifPresent(results::offer);
+                        }
+                        case "content_block_delta" -> {
+                            parseMessage(parser).ifPresent(results::offer);
+                        }
+                        case "message_start", "message_stop", "message_delta", "content_block_stop", "ping" -> {
+                            log.debug("Skipping event type [{}] for line [{}].", eventType, item);
+                        }
+                        default -> {
+                            // "handle unknown events gracefully" https://docs.anthropic.com/en/api/messages-streaming#other-events
+                            // we'll ignore unknown events
+                            log.debug("Unknown event type [{}] for line [{}].", eventType, item);
+                        }
+                    }
+                } catch (Exception e) {
+                    log.warn("Failed to parse line {}", event);
+                    throw e;
+                }
+            }
+        }
+
+        if (results.isEmpty()) {
+            upstream().request(1);
+        } else {
+            downstream().onNext(new StreamingChatCompletionResults.Results(results));
+        }
+    }
+
+    private Throwable parseError(XContentParser parser) throws IOException {
+        positionParserAtTokenAfterField(parser, "error", FAILED_TO_FIND_FIELD_TEMPLATE);
+        var type = parseString(parser, "type");
+        var message = parseString(parser, "message");
+        var statusCode = switch (type) {
+            case "invalid_request_error" -> RestStatus.BAD_REQUEST;
+            case "authentication_error" -> RestStatus.UNAUTHORIZED;
+            case "permission_error" -> RestStatus.FORBIDDEN;
+            case "not_found_error" -> RestStatus.NOT_FOUND;
+            case "request_too_large" -> RestStatus.REQUEST_ENTITY_TOO_LARGE;
+            case "rate_limit_error" -> RestStatus.TOO_MANY_REQUESTS;
+            default -> RestStatus.INTERNAL_SERVER_ERROR;
+        };
+        return new ElasticsearchStatusException(message, statusCode);
+    }
+
+    private Optional<StreamingChatCompletionResults.Result> parseStartBlock(XContentParser parser) throws IOException {
+        positionParserAtTokenAfterField(parser, "content_block", FAILED_TO_FIND_FIELD_TEMPLATE);
+        var text = parseString(parser, "text");
+        return text.isBlank() ? Optional.empty() : Optional.of(new StreamingChatCompletionResults.Result(text));
+    }
+
+    private Optional<StreamingChatCompletionResults.Result> parseMessage(XContentParser parser) throws IOException {
+        positionParserAtTokenAfterField(parser, "delta", FAILED_TO_FIND_FIELD_TEMPLATE);
+        var text = parseString(parser, "text");
+        return text.isBlank() ? Optional.empty() : Optional.of(new StreamingChatCompletionResults.Result(text));
+    }
+
+    private static XContentParser parser(String line) throws IOException {
+        return XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, line);
+    }
+
+    private static String eventType(XContentParser parser) throws IOException {
+        moveToFirstToken(parser);
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
+        return parseString(parser, "type");
+    }
+
+    private static String parseString(XContentParser parser, String fieldName) throws IOException {
+        positionParserAtTokenAfterField(parser, fieldName, FAILED_TO_FIND_FIELD_TEMPLATE);
+        ensureExpectedToken(XContentParser.Token.VALUE_STRING, parser.currentToken(), parser);
+        return parser.text();
+    }
+}

+ 5 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java

@@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicCha
 import org.elasticsearch.xpack.inference.external.response.anthropic.AnthropicChatCompletionResponseEntity;
 import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
 
-import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
 
@@ -47,13 +46,15 @@ public class AnthropicCompletionRequestManager extends AnthropicRequestManager {
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs();
-        AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model);
+        var docsOnly = DocumentsOnlyInput.of(inferenceInputs);
+        var docsInput = docsOnly.getInputs();
+        var stream = docsOnly.stream();
+        AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream);
 
         execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
     }
 
     private static ResponseHandler createCompletionHandler() {
-        return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse);
+        return new AnthropicResponseHandler("anthropic completions", AnthropicChatCompletionResponseEntity::fromResponse, true);
     }
 }

+ 9 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequest.java

@@ -29,11 +29,13 @@ public class AnthropicChatCompletionRequest implements Request {
     private final AnthropicAccount account;
     private final List<String> input;
     private final AnthropicChatCompletionModel model;
+    private final boolean stream;
 
-    public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model) {
+    public AnthropicChatCompletionRequest(List<String> input, AnthropicChatCompletionModel model, boolean stream) {
         this.account = AnthropicAccount.of(model);
         this.input = Objects.requireNonNull(input);
         this.model = Objects.requireNonNull(model);
+        this.stream = stream;
     }
 
     @Override
@@ -41,7 +43,7 @@ public class AnthropicChatCompletionRequest implements Request {
         HttpPost httpPost = new HttpPost(account.uri());
 
         ByteArrayEntity byteEntity = new ByteArrayEntity(
-            Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings()))
+            Strings.toString(new AnthropicChatCompletionRequestEntity(input, model.getServiceSettings(), model.getTaskSettings(), stream))
                 .getBytes(StandardCharsets.UTF_8)
         );
         httpPost.setEntity(byteEntity);
@@ -75,4 +77,9 @@ public class AnthropicChatCompletionRequest implements Request {
         return model.getInferenceEntityId();
     }
 
+    @Override
+    public boolean isStreaming() {
+        return stream;
+    }
+
 }

+ 9 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntity.java

@@ -28,19 +28,23 @@ public class AnthropicChatCompletionRequestEntity implements ToXContentObject {
     private static final String TEMPERATURE_FIELD = "temperature";
     private static final String TOP_P_FIELD = "top_p";
     private static final String TOP_K_FIELD = "top_k";
+    private static final String STREAM = "stream";
 
     private final List<String> messages;
     private final AnthropicChatCompletionServiceSettings serviceSettings;
     private final AnthropicChatCompletionTaskSettings taskSettings;
+    private final boolean stream;
 
     public AnthropicChatCompletionRequestEntity(
         List<String> messages,
         AnthropicChatCompletionServiceSettings serviceSettings,
-        AnthropicChatCompletionTaskSettings taskSettings
+        AnthropicChatCompletionTaskSettings taskSettings,
+        boolean stream
     ) {
         this.messages = Objects.requireNonNull(messages);
         this.serviceSettings = Objects.requireNonNull(serviceSettings);
         this.taskSettings = Objects.requireNonNull(taskSettings);
+        this.stream = stream;
     }
 
     @Override
@@ -77,6 +81,10 @@ public class AnthropicChatCompletionRequestEntity implements ToXContentObject {
             builder.field(TOP_K_FIELD, taskSettings.topK());
         }
 
+        if (stream) {
+            builder.field(STREAM, true);
+        }
+
         builder.endObject();
 
         return builder;

+ 6 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

@@ -33,6 +33,7 @@ import org.elasticsearch.xpack.inference.services.anthropic.completion.Anthropic
 
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
@@ -199,4 +200,9 @@ public class AnthropicService extends SenderService {
     public TransportVersion getMinimalSupportedVersion() {
         return TransportVersions.ML_ANTHROPIC_INTEGRATION_ADDED;
     }
+
+    @Override
+    public Set<TaskType> supportedStreamingTasks() {
+        return COMPLETION_ONLY;
+    }
 }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicResponseHandlerTests.java

@@ -160,7 +160,7 @@ public class AnthropicResponseHandlerTests extends ESTestCase {
         var mockRequest = mock(Request.class);
         when(mockRequest.getInferenceEntityId()).thenReturn(inferenceEntityId);
         var httpResult = new HttpResult(httpResponse, new byte[] {});
-        var handler = new AnthropicResponseHandler("", (request, result) -> null);
+        var handler = new AnthropicResponseHandler("", (request, result) -> null, false);
 
         handler.checkForFailureStatusCode(mockRequest, httpResult);
     }

+ 170 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/anthropic/AnthropicStreamingProcessorTests.java

@@ -0,0 +1,170 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.external.anthropic;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent;
+import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField;
+import org.hamcrest.Matcher;
+import org.hamcrest.Matchers;
+
+import java.util.ArrayDeque;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.Map;
+import java.util.concurrent.Flow;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.elasticsearch.xpack.inference.common.DelegatingProcessorTests.onNext;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.isA;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+
+public class AnthropicStreamingProcessorTests extends ESTestCase {
+
+    public void testParseSuccess() {
+        var item = events("""
+            {
+                "type": "message_start",
+                "message": {
+                    "id": "a cool id",
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [],
+                    "model": "claude, probably",
+                    "stop_reason": null,
+                    "stop_sequence": null,
+                    "usage": {
+                        "input_tokens": 25,
+                        "output_tokens": 1
+                    }
+                }
+            }""", """
+            {
+                "type": "content_block_start",
+                "index": 0,
+                "content_block": {
+                    "type": "text",
+                    "text": ""
+                }
+            }""", """
+            {"type": "ping"}""", """
+            {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}""", """
+            {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": ", World"}}""", """
+            {"type": "content_block_stop", "index": 0}""", """
+            {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 4}}""", """
+            {"type": "message_stop"}""");
+
+        var response = onNext(new AnthropicStreamingProcessor(), item);
+        assertThat(response.results().size(), equalTo(2));
+        assertThat(response.results(), containsResults("Hello", ", World"));
+    }
+
+    public void testParseWithError() {
+        var item = events("""
+            {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}""", """
+            {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": ", World"}}""", """
+            {"type": "error", "error": {"type": "rate_limit_error", "message": "You're going too fast, ahhhh!"}}""");
+
+        var statusException = onError(item);
+        assertThat(statusException.status(), equalTo(RestStatus.TOO_MANY_REQUESTS));
+        assertThat(statusException.getMessage(), equalTo("You're going too fast, ahhhh!"));
+    }
+
+    public void testErrors() {
+        var errors = Map.of("""
+            {"type": "error", "error": {"type": "invalid_request_error", "message": "blah"}}""", RestStatus.BAD_REQUEST, """
+            {"type": "error", "error": {"type": "authentication_error", "message": "blah"}}""", RestStatus.UNAUTHORIZED, """
+            {"type": "error", "error": {"type": "permission_error", "message": "blah"}}""", RestStatus.FORBIDDEN, """
+            {"type": "error", "error": {"type": "not_found_error", "message": "blah"}}""", RestStatus.NOT_FOUND, """
+            {"type": "error", "error": {"type": "request_too_large", "message": "blah"}}""", RestStatus.REQUEST_ENTITY_TOO_LARGE, """
+            {"type": "error", "error": {"type": "rate_limit_error", "message": "blah"}}""", RestStatus.TOO_MANY_REQUESTS, """
+            {"type": "error", "error": {"type": "overloaded_error", "message": "blah"}}""", RestStatus.INTERNAL_SERVER_ERROR, """
+            {"type": "error", "error": {"type": "some_cool_new_error", "message": "blah"}}""", RestStatus.INTERNAL_SERVER_ERROR);
+        errors.forEach((json, expectedStatus) -> { assertThat(onError(events(json)).status(), equalTo(expectedStatus)); });
+    }
+
+    public void testEmptyResultsRequestsMoreData() throws Exception {
+        var emptyDeque = new ArrayDeque<ServerSentEvent>();
+
+        var processor = new AnthropicStreamingProcessor();
+
+        Flow.Subscriber<ChunkedToXContent> downstream = mock();
+        processor.subscribe(downstream);
+
+        Flow.Subscription upstream = mock();
+        processor.onSubscribe(upstream);
+
+        processor.next(emptyDeque);
+
+        verify(upstream, times(1)).request(1);
+        verify(downstream, times(0)).onNext(any());
+    }
+
+    public void testDroppedEventsRequestsMoreData() throws Exception {
+        var item = events("""
+            {"type": "ping"}""");
+
+        var processor = new AnthropicStreamingProcessor();
+
+        Flow.Subscriber<ChunkedToXContent> downstream = mock();
+        processor.subscribe(downstream);
+
+        Flow.Subscription upstream = mock();
+        processor.onSubscribe(upstream);
+
+        processor.next(item);
+
+        verify(upstream, times(1)).request(1);
+        verify(downstream, times(0)).onNext(any());
+    }
+
+    private Deque<ServerSentEvent> events(String... data) {
+        var item = new ArrayDeque<ServerSentEvent>();
+        Arrays.stream(data).map(datum -> new ServerSentEvent(ServerSentEventField.DATA, datum)).forEach(item::offer);
+        return item;
+    }
+
+    @SuppressWarnings("unchecked")
+    private Matcher<Iterable<? extends StreamingChatCompletionResults.Result>> containsResults(String... results) {
+        Matcher<StreamingChatCompletionResults.Result>[] resultMatcher = Arrays.stream(results)
+            .map(StreamingChatCompletionResults.Result::new)
+            .map(Matchers::equalTo)
+            .toArray(Matcher[]::new);
+        return Matchers.contains(resultMatcher);
+    }
+
+    private static ElasticsearchStatusException onError(Deque<ServerSentEvent> item) {
+        var processor = new AnthropicStreamingProcessor();
+        var response = new AtomicReference<Throwable>();
+
+        Flow.Subscription upstream = mock();
+        processor.onSubscribe(upstream);
+
+        Flow.Subscriber<ChunkedToXContent> downstream = mock();
+        doAnswer(ans -> {
+            response.set(ans.getArgument(0));
+            return null;
+        }).when(downstream).onError(any());
+        processor.subscribe(downstream);
+
+        processor.onNext(item);
+        assertThat("Error from processor was null", response.get(), notNullValue());
+        assertThat(response.get(), isA(ElasticsearchStatusException.class));
+        return (ElasticsearchStatusException) response.get();
+    }
+}

+ 4 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestEntityTests.java

@@ -26,7 +26,8 @@ public class AnthropicChatCompletionRequestEntityTests extends ESTestCase {
         var entity = new AnthropicChatCompletionRequestEntity(
             List.of("abc"),
             new AnthropicChatCompletionServiceSettings("model", null),
-            new AnthropicChatCompletionTaskSettings(1, -1.0, 1.2, 3)
+            new AnthropicChatCompletionTaskSettings(1, -1.0, 1.2, 3),
+            false
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
@@ -42,7 +43,8 @@ public class AnthropicChatCompletionRequestEntityTests extends ESTestCase {
         var entity = new AnthropicChatCompletionRequestEntity(
             List.of("abc"),
             new AnthropicChatCompletionServiceSettings("model", null),
-            new AnthropicChatCompletionTaskSettings(1, null, 1.2, 3)
+            new AnthropicChatCompletionTaskSettings(1, null, 1.2, 3),
+            false
         );
 
         XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);

+ 2 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/anthropic/AnthropicChatCompletionRequestTests.java

@@ -95,12 +95,12 @@ public class AnthropicChatCompletionRequestTests extends ESTestCase {
 
     public static AnthropicChatCompletionRequest createRequest(String apiKey, String input, String model, int maxTokens) {
         var chatCompletionModel = AnthropicChatCompletionModelTests.createChatCompletionModel(apiKey, model, maxTokens);
-        return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel);
+        return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel, false);
     }
 
     public static AnthropicChatCompletionRequest createRequest(String url, String apiKey, String input, String model, int maxTokens) {
         var chatCompletionModel = AnthropicChatCompletionModelTests.createChatCompletionModel(url, apiKey, model, maxTokens);
-        return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel);
+        return new AnthropicChatCompletionRequest(List.of(input), chatCompletionModel, false);
     }
 
     private static String buildAnthropicUri() {

+ 63 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

@@ -17,6 +17,7 @@ import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.http.MockResponse;
 import org.elasticsearch.test.http.MockWebServer;
@@ -29,6 +30,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderT
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils;
 import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion;
 import org.elasticsearch.xpack.inference.services.ServiceFields;
 import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
 import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModelTests;
@@ -530,6 +532,67 @@ public class AnthropicServiceTests extends ESTestCase {
         }
     }
 
+    public void testInfer_StreamRequest() throws Exception {
+        String responseJson = """
+            data: {"type": "message_start", "message": {"model": "claude, probably"}}
+            data: {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}}
+            data: {"type": "ping"}
+            data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello"}}
+            data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": ", World"}}
+            data: {"type": "content_block_stop", "index": 0}
+            data: {"type": "message_delta", "delta": {"stop_reason": "end_turn", "stop_sequence":null}, "usage": {"output_tokens": 4}}
+            data: {"type": "message_stop"}
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+        var result = streamChatCompletion();
+
+        InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent("""
+            {"completion":[{"delta":"Hello"},{"delta":", World"}]}""");
+    }
+
+    private InferenceServiceResults streamChatCompletion() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        try (var service = new AnthropicService(senderFactory, createWithEmptySettings(threadPool))) {
+            var model = AnthropicChatCompletionModelTests.createChatCompletionModel(
+                getUrl(webServer),
+                "secret",
+                "model",
+                Integer.MAX_VALUE
+            );
+            var listener = new PlainActionFuture<InferenceServiceResults>();
+            service.infer(
+                model,
+                null,
+                List.of("abc"),
+                true,
+                new HashMap<>(),
+                InputType.INGEST,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            return listener.actionGet(TIMEOUT);
+        }
+    }
+
+    public void testInfer_StreamRequest_ErrorResponse() throws Exception {
+        String responseJson = """
+            data: {"type": "error", "error": {"type": "request_too_large", "message": "blah"}}
+
+            """;
+        webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+        var result = streamChatCompletion();
+
+        InferenceEventsAssertion.assertThat(result)
+            .hasFinishedStream()
+            .hasNoEvents()
+            .hasErrorWithStatusCode(RestStatus.REQUEST_ENTITY_TOO_LARGE.getStatus())
+            .hasErrorContaining("blah");
+    }
+
     private AnthropicService createServiceWithMockSender() {
         return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
     }