Browse Source

[EIS] Implement chunked & batched inference for sparse text embeddings (#129922)

Tim Grein 3 months ago
parent
commit
870d581084

+ 56 - 13
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,6 +37,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -71,6 +74,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -80,6 +84,7 @@ public class ElasticInferenceService extends SenderService {
 
     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
+    public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
 
     private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
         TaskType.SPARSE_EMBEDDING,
@@ -161,7 +166,8 @@ public class ElasticInferenceService extends SenderService {
                     new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
                     EmptyTaskSettings.INSTANCE,
                     EmptySecretSettings.INSTANCE,
-                    elasticInferenceServiceComponents
+                    elasticInferenceServiceComponents,
+                    ChunkingSettingsBuilder.DEFAULT_SETTINGS
                 ),
                 MinimalServiceSettings.sparseEmbedding(NAME)
             ),
@@ -304,12 +310,25 @@ public class ElasticInferenceService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        // Pass-through without actually performing chunking (result will have a single chunk per input)
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
-        );
+        if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
+            var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
+
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+                inputs.getInputs(),
+                SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
+                model.getConfigurations().getChunkingSettings()
+            ).batchRequestsWithListeners(listener);
+
+            for (var request : batchedRequests) {
+                var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
+                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            }
+
+            return;
+        }
 
-        doInfer(model, inputs, taskSettings, timeout, inferListener);
+        // Model cannot perform chunked inference
+        listener.onFailure(createInvalidModelException(model));
     }
 
     @Override
@@ -328,6 +347,13 @@ public class ElasticInferenceService extends SenderService {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
             ElasticInferenceServiceModel model = createModel(
                 inferenceEntityId,
                 taskType,
@@ -336,7 +362,8 @@ public class ElasticInferenceService extends SenderService {
                 serviceSettingsMap,
                 elasticInferenceServiceComponents,
                 TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
-                ConfigurationParseContext.REQUEST
+                ConfigurationParseContext.REQUEST,
+                chunkingSettings
             );
 
             throwIfNotEmptyMap(config, NAME);
@@ -372,7 +399,8 @@ public class ElasticInferenceService extends SenderService {
         @Nullable Map<String, Object> secretSettings,
         ElasticInferenceServiceComponents elasticInferenceServiceComponents,
         String failureMessage,
-        ConfigurationParseContext context
+        ConfigurationParseContext context,
+        ChunkingSettings chunkingSettings
     ) {
         return switch (taskType) {
             case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
@@ -383,7 +411,8 @@ public class ElasticInferenceService extends SenderService {
                 taskSettings,
                 secretSettings,
                 elasticInferenceServiceComponents,
-                context
+                context,
+                chunkingSettings
             );
             case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
                 inferenceEntityId,
@@ -420,13 +449,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }
 
@@ -435,13 +470,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             null,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }
 
@@ -456,7 +497,8 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ChunkingSettings chunkingSettings
     ) {
         return createModel(
             inferenceEntityId,
@@ -466,7 +508,8 @@ public class ElasticInferenceService extends SenderService {
             secretSettings,
             elasticInferenceServiceComponents,
             failureMessage,
-            ConfigurationParseContext.PERSISTENT
+            ConfigurationParseContext.PERSISTENT,
+            chunkingSettings
         );
     }
 

+ 8 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsModel.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.services.elastic.sparseembeddings;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -39,7 +40,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
         Map<String, Object> taskSettings,
         Map<String, Object> secrets,
         ElasticInferenceServiceComponents elasticInferenceServiceComponents,
-        ConfigurationParseContext context
+        ConfigurationParseContext context,
+        ChunkingSettings chunkingSettings
     ) {
         this(
             inferenceEntityId,
@@ -48,7 +50,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
             ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             EmptySecretSettings.INSTANCE,
-            elasticInferenceServiceComponents
+            elasticInferenceServiceComponents,
+            chunkingSettings
         );
     }
 
@@ -67,10 +70,11 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
         ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings,
         @Nullable TaskSettings taskSettings,
         @Nullable SecretSettings secretSettings,
-        ElasticInferenceServiceComponents elasticInferenceServiceComponents
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
+        ChunkingSettings chunkingSettings
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
             new ModelSecrets(secretSettings),
             serviceSettings,
             elasticInferenceServiceComponents

+ 3 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
 import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
 
@@ -28,7 +29,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCas
             new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
             EmptyTaskSettings.INSTANCE,
             EmptySecretSettings.INSTANCE,
-            ElasticInferenceServiceComponents.of(url)
+            ElasticInferenceServiceComponents.of(url),
+            ChunkingSettingsBuilder.DEFAULT_SETTINGS
         );
     }
 }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -835,7 +835,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
-    public void testChunkedInfer_PassesThrough() throws IOException {
+    public void testChunkedInfer() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         var elasticInferenceServiceURL = getUrl(webServer);
 

+ 3 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
 import org.elasticsearch.xpack.inference.Utils;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
@@ -196,7 +197,8 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNo
                     new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null),
                     EmptyTaskSettings.INSTANCE,
                     EmptySecretSettings.INSTANCE,
-                    ElasticInferenceServiceComponents.EMPTY_INSTANCE
+                    ElasticInferenceServiceComponents.EMPTY_INSTANCE,
+                    ChunkingSettingsBuilder.DEFAULT_SETTINGS
                 ),
                 MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME)
             )