Kaynağa Gözat

[ML] Adding assignment_memory_basis to model_size_stats (#65561)

At present the Java code makes a decision on whether to
use current model memory or model memory limit to calculate
how much memory a job requires to be assigned.

The plan is to move this decision to the C++ code, which will
report it via a new field in the model size stats.  An
additional change will be that once we have made the switch
from using model memory limit to using current model memory
we will never switch back, as this causes large fluctuations
up and down in memory requirement which will be much more
noticeable when autoscaling is in use.

Although the only two options at present are model memory
limit and current model memory, the new enum includes a
third possibility, peak model memory.  To switch to this
now would be tricky, as there have been two bugs in the
implementation of peak model memory which render its value
unreliable in 7.x.  However, in 8.x it might make sense to
switch to using peak model memory instead of current model
memory and it's much easier from a BWC perspective if the
enum contains all the values from the start.

Relates #63163
David Roberts 4 yıl önce
ebeveyn
işleme
49e492f313

+ 53 - 5
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/job/process/ModelSizeStats.java

@@ -21,6 +21,7 @@ package org.elasticsearch.client.ml.job.process;
 import org.elasticsearch.client.common.TimeUtil;
 import org.elasticsearch.client.ml.job.config.Job;
 import org.elasticsearch.client.ml.job.results.Result;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.xcontent.ConstructingObjectParser;
 import org.elasticsearch.common.xcontent.ObjectParser.ValueType;
@@ -55,6 +56,7 @@ public class ModelSizeStats implements ToXContentObject {
     public static final ParseField TOTAL_PARTITION_FIELD_COUNT_FIELD = new ParseField("total_partition_field_count");
     public static final ParseField BUCKET_ALLOCATION_FAILURES_COUNT_FIELD = new ParseField("bucket_allocation_failures_count");
     public static final ParseField MEMORY_STATUS_FIELD = new ParseField("memory_status");
+    public static final ParseField ASSIGNMENT_MEMORY_BASIS_FIELD = new ParseField("assignment_memory_basis");
     public static final ParseField CATEGORIZED_DOC_COUNT_FIELD = new ParseField("categorized_doc_count");
     public static final ParseField TOTAL_CATEGORY_COUNT_FIELD = new ParseField("total_category_count");
     public static final ParseField FREQUENT_CATEGORY_COUNT_FIELD = new ParseField("frequent_category_count");
@@ -79,6 +81,8 @@ public class ModelSizeStats implements ToXContentObject {
         PARSER.declareLong(Builder::setTotalOverFieldCount, TOTAL_OVER_FIELD_COUNT_FIELD);
         PARSER.declareLong(Builder::setTotalPartitionFieldCount, TOTAL_PARTITION_FIELD_COUNT_FIELD);
         PARSER.declareField(Builder::setMemoryStatus, p -> MemoryStatus.fromString(p.text()), MEMORY_STATUS_FIELD, ValueType.STRING);
+        PARSER.declareField(Builder::setAssignmentMemoryBasis,
+            p -> AssignmentMemoryBasis.fromString(p.text()), ASSIGNMENT_MEMORY_BASIS_FIELD, ValueType.STRING);
         PARSER.declareLong(Builder::setCategorizedDocCount, CATEGORIZED_DOC_COUNT_FIELD);
         PARSER.declareLong(Builder::setTotalCategoryCount, TOTAL_CATEGORY_COUNT_FIELD);
         PARSER.declareLong(Builder::setFrequentCategoryCount, FREQUENT_CATEGORY_COUNT_FIELD);
@@ -116,6 +120,29 @@ public class ModelSizeStats implements ToXContentObject {
         }
     }
 
+    /**
+     * Where will we get the memory requirement from when assigning this job to
+     * a node?  There are three possibilities:
+     * <ol>
+     * <li>{@link #MODEL_MEMORY_LIMIT} - the job's model_memory_limit</li>
+     * <li>{@link #CURRENT_MODEL_BYTES} - the current model memory, i.e. what's reported in model_bytes of this object</li>
+     * <li>{@link #PEAK_MODEL_BYTES} - the peak model memory, i.e. what's reported in peak_model_bytes of this object</li>
+     * </ol>
+     * The field storing this enum can also be <code>null</code>, which means the
+     * assignment code will decide on the fly - this was the old behaviour prior
+     * to 7.11.
+     * <p>
+     * {@link #fromString(String)} trims its argument and upper-cases it using
+     * {@link Locale#ROOT} before looking up the constant, so <code>"model_memory_limit"</code>
+     * and <code>" MODEL_MEMORY_LIMIT "</code> both resolve to {@link #MODEL_MEMORY_LIMIT}.
+     * A string that matches no constant makes the underlying <code>valueOf</code> call throw
+     * {@link IllegalArgumentException}; a <code>null</code> argument fails with
+     * {@link NullPointerException}.  {@link #toString()} returns the constant's name
+     * lower-cased using {@link Locale#ROOT}, i.e. the same spelling that
+     * {@link #fromString(String)} accepts.
+     */
+    public enum AssignmentMemoryBasis {
+        MODEL_MEMORY_LIMIT, CURRENT_MODEL_BYTES, PEAK_MODEL_BYTES;
+
+        public static AssignmentMemoryBasis fromString(String statusName) {
+            return valueOf(statusName.trim().toUpperCase(Locale.ROOT));
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
+
     /**
      * The status of categorization for a job. OK is default, WARN
      * means that inappropriate numbers of categories are being found
@@ -143,6 +170,7 @@ public class ModelSizeStats implements ToXContentObject {
     private final long totalPartitionFieldCount;
     private final long bucketAllocationFailuresCount;
     private final MemoryStatus memoryStatus;
+    private final AssignmentMemoryBasis assignmentMemoryBasis;
     private final long categorizedDocCount;
     private final long totalCategoryCount;
     private final long frequentCategoryCount;
@@ -155,7 +183,8 @@ public class ModelSizeStats implements ToXContentObject {
 
     private ModelSizeStats(String jobId, long modelBytes, Long peakModelBytes, Long modelBytesExceeded, Long modelBytesMemoryLimit,
                            long totalByFieldCount, long totalOverFieldCount, long totalPartitionFieldCount,
-                           long bucketAllocationFailuresCount, MemoryStatus memoryStatus, long categorizedDocCount, long totalCategoryCount,
+                           long bucketAllocationFailuresCount, MemoryStatus memoryStatus,
+                           AssignmentMemoryBasis assignmentMemoryBasis, long categorizedDocCount, long totalCategoryCount,
                            long frequentCategoryCount, long rareCategoryCount, long deadCategoryCount, long failedCategoryCount,
                            CategorizationStatus categorizationStatus, Date timestamp, Date logTime) {
         this.jobId = jobId;
@@ -168,6 +197,7 @@ public class ModelSizeStats implements ToXContentObject {
         this.totalPartitionFieldCount = totalPartitionFieldCount;
         this.bucketAllocationFailuresCount = bucketAllocationFailuresCount;
         this.memoryStatus = memoryStatus;
+        this.assignmentMemoryBasis = assignmentMemoryBasis;
         this.categorizedDocCount = categorizedDocCount;
         this.totalCategoryCount = totalCategoryCount;
         this.frequentCategoryCount = frequentCategoryCount;
@@ -200,6 +230,9 @@ public class ModelSizeStats implements ToXContentObject {
         builder.field(TOTAL_PARTITION_FIELD_COUNT_FIELD.getPreferredName(), totalPartitionFieldCount);
         builder.field(BUCKET_ALLOCATION_FAILURES_COUNT_FIELD.getPreferredName(), bucketAllocationFailuresCount);
         builder.field(MEMORY_STATUS_FIELD.getPreferredName(), memoryStatus);
+        if (assignmentMemoryBasis != null) {
+            builder.field(ASSIGNMENT_MEMORY_BASIS_FIELD.getPreferredName(), assignmentMemoryBasis);
+        }
         builder.field(CATEGORIZED_DOC_COUNT_FIELD.getPreferredName(), categorizedDocCount);
         builder.field(TOTAL_CATEGORY_COUNT_FIELD.getPreferredName(), totalCategoryCount);
         builder.field(FREQUENT_CATEGORY_COUNT_FIELD.getPreferredName(), frequentCategoryCount);
@@ -256,6 +289,11 @@ public class ModelSizeStats implements ToXContentObject {
         return memoryStatus;
     }
 
+    @Nullable // null when no basis was recorded in these stats, e.g. documents that lack the assignment_memory_basis field
+    public AssignmentMemoryBasis getAssignmentMemoryBasis() {
+        return assignmentMemoryBasis;
+    }
+
     public long getCategorizedDocCount() {
         return categorizedDocCount;
     }
@@ -306,8 +344,9 @@ public class ModelSizeStats implements ToXContentObject {
     public int hashCode() {
         return Objects.hash(
             jobId, modelBytes, peakModelBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, totalOverFieldCount,
-            totalPartitionFieldCount, this.bucketAllocationFailuresCount, memoryStatus, categorizedDocCount, totalCategoryCount,
-            frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus, timestamp, logTime);
+            totalPartitionFieldCount, this.bucketAllocationFailuresCount, memoryStatus, assignmentMemoryBasis, categorizedDocCount,
+            totalCategoryCount, frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus,
+            timestamp, logTime);
     }
 
     /**
@@ -332,6 +371,7 @@ public class ModelSizeStats implements ToXContentObject {
             && this.totalOverFieldCount == that.totalOverFieldCount && this.totalPartitionFieldCount == that.totalPartitionFieldCount
             && this.bucketAllocationFailuresCount == that.bucketAllocationFailuresCount
             && Objects.equals(this.memoryStatus, that.memoryStatus)
+            && Objects.equals(this.assignmentMemoryBasis, that.assignmentMemoryBasis)
             && this.categorizedDocCount == that.categorizedDocCount
             && this.totalCategoryCount == that.totalCategoryCount
             && this.frequentCategoryCount == that.frequentCategoryCount
@@ -356,6 +396,7 @@ public class ModelSizeStats implements ToXContentObject {
         private long totalPartitionFieldCount;
         private long bucketAllocationFailuresCount;
         private MemoryStatus memoryStatus;
+        private AssignmentMemoryBasis assignmentMemoryBasis;
         private long categorizedDocCount;
         private long totalCategoryCount;
         private long frequentCategoryCount;
@@ -384,6 +425,7 @@ public class ModelSizeStats implements ToXContentObject {
             this.totalPartitionFieldCount = modelSizeStats.totalPartitionFieldCount;
             this.bucketAllocationFailuresCount = modelSizeStats.bucketAllocationFailuresCount;
             this.memoryStatus = modelSizeStats.memoryStatus;
+            this.assignmentMemoryBasis = modelSizeStats.assignmentMemoryBasis;
             this.categorizedDocCount = modelSizeStats.categorizedDocCount;
             this.totalCategoryCount = modelSizeStats.totalCategoryCount;
             this.frequentCategoryCount = modelSizeStats.frequentCategoryCount;
@@ -441,6 +483,11 @@ public class ModelSizeStats implements ToXContentObject {
             return this;
         }
 
+        public Builder setAssignmentMemoryBasis(AssignmentMemoryBasis assignmentMemoryBasis) {
+            this.assignmentMemoryBasis = assignmentMemoryBasis; // stored as-is, no validation; may be null
+            return this;
+        }
+
         public Builder setCategorizedDocCount(long categorizedDocCount) {
             this.categorizedDocCount = categorizedDocCount;
             return this;
@@ -490,8 +537,9 @@ public class ModelSizeStats implements ToXContentObject {
         public ModelSizeStats build() {
             return new ModelSizeStats(
                 jobId, modelBytes, peakModelBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, totalOverFieldCount,
-                totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, categorizedDocCount, totalCategoryCount,
-                frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus, timestamp, logTime);
+                totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, assignmentMemoryBasis, categorizedDocCount,
+                totalCategoryCount, frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus,
+                timestamp, logTime);
         }
     }
 }

+ 126 - 31
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningGetResultsIT.java

@@ -61,6 +61,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.lessThan;
 import static org.hamcrest.Matchers.lessThanOrEqualTo;
+import static org.hamcrest.Matchers.nullValue;
 
 public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
 
@@ -71,8 +72,8 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
     // 2018-08-01T00:00:00Z
     private static final long START_TIME_EPOCH_MS = 1533081600000L;
 
-    private Stats bucketStats = new Stats();
-    private Stats recordStats = new Stats();
+    private final Stats bucketStats = new Stats();
+    private final Stats recordStats = new Stats();
 
     @Before
     public void createJobAndIndexResults() throws IOException {
@@ -149,10 +150,10 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
     }
 
     private void addModelSnapshotIndexRequests(BulkRequest bulkRequest) {
+        // Index a number of model snapshots, one of which contains the new model_size_stats fields
+        // 'model_bytes_exceeded' and 'model_bytes_memory_limit' that were introduced in 7.2.0.
+        // We want to verify that we can parse the snapshots whether or not these fields are present.
         {
-            // Index a number of model snapshots, one of which contains the new model_size_stats fields
-            // 'model_bytes_exceeded' and 'model_bytes_memory_limit' that were introduced in 7.2.0.
-            // We want to verify that we can parse the snapshots whether or not these fields are present.
             IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX);
             indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"timestamp\":1541587919000, " +
                 "\"description\":\"State persisted due to job close at 2018-11-07T10:51:59+0000\", \"snapshot_id\":\"1541587919\"," +
@@ -164,6 +165,19 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
                 " \"retain\":false }", XContentType.JSON);
             bulkRequest.add(indexRequest);
         }
+        // Also index one that contains 'assignment_memory_basis', which was added in 7.11
+        {
+            IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX);
+            indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"timestamp\":1541587929000, " +
+                "\"description\":\"State persisted due to job close at 2018-11-07T10:52:09+0000\", \"snapshot_id\":\"1541587929\"," +
+                "\"snapshot_doc_count\":1, \"model_size_stats\":{\"job_id\":\"" + JOB_ID + "\", \"result_type\":\"model_size_stats\"," +
+                "\"model_bytes\":51722, \"peak_model_bytes\":61322, \"model_bytes_exceeded\":10762, \"model_bytes_memory_limit\":40960," +
+                "\"total_by_field_count\":3, \"total_over_field_count\":0, \"total_partition_field_count\":2," +
+                "\"bucket_allocation_failures_count\":0, \"memory_status\":\"ok\", \"assignment_memory_basis\":\"model_memory_limit\"," +
+                " \"log_time\":1541587929000, \"timestamp\":1519930800000},\"latest_record_time_stamp\":1519931700000," +
+                "\"latest_result_time_stamp\":1519930800000, \"retain\":false }", XContentType.JSON);
+            bulkRequest.add(indexRequest);
+        }
         {
             IndexRequest indexRequest = new IndexRequest(RESULTS_INDEX);
             indexRequest.source("{\"job_id\":\"" + JOB_ID + "\", \"timestamp\":1541588919000, " +
@@ -214,8 +228,8 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             GetModelSnapshotsResponse response = execute(request, machineLearningClient::getModelSnapshots,
                 machineLearningClient::getModelSnapshotsAsync);
 
-            assertThat(response.count(), equalTo(3L));
-            assertThat(response.snapshots().size(), equalTo(3));
+            assertThat(response.count(), equalTo(4L));
+            assertThat(response.snapshots().size(), equalTo(4));
             assertThat(response.snapshots().get(0).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(0).getSnapshotId(), equalTo("1541587919"));
             assertThat(response.snapshots().get(0).getSnapshotDocCount(), equalTo(1));
@@ -236,35 +250,38 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
 
             assertThat(response.snapshots().get(1).getJobId(), equalTo(JOB_ID));
-            assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541588919"));
+            assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541587929"));
             assertThat(response.snapshots().get(1).getSnapshotDocCount(), equalTo(1));
             assertThat(response.snapshots().get(1).getDescription(), equalTo("State persisted due to job close at" +
-                " 2018-11-07T11:08:39+0000"));
+                " 2018-11-07T10:52:09+0000"));
             assertThat(response.snapshots().get(1).getSnapshotDocCount(), equalTo(1));
-            assertThat(response.snapshots().get(1).getTimestamp(), equalTo(new Date(1541588919000L)));
+            assertThat(response.snapshots().get(1).getTimestamp(), equalTo(new Date(1541587929000L)));
             assertThat(response.snapshots().get(1).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
             assertThat(response.snapshots().get(1).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
             assertThat(response.snapshots().get(1).getModelSizeStats().getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytes(), equalTo(51722L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getPeakModelBytes(), equalTo(61322L));
-            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesExceeded(), equalTo(null));
-            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(null));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesExceeded(), equalTo(10762L));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(40960L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalByFieldCount(), equalTo(3L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalOverFieldCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalPartitionFieldCount(), equalTo(2L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getAssignmentMemoryBasis(),
+                equalTo(ModelSizeStats.AssignmentMemoryBasis.MODEL_MEMORY_LIMIT));
 
             assertThat(response.snapshots().get(2).getJobId(), equalTo(JOB_ID));
-            assertThat(response.snapshots().get(2).getSnapshotId(), equalTo("1541589919"));
+            assertThat(response.snapshots().get(2).getSnapshotId(), equalTo("1541588919"));
             assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
             assertThat(response.snapshots().get(2).getDescription(), equalTo("State persisted due to job close at" +
-                " 2018-11-07T11:25:19+0000"));
+                " 2018-11-07T11:08:39+0000"));
             assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
-            assertThat(response.snapshots().get(2).getTimestamp(), equalTo(new Date(1541589919000L)));
+            assertThat(response.snapshots().get(2).getTimestamp(), equalTo(new Date(1541588919000L)));
             assertThat(response.snapshots().get(2).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
             assertThat(response.snapshots().get(2).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
             assertThat(response.snapshots().get(2).getModelSizeStats().getJobId(), equalTo(JOB_ID));
@@ -278,6 +295,29 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(2).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(2).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
+
+            assertThat(response.snapshots().get(3).getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(3).getSnapshotId(), equalTo("1541589919"));
+            assertThat(response.snapshots().get(3).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(3).getDescription(), equalTo("State persisted due to job close at" +
+                " 2018-11-07T11:25:19+0000"));
+            assertThat(response.snapshots().get(3).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(3).getTimestamp(), equalTo(new Date(1541589919000L)));
+            assertThat(response.snapshots().get(3).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
+            assertThat(response.snapshots().get(3).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytes(), equalTo(51722L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getPeakModelBytes(), equalTo(61322L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytesExceeded(), equalTo(null));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(null));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalByFieldCount(), equalTo(3L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalOverFieldCount(), equalTo(0L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalPartitionFieldCount(), equalTo(2L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getMemoryStatus(),
+                equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
@@ -288,15 +328,37 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             GetModelSnapshotsResponse response = execute(request, machineLearningClient::getModelSnapshots,
                 machineLearningClient::getModelSnapshotsAsync);
 
-            assertThat(response.count(), equalTo(3L));
-            assertThat(response.snapshots().size(), equalTo(3));
+            assertThat(response.count(), equalTo(4L));
+            assertThat(response.snapshots().size(), equalTo(4));
+            assertThat(response.snapshots().get(3).getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(3).getSnapshotId(), equalTo("1541587919"));
+            assertThat(response.snapshots().get(3).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(3).getDescription(), equalTo("State persisted due to job close at" +
+                " 2018-11-07T10:51:59+0000"));
+            assertThat(response.snapshots().get(3).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(3).getTimestamp(), equalTo(new Date(1541587919000L)));
+            assertThat(response.snapshots().get(3).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
+            assertThat(response.snapshots().get(3).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytes(), equalTo(51722L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getPeakModelBytes(), equalTo(61322L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytesExceeded(), equalTo(10762L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(40960L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalByFieldCount(), equalTo(3L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalOverFieldCount(), equalTo(0L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getTotalPartitionFieldCount(), equalTo(2L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getMemoryStatus(),
+                equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(3).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
+
             assertThat(response.snapshots().get(2).getJobId(), equalTo(JOB_ID));
-            assertThat(response.snapshots().get(2).getSnapshotId(), equalTo("1541587919"));
+            assertThat(response.snapshots().get(2).getSnapshotId(), equalTo("1541587929"));
             assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
             assertThat(response.snapshots().get(2).getDescription(), equalTo("State persisted due to job close at" +
-                " 2018-11-07T10:51:59+0000"));
+                " 2018-11-07T10:52:09+0000"));
             assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
-            assertThat(response.snapshots().get(2).getTimestamp(), equalTo(new Date(1541587919000L)));
+            assertThat(response.snapshots().get(2).getTimestamp(), equalTo(new Date(1541587929000L)));
             assertThat(response.snapshots().get(2).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
             assertThat(response.snapshots().get(2).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
             assertThat(response.snapshots().get(2).getModelSizeStats().getJobId(), equalTo(JOB_ID));
@@ -310,6 +372,8 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(2).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(2).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getAssignmentMemoryBasis(),
+                equalTo(ModelSizeStats.AssignmentMemoryBasis.MODEL_MEMORY_LIMIT));
 
             assertThat(response.snapshots().get(1).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541588919"));
@@ -331,6 +395,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(1).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
 
             assertThat(response.snapshots().get(0).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(0).getSnapshotId(), equalTo("1541589919"));
@@ -352,6 +417,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
@@ -362,7 +428,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             GetModelSnapshotsResponse response = execute(request, machineLearningClient::getModelSnapshots,
                 machineLearningClient::getModelSnapshotsAsync);
 
-            assertThat(response.count(), equalTo(3L));
+            assertThat(response.count(), equalTo(4L));
             assertThat(response.snapshots().size(), equalTo(1));
             assertThat(response.snapshots().get(0).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(0).getSnapshotId(), equalTo("1541587919"));
@@ -384,17 +450,18 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
             request.setSort("timestamp");
             request.setDesc(false);
-            request.setPageParams(new PageParams(1, 2));
+            request.setPageParams(new PageParams(2, 3));
 
             GetModelSnapshotsResponse response = execute(request, machineLearningClient::getModelSnapshots,
                 machineLearningClient::getModelSnapshotsAsync);
 
-            assertThat(response.count(), equalTo(3L));
+            assertThat(response.count(), equalTo(4L));
             assertThat(response.snapshots().size(), equalTo(2));
 
             assertThat(response.snapshots().get(0).getJobId(), equalTo(JOB_ID));
@@ -417,7 +484,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
-
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
 
             assertThat(response.snapshots().get(1).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541589919"));
@@ -439,6 +506,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(1).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
@@ -470,6 +538,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
@@ -491,8 +560,8 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             GetModelSnapshotsResponse response = execute(request, machineLearningClient::getModelSnapshots,
                 machineLearningClient::getModelSnapshotsAsync);
 
-            assertThat(response.count(), equalTo(2L));
-            assertThat(response.snapshots().size(), equalTo(2));
+            assertThat(response.count(), equalTo(3L));
+            assertThat(response.snapshots().size(), equalTo(3));
             assertThat(response.snapshots().get(0).getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(0).getSnapshotId(), equalTo("1541587919"));
             assertThat(response.snapshots().get(0).getSnapshotDocCount(), equalTo(1));
@@ -513,27 +582,52 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
 
             assertThat(response.snapshots().get(1).getJobId(), equalTo(JOB_ID));
-            assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541588919"));
+            assertThat(response.snapshots().get(1).getSnapshotId(), equalTo("1541587929"));
             assertThat(response.snapshots().get(1).getSnapshotDocCount(), equalTo(1));
             assertThat(response.snapshots().get(1).getDescription(), equalTo("State persisted due to job close at" +
-                " 2018-11-07T11:08:39+0000"));
+                " 2018-11-07T10:52:09+0000"));
             assertThat(response.snapshots().get(1).getSnapshotDocCount(), equalTo(1));
-            assertThat(response.snapshots().get(1).getTimestamp(), equalTo(new Date(1541588919000L)));
+            assertThat(response.snapshots().get(1).getTimestamp(), equalTo(new Date(1541587929000L)));
             assertThat(response.snapshots().get(1).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
             assertThat(response.snapshots().get(1).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
             assertThat(response.snapshots().get(1).getModelSizeStats().getJobId(), equalTo(JOB_ID));
             assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytes(), equalTo(51722L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getPeakModelBytes(), equalTo(61322L));
-            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesExceeded(), equalTo(null));
-            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(null));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesExceeded(), equalTo(10762L));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(40960L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalByFieldCount(), equalTo(3L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalOverFieldCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getTotalPartitionFieldCount(), equalTo(2L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(1).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(1).getModelSizeStats().getAssignmentMemoryBasis(),
+                equalTo(ModelSizeStats.AssignmentMemoryBasis.MODEL_MEMORY_LIMIT));
+
+            assertThat(response.snapshots().get(2).getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(2).getSnapshotId(), equalTo("1541588919"));
+            assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(2).getDescription(), equalTo("State persisted due to job close at" +
+                " 2018-11-07T11:08:39+0000"));
+            assertThat(response.snapshots().get(2).getSnapshotDocCount(), equalTo(1));
+            assertThat(response.snapshots().get(2).getTimestamp(), equalTo(new Date(1541588919000L)));
+            assertThat(response.snapshots().get(2).getLatestRecordTimeStamp(), equalTo(new Date(1519931700000L)));
+            assertThat(response.snapshots().get(2).getLatestResultTimeStamp(), equalTo(new Date(1519930800000L)));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getJobId(), equalTo(JOB_ID));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getModelBytes(), equalTo(51722L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getPeakModelBytes(), equalTo(61322L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getModelBytesExceeded(), equalTo(null));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getModelBytesMemoryLimit(), equalTo(null));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getTotalByFieldCount(), equalTo(3L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getTotalOverFieldCount(), equalTo(0L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getTotalPartitionFieldCount(), equalTo(2L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getMemoryStatus(),
+                equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(2).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
         {
             GetModelSnapshotsRequest request = new GetModelSnapshotsRequest(JOB_ID);
@@ -566,6 +660,7 @@ public class MachineLearningGetResultsIT extends ESRestHighLevelClientTestCase {
             assertThat(response.snapshots().get(0).getModelSizeStats().getBucketAllocationFailuresCount(), equalTo(0L));
             assertThat(response.snapshots().get(0).getModelSizeStats().getMemoryStatus(),
                 equalTo(ModelSizeStats.MemoryStatus.fromString("ok")));
+            assertThat(response.snapshots().get(0).getModelSizeStats().getAssignmentMemoryBasis(), nullValue());
         }
     }
 

+ 5 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/ml/job/process/ModelSizeStatsTests.java

@@ -24,6 +24,7 @@ import org.elasticsearch.test.AbstractXContentTestCase;
 
 import java.util.Date;
 
+import static org.elasticsearch.client.ml.job.process.ModelSizeStats.AssignmentMemoryBasis;
 import static org.elasticsearch.client.ml.job.process.ModelSizeStats.CategorizationStatus;
 import static org.elasticsearch.client.ml.job.process.ModelSizeStats.MemoryStatus;
 
@@ -40,6 +41,7 @@ public class ModelSizeStatsTests extends AbstractXContentTestCase<ModelSizeStats
         assertEquals(0, stats.getTotalPartitionFieldCount());
         assertEquals(0, stats.getBucketAllocationFailuresCount());
         assertEquals(MemoryStatus.OK, stats.getMemoryStatus());
+        assertNull(stats.getAssignmentMemoryBasis());
         assertEquals(0, stats.getCategorizedDocCount());
         assertEquals(0, stats.getTotalCategoryCount());
         assertEquals(0, stats.getFrequentCategoryCount());
@@ -99,6 +101,9 @@ public class ModelSizeStatsTests extends AbstractXContentTestCase<ModelSizeStats
         if (randomBoolean()) {
             stats.setMemoryStatus(randomFrom(MemoryStatus.values()));
         }
+        if (randomBoolean()) {
+            stats.setAssignmentMemoryBasis(randomFrom(AssignmentMemoryBasis.values()));
+        }
         if (randomBoolean()) {
             stats.setCategorizedDocCount(randomNonNegativeLong());
         }

+ 4 - 0
docs/reference/ml/anomaly-detection/apis/get-job-stats.asciidoc

@@ -198,6 +198,10 @@ model.
 .Properties of `model_size_stats`
 [%collapsible%open]
 ====
+`assignment_memory_basis`:::
+(string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=assignment-memory-basis]
+
 `bucket_allocation_failures_count`:::
 (long)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=bucket-allocation-failures-count]

+ 17 - 0
docs/reference/ml/ml-shared.asciidoc

@@ -108,6 +108,23 @@ tag::assignment-explanation-dfanalytics[]
 Contains messages relating to the selection of a node.
 end::assignment-explanation-dfanalytics[]
 
+tag::assignment-memory-basis[]
+Indicates where to find the memory requirement that is used to decide which
+node the job runs on. The possible values are:
++
+--
+* `model_memory_limit`: The job's memory requirement will be calculated on
+the basis that its model memory will grow to the `model_memory_limit`
+specified in the `analysis_limits` of its config.
+* `current_model_bytes`: The job's memory requirement will be calculated on
+the basis that its current model memory size is a good reflection of what
+the model size will be in the future.
+* `peak_model_bytes`: The job's memory requirement will be calculated on
+the basis that its peak model memory size is a good reflection of what
+the model size will be in the future.
+--
+end::assignment-memory-basis[]
+
 tag::background-persist-interval[]
 Advanced configuration option. The time between each periodic persistence of the
 model. The default value is a randomized value between 3 to 4 hours, which

+ 80 - 5
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java

@@ -5,6 +5,8 @@
  */
 package org.elasticsearch.xpack.core.ml.job.process.autodetect.state;
 
+import org.elasticsearch.Version;
+import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
@@ -45,6 +47,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
     public static final ParseField TOTAL_PARTITION_FIELD_COUNT_FIELD = new ParseField("total_partition_field_count");
     public static final ParseField BUCKET_ALLOCATION_FAILURES_COUNT_FIELD = new ParseField("bucket_allocation_failures_count");
     public static final ParseField MEMORY_STATUS_FIELD = new ParseField("memory_status");
+    public static final ParseField ASSIGNMENT_MEMORY_BASIS_FIELD = new ParseField("assignment_memory_basis");
     public static final ParseField CATEGORIZED_DOC_COUNT_FIELD = new ParseField("categorized_doc_count");
     public static final ParseField TOTAL_CATEGORY_COUNT_FIELD = new ParseField("total_category_count");
     public static final ParseField FREQUENT_CATEGORY_COUNT_FIELD = new ParseField("frequent_category_count");
@@ -73,6 +76,8 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         parser.declareLong(Builder::setTotalOverFieldCount, TOTAL_OVER_FIELD_COUNT_FIELD);
         parser.declareLong(Builder::setTotalPartitionFieldCount, TOTAL_PARTITION_FIELD_COUNT_FIELD);
         parser.declareField(Builder::setMemoryStatus, p -> MemoryStatus.fromString(p.text()), MEMORY_STATUS_FIELD, ValueType.STRING);
+        parser.declareField(Builder::setAssignmentMemoryBasis,
+            p -> AssignmentMemoryBasis.fromString(p.text()), ASSIGNMENT_MEMORY_BASIS_FIELD, ValueType.STRING);
         parser.declareLong(Builder::setCategorizedDocCount, CATEGORIZED_DOC_COUNT_FIELD);
         parser.declareLong(Builder::setTotalCategoryCount, TOTAL_CATEGORY_COUNT_FIELD);
         parser.declareLong(Builder::setFrequentCategoryCount, FREQUENT_CATEGORY_COUNT_FIELD);
@@ -117,6 +122,38 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         }
     }
 
+    /**
+     * Where will we get the memory requirement from when assigning this job to
+     * a node?  There are three possibilities:
+     * 1. The job's model_memory_limit
+     * 2. The current model memory, i.e. what's reported in model_bytes of this object
+     * 3. The peak model memory, i.e. what's reported in peak_model_bytes of this object
+     * The field storing this enum can also be <code>null</code>, which means the
+     * assignment code will decide on the fly - this was the old behaviour prior
+     * to 7.11.
+     */
+    public enum AssignmentMemoryBasis implements Writeable {
+        MODEL_MEMORY_LIMIT, CURRENT_MODEL_BYTES, PEAK_MODEL_BYTES;
+
+        public static AssignmentMemoryBasis fromString(String statusName) {
+            return valueOf(statusName.trim().toUpperCase(Locale.ROOT));
+        }
+
+        public static AssignmentMemoryBasis readFromStream(StreamInput in) throws IOException {
+            return in.readEnum(AssignmentMemoryBasis.class);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeEnum(this);
+        }
+
+        @Override
+        public String toString() {
+            return name().toLowerCase(Locale.ROOT);
+        }
+    }
+
     private final String jobId;
     private final long modelBytes;
     private final Long peakModelBytes;
@@ -127,6 +164,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
     private final long totalPartitionFieldCount;
     private final long bucketAllocationFailuresCount;
     private final MemoryStatus memoryStatus;
+    private final AssignmentMemoryBasis assignmentMemoryBasis;
     private final long categorizedDocCount;
     private final long totalCategoryCount;
     private final long frequentCategoryCount;
@@ -139,7 +177,8 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
 
     private ModelSizeStats(String jobId, long modelBytes, Long peakModelBytes, Long modelBytesExceeded, Long modelBytesMemoryLimit,
                            long totalByFieldCount, long totalOverFieldCount, long totalPartitionFieldCount,
-                           long bucketAllocationFailuresCount, MemoryStatus memoryStatus, long categorizedDocCount, long totalCategoryCount,
+                           long bucketAllocationFailuresCount, MemoryStatus memoryStatus,
+                           AssignmentMemoryBasis assignmentMemoryBasis, long categorizedDocCount, long totalCategoryCount,
                            long frequentCategoryCount, long rareCategoryCount, long deadCategoryCount, long failedCategoryCount,
                            CategorizationStatus categorizationStatus, Date timestamp, Date logTime) {
         this.jobId = jobId;
@@ -152,6 +191,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         this.totalPartitionFieldCount = totalPartitionFieldCount;
         this.bucketAllocationFailuresCount = bucketAllocationFailuresCount;
         this.memoryStatus = memoryStatus;
+        this.assignmentMemoryBasis = assignmentMemoryBasis;
         this.categorizedDocCount = categorizedDocCount;
         this.totalCategoryCount = totalCategoryCount;
         this.frequentCategoryCount = frequentCategoryCount;
@@ -174,6 +214,15 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         totalPartitionFieldCount = in.readVLong();
         bucketAllocationFailuresCount = in.readVLong();
         memoryStatus = MemoryStatus.readFromStream(in);
+        if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
+            if (in.readBoolean()) {
+                assignmentMemoryBasis = AssignmentMemoryBasis.readFromStream(in);
+            } else {
+                assignmentMemoryBasis = null;
+            }
+        } else {
+            assignmentMemoryBasis = null;
+        }
         categorizedDocCount = in.readVLong();
         totalCategoryCount = in.readVLong();
         frequentCategoryCount = in.readVLong();
@@ -205,6 +254,14 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         out.writeVLong(totalPartitionFieldCount);
         out.writeVLong(bucketAllocationFailuresCount);
         memoryStatus.writeTo(out);
+        if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
+            if (assignmentMemoryBasis != null) {
+                out.writeBoolean(true);
+                assignmentMemoryBasis.writeTo(out);
+            } else {
+                out.writeBoolean(false);
+            }
+        }
         out.writeVLong(categorizedDocCount);
         out.writeVLong(totalCategoryCount);
         out.writeVLong(frequentCategoryCount);
@@ -246,6 +303,9 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         builder.field(TOTAL_PARTITION_FIELD_COUNT_FIELD.getPreferredName(), totalPartitionFieldCount);
         builder.field(BUCKET_ALLOCATION_FAILURES_COUNT_FIELD.getPreferredName(), bucketAllocationFailuresCount);
         builder.field(MEMORY_STATUS_FIELD.getPreferredName(), memoryStatus);
+        if (assignmentMemoryBasis != null) {
+            builder.field(ASSIGNMENT_MEMORY_BASIS_FIELD.getPreferredName(), assignmentMemoryBasis);
+        }
         builder.field(CATEGORIZED_DOC_COUNT_FIELD.getPreferredName(), categorizedDocCount);
         builder.field(TOTAL_CATEGORY_COUNT_FIELD.getPreferredName(), totalCategoryCount);
         builder.field(FREQUENT_CATEGORY_COUNT_FIELD.getPreferredName(), frequentCategoryCount);
@@ -301,6 +361,11 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         return memoryStatus;
     }
 
