Browse Source

Removing error object check (#132502)

Jonathan Buttner 2 months ago
parent
commit
de4245cd57
16 changed files with 35 additions and 130 deletions
  1. 1 26
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java
  2. 0 13
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java
  3. 1 4
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java
  4. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java
  5. 2 7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java
  6. 2 7
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java
  7. 0 5
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java
  8. 2 7
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java
  9. 14 36
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java
  10. 5 11
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java
  11. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandlerTests.java
  12. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java
  13. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java
  14. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java
  15. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandlerTests.java
  16. 1 2
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java

+ 1 - 26
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

@@ -76,38 +76,13 @@ public abstract class BaseResponseHandler implements ResponseHandler {
     }
 
     @Override
-    public void validateResponse(
-        ThrottlerManager throttlerManager,
-        Logger logger,
-        Request request,
-        HttpResult result,
-        boolean checkForErrorObject
-    ) {
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
         checkForFailureStatusCode(request, result);
         checkForEmptyBody(throttlerManager, logger, request, result);
-
-        if (checkForErrorObject) {
-            // When the response is streamed the status code could be 200 but the error object will be set
-            // so we need to check for that specifically
-            checkForErrorObject(request, result);
-        }
     }
 
     protected abstract void checkForFailureStatusCode(Request request, HttpResult result);
 
-    protected void checkForErrorObject(Request request, HttpResult result) {
-        var errorEntity = errorParseFunction.apply(result);
-
-        if (errorEntity.errorStructureFound()) {
-            // We don't really know what happened because the status code was 200 so we'll return a failure and let the
-            // client retry if necessary
-            // If we did want to retry here, we'll need to determine if this was a streaming request, if it was
-            // we shouldn't retry because that would replay the entire streaming request and the client would get
-            // duplicate chunks back
-            throw new RetryException(false, buildError(SERVER_ERROR_OBJECT, request, result, errorEntity));
-        }
-    }
-
     protected Exception buildError(String message, Request request, HttpResult result) {
         var errorEntityMsg = errorParseFunction.apply(result);
         return buildError(message, request, result, errorEntityMsg);

+ 0 - 13
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ChatCompletionErrorResponseHandler.java

@@ -28,19 +28,6 @@ public class ChatCompletionErrorResponseHandler {
         this.unifiedChatCompletionErrorParser = Objects.requireNonNull(errorParser);
     }
 
-    public void checkForErrorObject(Request request, HttpResult result) {
-        var errorEntity = unifiedChatCompletionErrorParser.parse(result);
-
-        if (errorEntity.errorStructureFound()) {
-            // We don't really know what happened because the status code was 200 so we'll return a failure and let the
-            // client retry if necessary
-            // If we did want to retry here, we'll need to determine if this was a streaming request, if it was
-            // we shouldn't retry because that would replay the entire streaming request and the client would get
-            // duplicate chunks back
-            throw new RetryException(false, buildChatCompletionErrorInternal(SERVER_ERROR_OBJECT, request, result, errorEntity));
-        }
-    }
-
     public UnifiedChatCompletionException buildChatCompletionError(String message, Request request, HttpResult result) {
         var errorResponse = unifiedChatCompletionErrorParser.parse(result);
         return buildChatCompletionErrorInternal(message, request, result, errorResponse);

+ 1 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java

@@ -29,12 +29,9 @@ public interface ResponseHandler {
      * @param logger the logger to use for logging
      * @param request the original request
      * @param result the response from the server
-     * @param checkForErrorObject if true, the validation function should check for the presence of an error object even if the status code
-     *                            indicates a success
      * @throws RetryException if the response is invalid
      */
-    void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result, boolean checkForErrorObject)
-        throws RetryException;
+    void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException;
 
     /**
      * A method for parsing the response from the server.

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java

@@ -121,7 +121,7 @@ public class RetryingHttpSender implements RequestSender {
                         } else {
                             r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
                                 try {
-                                    responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
+                                    responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
                                     InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
                                     ll.onResponse(inferenceResults);
                                 } catch (Exception e) {
@@ -134,7 +134,7 @@ public class RetryingHttpSender implements RequestSender {
                 } else {
                     httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
                         try {
-                            responseHandler.validateResponse(throttlerManager, logger, request, r, false);
+                            responseHandler.validateResponse(throttlerManager, logger, request, r);
                             InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);
 
                             l.onResponse(inferenceResults);

+ 2 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java

@@ -22,13 +22,8 @@ public abstract class AmazonBedrockResponseHandler implements ResponseHandler {
     }
 
     @Override
-    public final void validateResponse(
-        ThrottlerManager throttlerManager,
-        Logger logger,
-        Request request,
-        HttpResult result,
-        boolean checkForErrorObject
-    ) throws RetryException {
+    public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
         // do nothing as the AWS SDK will take care of validation for us
     }
 }

+ 2 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java

@@ -63,13 +63,8 @@ public class AzureMistralOpenAiExternalResponseHandler extends BaseResponseHandler {
     }
 
     @Override
-    public void validateResponse(
-        ThrottlerManager throttlerManager,
-        Logger logger,
-        Request request,
-        HttpResult result,
-        boolean checkForErrorObject
-    ) throws RetryException {
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
         checkForFailureStatusCode(request, result);
         checkForEmptyBody(throttlerManager, logger, request, result);
     }

+ 0 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandler.java

@@ -65,11 +65,6 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandler extends GoogleVe
         return chatCompletionErrorResponseHandler.buildChatCompletionError(message, request, result);
     }
 
-    @Override
-    protected void checkForErrorObject(Request request, HttpResult result) {
-        chatCompletionErrorResponseHandler.checkForErrorObject(request, result);
-    }
-
     private static class GoogleVertexAiErrorParser implements UnifiedChatCompletionErrorParser {
 
         @Override

+ 2 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java

@@ -36,13 +36,8 @@ public class AlwaysRetryingResponseHandler implements ResponseHandler {
     }
 
     @Override
-    public void validateResponse(
-        ThrottlerManager throttlerManager,
-        Logger logger,
-        Request request,
-        HttpResult result,
-        boolean checkForErrorObject
-    ) throws RetryException {
+    public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+        throws RetryException {
         try {
             checkForFailureStatusCode(throttlerManager, logger, request, result);
             checkForEmptyBody(throttlerManager, logger, request, result);

+ 14 - 36
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java

@@ -59,12 +59,11 @@ public class BaseResponseHandlerTests extends ESTestCase {
             mock(ThrottlerManager.class),
             mock(Logger.class),
             request,
-            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
-            true
+            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
         );
     }
 
-    public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() {
+    public void testValidateResponse_DoesNotThrowError_WhenStatus200_AndMalformedErrorObject() {
         var handler = getBaseResponseHandler();
 
         String responseJson = """
@@ -80,25 +79,15 @@ public class BaseResponseHandlerTests extends ESTestCase {
         var request = mock(Request.class);
         when(request.getInferenceEntityId()).thenReturn("abc");
 
-        var exception = expectThrows(
-            RetryException.class,
-            () -> handler.validateResponse(
-                mock(ThrottlerManager.class),
-                mock(Logger.class),
-                request,
-                new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
-            )
-        );
-
-        assertFalse(exception.shouldRetry());
-        assertThat(
-            exception.getCause().getMessage(),
-            is("Received an error response for request from inference entity id [abc] status [200]")
+        handler.validateResponse(
+            mock(ThrottlerManager.class),
+            mock(Logger.class),
+            request,
+            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
         );
     }
 
-    public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
+    public void testValidateResponse_DoesNotThrow_WhenStatus200_AndWellFormedErrorObjectExists() {
         var handler = getBaseResponseHandler();
 
         String responseJson = """
@@ -115,21 +104,11 @@ public class BaseResponseHandlerTests extends ESTestCase {
         var request = mock(Request.class);
         when(request.getInferenceEntityId()).thenReturn("abc");
 
-        var exception = expectThrows(
-            RetryException.class,
-            () -> handler.validateResponse(
-                mock(ThrottlerManager.class),
-                mock(Logger.class),
-                request,
-                new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
-            )
-        );
-
-        assertFalse(exception.shouldRetry());
-        assertThat(
-            exception.getCause().getMessage(),
-            is("Received an error response for request from inference entity id [abc] status [200]. Error message: [a message]")
+        handler.validateResponse(
+            mock(ThrottlerManager.class),
+            mock(Logger.class),
+            request,
+            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
         );
     }
 
@@ -154,8 +133,7 @@ public class BaseResponseHandlerTests extends ESTestCase {
             mock(ThrottlerManager.class),
             mock(Logger.class),
             request,
-            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
-            false
+            new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
         );
     }
 

+ 5 - 11
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java

@@ -42,7 +42,6 @@ import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
@@ -77,7 +76,7 @@ public class RetryingHttpSenderTests extends ESTestCase {
         Answer<InferenceServiceResults> answer = (invocation) -> inferenceResults;
 
         var handler = mock(ResponseHandler.class);
-        doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any(), anyBoolean());
+        doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any());
         // Mockito.thenReturn() does not compile when returning a
         // bounded wild card list, thenAnswer must be used instead.
         when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
@@ -352,7 +351,7 @@ public class RetryingHttpSenderTests extends ESTestCase {
         var handler = mock(ResponseHandler.class);
         doThrow(new RetryException(true, "failed")).doThrow(new IllegalStateException("failed again"))
             .when(handler)
-            .validateResponse(any(), any(), any(), any(), anyBoolean());
+            .validateResponse(any(), any(), any(), any());
         when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
 
         var retrier = createRetrier(sender);
@@ -389,7 +388,7 @@ public class RetryingHttpSenderTests extends ESTestCase {
         var handler = mock(ResponseHandler.class);
         doThrow(new RetryException(true, "failed")).doThrow(new RetryException(false, "failed again"))
             .when(handler)
-            .validateResponse(any(), any(), any(), any(), anyBoolean());
+            .validateResponse(any(), any(), any(), any());
         when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
 
         var retrier = createRetrier(httpClient);
@@ -702,13 +701,8 @@ public class RetryingHttpSenderTests extends ESTestCase {
         // testing failed requests
         return new ResponseHandler() {
             @Override
-            public void validateResponse(
-                ThrottlerManager throttlerManager,
-                Logger logger,
-                Request request,
-                HttpResult result,
-                boolean checkForErrorObject
-            ) throws RetryException {
+            public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
+                throws RetryException {
                 throw new RetryException(true, new IOException("response handler validate failed as designed"));
             }
 

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ai21/completion/Ai21ChatCompletionResponseHandlerTests.java

@@ -112,8 +112,7 @@ public class Ai21ChatCompletionResponseHandlerTests extends ESTestCase {
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiUnifiedChatCompletionResponseHandlerTests.java

@@ -120,8 +120,7 @@ public class GoogleVertexAiUnifiedChatCompletionResponseHandlerTests extends ESTestCase {
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mockHttpResponse(500), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceChatCompletionResponseHandlerTests.java

@@ -92,8 +92,7 @@ public class HuggingFaceChatCompletionResponseHandlerTests extends ESTestCase {
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/llama/completion/LlamaChatCompletionResponseHandlerTests.java

@@ -123,8 +123,7 @@ public class LlamaChatCompletionResponseHandlerTests extends ESTestCase {
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralUnifiedChatCompletionResponseHandlerTests.java

@@ -116,8 +116,7 @@ public class MistralUnifiedChatCompletionResponseHandlerTests extends ESTestCase
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mockErrorResponse(statusCode), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }

+ 1 - 2
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java

@@ -96,8 +96,7 @@ public class OpenAiUnifiedChatCompletionResponseHandlerTests extends ESTestCase {
                 mock(),
                 mock(),
                 mockRequest(),
-                new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
-                true
+                new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
             )
         );
     }