|
@@ -23,10 +23,6 @@ import org.elasticsearch.inference.ModelConfigurations;
|
|
|
import org.elasticsearch.inference.ModelSecrets;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
|
|
|
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
|
|
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
|
|
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
|
|
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
|
|
import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator;
|
|
|
import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender;
|
|
@@ -47,7 +43,6 @@ import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
|
|
|
import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED;
|
|
|
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
|
|
@@ -115,10 +110,6 @@ public class AmazonBedrockService extends SenderService {
|
|
|
TimeValue timeout,
|
|
|
ActionListener<List<ChunkedInferenceServiceResults>> listener
|
|
|
) {
|
|
|
- ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
|
|
|
- (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
|
|
|
- );
|
|
|
-
|
|
|
var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout);
|
|
|
if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) {
|
|
|
var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider());
|
|
@@ -126,26 +117,13 @@ public class AmazonBedrockService extends SenderService {
|
|
|
.batchRequestsWithListeners(listener);
|
|
|
for (var request : batchedRequests) {
|
|
|
var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings);
|
|
|
- action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, inferListener);
|
|
|
+ action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
|
|
|
}
|
|
|
} else {
|
|
|
listener.onFailure(createInvalidModelException(model));
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
|
|
|
- List<String> inputs,
|
|
|
- InferenceServiceResults inferenceResults
|
|
|
- ) {
|
|
|
- if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
|
|
|
- return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
|
|
|
- } else if (inferenceResults instanceof ErrorInferenceResults error) {
|
|
|
- return List.of(new ErrorChunkedInferenceResults(error.getException()));
|
|
|
- } else {
|
|
|
- throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
@Override
|
|
|
public String name() {
|
|
|
return NAME;
|