Browse Source

[ML] Add status and increased estimate to memory usage (#58588)

Adds parsing of `status` and `increased_memory_estimate_bytes`
to data frame analytics `memory_usage`. When training exceeds
the model memory limit, the status is set to `hard_limit`, and
`increased_memory_estimate_bytes` can be used to update the job's
memory limit so that the job can be restarted.
Dimitris Athanasiou 5 years ago
parent
commit
0994005c2e

+ 49 - 4
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsage.java

@@ -26,18 +26,22 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
 
 import java.io.IOException;
 import java.time.Instant;
+import java.util.Locale;
 import java.util.Objects;
 
 public class MemoryUsage implements ToXContentObject {
 
     static final ParseField TIMESTAMP = new ParseField("timestamp");
     static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
+    static final ParseField STATUS = new ParseField("status");
+    static final ParseField INCREASED_MEMORY_ESTIMATE_BYTES = new ParseField("increased_memory_estimate_bytes");
 
     public static final ConstructingObjectParser<MemoryUsage, Void> PARSER = new ConstructingObjectParser<>("analytics_memory_usage",
-        true, a -> new MemoryUsage((Instant) a[0], (long) a[1]));
+        true, a -> new MemoryUsage((Instant) a[0], (long) a[1], (Status) a[2], (Long) a[3]));
 
     static {
         PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(),
@@ -45,15 +49,26 @@ public class MemoryUsage implements ToXContentObject {
             TIMESTAMP,
             ObjectParser.ValueType.VALUE);
         PARSER.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
+        PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return Status.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, STATUS, ObjectParser.ValueType.STRING);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), INCREASED_MEMORY_ESTIMATE_BYTES);
     }
 
     @Nullable
     private final Instant timestamp;
     private final long peakUsageBytes;
+    private final Status status;
+    private final Long increasedMemoryEstimateBytes;
 
-    public MemoryUsage(@Nullable Instant timestamp, long peakUsageBytes) {
+    public MemoryUsage(@Nullable Instant timestamp, long peakUsageBytes, Status status, @Nullable Long increasedMemoryEstimateBytes) {
         this.timestamp = timestamp == null ? null : Instant.ofEpochMilli(Objects.requireNonNull(timestamp).toEpochMilli());
         this.peakUsageBytes = peakUsageBytes;
+        this.status = status;
+        this.increasedMemoryEstimateBytes = increasedMemoryEstimateBytes;
     }
 
     @Nullable
@@ -65,6 +80,14 @@ public class MemoryUsage implements ToXContentObject {
         return peakUsageBytes;
     }
 
+    public Status getStatus() {
+        return status;
+    }
+
+    public Long getIncreasedMemoryEstimateBytes() {
+        return increasedMemoryEstimateBytes;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
@@ -72,6 +95,10 @@ public class MemoryUsage implements ToXContentObject {
             builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timestamp.toEpochMilli());
         }
         builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
+        builder.field(STATUS.getPreferredName(), status);
+        if (increasedMemoryEstimateBytes != null) {
+            builder.field(INCREASED_MEMORY_ESTIMATE_BYTES.getPreferredName(), increasedMemoryEstimateBytes);
+        }
         builder.endObject();
         return builder;
     }
@@ -83,12 +110,14 @@ public class MemoryUsage implements ToXContentObject {
 
         MemoryUsage other = (MemoryUsage) o;
         return Objects.equals(timestamp, other.timestamp)
-            && peakUsageBytes == other.peakUsageBytes;
+            && peakUsageBytes == other.peakUsageBytes
+            && Objects.equals(status, other.status)
+            && Objects.equals(increasedMemoryEstimateBytes, other.increasedMemoryEstimateBytes);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(timestamp, peakUsageBytes);
+        return Objects.hash(timestamp, peakUsageBytes, status, increasedMemoryEstimateBytes);
     }
 
     @Override
@@ -96,6 +125,22 @@ public class MemoryUsage implements ToXContentObject {
         return new ToStringBuilder(getClass())
             .add(TIMESTAMP.getPreferredName(), timestamp == null ? null : timestamp.getEpochSecond())
             .add(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes)
+            .add(STATUS.getPreferredName(), status)
+            .add(INCREASED_MEMORY_ESTIMATE_BYTES.getPreferredName(), increasedMemoryEstimateBytes)
             .toString();
     }
