Browse Source

[ML] Handle Errors and pre-streaming exceptions (#115868) (#115967)

If we fail to establish a connection to bedrock, the error is returned in the client's CompletableFuture. We will forward it to the listener via the stream processor. Any Errors are thrown on another thread.
Pat Whelan 11 months ago
parent
commit
ff464185ae

+ 5 - 0
docs/changelog/115868.yaml

@@ -0,0 +1,5 @@
+pr: 115868
+summary: Forward bedrock connection errors to user
+area: Machine Learning
+type: bug
+issues: []

+ 6 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java

@@ -23,6 +23,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
 import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
 import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
 
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.SpecialPermission;
 import org.elasticsearch.SpecialPermission;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
@@ -93,11 +94,15 @@ public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient {
         internalClient.converseStream(
         internalClient.converseStream(
             request,
             request,
             ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
             ConverseStreamResponseHandler.builder().subscriber(() -> FlowAdapters.toSubscriber(awsResponseProcessor)).build()
-        );
+        ).exceptionally(e -> {
+            awsResponseProcessor.onError(e);
+            return null; // Void
+        });
         return awsResponseProcessor;
         return awsResponseProcessor;
     }
     }
 
 
     private void onFailure(ActionListener<?> listener, Throwable t, String method) {
     private void onFailure(ActionListener<?> listener, Throwable t, String method) {
+        ExceptionsHelper.maybeDieOnAnotherThread(t);
         var unwrappedException = t;
         var unwrappedException = t;
         if (t instanceof CompletionException || t instanceof ExecutionException) {
         if (t instanceof CompletionException || t instanceof ExecutionException) {
             unwrappedException = t.getCause() != null ? t.getCause() : t;
             unwrappedException = t.getCause() != null ? t.getCause() : t;

+ 3 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockStreamingChatProcessor.java

@@ -12,6 +12,7 @@ import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput
 import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
 import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
 
 
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.logging.LogManager;
 import org.elasticsearch.logging.LogManager;
@@ -89,6 +90,7 @@ class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStre
 
 
     @Override
     @Override
     public void onError(Throwable amazonBedrockRuntimeException) {
     public void onError(Throwable amazonBedrockRuntimeException) {
+        ExceptionsHelper.maybeDieOnAnotherThread(amazonBedrockRuntimeException);
         error.set(
         error.set(
             new ElasticsearchException(
             new ElasticsearchException(
                 Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()),
                 Strings.format("AmazonBedrock StreamingChatProcessor failure: [%s]", amazonBedrockRuntimeException.getMessage()),
@@ -96,7 +98,7 @@ class AmazonBedrockStreamingChatProcessor implements Flow.Processor<ConverseStre
             )
             )
         );
         );
         if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) {
         if (isDone.compareAndSet(false, true) && checkAndResetDemand() && onErrorCalled.compareAndSet(false, true)) {
-            downstream.onError(error.get());
+            runOnUtilityThreadPool(() -> downstream.onError(amazonBedrockRuntimeException));
         }
         }
     }
     }
 
 

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/ServerSentEventsRestActionListener.java

@@ -224,7 +224,7 @@ public class ServerSentEventsRestActionListener implements ActionListener<Infere
         @Override
         @Override
         public void onError(Throwable throwable) {
         public void onError(Throwable throwable) {
             if (isLastPart.compareAndSet(false, true)) {
             if (isLastPart.compareAndSet(false, true)) {
-                logger.error("A failure occurred in ElasticSearch while streaming the response.", throwable);
+                logger.warn("A failure occurred in ElasticSearch while streaming the response.", throwable);
                 nextBodyPartListener().onResponse(new ServerSentEventResponseBodyPart(ServerSentEvents.ERROR, errorChunk(throwable)));
                 nextBodyPartListener().onResponse(new ServerSentEventResponseBodyPart(ServerSentEvents.ERROR, errorChunk(throwable)));
             }
             }
         }
         }