
[ML] Missing chunkers for AzureOpenAIService, AzureAIStudioService and HuggingFace embeddings (#109875)

Adds chunking for AzureOpenAIService, AzureAIStudioService and the HuggingFace text embedding service.
David Kyle, 1 year ago · commit b60d77e074
13 changed files with 322 additions and 153 deletions
  1. +0 -1     x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java
  2. +11 -21   x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java
  3. +14 -5    x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java
  4. +7 -49    x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
  5. +37 -0    x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java
  6. +56 -0    x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java
  7. +0 -18    x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java
  8. +1 -1     x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceFields.java
  9. +27 -19   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
  10. +28 -19  x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
  11. +1 -1    x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
  12. +125 -0  x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java
  13. +15 -19  x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
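
The three embedding services move to the batching pattern already used elsewhere in the plugin (MistralService's apparently unused translation helper is deleted here as dead code): instead of translating a single inference response into chunked results, doChunkedInfer splits the inputs into batches with EmbeddingRequestChunker, dispatches one embeddings request per batch, and lets the chunker reassemble the per-batch responses into one chunked result per input. The HuggingFace ELSER service keeps the old translation path, with a TODO for sparse-embedding chunking. A condensed sketch of the shared pattern, assembled from the hunks below (ExampleModel and ExampleActionCreator are placeholders, not classes from this commit):

    // Sketch only: the shape of doChunkedInfer in a SenderService subclass such as
    // AzureOpenAiService or AzureAiStudioService, per the hunks below.
    @Override
    protected void doChunkedInfer(
        Model model,
        @Nullable String query,
        List<String> input,
        Map<String, Object> taskSettings,
        InputType inputType,
        ChunkingOptions chunkingOptions,
        TimeValue timeout,
        ActionListener<List<ChunkedInferenceServiceResults>> listener
    ) {
        if (model instanceof ExampleModel == false) {
            listener.onFailure(createInvalidModelException(model));
            return;
        }
        var exampleModel = (ExampleModel) model;
        var actionCreator = new ExampleActionCreator(getSender(), getServiceComponents());

        // Split the inputs into batches of at most EMBEDDING_MAX_BATCH_SIZE strings.
        // batchRequestsWithListeners returns one request per batch, each with its own
        // listener; the chunker merges the per-batch float embeddings and completes the
        // original listener with one ChunkedInferenceServiceResults entry per input.
        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
            .batchRequestsWithListeners(listener);
        for (var request : batchedRequests) {
            // The Azure services pass taskSettings to accept(); HuggingFaceModel#accept takes only the creator.
            var action = exampleModel.accept(actionCreator, taskSettings);
            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
        }
    }

EMBEDDING_MAX_BATCH_SIZE comes from OpenAiServiceFields for the Azure services (made public in this commit), while HuggingFaceBaseService introduces its own conservative constant of 20, since how a HuggingFace model is deployed, and therefore its optimal batch size, is unknown.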

+ 0 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java

@@ -191,7 +191,6 @@ public class EmbeddingRequestChunker {
                 case FLOAT -> handleFloatResults(inferenceServiceResults);
                 case BYTE -> handleByteResults(inferenceServiceResults);
             }
-            ;
         }
 
         private void handleFloatResults(InferenceServiceResults inferenceServiceResults) {

+ 11 - 21
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java

@@ -24,10 +24,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.azureaistudio.AzureAiStudioActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -44,7 +41,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
@@ -53,6 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNot
 import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProviderCapabilities.providerAllowsEndpointTypeForTask;
 import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioProviderCapabilities.providerAllowsTaskType;
 import static org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 
 public class AzureAiStudioService extends SenderService {
 
@@ -105,23 +102,16 @@ public class AzureAiStudioService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
-
-        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
-    }
-
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
+        if (model instanceof AzureAiStudioModel baseAzureAiStudioModel) {
+            var actionCreator = new AzureAiStudioActionCreator(getSender(), getServiceComponents());
+            var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+                .batchRequestsWithListeners(listener);
+            for (var request : batchedRequests) {
+                var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings);
+                action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+            }
         } else {
-            throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName());
+            listener.onFailure(createInvalidModelException(model));
         }
     }
 

+ 14 - 5
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java

@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResul
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -49,6 +50,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersi
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 
 public class AzureOpenAiService extends SenderService {
     public static final String NAME = "azureopenai";
@@ -230,11 +232,18 @@ public class AzureOpenAiService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
-
-        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
+        if (model instanceof AzureOpenAiModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+        AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model;
+        var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents());
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+            .batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = azureOpenAiModel.accept(actionCreator, taskSettings);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+        }
     }
 
     private static List<ChunkedInferenceServiceResults> translateToChunkedResults(

+ 7 - 49
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

@@ -8,23 +8,13 @@
 package org.elasticsearch.xpack.inference.services.huggingface;
 
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.common.Strings;
-import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.inference.ChunkedInferenceServiceResults;
-import org.elasticsearch.inference.ChunkingOptions;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -36,7 +26,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
@@ -44,6 +33,13 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNot
 
 public abstract class HuggingFaceBaseService extends SenderService {
 
+    /**
+     * The optimal batch size depends on the hardware the model is deployed on.
+     * For HuggingFace use a conservatively small max batch size as it is
+     * unknown how the model is deployed
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 20;
+
     public HuggingFaceBaseService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
         super(factory, serviceComponents);
     }
@@ -153,42 +149,4 @@ public abstract class HuggingFaceBaseService extends SenderService {
     ) {
         throw new UnsupportedOperationException("Hugging Face service does not support inference with query input");
     }
-
-    @Override
-    protected void doChunkedInfer(
-        Model model,
-        @Nullable String query,
-        List<String> input,
-        Map<String, Object> taskSettings,
-        InputType inputType,
-        ChunkingOptions chunkingOptions,
-        TimeValue timeout,
-        ActionListener<List<ChunkedInferenceServiceResults>> listener
-    ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
-
-        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
-    }
-
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
-            return InferenceChunkedSparseEmbeddingResults.listOf(inputs, sparseEmbeddingResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
-        } else {
-            String expectedClasses = Strings.format(
-                "One of [%s,%s]",
-                InferenceTextEmbeddingFloatResults.class.getSimpleName(),
-                SparseEmbeddingResults.class.getSimpleName()
-            );
-            throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName());
-        }
-    }
 }

+ 37 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

@@ -12,9 +12,16 @@ import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
+import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
+import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
@@ -22,8 +29,11 @@ import org.elasticsearch.xpack.inference.services.ServiceUtils;
 import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModel;
 import org.elasticsearch.xpack.inference.services.huggingface.embeddings.HuggingFaceEmbeddingsModel;
 
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+
 public class HuggingFaceService extends HuggingFaceBaseService {
     public static final String NAME = "hugging_face";
 
@@ -79,6 +89,33 @@ public class HuggingFaceService extends HuggingFaceBaseService {
         return new HuggingFaceEmbeddingsModel(model, serviceSettings);
     }
 
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        @Nullable String query,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        ChunkingOptions chunkingOptions,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
+    ) {
+        if (model instanceof HuggingFaceModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var huggingFaceModel = (HuggingFaceModel) model;
+        var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
+
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE, EmbeddingRequestChunker.EmbeddingType.FLOAT)
+            .batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = huggingFaceModel.accept(actionCreator);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+        }
+    }
+
     @Override
     public String name() {
         return NAME;

+ 56 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

@@ -10,17 +10,34 @@ package org.elasticsearch.xpack.inference.services.huggingface.elser;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.Strings;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
 import org.elasticsearch.xpack.inference.services.ServiceComponents;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService;
 import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel;
 
+import java.util.List;
 import java.util.Map;
 
+import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
+
 public class HuggingFaceElserService extends HuggingFaceBaseService {
     public static final String NAME = "hugging_face_elser";
 
@@ -48,6 +65,45 @@ public class HuggingFaceElserService extends HuggingFaceBaseService {
         };
     }
 
+    @Override
+    protected void doChunkedInfer(
+        Model model,
+        @Nullable String query,
+        List<String> input,
+        Map<String, Object> taskSettings,
+        InputType inputType,
+        ChunkingOptions chunkingOptions,
+        TimeValue timeout,
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
+    ) {
+        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
+            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
+        );
+
+        // TODO chunking sparse embeddings not implemented
+        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
+    }
+
+    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
+        List<String> inputs,
+        InferenceServiceResults inferenceResults
+    ) {
+        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
+            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
+        } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) {
+            return InferenceChunkedSparseEmbeddingResults.listOf(inputs, sparseEmbeddingResults);
+        } else if (inferenceResults instanceof ErrorInferenceResults error) {
+            return List.of(new ErrorChunkedInferenceResults(error.getException()));
+        } else {
+            String expectedClasses = Strings.format(
+                "One of [%s,%s]",
+                InferenceTextEmbeddingFloatResults.class.getSimpleName(),
+                SparseEmbeddingResults.class.getSimpleName()
+            );
+            throw createInvalidChunkedResultException(expectedClasses, inferenceResults.getWriteableName());
+        }
+    }
+
     @Override
     public TransportVersion getMinimalSupportedVersion() {
         return TransportVersions.V_8_12_0;

+ 0 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java

@@ -22,10 +22,6 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
 import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -42,7 +38,6 @@ import java.util.Map;
 import java.util.Set;
 
 import static org.elasticsearch.TransportVersions.ADD_MISTRAL_EMBEDDINGS_INFERENCE;
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
@@ -117,19 +112,6 @@ public class MistralService extends SenderService {
         }
     }
 
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) {
-            return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
-        } else {
-            throw createInvalidChunkedResultException(InferenceChunkedTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName());
-        }
-    }
-
     @Override
     public String name() {
         return NAME;

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceFields.java

@@ -16,6 +16,6 @@ public class OpenAiServiceFields {
     /**
      * Taken from https://platform.openai.com/docs/api-reference/embeddings/create
      */
-    static final int EMBEDDING_MAX_BATCH_SIZE = 2048;
+    public static final int EMBEDDING_MAX_BATCH_SIZE = 2048;
 
 }

+ 27 - 19
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

@@ -33,7 +33,6 @@ import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -55,14 +54,12 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.io.IOException;
-import java.net.URISyntaxException;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
-import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -849,7 +846,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
         verifyNoMoreInteractions(sender);
     }
 
-    public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException {
+    public void testChunkedInfer() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
         try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
@@ -865,6 +862,14 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 0.0123,
                 -0.0123
                 ]
+                },
+                {
+                "object": "embedding",
+                "index": 1,
+                "embedding": [
+                1.0123,
+                -1.0123
+                ]
                 }
                 ],
                 "model": "text-embedding-ada-002-v2",
@@ -892,7 +897,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
             PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
             service.chunkedInfer(
                 model,
-                List.of("abc"),
+                List.of("foo", "bar"),
                 new HashMap<>(),
                 InputType.INGEST,
                 new ChunkingOptions(null, null),
@@ -900,20 +905,23 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 listener
             );
 
-            var result = listener.actionGet(TIMEOUT).get(0);
-            assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertArrayEquals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
+            }
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertArrayEquals(new float[] { 1.0123f, -1.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
+            }
 
-            assertThat(
-                asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result),
-                Matchers.is(
-                    Map.of(
-                        InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
-                        List.of(
-                            Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(0.0123f, -0.0123f))
-                        )
-                    )
-                )
-            );
             assertThat(webServer.requests(), hasSize(1));
             assertNull(webServer.requests().get(0).getUri().getQuery());
             assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@@ -921,7 +929,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             assertThat(requestMap.size(), Matchers.is(2));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
             assertThat(requestMap.get("user"), Matchers.is("user"));
         }
     }

+ 28 - 19
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

@@ -32,7 +32,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -55,7 +54,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
-import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
@@ -1079,8 +1077,16 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                 "object": "embedding",
                 "index": 0,
                 "embedding": [
-                0.0123,
-                -0.0123
+                0.123,
+                -0.123
+                ]
+                },
+                {
+                "object": "embedding",
+                "index": 1,
+                "embedding": [
+                1.123,
+                -1.123
                 ]
                 }
                 ],
@@ -1098,7 +1104,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
             PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
             service.chunkedInfer(
                 model,
-                List.of("abc"),
+                List.of("foo", "bar"),
                 new HashMap<>(),
                 InputType.INGEST,
                 new ChunkingOptions(null, null),
@@ -1106,20 +1112,23 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                 listener
             );
 
-            var result = listener.actionGet(TIMEOUT).get(0);
-            assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
+            }
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertArrayEquals(new float[] { 1.123f, -1.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
+            }
 
-            assertThat(
-                asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result),
-                Matchers.is(
-                    Map.of(
-                        InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
-                        List.of(
-                            Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(0.0123f, -0.0123f))
-                        )
-                    )
-                )
-            );
             assertThat(webServer.requests(), hasSize(1));
             assertNull(webServer.requests().get(0).getUri().getQuery());
             assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@@ -1127,7 +1136,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             assertThat(requestMap.size(), Matchers.is(2));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
             assertThat(requestMap.get("user"), Matchers.is("user"));
         }
     }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java

@@ -90,7 +90,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
         verifyNoMoreInteractions(sender);
     }
 
-    private static final class TestService extends HuggingFaceBaseService {
+    private static final class TestService extends HuggingFaceService {
         TestService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
             super(factory, serviceComponents);
         }

+ 125 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java

@@ -0,0 +1,125 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.huggingface;
+
+import org.apache.http.HttpHeaders;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InputType;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.test.http.MockResponse;
+import org.elasticsearch.test.http.MockWebServer;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
+import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
+import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
+import org.hamcrest.MatcherAssert;
+import org.hamcrest.Matchers;
+import org.junit.After;
+import org.junit.Before;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
+import static org.mockito.Mockito.mock;
+
+public class HuggingFaceElserServiceTests extends ESTestCase {
+
+    private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
+
+    private final MockWebServer webServer = new MockWebServer();
+    private ThreadPool threadPool;
+    private HttpClientManager clientManager;
+
+    @Before
+    public void init() throws Exception {
+        webServer.start();
+        threadPool = createThreadPool(inferenceUtilityPool());
+        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
+    }
+
+    @After
+    public void shutdown() throws IOException {
+        clientManager.close();
+        terminate(threadPool);
+        webServer.close();
+    }
+
+    public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+
+        try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) {
+
+            String responseJson = """
+                [
+                    {
+                        ".": 0.133155956864357
+                    }
+                ]
+                """;
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret");
+            PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+            service.chunkedInfer(
+                model,
+                List.of("abc"),
+                new HashMap<>(),
+                InputType.INGEST,
+                new ChunkingOptions(null, null),
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+
+            var result = listener.actionGet(TIMEOUT).get(0);
+
+            MatcherAssert.assertThat(
+                result.asMap(),
+                Matchers.is(
+                    Map.of(
+                        InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
+                        List.of(
+                            Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f))
+                        )
+                    )
+                )
+            );
+
+            assertThat(webServer.requests(), hasSize(1));
+            assertNull(webServer.requests().get(0).getUri().getQuery());
+            assertThat(
+                webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
+                equalTo(XContentType.JSON.mediaTypeWithoutParameters())
+            );
+            assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
+
+            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
+            assertThat(requestMap.size(), Matchers.is(1));
+            assertThat(requestMap.get("inputs"), Matchers.is(List.of("abc")));
+        }
+    }
+}

+ 15 - 19
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

@@ -31,7 +31,6 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
@@ -649,21 +648,22 @@ public class HuggingFaceServiceTests extends ESTestCase {
         }
     }
 
-    public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException {
+    public void testChunkedInfer() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
         try (var service = new HuggingFaceService(senderFactory, createWithEmptySettings(threadPool))) {
 
             String responseJson = """
                 [
-                    {
-                        ".": 0.133155956864357
-                    }
+                      [
+                          0.123,
+                          -0.123
+                      ]
                 ]
                 """;
             webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
 
-            var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret");
+            var model = HuggingFaceEmbeddingsModelTests.createModel(getUrl(webServer), "secret");
             PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
             service.chunkedInfer(
                 model,
@@ -675,19 +675,15 @@ public class HuggingFaceServiceTests extends ESTestCase {
                 listener
             );
 
-            var result = listener.actionGet(TIMEOUT).get(0);
-
-            MatcherAssert.assertThat(
-                result.asMap(),
-                Matchers.is(
-                    Map.of(
-                        InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
-                        List.of(
-                            Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f))
-                        )
-                    )
-                )
-            );
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(1));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("abc", floatResult.chunks().get(0).matchedText());
+                assertArrayEquals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding(), 0.0f);
+            }
 
             assertThat(webServer.requests(), hasSize(1));
             assertNull(webServer.requests().get(0).getUri().getQuery());