
[ML] Fixes processing chunked results in AWS Bedrock service (#110592)

Fixes an error when using the Amazon Bedrock service with a large
input that was chunked.
David Kyle committed 1 year ago
commit b01949c6aa

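Why the change fixes the error: before this commit, every batch produced by EmbeddingRequestChunker.batchRequestsWithListeners(listener) was executed with the same shared inferListener, which translated each batch's response against the full input list and completed the caller's listener once per batch. When a large input was split into more than one batch, the per-batch responses no longer lined up with the full input list, producing the error this commit fixes. The fix passes each batch's own request.listener() to action.execute(...), so the chunker assembles the per-input chunked results itself, and the now-unused translateToChunkedResults helper is deleted here and in AzureOpenAiService. The sketch below is a minimal, hypothetical illustration of that per-batch listener pattern; the types and the method are simplified stand-ins, not the real Elasticsearch classes.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;

// Hypothetical, simplified illustration of the per-batch listener pattern relied on by
// the fix. These are stand-in types, not the Elasticsearch EmbeddingRequestChunker API.
public class PerBatchListenerSketch {

    // One batch of inputs plus the callback that folds its embeddings back into the final result.
    record BatchedRequest(List<String> inputs, Consumer<List<float[]>> listener) { }

    // Split the inputs into batches and give each batch its own listener. Each listener
    // writes its embeddings into the slots of its own inputs; the last batch to complete
    // hands the assembled, input-ordered embeddings to the caller's listener.
    static List<BatchedRequest> batchRequestsWithListeners(
        List<String> inputs,
        int maxBatchSize,
        Consumer<List<float[]>> finalListener
    ) {
        float[][] assembled = new float[inputs.size()][];
        AtomicInteger pending = new AtomicInteger((inputs.size() + maxBatchSize - 1) / maxBatchSize);
        List<BatchedRequest> batches = new ArrayList<>();
        for (int start = 0; start < inputs.size(); start += maxBatchSize) {
            int offset = start;
            List<String> batchInputs = inputs.subList(start, Math.min(start + maxBatchSize, inputs.size()));
            batches.add(new BatchedRequest(batchInputs, embeddings -> {
                // Place this batch's embeddings at the positions of its inputs.
                for (int i = 0; i < embeddings.size(); i++) {
                    assembled[offset + i] = embeddings.get(i);
                }
                // Only the last batch to finish reports the assembled results to the caller.
                if (pending.decrementAndGet() == 0) {
                    finalListener.accept(Arrays.asList(assembled));
                }
            }));
        }
        return batches;
    }
}

With per-batch listeners, a service-side dispatch loop only has to execute each batch with request.listener(), which is exactly the one-line change in the diff below. It is also why the test now enqueues two single-embedding results, one per batched request, instead of one combined two-embedding result.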
+ 1 - 23
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

@@ -23,10 +23,6 @@ import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator;
 import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
@@ -47,7 +43,6 @@ import java.util.Map;
 import java.util.Set;
 
 import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED;
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
@@ -115,10 +110,6 @@ public class AmazonBedrockService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
-
         var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
         if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
             var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
@@ -126,26 +117,13 @@ public class AmazonBedrockService extends SenderService {
                 .batchRequestsWithListeners(listener);
             for (var request : batchedRequests) {
                 var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
-                action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, inferListener);
+                action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
             }
         } else {
             listener.onFailure(createInvalidModelException(model));
         }
     }
 
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
-        } else {
-            throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName());
-        }
-    }
-
     @Override
     public String name() {
         return NAME;

+ 0 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -24,10 +24,6 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -44,7 +40,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
@@ -246,19 +241,6 @@ public class AzureOpenAiService extends SenderService {
         }
     }
 
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
-        } else {
-            throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName());
-        }
-    }
-
     /**
      * For text embedding models get the embedding size and
      * update the service settings.

+ 13 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

@@ -1048,13 +1048,18 @@ public class AmazonBedrockServiceTests extends ESTestCase {
 
         try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) {
             try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) {
-                var mockResults = new InferenceTextEmbeddingFloatResults(
-                    List.of(
-                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F }),
-                        new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.456F, 0.987F })
-                    )
-                );
-                requestSender.enqueue(mockResults);
+                {
+                    var mockResults1 = new InferenceTextEmbeddingFloatResults(
+                        List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F }))
+                    );
+                    requestSender.enqueue(mockResults1);
+                }
+                {
+                    var mockResults2 = new InferenceTextEmbeddingFloatResults(
+                        List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.223F, 0.278F }))
+                    );
+                    requestSender.enqueue(mockResults2);
+                }
 
                 var model = AmazonBedrockEmbeddingsModelTests.createModel(
                     "id",
@@ -1089,7 +1094,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                     var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
                     assertThat(floatResult.chunks(), hasSize(1));
                     assertEquals("xyz", floatResult.chunks().get(0).matchedText());
-                    assertArrayEquals(new float[] { 0.456F, 0.987F }, floatResult.chunks().get(0).embedding(), 0.0f);
+                    assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f);
                 }
             }
         }