@@ -17,6 +17,7 @@ import org.elasticsearch.common.util.LazyInitializable;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySecretSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
@@ -36,6 +37,8 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -71,6 +74,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -80,6 +84,7 @@ public class ElasticInferenceService extends SenderService {

     public static final String NAME = "elastic";
     public static final String ELASTIC_INFERENCE_SERVICE_IDENTIFIER = "Elastic Inference Service";
+    public static final int SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 512;

     private static final EnumSet<TaskType> IMPLEMENTED_TASK_TYPES = EnumSet.of(
         TaskType.SPARSE_EMBEDDING,
@@ -161,7 +166,8 @@ public class ElasticInferenceService extends SenderService {
                 new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_MODEL_ID_V2, null, null),
                 EmptyTaskSettings.INSTANCE,
                 EmptySecretSettings.INSTANCE,
-                elasticInferenceServiceComponents
+                elasticInferenceServiceComponents,
+                ChunkingSettingsBuilder.DEFAULT_SETTINGS
             ),
             MinimalServiceSettings.sparseEmbedding(NAME)
         ),
@@ -304,12 +310,25 @@ public class ElasticInferenceService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        // Pass-through without actually performing chunking (result will have a single chunk per input)
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(inputs, response))
-        );
+        if (model instanceof ElasticInferenceServiceSparseEmbeddingsModel sparseTextEmbeddingsModel) {
+            var actionCreator = new ElasticInferenceServiceActionCreator(getSender(), getServiceComponents(), getCurrentTraceInfo());
+
+            List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+                inputs.getInputs(),
+                SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE,
+                model.getConfigurations().getChunkingSettings()
+            ).batchRequestsWithListeners(listener);
+
+            for (var request : batchedRequests) {
+                var action = sparseTextEmbeddingsModel.accept(actionCreator, taskSettings);
+                action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
+            }
+
+            return;
+        }

-        doInfer(model, inputs, taskSettings, timeout, inferListener);
+        // Model cannot perform chunked inference
+        listener.onFailure(createInvalidModelException(model));
     }

     @Override
@@ -328,6 +347,13 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+            );
+        }
+
         ElasticInferenceServiceModel model = createModel(
             inferenceEntityId,
             taskType,
@@ -336,7 +362,8 @@ public class ElasticInferenceService extends SenderService {
             serviceSettingsMap,
             elasticInferenceServiceComponents,
             TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
-            ConfigurationParseContext.REQUEST
+            ConfigurationParseContext.REQUEST,
+            chunkingSettings
         );

         throwIfNotEmptyMap(config, NAME);
@@ -372,7 +399,8 @@ public class ElasticInferenceService extends SenderService {
         @Nullable Map<String, Object> secretSettings,
         ElasticInferenceServiceComponents elasticInferenceServiceComponents,
         String failureMessage,
-        ConfigurationParseContext context
+        ConfigurationParseContext context,
+        ChunkingSettings chunkingSettings
     ) {
         return switch (taskType) {
             case SPARSE_EMBEDDING -> new ElasticInferenceServiceSparseEmbeddingsModel(
@@ -383,7 +411,8 @@ public class ElasticInferenceService extends SenderService {
                 taskSettings,
                 secretSettings,
                 elasticInferenceServiceComponents,
-                context
+                context,
+                chunkingSettings
             );
             case CHAT_COMPLETION -> new ElasticInferenceServiceCompletionModel(
                 inferenceEntityId,
@@ -420,13 +449,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);

+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             secretSettingsMap,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }

@@ -435,13 +470,19 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);

+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.SPARSE_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             null,
-            parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
+            parsePersistedConfigErrorMsg(inferenceEntityId, NAME),
+            chunkingSettings
         );
     }

@@ -456,7 +497,8 @@ public class ElasticInferenceService extends SenderService {
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
-        String failureMessage
+        String failureMessage,
+        ChunkingSettings chunkingSettings
     ) {
         return createModel(
             inferenceEntityId,
@@ -466,7 +508,8 @@ public class ElasticInferenceService extends SenderService {
             secretSettings,
             elasticInferenceServiceComponents,
             failureMessage,
-            ConfigurationParseContext.PERSISTENT
+            ConfigurationParseContext.PERSISTENT,
+            chunkingSettings
         );
     }