+
+    public enum Status {
+        OK,
+        HARD_LIMIT;
+
+        public static Status fromString(String value) {
+            return valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
 }

+ 3 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -150,6 +150,7 @@ import org.elasticsearch.client.ml.dataframe.evaluation.softclassification.Recal
 import org.elasticsearch.client.ml.dataframe.explain.FieldSelection;
 import org.elasticsearch.client.ml.dataframe.explain.MemoryEstimation;
 import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts;
+import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.client.ml.filestructurefinder.FileStructure;
 import org.elasticsearch.client.ml.inference.InferenceToXContentCompressor;
 import org.elasticsearch.client.ml.inference.MlInferenceNamedXContentProvider;
@@ -1537,6 +1538,8 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(progress.get(2), equalTo(new PhaseProgress("computing_outliers", 0)));
         assertThat(progress.get(3), equalTo(new PhaseProgress("writing_results", 0)));
         assertThat(stats.getMemoryUsage().getPeakUsageBytes(), equalTo(0L));
+        assertThat(stats.getMemoryUsage().getStatus(), equalTo(MemoryUsage.Status.OK));
+        assertThat(stats.getMemoryUsage().getIncreasedMemoryEstimateBytes(), is(nullValue()));
         assertThat(stats.getDataCounts(), equalTo(new DataCounts(0, 0, 0)));
     }
 

+ 9 - 3
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/stats/common/MemoryUsageTests.java

@@ -34,7 +34,12 @@ public class MemoryUsageTests extends AbstractXContentTestCase<MemoryUsage> {
     }
 
     public static MemoryUsage createRandom() {
-        return new MemoryUsage(randomBoolean() ? null : Instant.now(), randomNonNegativeLong());
+        return new MemoryUsage(
+            randomBoolean() ? null : Instant.now(),
+            randomNonNegativeLong(),
+            randomFrom(MemoryUsage.Status.values()),
+            randomBoolean() ? null : randomNonNegativeLong()
+        );
     }
 
     @Override
@@ -48,7 +53,8 @@ public class MemoryUsageTests extends AbstractXContentTestCase<MemoryUsage> {
     }
 
     public void testToString_GivenNullTimestamp() {
-        MemoryUsage memoryUsage = new MemoryUsage(null, 42L);
-        assertThat(memoryUsage.toString(), equalTo("MemoryUsage[timestamp=null, peak_usage_bytes=42]"));
+        MemoryUsage memoryUsage = new MemoryUsage(null, 42L, MemoryUsage.Status.OK, null);
+        assertThat(memoryUsage.toString(), equalTo(
+            "MemoryUsage[timestamp=null, peak_usage_bytes=42, status=ok, increased_memory_estimate_bytes=null]"));
     }
 }

+ 9 - 0
docs/reference/ml/df-analytics/apis/get-dfanalytics-stats.asciidoc

@@ -435,6 +435,15 @@ job is started and memory usage is reported.
 (long)
 The number of bytes used at the highest peak of memory usage.
 
+`status`::::
+(string)
+The memory usage status. May have one of the following values:
++
+--
+* `ok`: usage stayed below the configured memory limit.
+* `hard_limit`: usage surpassed the configured memory limit.
+--
+
 `timestamp`::::
 (date)
 The timestamp when memory usage was calculated.

+ 67 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/MemoryUsage.java

@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.core.ml.dataframe.stats.common;
 
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.io.stream.StreamInput;
@@ -14,6 +16,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.ToXContentObject;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.xpack.core.common.time.TimeUtils;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.Fields;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -21,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
 
 import java.io.IOException;
 import java.time.Instant;