+    @Nullable
+    public AssignmentMemoryBasis getAssignmentMemoryBasis() {
+        return assignmentMemoryBasis;
+    }
+
     public long getCategorizedDocCount() {
         return categorizedDocCount;
     }
@@ -350,8 +415,9 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         // this.id excluded here as it is generated by the datastore
         return Objects.hash(
             jobId, modelBytes, peakModelBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, totalOverFieldCount,
-            totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, categorizedDocCount, totalCategoryCount,
-            frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus, timestamp, logTime);
+            totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, assignmentMemoryBasis, categorizedDocCount,
+            totalCategoryCount, frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus,
+            timestamp, logTime);
     }
 
     /**
@@ -377,6 +443,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
                 && this.totalOverFieldCount == that.totalOverFieldCount && this.totalPartitionFieldCount == that.totalPartitionFieldCount
                 && this.bucketAllocationFailuresCount == that.bucketAllocationFailuresCount
                 && Objects.equals(this.memoryStatus, that.memoryStatus)
+                && Objects.equals(this.assignmentMemoryBasis, that.assignmentMemoryBasis)
                 && Objects.equals(this.categorizedDocCount, that.categorizedDocCount)
                 && Objects.equals(this.totalCategoryCount, that.totalCategoryCount)
                 && Objects.equals(this.frequentCategoryCount, that.frequentCategoryCount)
@@ -401,6 +468,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         private long totalPartitionFieldCount;
         private long bucketAllocationFailuresCount;
         private MemoryStatus memoryStatus;
+        private AssignmentMemoryBasis assignmentMemoryBasis;
         private long categorizedDocCount;
         private long totalCategoryCount;
         private long frequentCategoryCount;
@@ -429,6 +497,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
             this.totalPartitionFieldCount = modelSizeStats.totalPartitionFieldCount;
             this.bucketAllocationFailuresCount = modelSizeStats.bucketAllocationFailuresCount;
             this.memoryStatus = modelSizeStats.memoryStatus;
+            this.assignmentMemoryBasis = modelSizeStats.assignmentMemoryBasis;
             this.categorizedDocCount = modelSizeStats.categorizedDocCount;
             this.totalCategoryCount = modelSizeStats.totalCategoryCount;
             this.frequentCategoryCount = modelSizeStats.frequentCategoryCount;
@@ -486,6 +555,11 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
             return this;
         }
 
+        public Builder setAssignmentMemoryBasis(AssignmentMemoryBasis assignmentMemoryBasis) {
+            this.assignmentMemoryBasis = assignmentMemoryBasis;
+            return this;
+        }
+
         public Builder setCategorizedDocCount(long categorizedDocCount) {
             this.categorizedDocCount = categorizedDocCount;
             return this;
@@ -535,8 +609,9 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
         public ModelSizeStats build() {
             return new ModelSizeStats(
                 jobId, modelBytes, peakModelBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, totalOverFieldCount,
-                totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, categorizedDocCount, totalCategoryCount,
-                frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus, timestamp, logTime);
+                totalPartitionFieldCount, bucketAllocationFailuresCount, memoryStatus, assignmentMemoryBasis, categorizedDocCount,
+                totalCategoryCount, frequentCategoryCount, rareCategoryCount, deadCategoryCount, failedCategoryCount, categorizationStatus,
+                timestamp, logTime);
         }
     }
 }

+ 1 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java

@@ -179,6 +179,7 @@ public final class ReservedFieldNames {
             ModelSizeStats.TOTAL_PARTITION_FIELD_COUNT_FIELD.getPreferredName(),
             ModelSizeStats.BUCKET_ALLOCATION_FAILURES_COUNT_FIELD.getPreferredName(),
             ModelSizeStats.MEMORY_STATUS_FIELD.getPreferredName(),
+            ModelSizeStats.ASSIGNMENT_MEMORY_BASIS_FIELD.getPreferredName(),
             ModelSizeStats.LOG_TIME_FIELD.getPreferredName(),
 
             ModelSnapshot.DESCRIPTION.getPreferredName(),

+ 6 - 0
x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/anomalydetection/results_index_mappings.json

@@ -24,6 +24,9 @@
       "anomaly_score" : {
         "type" : "double"
       },
+      "assignment_memory_basis" : {
+        "type" : "keyword"
+      },
       "average_bucket_processing_time_ms" : {
         "type" : "double"
       },
@@ -350,6 +353,9 @@
       },
       "model_size_stats" : {
         "properties" : {
+          "assignment_memory_basis" : {
+            "type" : "keyword"
+          },
           "bucket_allocation_failures_count" : {
             "type" : "long"
           },

+ 4 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStatsTests.java

@@ -30,6 +30,7 @@ public class ModelSizeStatsTests extends AbstractSerializingTestCase<ModelSizeSt
         assertEquals(0, stats.getTotalPartitionFieldCount());
         assertEquals(0, stats.getBucketAllocationFailuresCount());
         assertEquals(MemoryStatus.OK, stats.getMemoryStatus());
+        assertNull(stats.getAssignmentMemoryBasis());
         assertEquals(0, stats.getCategorizedDocCount());
         assertEquals(0, stats.getTotalCategoryCount());
         assertEquals(0, stats.getFrequentCategoryCount());
@@ -95,6 +96,9 @@ public class ModelSizeStatsTests extends AbstractSerializingTestCase<ModelSizeSt
         if (randomBoolean()) {
             stats.setMemoryStatus(randomFrom(MemoryStatus.values()));
         }
+        if (randomBoolean()) {
+            stats.setAssignmentMemoryBasis(randomFrom(ModelSizeStats.AssignmentMemoryBasis.values()));
+        }
         if (randomBoolean()) {
             stats.setCategorizedDocCount(randomNonNegativeLong());
         }

+ 39 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java

@@ -1197,6 +1197,9 @@ public class JobResultsProvider {
      * - Have low variability of model bytes in model size stats documents in the time period covered by the last
      *   <code>BUCKETS_FOR_ESTABLISHED_MEMORY_SIZE</code> buckets, which is defined as having a coefficient of variation
      *   of no more than <code>ESTABLISHED_MEMORY_CV_THRESHOLD</code>
+     * If necessary this calculation will be done by performing searches against the results index.  However, the
+     * calculation may have already been done in the C++ code, in which case the answer can just be read from the latest
+     * model size stats.
      * @param jobId the id of the job for which established memory usage is required
      * @param latestBucketTimestamp the latest bucket timestamp to be used for the calculation, if known, otherwise
      *                              <code>null</code>, implying the latest bucket that exists in the results index
@@ -1209,6 +1212,36 @@ public class JobResultsProvider {
     public void getEstablishedMemoryUsage(String jobId, Date latestBucketTimestamp, ModelSizeStats latestModelSizeStats,
                                           Consumer<Long> handler, Consumer<Exception> errorHandler) {
 
+        if (latestModelSizeStats != null) {
+            calculateEstablishedMemoryUsage(jobId, latestBucketTimestamp, latestModelSizeStats, handler, errorHandler);
+        } else {
+            modelSizeStats(jobId,
+                modelSizeStats -> calculateEstablishedMemoryUsage(jobId, latestBucketTimestamp, modelSizeStats, handler, errorHandler),
+                errorHandler);
+        }
+    }
+
+    void calculateEstablishedMemoryUsage(String jobId, Date latestBucketTimestamp, ModelSizeStats latestModelSizeStats,
+                                         Consumer<Long> handler, Consumer<Exception> errorHandler) {
+
+        assert latestModelSizeStats != null;
+
+        // There might be an easy short-circuit if the latest model size stats say which number to use
+        if (latestModelSizeStats.getAssignmentMemoryBasis() != null) {
+            switch (latestModelSizeStats.getAssignmentMemoryBasis()) {
+                case MODEL_MEMORY_LIMIT:
+                    handler.accept(0L);
+                    return;
+                case CURRENT_MODEL_BYTES:
+                    handler.accept(latestModelSizeStats.getModelBytes());
+                    return;
+                case PEAK_MODEL_BYTES:
+                    Long storedPeak = latestModelSizeStats.getPeakModelBytes();
+                    handler.accept((storedPeak != null) ? storedPeak : latestModelSizeStats.getModelBytes());
+                    return;
+            }
+        }
+
         String indexName = AnomalyDetectorsIndex.jobResultsAliasedName(jobId);
 
         // Step 2. Find the count, mean and standard deviation of memory usage over the time span of the last N bucket results,
@@ -1231,13 +1264,11 @@ public class JobResultsProvider {
                                     if (aggregations.size() == 1) {
                                         ExtendedStats extendedStats = (ExtendedStats) aggregations.get(0);
                                         long count = extendedStats.getCount();
-                                        if (count <= 0) {
-                                            // model size stats haven't changed in the last N buckets,
-                                            // so the latest (older) ones are established
-                                            handleLatestModelSizeStats(jobId, latestModelSizeStats, handler, errorHandler);
-                                        } else if (count == 1) {
-                                            // no need to do an extra search in the case of exactly one document being aggregated
-                                            handler.accept((long) extendedStats.getAvg());
+                                        if (count <= 1) {
+                                            // model size stats either haven't changed in the last N buckets,
+                                            // so the latest (older) ones are established, or have only changed
+                                            // once, so again there's no recent variation
+                                            handler.accept(latestModelSizeStats.getModelBytes());
                                         } else {
                                             double coefficientOfVaration = extendedStats.getStdDeviation() / extendedStats.getAvg();
                                             LOGGER.trace("[{}] Coefficient of variation [{}] when calculating established memory use",
@@ -1245,7 +1276,7 @@ public class JobResultsProvider {
                                             // is there sufficient stability in the latest model size stats readings?
                                             if (coefficientOfVaration <= ESTABLISHED_MEMORY_CV_THRESHOLD) {
                                                 // yes, so return the latest model size as established
-                                                handleLatestModelSizeStats(jobId, latestModelSizeStats, handler, errorHandler);
+                                                handler.accept(latestModelSizeStats.getModelBytes());
                                             } else {
                                                 // no - we don't have an established model size
                                                 handler.accept(0L);
@@ -1569,15 +1600,6 @@ public class JobResultsProvider {
         client::get);
     }
 
-    private void handleLatestModelSizeStats(String jobId, ModelSizeStats latestModelSizeStats, Consumer<Long> handler,
-                                            Consumer<Exception> errorHandler) {
-        if (latestModelSizeStats != null) {
-            handler.accept(latestModelSizeStats.getModelBytes());
-        } else {
-            modelSizeStats(jobId, modelSizeStats -> handler.accept(modelSizeStats.getModelBytes()), errorHandler);
-        }
-    }
-
     /**
      * Returns information needed to decide how to restart a job from a datafeed
      * @param jobId the job id