@@ -15,6 +15,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.bulk.BulkItemRequest;
 import org.elasticsearch.action.bulk.BulkItemResponse;
 import org.elasticsearch.action.bulk.BulkShardRequest;
+import org.elasticsearch.action.bulk.BulkShardResponse;
 import org.elasticsearch.action.bulk.TransportShardBulkAction;
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.ActionFilterChain;
@@ -26,19 +27,24 @@ import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.CheckedBiFunction;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.index.IndexVersion;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
@@ -48,6 +54,8 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xcontent.json.JsonXContent;
 import org.elasticsearch.xpack.core.XPackField;
@@ -72,13 +80,16 @@ import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;

+import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch;
+import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE;
 import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName;
-import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbedding;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults;
@@ -87,13 +98,23 @@ import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.longThat;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;

 public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     private static final Object EXPLICIT_NULL = new Object();
+    private static final IndexingPressure NOOP_INDEXING_PRESSURE = new NoopIndexingPressure();

     private final boolean useLegacyFormat;
     private ThreadPool threadPool;
@@ -119,7 +140,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testFilterNoop() throws Exception {
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -145,7 +166,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testLicenseInvalidForInference() throws InterruptedException {
         StaticModel model = StaticModel.createRandomInstance();
-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -186,6 +207,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -227,16 +249,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testItemFailures() throws Exception {
-        StaticModel model = StaticModel.createRandomInstance();
+        StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);

         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
-        model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
+        model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -295,13 +318,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {

     @SuppressWarnings({ "unchecked", "rawtypes" })
     public void testExplicitNull() throws Exception {
-        StaticModel model = StaticModel.createRandomInstance();
+        StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
         model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom")));
-        model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success")));
+        model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success")));

         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -372,6 +396,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         ShardBulkInferenceActionFilter filter = createFilter(
             threadPool,
             Map.of(model.getInferenceEntityId(), model),
+            NOOP_INDEXING_PRESSURE,
             useLegacyFormat,
             true
         );
@@ -444,7 +469,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             modifiedRequests[id] = res[1];
         }

-        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true);
+        ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true);
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
@@ -474,10 +499,397 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }

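+    // Happy path: each inference-backed document increments coordinating indexing pressure once
+    // for its original source bytes and once more (with zero operations) for the additional bytes
+    // of its inference results; the coordinating handle is closed only after the chain completes.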
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public void testIndexingPressure() throws Exception {
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value");
+        XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", "another test value");
+        XContentBuilder doc2Source = IndexRequest.getXContentBuilder(
+            XContentType.JSON,
+            "sparse_field",
+            "a test value",
+            "dense_field",
+            "another test value"
+        );
+        XContentBuilder doc3Source = IndexRequest.getXContentBuilder(
+            XContentType.JSON,
+            "dense_field",
+            List.of("value one", " ", "value two")
+        );
+        XContentBuilder doc4Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", " ");
+        XContentBuilder doc5Source = XContentFactory.contentBuilder(XContentType.JSON);
+        {
+            doc5Source.startObject();
+            if (useLegacyFormat == false) {
+                doc5Source.field("sparse_field", "a test value");
+            }
+            addSemanticTextInferenceResults(
+                useLegacyFormat,
+                doc5Source,
+                List.of(randomSemanticText(useLegacyFormat, "sparse_field", sparseModel, null, List.of("a test value"), XContentType.JSON))
+            );
+            doc5Source.endObject();
+        }
+        XContentBuilder doc0UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "an updated value");
+        XContentBuilder doc1UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", null);
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
+            try {
+                BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
+                assertNull(bulkShardRequest.getInferenceFieldMap());
+                assertThat(bulkShardRequest.items().length, equalTo(10));
+
+                for (BulkItemRequest item : bulkShardRequest.items()) {
+                    assertNull(item.getPrimaryResponse());
+                }
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource));
+                if (useLegacyFormat == false) {
+                    verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource));
+                }
+
+                verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
+
+                // Verify that increment is called only the number of times verified above
+                verify(coordinatingIndexingPressure, times(useLegacyFormat ? 12 : 14)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, as is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener actionListener = mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null),
+            "dense_field",
+            new InferenceFieldMetadata("dense_field", denseModel.getInferenceEntityId(), new String[] { "dense_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[10];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source(doc0Source));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
+        items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source(doc3Source));
+        items[4] = new BulkItemRequest(4, new IndexRequest("index").id("doc_4").source(doc4Source));
+        items[5] = new BulkItemRequest(5, new IndexRequest("index").id("doc_5").source(doc5Source));
+        items[6] = new BulkItemRequest(6, new IndexRequest("index").id("doc_6").source("non_inference_field", "yet another test value"));
+        items[7] = new BulkItemRequest(7, new UpdateRequest().doc(new IndexRequest("index").id("doc_0").source(doc0UpdateSource)));
+        items[8] = new BulkItemRequest(8, new UpdateRequest().doc(new IndexRequest("index").id("doc_1").source(doc1UpdateSource)));
+        items[9] = new BulkItemRequest(
+            9,
+            new UpdateRequest().doc(new IndexRequest("index").id("doc_3").source("non_inference_field", "yet another updated value"))
+        );
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
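+    // With a 1-byte coordinating limit, reserving memory to update the document source fails, so
+    // the inference-backed item is rejected with TOO_MANY_REQUESTS while the non-inference items
+    // proceed untouched.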
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception {
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build()
+        );
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(3));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[2].getPrimaryResponse());
+
+                BulkItemRequest doc1Request = request.items()[1];
+                BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
+                assertNotNull(doc1Response);
+                assertTrue(doc1Response.isFailed());
+                BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
+                assertThat(
+                    doc1Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to update source on document [doc_1]")
+                );
+                assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
+                assertThat(doc1IndexRequest, notNullValue());
+                assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, as is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
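+    // The limit leaves room for the original source but not for the appended inference results,
+    // so the rejection happens when the inference response is merged back into the document.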
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
+        final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
+        );
+
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar")));
+
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(3));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[2].getPrimaryResponse());
+
+                BulkItemRequest doc1Request = request.items()[1];
+                BulkItemResponse doc1Response = doc1Request.getPrimaryResponse();
+                assertNotNull(doc1Response);
+                assertTrue(doc1Response.isFailed());
+                BulkItemResponse.Failure doc1Failure = doc1Response.getFailure();
+                assertThat(
+                    doc1Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to insert inference results into document [doc_1]")
+                );
+                assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request());
+                assertThat(doc1IndexRequest, notNullValue());
+                assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
+                verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, as is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[3];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
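+    // Partial failure: the coordinating budget fits doc 1's inference results but not doc 2's,
+    // so doc 1 is augmented successfully while doc 2 is rejected with TOO_MANY_REQUESTS.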
+    @SuppressWarnings("unchecked")
+    public void testIndexingPressurePartialFailure() throws Exception {
+        // Use different-length strings so that the doc 1 and doc 2 sources are different sizes
+        final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
+        final XContentBuilder doc2Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bazzz");
+
+        final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
+        final ChunkedInferenceEmbedding barEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bar"));
+        final ChunkedInferenceEmbedding bazzzEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bazzz"));
+        sparseModel.putResult("bar", barEmbedding);
+        sparseModel.putResult("bazzz", bazzzEmbedding);
+
+        CheckedBiFunction<List<String>, ChunkedInference, Long, IOException> estimateInferenceResultsBytes = (inputs, inference) -> {
+            SemanticTextField semanticTextField = semanticTextFieldFromChunkedInferenceResults(
+                useLegacyFormat,
+                "sparse_field",
+                sparseModel,
+                null,
+                inputs,
+                inference,
+                XContentType.JSON
+            );
+            XContentBuilder builder = XContentFactory.jsonBuilder();
+            semanticTextField.toXContent(builder, EMPTY_PARAMS);
+            return bytesUsed(builder);
+        };
+
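+        // Budget: both raw sources plus all of doc 1's estimated inference results, but only half
+        // of doc 2's, so merging doc 2's results should trip the coordinating limit.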
+        final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
+            Settings.builder()
+                .put(
+                    MAX_COORDINATING_BYTES.getKey(),
+                    (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
+                        + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
+                )
+                .build()
+        );
+
+        final ShardBulkInferenceActionFilter filter = createFilter(
+            threadPool,
+            Map.of(sparseModel.getInferenceEntityId(), sparseModel),
+            indexingPressure,
+            useLegacyFormat,
+            true
+        );
+
+        CountDownLatch chainExecuted = new CountDownLatch(1);
+        ActionFilterChain<BulkShardRequest, BulkShardResponse> actionFilterChain = (task, action, request, listener) -> {
+            try {
+                assertNull(request.getInferenceFieldMap());
+                assertThat(request.items().length, equalTo(4));
+
+                assertNull(request.items()[0].getPrimaryResponse());
+                assertNull(request.items()[1].getPrimaryResponse());
+                assertNull(request.items()[3].getPrimaryResponse());
+
+                BulkItemRequest doc2Request = request.items()[2];
+                BulkItemResponse doc2Response = doc2Request.getPrimaryResponse();
+                assertNotNull(doc2Response);
+                assertTrue(doc2Response.isFailed());
+                BulkItemResponse.Failure doc2Failure = doc2Response.getFailure();
+                assertThat(
+                    doc2Failure.getCause().getMessage(),
+                    containsString("Insufficient memory available to insert inference results into document [doc_2]")
+                );
+                assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class));
+                assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS));
+
+                IndexRequest doc2IndexRequest = getIndexRequestOrNull(doc2Request.request());
+                assertThat(doc2IndexRequest, notNullValue());
+                assertThat(doc2IndexRequest.source(), equalTo(BytesReference.bytes(doc2Source)));
+
+                IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+                assertThat(coordinatingIndexingPressure, notNullValue());
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+                verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
+                verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());
+
+                // Verify that the coordinating indexing pressure is maintained through downstream action filters
+                verify(coordinatingIndexingPressure, never()).close();
+
+                // Call the listener once the request is successfully processed, as is done in the production code path
+                listener.onResponse(null);
+            } finally {
+                chainExecuted.countDown();
+            }
+        };
+        ActionListener<BulkShardResponse> actionListener = (ActionListener<BulkShardResponse>) mock(ActionListener.class);
+        Task task = mock(Task.class);
+
+        Map<String, InferenceFieldMetadata> inferenceFieldMap = Map.of(
+            "sparse_field",
+            new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null)
+        );
+
+        BulkItemRequest[] items = new BulkItemRequest[4];
+        items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo"));
+        items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source));
+        items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source));
+        items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source("non_inference_field", "baz"));
+
+        BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items);
+        request.setInferenceFieldMap(inferenceFieldMap);
+        filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
+        awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
+
+        IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
+        assertThat(coordinatingIndexingPressure, notNullValue());
+        verify(coordinatingIndexingPressure).close();
+    }
+
     @SuppressWarnings("unchecked")
     private static ShardBulkInferenceActionFilter createFilter(
         ThreadPool threadPool,
         Map<String, StaticModel> modelMap,
+        IndexingPressure indexingPressure,
         boolean useLegacyFormat,
         boolean isLicenseValidForInference
     ) {
@@ -503,6 +915,17 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         };
         doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any());

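+        // Stub getMinimalServiceSettings to serve settings for the models in the static model map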
+        Answer<MinimalServiceSettings> minimalServiceSettingsAnswer = invocationOnMock -> {
+            String inferenceId = (String) invocationOnMock.getArguments()[0];
+            var model = modelMap.get(inferenceId);
+            if (model == null) {
+                throw new ResourceNotFoundException("model id [{}] not found", inferenceId);
+            }
+
+            return new MinimalServiceSettings(model);
+        };
+        doAnswer(minimalServiceSettingsAnswer).when(modelRegistry).getMinimalServiceSettings(any());
+
         InferenceService inferenceService = mock(InferenceService.class);
         Answer<?> chunkedInferAnswer = invocationOnMock -> {
             StaticModel model = (StaticModel) invocationOnMock.getArguments()[0];
@@ -544,7 +967,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             createClusterService(useLegacyFormat),
             inferenceServiceRegistry,
             modelRegistry,
-            licenseState
+            licenseState,
+            indexingPressure
         );
     }

@@ -629,6 +1053,10 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
    }

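+    // Size helper used by the indexing pressure assertions above: the heap bytes occupied by the
+    // rendered document source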
+    private static long bytesUsed(XContentBuilder builder) {
+        return BytesReference.bytes(builder).ramBytesUsed();
+    }
+
     @SuppressWarnings({ "unchecked" })
     private static void assertInferenceResults(
         boolean useLegacyFormat,
@@ -693,7 +1121,11 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
         }

         public static StaticModel createRandomInstance() {
-            TestModel testModel = TestModel.createRandomInstance();
+            return createRandomInstance(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING));
+        }
+
+        public static StaticModel createRandomInstance(TaskType taskType) {
+            TestModel testModel = TestModel.createRandomInstance(taskType);
             return new StaticModel(
                 testModel.getInferenceEntityId(),
                 testModel.getTaskType(),
@@ -716,4 +1148,42 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             return resultMap.containsKey(text);
         }
     }
+
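+    // Wraps each coordinating operation in a Mockito spy so tests can verify increment/close
+    // calls while the real accounting still runs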
+    private static class InstrumentedIndexingPressure extends IndexingPressure {
+        private Coordinating coordinating = null;
+
+        private InstrumentedIndexingPressure(Settings settings) {
+            super(settings);
+        }
+
+        private Coordinating getCoordinating() {
+            return coordinating;
+        }
+
+        @Override
+        public Coordinating createCoordinatingOperation(boolean forceExecution) {
+            coordinating = spy(super.createCoordinatingOperation(forceExecution));
+            return coordinating;
+        }
+    }
+
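+    // No-op accounting for tests that don't exercise indexing pressure, so memory limits can
+    // never trip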
+    private static class NoopIndexingPressure extends IndexingPressure {
+        private NoopIndexingPressure() {
+            super(Settings.EMPTY);
+        }
+
+        @Override
+        public Coordinating createCoordinatingOperation(boolean forceExecution) {
+            return new NoopCoordinating(forceExecution);
+        }
+
+        private class NoopCoordinating extends Coordinating {
+            private NoopCoordinating(boolean forceExecution) {
+                super(forceExecution);
+            }
+
+            @Override
+            public void increment(int operations, long bytes) {}
+        }
+    }
 }