+import java.util.Locale;
 import java.util.Objects;
 
 public class MemoryUsage implements Writeable, ToXContentObject {
@@ -28,13 +32,15 @@ public class MemoryUsage implements Writeable, ToXContentObject {
     public static final String TYPE_VALUE = "analytics_memory_usage";
 
     public static final ParseField PEAK_USAGE_BYTES = new ParseField("peak_usage_bytes");
+    public static final ParseField STATUS = new ParseField("status");
+    public static final ParseField INCREASED_MEMORY_ESTIMATE_BYTES = new ParseField("increased_memory_estimate_bytes");
 
     public static final ConstructingObjectParser<MemoryUsage, Void> STRICT_PARSER = createParser(false);
     public static final ConstructingObjectParser<MemoryUsage, Void> LENIENT_PARSER = createParser(true);
 
     private static ConstructingObjectParser<MemoryUsage, Void> createParser(boolean ignoreUnknownFields) {
         ConstructingObjectParser<MemoryUsage, Void> parser = new ConstructingObjectParser<>(TYPE_VALUE,
-            ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2]));
+            ignoreUnknownFields, a -> new MemoryUsage((String) a[0], (Instant) a[1], (long) a[2], (Status) a[3], (Long) a[4]));
 
         parser.declareString((bucket, s) -> {}, Fields.TYPE);
         parser.declareString(ConstructingObjectParser.constructorArg(), Fields.JOB_ID);
@@ -43,6 +49,13 @@ public class MemoryUsage implements Writeable, ToXContentObject {
             Fields.TIMESTAMP,
             ObjectParser.ValueType.VALUE);
         parser.declareLong(ConstructingObjectParser.constructorArg(), PEAK_USAGE_BYTES);
+        parser.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
+            if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                return Status.fromString(p.text());
+            }
+            throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
+        }, STATUS, ObjectParser.ValueType.STRING);
+        parser.declareLong(ConstructingObjectParser.optionalConstructorArg(), INCREASED_MEMORY_ESTIMATE_BYTES);
         return parser;
     }
 
@@ -52,27 +65,43 @@ public class MemoryUsage implements Writeable, ToXContentObject {
      */
     private final Instant timestamp;
     private final long peakUsageBytes;
+    private final Status status;
+    @Nullable private final Long increasedMemoryEstimateBytes;
 
     /**
      * Creates a zero usage object
      */
     public MemoryUsage(String jobId) {
-        this(jobId, null, 0);
+        this(jobId, null, 0, null, null);
     }
 
-    public MemoryUsage(String jobId, Instant timestamp, long peakUsageBytes) {
+    public MemoryUsage(String jobId, Instant timestamp, long peakUsageBytes, @Nullable Status status,
+                       @Nullable Long increasedMemoryEstimateBytes) {
         this.jobId = Objects.requireNonNull(jobId);
         // We intend to store this timestamp in millis granularity. Thus we're rounding here to ensure
         // internal representation matches toXContent
         this.timestamp = timestamp == null ? null : Instant.ofEpochMilli(
             ExceptionsHelper.requireNonNull(timestamp, Fields.TIMESTAMP).toEpochMilli());
         this.peakUsageBytes = peakUsageBytes;
+        this.status = status == null ? Status.OK : status;
+        this.increasedMemoryEstimateBytes = increasedMemoryEstimateBytes;
     }
 
     public MemoryUsage(StreamInput in) throws IOException {
         jobId = in.readString();
         timestamp = in.readOptionalInstant();
         peakUsageBytes = in.readVLong();
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            status = Status.readFromStream(in);
+            increasedMemoryEstimateBytes = in.readOptionalVLong();
+        } else {
+            status = Status.OK;
+            increasedMemoryEstimateBytes = null;
+        }
+    }
+
+    public Status getStatus() {
+        return status;
     }
 
     @Override
@@ -80,6 +109,10 @@ public class MemoryUsage implements Writeable, ToXContentObject {
         out.writeString(jobId);
         out.writeOptionalInstant(timestamp);
         out.writeVLong(peakUsageBytes);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            status.writeTo(out);
+            out.writeOptionalVLong(increasedMemoryEstimateBytes);
+        }
     }
 
     @Override
@@ -94,6 +127,10 @@ public class MemoryUsage implements Writeable, ToXContentObject {
                 timestamp.toEpochMilli());
         }
         builder.field(PEAK_USAGE_BYTES.getPreferredName(), peakUsageBytes);
