@@ -25,7 +25,11 @@ import org.elasticsearch.action.update.UpdateRequest;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.ProjectMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.common.settings.Setting;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
@@ -43,6 +47,10 @@ import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.xcontent.XContent;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.XPackField;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError;
 import org.elasticsearch.xpack.inference.InferenceException;
@@ -63,6 +71,8 @@ import java.util.Map;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunks;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.toSemanticTextFieldChunksLegacy;
 
 /**
  * A {@link MappedActionFilter} that intercepts {@link BulkShardRequest} to apply inference on fields specified
@@ -72,10 +82,23 @@ import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
  * This transformation happens on the bulk coordinator node, and the {@link SemanticTextFieldMapper} parses the
  * results during indexing on the shard.
  *
- * TODO: batchSize should be configurable via a cluster setting
  */
 public class ShardBulkInferenceActionFilter implements MappedActionFilter {
-    protected static final int DEFAULT_BATCH_SIZE = 512;
+    private static final ByteSizeValue DEFAULT_BATCH_SIZE = ByteSizeValue.ofMb(1);
+
+    /**
+     * Defines the cumulative size limit of input data before triggering a batch inference call.
+     * This setting controls how much data can be accumulated before an inference request is sent as a single batch.
+     */
+    public static Setting<ByteSizeValue> INDICES_INFERENCE_BATCH_SIZE = Setting.byteSizeSetting(
+        "indices.inference.batch_size",
+        DEFAULT_BATCH_SIZE,
+        ByteSizeValue.ONE,
+        ByteSizeValue.ofMb(100),
+        Setting.Property.NodeScope,
+        Setting.Property.OperatorDynamic
+    );
+
     private static final Object EXPLICIT_NULL = new Object();
     private static final ChunkedInference EMPTY_CHUNKED_INFERENCE = new EmptyChunkedInference();
 
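[Example — editor's sketch, not part of the patch. It shows how the new byte-size setting resolves values through the standard Setting API; the node settings built here are hypothetical.]

    Settings nodeSettings = Settings.builder()
        .put("indices.inference.batch_size", "4mb") // hypothetical operator-provided value
        .build();
    ByteSizeValue batchSize = ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE.get(nodeSettings);
    assert batchSize.getBytes() == ByteSizeValue.ofMb(4).getBytes();
    // Without an explicit value the 1mb default applies; values outside [1b, 100mb] fail validation,
    // and Property.OperatorDynamic allows operators to update the value on a running cluster.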
@@ -83,29 +106,24 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
-    private final int batchSize;
+    private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
         XPackLicenseState licenseState
-    ) {
-        this(clusterService, inferenceServiceRegistry, modelRegistry, licenseState, DEFAULT_BATCH_SIZE);
-    }
-
-    public ShardBulkInferenceActionFilter(
-        ClusterService clusterService,
-        InferenceServiceRegistry inferenceServiceRegistry,
-        ModelRegistry modelRegistry,
-        XPackLicenseState licenseState,
-        int batchSize
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
-        this.batchSize = batchSize;
+        this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
+        clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
+    }
+
+    private void setBatchSize(ByteSizeValue newBatchSize) {
+        batchSizeInBytes = newBatchSize.getBytes();
     }
 
     @Override
@@ -148,14 +166,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
 
     /**
      * A field inference request on a single input.
-     * @param index The index of the request in the original bulk request.
+     * @param bulkItemIndex The index of the item in the original bulk request.
      * @param field The target field.
      * @param sourceField The source field.
      * @param input The input to run inference on.
      * @param inputOrder The original order of the input.
     * @param offsetAdjustment The adjustment to apply to the chunk text offsets.
      */
-    private record FieldInferenceRequest(int index, String field, String sourceField, String input, int inputOrder, int offsetAdjustment) {}
+    private record FieldInferenceRequest(
+        int bulkItemIndex,
+        String field,
+        String sourceField,
+        String input,
+        int inputOrder,
+        int offsetAdjustment
+    ) {}
 
     /**
      * The field inference response.
@@ -218,29 +243,54 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
 
         @Override
         public void run() {
-            Map<String, List<FieldInferenceRequest>> inferenceRequests = createFieldInferenceRequests(bulkShardRequest);
+            executeNext(0);
+        }
+
+        private void executeNext(int itemOffset) {
+            if (itemOffset >= bulkShardRequest.items().length) {
+                onCompletion.run();
+                return;
+            }
+
+            var items = bulkShardRequest.items();
+            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new HashMap<>();
+            long totalInputLength = 0;
+            int itemIndex = itemOffset;
+            while (itemIndex < items.length && totalInputLength < batchSizeInBytes) {
+                var item = items[itemIndex];
+                totalInputLength += addFieldInferenceRequests(item, itemIndex, fieldRequestsMap);
+                itemIndex += 1;
+            }
+            int nextItemOffset = itemIndex;
             Runnable onInferenceCompletion = () -> {
                 try {
-                    for (var inferenceResponse : inferenceResults.asList()) {
-                        var request = bulkShardRequest.items()[inferenceResponse.id];
+                    for (int i = itemOffset; i < nextItemOffset; i++) {
+                        var result = inferenceResults.get(i);
+                        if (result == null) {
+                            continue;
+                        }
+                        var item = items[i];
                         try {
-                            applyInferenceResponses(request, inferenceResponse);
+                            applyInferenceResponses(item, result);
                         } catch (Exception exc) {
-                            request.abort(bulkShardRequest.index(), exc);
+                            item.abort(bulkShardRequest.index(), exc);
                         }
+                        // we don't need to keep the inference results around
+                        inferenceResults.set(i, null);
                     }
                 } finally {
-                    onCompletion.run();
+                    executeNext(nextItemOffset);
                 }
             };
+
             try (var releaseOnFinish = new RefCountingRunnable(onInferenceCompletion)) {
-                for (var entry : inferenceRequests.entrySet()) {
-                    executeShardBulkInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
+                for (var entry : fieldRequestsMap.entrySet()) {
+                    executeChunkedInferenceAsync(entry.getKey(), null, entry.getValue(), releaseOnFinish.acquire());
                 }
             }
         }
 
-        private void executeShardBulkInferenceAsync(
+        private void executeChunkedInferenceAsync(
            final String inferenceId,
            @Nullable InferenceProvider inferenceProvider,
            final List<FieldInferenceRequest> requests,
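[Example — editor's sketch, not part of the patch. The loop in executeNext batches bulk items by a byte budget rather than a fixed document count: it consumes items until the cumulative input length crosses batchSizeInBytes, runs inference for that window, and re-enters with the next offset once the in-flight batch completes. The standalone helper below mirrors that windowing logic; all names are illustrative.]

    import java.util.ArrayList;
    import java.util.List;

    class BatchWindows {
        /** Half-open [start, end) windows over per-item input sizes. */
        static List<int[]> windows(long[] inputSizes, long batchSizeInBytes) {
            List<int[]> result = new ArrayList<>();
            int start = 0;
            while (start < inputSizes.length) {
                long total = 0;
                int end = start;
                // The item that crosses the threshold is still included,
                // so every window holds at least one item.
                while (end < inputSizes.length && total < batchSizeInBytes) {
                    total += inputSizes[end++];
                }
                result.add(new int[] { start, end });
                start = end;
            }
            return result;
        }
    }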
@@ -262,11 +312,11 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 unparsedModel.secrets()
                             )
                         );
-                        executeShardBulkInferenceAsync(inferenceId, provider, requests, onFinish);
+                        executeChunkedInferenceAsync(inferenceId, provider, requests, onFinish);
                     } else {
                         try (onFinish) {
                             for (FieldInferenceRequest request : requests) {
-                                inferenceResults.get(request.index).failures.add(
+                                inferenceResults.get(request.bulkItemIndex).failures.add(
                                     new ResourceNotFoundException(
                                         "Inference service [{}] not found for field [{}]",
                                         unparsedModel.service(),
@@ -297,7 +347,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 request.field
                             );
                         }
-                        inferenceResults.get(request.index).failures.add(failure);
+                        inferenceResults.get(request.bulkItemIndex).failures.add(failure);
                     }
                 }
             }
@@ -305,18 +355,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener);
             return;
         }
-        int currentBatchSize = Math.min(requests.size(), batchSize);
-        final List<FieldInferenceRequest> currentBatch = requests.subList(0, currentBatchSize);
-        final List<FieldInferenceRequest> nextBatch = requests.subList(currentBatchSize, requests.size());
-        final List<String> inputs = currentBatch.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
+        final List<String> inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList());
         ActionListener<List<ChunkedInference>> completionListener = new ActionListener<>() {
             @Override
             public void onResponse(List<ChunkedInference> results) {
-                try {
+                try (onFinish) {
                     var requestsIterator = requests.iterator();
                     for (ChunkedInference result : results) {
                         var request = requestsIterator.next();
-                        var acc = inferenceResults.get(request.index);
+                        var acc = inferenceResults.get(request.bulkItemIndex);
                         if (result instanceof ChunkedInferenceError error) {
                             acc.addFailure(
                                 new InferenceException(
@@ -331,7 +378,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                                 new FieldInferenceResponse(
                                     request.field(),
                                     request.sourceField(),
-                                    request.input(),
+                                    useLegacyFormat ? request.input() : null,
                                     request.inputOrder(),
                                     request.offsetAdjustment(),
                                     inferenceProvider.model,
@@ -340,17 +387,15 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             );
                         }
                     }
-                } finally {
-                    onFinish();
                 }
             }
 
             @Override
             public void onFailure(Exception exc) {
-                try {
+                try (onFinish) {
                     for (FieldInferenceRequest request : requests) {
                         addInferenceResponseFailure(
-                            request.index,
+                            request.bulkItemIndex,
                             new InferenceException(
                                 "Exception when running inference id [{}] on field [{}]",
                                 exc,
@@ -359,16 +404,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                             )
                         );
                     }
-                } finally {
-                    onFinish();
-                }
-            }
-
-            private void onFinish() {
-                if (nextBatch.isEmpty()) {
-                    onFinish.close();
-                } else {
-                    executeShardBulkInferenceAsync(inferenceId, inferenceProvider, nextBatch, onFinish);
                 }
             }
         };
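[Note — editorial, not part of the patch: with the size-based windows now handled in executeNext, the recursive nextBatch chaining removed above is no longer needed. onFinish is a Releasable, so declaring it in try-with-resources guarantees exactly one release on both the success and failure paths, which the deleted onFinish() helper previously did by hand.]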
@@ -376,6 +411,132 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                 .chunkedInfer(inferenceProvider.model(), null, inputs, Map.of(), InputType.INGEST, TimeValue.MAX_VALUE, completionListener);
         }
 
+        /**
+         * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap}
+         * for the specified {@code item}.
+         *
+         * @param item The bulk request item to process.
+         * @param itemIndex The position of the item within the original bulk request.
+         * @param requestsMap A map storing inference requests, where each key is an inference ID,
+         *                    and the value is a list of associated {@link FieldInferenceRequest} objects.
+         * @return The total content length of all newly added requests, or {@code 0} if no requests were added.
+         */
+        private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
+            boolean isUpdateRequest = false;
+            final IndexRequest indexRequest;
+            if (item.request() instanceof IndexRequest ir) {
+                indexRequest = ir;
+            } else if (item.request() instanceof UpdateRequest updateRequest) {
+                isUpdateRequest = true;
+                if (updateRequest.script() != null) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new ElasticsearchStatusException(
+                            "Cannot apply update with a script on indices that contain [{}] field(s)",
+                            RestStatus.BAD_REQUEST,
+                            SemanticTextFieldMapper.CONTENT_TYPE
+                        )
+                    );
+                    return 0;
+                }
+                indexRequest = updateRequest.doc();
+            } else {
+                // ignore delete request
+                return 0;
+            }
+
+            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            long inputLength = 0;
+            for (var entry : fieldInferenceMap.values()) {
+                String field = entry.getName();
+                String inferenceId = entry.getInferenceId();
+
+                if (useLegacyFormat) {
+                    var originalFieldValue = XContentMapValues.extractValue(field, docMap);
+                    if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
+                        // Inference has already been computed, or there is no inference required.
+                        continue;
+                    }
+                } else {
+                    var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
+                        InferenceMetadataFieldsMapper.NAME + "." + field,
+                        docMap,
+                        EXPLICIT_NULL
+                    );
+                    if (inferenceMetadataFieldsValue != null) {
+                        // Inference has already been computed
+                        continue;
+                    }
+                }
+
+                int order = 0;
+                for (var sourceField : entry.getSourceFields()) {
+                    var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
+                    if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
+                        /**
+                         * It's an update request, and the source field is explicitly set to null,
+                         * so we need to propagate this information to the inference fields metadata
+                         * to overwrite any inference previously computed on the field.
+                         * This ensures that the field is treated as intentionally cleared,
+                         * preventing any unintended carryover of prior inference results.
+                         */
+                        var slot = ensureResponseAccumulatorSlot(itemIndex);
+                        slot.addOrUpdateResponse(
+                            new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
+                        );
+                        continue;
+                    }
+                    if (valueObj == null || valueObj == EXPLICIT_NULL) {
+                        if (isUpdateRequest && useLegacyFormat) {
+                            addInferenceResponseFailure(
+                                itemIndex,
+                                new ElasticsearchStatusException(
+                                    "Field [{}] must be specified on an update request to calculate inference for field [{}]",
+                                    RestStatus.BAD_REQUEST,
+                                    sourceField,
+                                    field
+                                )
+                            );
+                            break;
+                        }
+                        continue;
+                    }
+                    var slot = ensureResponseAccumulatorSlot(itemIndex);
+                    final List<String> values;
+                    try {
+                        values = SemanticTextUtils.nodeStringValues(field, valueObj);
+                    } catch (Exception exc) {
+                        addInferenceResponseFailure(itemIndex, exc);
+                        break;
+                    }
+
+                    if (INFERENCE_API_FEATURE.check(licenseState) == false) {
+                        addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
+                        break;
+                    }
+
+                    List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                    int offsetAdjustment = 0;
+                    for (String v : values) {
+                        inputLength += v.length();
+                        if (v.isBlank()) {
+                            slot.addOrUpdateResponse(
+                                new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
+                            );
+                        } else {
+                            requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
+                        }
+
+                        // When using the inference metadata fields format, all the input values are concatenated so that the
+                        // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
+                        // to apply to account for this.
+                        offsetAdjustment += v.length() + 1; // Add one for separator char length
+                    }
+                }
+            }
+            return inputLength;
+        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
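[Example — editor's worked example for the offset bookkeeping in addFieldInferenceRequests, with hypothetical values. For source values "foo" and "bar" on the same field, the first FieldInferenceRequest is created with offsetAdjustment 0 and the second with 4 ("foo".length() + 1 for the separator character), so chunk offsets computed against the concatenated input can be mapped back to each original value. The returned inputLength for this item would be 6.]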
@@ -404,7 +565,6 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             }
 
             final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
-            var newDocMap = indexRequest.sourceAsMap();
             Map<String, Object> inferenceFieldsMap = new HashMap<>();
             for (var entry : response.responses.entrySet()) {
                 var fieldName = entry.getKey();
@@ -426,28 +586,22 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     }
 
                     var lst = chunkMap.computeIfAbsent(resp.sourceField, k -> new ArrayList<>());
-                    lst.addAll(
-                        SemanticTextField.toSemanticTextFieldChunks(
-                            resp.input,
-                            resp.offsetAdjustment,
-                            resp.chunkedResults,
-                            indexRequest.getContentType(),
-                            useLegacyFormat
-                        )
-                    );
+                    var chunks = useLegacyFormat
+                        ? toSemanticTextFieldChunksLegacy(resp.input, resp.chunkedResults, indexRequest.getContentType())
+                        : toSemanticTextFieldChunks(resp.offsetAdjustment, resp.chunkedResults, indexRequest.getContentType());
+                    lst.addAll(chunks);
                 }
 
-                List<String> inputs = responses.stream()
-                    .filter(r -> r.sourceField().equals(fieldName))
-                    .map(r -> r.input)
-                    .collect(Collectors.toList());
+                List<String> inputs = useLegacyFormat
+                    ? responses.stream().filter(r -> r.sourceField().equals(fieldName)).map(r -> r.input).collect(Collectors.toList())
+                    : null;
 
                 // The model can be null if we are only processing update requests that clear inference results. This is ok because we will
                 // merge in the field's existing model settings on the data node.
                 var result = new SemanticTextField(
                     useLegacyFormat,
                     fieldName,
-                    useLegacyFormat ? inputs : null,
+                    inputs,
                     new SemanticTextField.InferenceResult(
                         inferenceFieldMetadata.getInferenceId(),
                         model != null ? new MinimalServiceSettings(model) : null,
@@ -455,149 +609,52 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     ),
                     indexRequest.getContentType()
                 );
-
-                if (useLegacyFormat) {
-                    SemanticTextUtils.insertValue(fieldName, newDocMap, result);
-                } else {
-                    inferenceFieldsMap.put(fieldName, result);
-                }
-            }
-            if (useLegacyFormat == false) {
-                newDocMap.put(InferenceMetadataFieldsMapper.NAME, inferenceFieldsMap);
+                inferenceFieldsMap.put(fieldName, result);
             }
-            indexRequest.source(newDocMap, indexRequest.getContentType());
-        }
 
-        /**
-         * Register a {@link FieldInferenceRequest} for every non-empty field referencing an inference ID in the index.
-         * If results are already populated for fields in the original index request, the inference request for this specific
-         * field is skipped, and the existing results remain unchanged.
-         * Validation of inference ID and model settings occurs in the {@link SemanticTextFieldMapper} during field indexing,
-         * where an error will be thrown if they mismatch or if the content is malformed.
-         * <p>
-         * TODO: We should validate the settings for pre-existing results here and apply the inference only if they differ?
-         */
-        private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
-            Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
-            for (int itemIndex = 0; itemIndex < bulkShardRequest.items().length; itemIndex++) {
-                var item = bulkShardRequest.items()[itemIndex];
-                if (item.getPrimaryResponse() != null) {
-                    // item was already aborted/processed by a filter in the chain upstream (e.g. security)
-                    continue;
+            if (useLegacyFormat) {
+                var newDocMap = indexRequest.sourceAsMap();
+                for (var entry : inferenceFieldsMap.entrySet()) {
+                    SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue());
                 }
-                boolean isUpdateRequest = false;
-                final IndexRequest indexRequest;
-                if (item.request() instanceof IndexRequest ir) {
-                    indexRequest = ir;
-                } else if (item.request() instanceof UpdateRequest updateRequest) {
-                    isUpdateRequest = true;
-                    if (updateRequest.script() != null) {
-                        addInferenceResponseFailure(
-                            itemIndex,
-                            new ElasticsearchStatusException(
-                                "Cannot apply update with a script on indices that contain [{}] field(s)",
-                                RestStatus.BAD_REQUEST,
-                                SemanticTextFieldMapper.CONTENT_TYPE
-                            )
-                        );
-                        continue;
-                    }
-                    indexRequest = updateRequest.doc();
-                } else {
-                    // ignore delete request
-                    continue;
+                indexRequest.source(newDocMap, indexRequest.getContentType());
+            } else {
+                try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) {
+                    appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap);
+                    indexRequest.source(builder);
                 }
+            }
+        }
 
-                final Map<String, Object> docMap = indexRequest.sourceAsMap();
-                for (var entry : fieldInferenceMap.values()) {
-                    String field = entry.getName();
-                    String inferenceId = entry.getInferenceId();
-
-                    if (useLegacyFormat) {
-                        var originalFieldValue = XContentMapValues.extractValue(field, docMap);
-                        if (originalFieldValue instanceof Map || (originalFieldValue == null && entry.getSourceFields().length == 1)) {
-                            // Inference has already been computed, or there is no inference required.
-                            continue;
-                        }
-                    } else {
-                        var inferenceMetadataFieldsValue = XContentMapValues.extractValue(
-                            InferenceMetadataFieldsMapper.NAME + "." + field,
-                            docMap,
-                            EXPLICIT_NULL
-                        );
-                        if (inferenceMetadataFieldsValue != null) {
-                            // Inference has already been computed
-                            continue;
-                        }
-                    }
-
-                    int order = 0;
-                    for (var sourceField : entry.getSourceFields()) {
-                        var valueObj = XContentMapValues.extractValue(sourceField, docMap, EXPLICIT_NULL);
-                        if (useLegacyFormat == false && isUpdateRequest && valueObj == EXPLICIT_NULL) {
-                            /**
-                             * It's an update request, and the source field is explicitly set to null,
-                             * so we need to propagate this information to the inference fields metadata
-                             * to overwrite any inference previously computed on the field.
-                             * This ensures that the field is treated as intentionally cleared,
-                             * preventing any unintended carryover of prior inference results.
-                             */
-                            var slot = ensureResponseAccumulatorSlot(itemIndex);
-                            slot.addOrUpdateResponse(
-                                new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
-                            );
-                            continue;
-                        }
-                        if (valueObj == null || valueObj == EXPLICIT_NULL) {
-                            if (isUpdateRequest && useLegacyFormat) {
-                                addInferenceResponseFailure(
-                                    itemIndex,
-                                    new ElasticsearchStatusException(
-                                        "Field [{}] must be specified on an update request to calculate inference for field [{}]",
-                                        RestStatus.BAD_REQUEST,
-                                        sourceField,
-                                        field
-                                    )
-                                );
-                                break;
-                            }
-                            continue;
-                        }
-                        var slot = ensureResponseAccumulatorSlot(itemIndex);
-                        final List<String> values;
-                        try {
-                            values = SemanticTextUtils.nodeStringValues(field, valueObj);
-                        } catch (Exception exc) {
-                            addInferenceResponseFailure(itemIndex, exc);
-                            break;
-                        }
-
-                        if (INFERENCE_API_FEATURE.check(licenseState) == false) {
-                            addInferenceResponseFailure(itemIndex, LicenseUtils.newComplianceException(XPackField.INFERENCE));
-                            break;
-                        }
-
-                        List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
-                        int offsetAdjustment = 0;
-                        for (String v : values) {
-                            if (v.isBlank()) {
-                                slot.addOrUpdateResponse(
-                                    new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
-                                );
-                            } else {
-                                fieldRequests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment));
-                            }
-
-                            // When using the inference metadata fields format, all the input values are concatenated so that the
-                            // chunk text offsets are expressed in the context of a single string. Calculate the offset adjustment
-                            // to apply to account for this.
-                            offsetAdjustment += v.length() + 1; // Add one for separator char length
-                        }
-                    }
-                }
+    /**
+     * Appends the original source and the new inference metadata field directly to the provided
+     * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}.
+     */
+    private static void appendSourceAndInferenceMetadata(
+        XContentBuilder builder,
+        BytesReference source,
+        XContentType xContentType,
+        Map<String, Object> inferenceFieldsMap
+    ) throws IOException {
+        builder.startObject();
+
+        // append the original source
+        try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) {
+            // skip start object
+            parser.nextToken();
+            while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
+                builder.copyCurrentStructure(parser);
+            }
-            return fieldRequestsMap;
         }
+
+        // add the inference metadata field
+        builder.field(InferenceMetadataFieldsMapper.NAME);
+        try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) {
+            builder.copyCurrentStructure(parser);
+        }
+
+        builder.endObject();
     }
 
     static IndexRequest getIndexRequestOrNull(DocWriteRequest<?> docWriteRequest) {
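[Example — editor's sketch, not part of the patch. A hypothetical harness for appendSourceAndInferenceMetadata (the method is private, so this would live in the same class or a test). BytesArray, XContentFactory and Strings are standard Elasticsearch utilities; the expected output assumes InferenceMetadataFieldsMapper.NAME is "_inference_fields" in this codebase, and the inference map value here is a placeholder rather than a real SemanticTextField.]

    XContentBuilder builder = XContentFactory.jsonBuilder();
    BytesReference source = new BytesArray("{\"title\":\"hello\"}");
    appendSourceAndInferenceMetadata(builder, source, XContentType.JSON, Map.of("title", Map.of()));
    // Strings.toString(builder) -> {"title":"hello","_inference_fields":{"title":{}}}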