
Fix incorrect accounting of semantic text indexing memory pressure (#130221)

Mike Pellegrini, 3 months ago
commit 52495aa5fc

+ 5 - 0
docs/changelog/130221.yaml

@@ -0,0 +1,5 @@
+pr: 130221
+summary: Fix incorrect accounting of semantic text indexing memory pressure
+area: Distributed
+type: bug
+issues: []

+ 4 - 1
server/src/main/java/org/elasticsearch/common/bytes/BytesReference.java

@@ -159,7 +159,10 @@ public interface BytesReference extends Comparable<BytesReference>, ToXContentFr
     BytesReference slice(int from, int length);
 
     /**
-     * The amount of memory used by this BytesReference
+     * The amount of memory used by this BytesReference.
+     * <p>
+     * Note that this is not always the same as length and can vary by implementation.
+     * </p>
      */
     long ramBytesUsed();
 

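The javadoc note above is the heart of the fix: length() is the number of addressable bytes in the reference, while ramBytesUsed() is allocated heap, which can include backing-buffer overhead. A minimal sketch of the distinction (illustrative only, not part of the commit; the exact ramBytesUsed() figure is implementation-dependent):

    // Hypothetical illustration of the two metrics. The gap is most visible
    // for a slice, which addresses a few bytes but may retain the heap of
    // its parent's backing allocation.
    import org.elasticsearch.common.bytes.BytesArray;
    import org.elasticsearch.common.bytes.BytesReference;

    public class RamVsLength {
        public static void main(String[] args) {
            BytesReference source = new BytesArray(new byte[4096]); // 4 KiB backing array
            BytesReference slice = source.slice(0, 16);             // addresses 16 bytes

            System.out.println(slice.length());       // 16: the addressable byte count
            System.out.println(slice.ramBytesUsed()); // implementation-dependent; may
                                                      // account for the full backing array
        }
    }
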
+ 3 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java

@@ -631,7 +631,7 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
             if (indexRequest.isIndexingPressureIncremented() == false) {
                 try {
                     // Track operation count as one operation per document source update
-                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().length());
                     indexRequest.setIndexingPressureIncremented();
                 } catch (EsRejectedExecutionException e) {
                     addInferenceResponseFailure(
@@ -737,13 +737,13 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
                     indexRequest.source(builder);
                 }
             }
-            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+            long modifiedSourceSize = indexRequest.source().length();
 
             // Add the indexing pressure from the source modifications.
             // Don't increment operation count because we count one source update as one operation, and we already accounted for those
             // in addFieldInferenceRequests.
             try {
-                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.length());
             } catch (EsRejectedExecutionException e) {
                 indexRequest.source(originalSource, indexRequest.getContentType());
                 item.abort(

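Both hunks apply the same two-step accounting: charge one operation sized by the serialized source when a document is first seen, then charge a zero-operation delta once inference results have been written back into the source. Measuring with length() charges what the coordinating limit is meant to bound, the serialized request size, rather than heap allocation. A hedged sketch of that pattern, with a simplified interface standing in for the real IndexingPressure.Coordinating API:

    // Simplified sketch of the accounting pattern in this file. Every name
    // except length() is a stand-in, not the real IndexingPressure API.
    import org.elasticsearch.common.bytes.BytesReference;

    interface CoordinatingPressure {
        void increment(int operations, long bytes); // may reject when over the limit
    }

    final class SourcePressureAccounting {
        private final CoordinatingPressure pressure;

        SourcePressureAccounting(CoordinatingPressure pressure) {
            this.pressure = pressure;
        }

        // First sight of a document: one operation, sized by the wire
        // length of its source, not by heap allocation.
        void trackOriginal(BytesReference source) {
            pressure.increment(1, source.length());
        }

        // After inference results are embedded: zero additional operations,
        // only the growth (or shrinkage) in serialized size.
        void trackModification(BytesReference original, BytesReference modified) {
            pressure.increment(0, modified.length() - original.length());
        }
    }
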
+ 16 - 16
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

@@ -616,14 +616,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
                 IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
                 assertThat(coordinatingIndexingPressure, notNullValue());
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource));
+                verify(coordinatingIndexingPressure).increment(1, length(doc0Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc2Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc3Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc4Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc0UpdateSource));
                 if (useLegacyFormat == false) {
-                    verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource));
+                    verify(coordinatingIndexingPressure).increment(1, length(doc1UpdateSource));
                 }
 
                 verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
@@ -720,7 +720,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
                 IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
                 assertThat(coordinatingIndexingPressure, notNullValue());
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
                 verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong());
 
                 // Verify that the coordinating indexing pressure is maintained through downstream action filters
@@ -759,7 +759,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
     public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception {
         final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar");
         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
-            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build()
+            Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (length(doc1Source) + 1) + "b").build()
         );
 
         final InferenceStats inferenceStats = new InferenceStats(mock(), mock());
@@ -802,7 +802,7 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
                 IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
                 assertThat(coordinatingIndexingPressure, notNullValue());
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
                 verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0));
                 verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong());
 
@@ -862,14 +862,14 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             );
             XContentBuilder builder = XContentFactory.jsonBuilder();
             semanticTextField.toXContent(builder, EMPTY_PARAMS);
-            return bytesUsed(builder);
+            return length(builder);
         };
 
         final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(
             Settings.builder()
                 .put(
                     MAX_COORDINATING_BYTES.getKey(),
-                    (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
+                    (length(doc1Source) + length(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding)
                         + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b"
                 )
                 .build()
@@ -913,8 +913,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
 
                 IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
                 assertThat(coordinatingIndexingPressure, notNullValue());
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source));
-                verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc1Source));
+                verify(coordinatingIndexingPressure).increment(1, length(doc2Source));
                 verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0));
                 verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong());
 
@@ -1124,8 +1124,8 @@ public class ShardBulkInferenceActionFilterTests extends ESTestCase {
             new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) };
     }
 
-    private static long bytesUsed(XContentBuilder builder) {
-        return BytesReference.bytes(builder).ramBytesUsed();
+    private static long length(XContentBuilder builder) {
+        return BytesReference.bytes(builder).length();
     }
 
     @SuppressWarnings({ "unchecked" })
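The test updates mirror the production change one-for-one: the bytesUsed helper becomes length, and the limit-tripping tests size MAX_COORDINATING_BYTES relative to the same measure (for example, one byte above a document's serialized source) so that a follow-up increment is rejected. A minimal sketch of that sizing idiom; apart from the helper's shape, the class and method names here are hypothetical:

    // Hypothetical sketch of the limit-sizing idiom used in these tests.
    import org.elasticsearch.common.bytes.BytesReference;
    import org.elasticsearch.xcontent.XContentBuilder;

    final class PressureTestSizing {
        // Same shape as the renamed test helper: serialized size of a builder.
        static long length(XContentBuilder builder) {
            return BytesReference.bytes(builder).length();
        }

        // Build a byte-size setting string one byte above a single document's
        // source, so the first increment fits and any follow-up trips the limit.
        static String limitJustAbove(XContentBuilder docSource) {
            return (length(docSource) + 1) + "b";
        }
    }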