Browse Source

[8.19] [EIS] Implement chunked & batched inference for sparse text embeddings (#129922) (#129958)

Tim Grein 3 months ago
parent
commit
2a57e4e9ba

+ 56 - 13
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,6 +37,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -68,6 +71,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -77,6 +81,7 @@ public class ElasticInferenceService extends SenderService {
 
     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
+    public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;
 
     private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
         TaskType.SPARSE_EMBEDDING,
@@ -154,7 +159,8 @@ public class ElasticInferenceService extends SenderService {
                     new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
                     EmptyTaskSettings.INSTANCE,
                     EmptySecretSettings.INSTANCE,
-                    elasticInferenceServiceComponents
+                    elasticInferenceServiceComponents,
+                    ChunkingSettingsBuilder.DEFAULT_SETTINGS
                 ),
                 MinimalServiceSettings.sparseEmbedding(NAME)
             )
@@ -284,12 +290,25 @@ public class ElasticInferenceService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        // Pass-through without actually performing chunking (result will have a single chunk per input)
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
-        );
+        if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
+            var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
+
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+                inputs.getInputs(),
+                SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
+                model.getConfigurations().getChunkingSettings()
+            ).batchRequestsWithListeners(listener);
+
+            for (var request : batchedRequests) {
+                var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
+                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            }
+
+            return;
+        }
 
-        doInfer(model, inputs, taskSettings, timeout, inferListener);
+        // Model cannot perform chunked inference
+        listener.onFailure(createInvalidModelException(model));
     }
 
     @Override
@@ -308,6 +327,13 @@ public class ElasticInferenceService extends SenderService {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
             ElasticInferenceServiceModel model = createModel(
                 inferenceEntityId,
                 taskType,
@@ -316,7 +342,8 @@ public class ElasticInferenceService extends SenderService {
                 serviceSettingsMap,
                 elasticInferenceServiceComponents,
                 TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
-                ConfigurationParseContext.REQUEST
+                ConfigurationParseContext.REQUEST,
+                chunkingSettings
             );
 
             throwIfNotEmptyMap(config, NAME);
@@ -352,7 +379,8 @@ public class ElasticInferenceService extends SenderService {
         @Nullable Map<String, Object> secretSettings,
         ElasticInferenceServiceComponents elasticInferenceServiceComponents,
         String failureMessage,
-        ConfigurationParseContext context
+        ConfigurationParseContext context,
+        ChunkingSettings chunkingSettings
     ) {
         return switch (taskType) {
             case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
@@ -363,7 +391,8 @@ public class ElasticInferenceService extends SenderService {
                 taskSettings,
                 secretSettings,
                 elasticInferenceServiceComponents,
-                context
+                context,
+                chunkingSettings
             );
             case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
                 inferenceEntityId,
@@ -400,13 +429,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }
 
@@ -415,13 +450,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             null,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }
 
@@ -436,7 +477,8 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ChunkingSettings chunkingSettings
     ) {
         return createModel(
             inferenceEntityId,
@@ -446,7 +488,8 @@ public class ElasticInferenceService extends SenderService {
             secretSettings,
             elasticInferenceServiceComponents,
             failureMessage,
-            ConfigurationParseContext.PERSISTENT
+            ConfigurationParseContext.PERSISTENT,
+            chunkingSettings
         );
     }
 

+ 8 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModel.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.services.elastic;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -37,7 +38,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
         Map<String, Object> taskSettings,
         Map<String, Object> secrets,
         ElasticInferenceServiceComponents elasticInferenceServiceComponents,
-        ConfigurationParseContext context
+        ConfigurationParseContext context,
+        ChunkingSettings chunkingSettings
     ) {
         this(
             inferenceEntityId,
@@ -46,7 +48,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
             ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
             EmptySecretSettings.INSTANCE,
-            elasticInferenceServiceComponents
+            elasticInferenceServiceComponents,
+            chunkingSettings
         );
     }
 
@@ -65,10 +68,11 @@ public class ElasticInferenceServiceSparseEmbeddingsModel extends ElasticInferen
         ElasticInferenceServiceSparseEmbeddingsServiceSettings serviceSettings,
         @Nullable TaskSettings taskSettings,
         @Nullable SecretSettings secretSettings,
-        ElasticInferenceServiceComponents elasticInferenceServiceComponents
+        ElasticInferenceServiceComponents elasticInferenceServiceComponents,
+        ChunkingSettings chunkingSettings
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
             new ModelSecrets(secretSettings),
             serviceSettings,
             elasticInferenceServiceComponents

+ 3 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java

@@ -11,6 +11,7 @@ import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 
 public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCase {
 
@@ -26,7 +27,8 @@ public class ElasticInferenceServiceSparseEmbeddingsModelTests extends ESTestCas
             new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null),
             EmptyTaskSettings.INSTANCE,
             EmptySecretSettings.INSTANCE,
-            ElasticInferenceServiceComponents.of(url)
+            ElasticInferenceServiceComponents.of(url),
+            ChunkingSettingsBuilder.DEFAULT_SETTINGS
         );
     }
 }

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -834,7 +834,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
-    public void testChunkedInfer_PassesThrough() throws IOException {
+    public void testChunkedInfer() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         var elasticInferenceServiceURL = getUrl(webServer);
 

+ 3 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java

@@ -21,6 +21,7 @@ import org.elasticsearch.plugins.Plugin;
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
 import org.elasticsearch.xpack.inference.Utils;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig;
@@ -196,7 +197,8 @@ public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNo
                     new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-v2", null, null),
                     EmptyTaskSettings.INSTANCE,
                     EmptySecretSettings.INSTANCE,
-                    ElasticInferenceServiceComponents.EMPTY_INSTANCE
+                    ElasticInferenceServiceComponents.EMPTY_INSTANCE,
+                    ChunkingSettingsBuilder.DEFAULT_SETTINGS
                 ),
                 MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME)
             )