+        builder.field(STATUS.getPreferredName(), status);
+        if (increasedMemoryEstimateBytes != null) {
+            builder.field(INCREASED_MEMORY_ESTIMATE_BYTES.getPreferredName(), increasedMemoryEstimateBytes);
+        }
         builder.endObject();
         return builder;
     }
@@ -106,12 +143,14 @@ public class MemoryUsage implements Writeable, ToXContentObject {
         MemoryUsage other = (MemoryUsage) o;
         return Objects.equals(jobId, other.jobId)
             && Objects.equals(timestamp, other.timestamp)
-            && peakUsageBytes == other.peakUsageBytes;
+            && peakUsageBytes == other.peakUsageBytes
+            && Objects.equals(status, other.status)
+            && Objects.equals(increasedMemoryEstimateBytes, other.increasedMemoryEstimateBytes);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(jobId, timestamp, peakUsageBytes);
+        return Objects.hash(jobId, timestamp, peakUsageBytes, status, increasedMemoryEstimateBytes);
     }
 
     @Override
@@ -127,4 +166,27 @@ public class MemoryUsage implements Writeable, ToXContentObject {
     public static String documentIdPrefix(String jobId) {
         return TYPE_VALUE + "_" + jobId + "_";
     }
+
+    public enum Status implements Writeable  {
+        OK,
+        HARD_LIMIT;
+
+        public static Status fromString(String value) {
+            return valueOf(value.toUpperCase(Locale.ROOT));
+        }
+
+        public static Status readFromStream(StreamInput in) throws IOException {
+            return in.readEnum(Status.class);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeEnum(this);
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
 }

+ 8 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/stats/common/MemoryUsageTests.java

@@ -44,7 +44,13 @@ public class MemoryUsageTests extends AbstractSerializingTestCase<MemoryUsage> {
     }
 
     public static MemoryUsage createRandom() {
-        return new MemoryUsage(randomAlphaOfLength(10), Instant.now(), randomNonNegativeLong());
+        return new MemoryUsage(
+            randomAlphaOfLength(10),
+            Instant.now(),
+            randomNonNegativeLong(),
+            randomBoolean() ? null : randomFrom(MemoryUsage.Status.values()),
+            randomBoolean() ? null : randomNonNegativeLong()
+        );
     }
 
     @Override
@@ -60,6 +66,6 @@ public class MemoryUsageTests extends AbstractSerializingTestCase<MemoryUsage> {
     public void testZeroUsage() {
         MemoryUsage memoryUsage = new MemoryUsage("zero_usage_job");
         String asJson = Strings.toString(memoryUsage);
-        assertThat(asJson, equalTo("{\"peak_usage_bytes\":0}"));
+        assertThat(asJson, equalTo("{\"peak_usage_bytes\":0,\"status\":\"ok\"}"));
     }
 }

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java

@@ -172,8 +172,7 @@ public class AnalyticsResultProcessor {
         }
         MemoryUsage memoryUsage = result.getMemoryUsage();
         if (memoryUsage != null) {
-            statsHolder.setMemoryUsage(memoryUsage);
-            statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId);
+            processMemoryUsage(memoryUsage);
         }
         OutlierDetectionStats outlierDetectionStats = result.getOutlierDetectionStats();
         if (outlierDetectionStats != null) {
@@ -273,4 +272,9 @@ public class AnalyticsResultProcessor {
         failure = "error processing results; " + e.getMessage();
         auditor.error(analytics.getId(), "Error processing results; " + e.getMessage());
     }
+
+    private void processMemoryUsage(MemoryUsage memoryUsage) {
+        statsHolder.setMemoryUsage(memoryUsage);
+        statsPersister.persistWithRetry(memoryUsage, memoryUsage::documentId);
+    }
 }

+ 1 - 0
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml

@@ -931,6 +931,7 @@ setup:
   - match: { data_frame_analytics.0.data_counts.test_docs_count: 0 }
   - match: { data_frame_analytics.0.data_counts.skipped_docs_count: 0 }
   - match: { data_frame_analytics.0.memory_usage.peak_usage_bytes: 0 }
+  - match: { data_frame_analytics.0.memory_usage.status: "ok" }
 
 ---
 "Test delete given stopped config":