Browse Source

[ML] Improve reporting of trained model size stats (#82000)

This improves reporting of trained model size in the response of the stats API.

In particular, it removes the `model_size_bytes` from the `deployment_stats` section and
replaces it with a top-level `model_size_stats` object that contains:

- `model_size_bytes`: the actual model size
- `required_native_memory_bytes`: the amount of memory required to load a model

In addition, these are now reported for PyTorch models regardless of their deployment state.
Dimitris Athanasiou 3 years ago
parent
commit
14a63ac115
14 changed files with 348 additions and 86 deletions
  1. 17 4
      docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc
  2. 30 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java
  3. 13 9
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java
  4. 7 26
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java
  5. 70 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelSizeStats.java
  6. 4 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java
  7. 0 2
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java
  8. 28 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelSizeStatsTests.java
  9. 34 19
      x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
  10. 7 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java
  11. 1 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java
  12. 97 12
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java
  13. 40 4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java
  14. 0 3
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

+ 17 - 4
docs/reference/ml/trained-models/apis/get-trained-models-stats.asciidoc

@@ -116,10 +116,6 @@ The desired number of nodes for model allocation.
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
-`model_size`:::
-(<<byte-units,byte value>>)
-The size of the loaded model in bytes.
-
 `nodes`:::
 (array of objects)
 The deployment stats for each node that currently has the model allocated.
@@ -249,6 +245,23 @@ section in <<cluster-nodes-stats>>.
 (string)
 include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
 
+`model_size_stats`:::
+(object)
+A collection of model size stats fields.
++
+.Properties of model size stats
+[%collapsible%open]
+=====
+
+`model_size_bytes`:::
+(integer)
+The size of the model in bytes.
+
+`required_native_memory_bytes`:::
+(integer)
+The amount of memory required to load the model in bytes.
+=====
+
 `pipeline_count`:::
 (integer)
 The number of ingest pipelines that currently refer to the model.

+ 30 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

@@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -41,6 +42,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
     public static final String NAME = "cluster:monitor/xpack/ml/inference/stats/get";
 
     public static final ParseField MODEL_ID = new ParseField("model_id");
+    public static final ParseField MODEL_SIZE_STATS = new ParseField("model_size_stats");
     public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
     public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
     public static final ParseField DEPLOYMENT_STATS = new ParseField("deployment_stats");
@@ -77,6 +79,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
 
         public static class TrainedModelStats implements ToXContentObject, Writeable {
             private final String modelId;
+            private final TrainedModelSizeStats modelSizeStats;
             private final IngestStats ingestStats;
             private final InferenceStats inferenceStats;
             private final AllocationStats deploymentStats;
@@ -90,12 +93,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
 
             public TrainedModelStats(
                 String modelId,
+                TrainedModelSizeStats modelSizeStats,
                 IngestStats ingestStats,
                 int pipelineCount,
                 InferenceStats inferenceStats,
                 AllocationStats deploymentStats
             ) {
                 this.modelId = Objects.requireNonNull(modelId);
+                this.modelSizeStats = modelSizeStats;
                 this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
                 if (pipelineCount < 0) {
                     throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
@@ -107,6 +112,11 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
 
             public TrainedModelStats(StreamInput in) throws IOException {
                 modelId = in.readString();
+                if (in.getVersion().onOrAfter(Version.V_8_1_0)) {
+                    modelSizeStats = in.readOptionalWriteable(TrainedModelSizeStats::new);
+                } else {
+                    modelSizeStats = null;
+                }
                 ingestStats = new IngestStats(in);
                 pipelineCount = in.readVInt();
                 inferenceStats = in.readOptionalWriteable(InferenceStats::new);
@@ -121,6 +131,10 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return modelId;
             }
 
+            public TrainedModelSizeStats getModelSizeStats() {
+                return modelSizeStats;
+            }
+
             public IngestStats getIngestStats() {
                 return ingestStats;
             }
@@ -141,6 +155,9 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
                 builder.startObject();
                 builder.field(MODEL_ID.getPreferredName(), modelId);
+                if (modelSizeStats != null) {
+                    builder.field(MODEL_SIZE_STATS.getPreferredName(), modelSizeStats);
+                }
                 builder.field(PIPELINE_COUNT.getPreferredName(), pipelineCount);
                 if (pipelineCount > 0) {
                     // Ingest stats is a fragment
@@ -159,6 +176,9 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             @Override
             public void writeTo(StreamOutput out) throws IOException {
                 out.writeString(modelId);
+                if (out.getVersion().onOrAfter(Version.V_8_1_0)) {
+                    out.writeOptionalWriteable(modelSizeStats);
+                }
                 ingestStats.writeTo(out);
                 out.writeVInt(pipelineCount);
                 out.writeOptionalWriteable(inferenceStats);
@@ -169,7 +189,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
 
             @Override
             public int hashCode() {
-                return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats, deploymentStats);
+                return Objects.hash(modelId, modelSizeStats, ingestStats, pipelineCount, inferenceStats, deploymentStats);
             }
 
             @Override
@@ -182,6 +202,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 }
                 TrainedModelStats other = (TrainedModelStats) obj;
                 return Objects.equals(this.modelId, other.modelId)
+                    && Objects.equals(this.modelSizeStats, other.modelSizeStats)
                     && Objects.equals(this.ingestStats, other.ingestStats)
                     && Objects.equals(this.pipelineCount, other.pipelineCount)
                     && Objects.equals(this.deploymentStats, other.deploymentStats)
@@ -208,6 +229,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
 
             private long totalModelCount;
             private Map<String, Set<String>> expandedIdsWithAliases;
+            private Map<String, TrainedModelSizeStats> modelSizeStatsMap;
             private Map<String, IngestStats> ingestStatsMap;
             private Map<String, InferenceStats> inferenceStatsMap;
             private Map<String, AllocationStats> allocationStatsMap;
@@ -226,6 +248,11 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
                 return this.expandedIdsWithAliases;
             }
 
+            public Builder setModelSizeStatsByModelId(Map<String, TrainedModelSizeStats> modelSizeStatsByModelId) {
+                this.modelSizeStatsMap = modelSizeStatsByModelId;
+                return this;
+            }
+
             public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByModelId) {
                 this.ingestStatsMap = ingestStatsByModelId;
                 return this;
@@ -244,12 +271,14 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
             public Response build() {
                 List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size());
                 expandedIdsWithAliases.keySet().forEach(id -> {
+                    TrainedModelSizeStats modelSizeStats = modelSizeStatsMap.get(id);
                     IngestStats ingestStats = ingestStatsMap.get(id);
                     InferenceStats inferenceStats = inferenceStatsMap.get(id);
                     AllocationStats allocationStats = allocationStatsMap.get(id);
                     trainedModelStats.add(
                         new TrainedModelStats(
                             id,
+                            modelSizeStats,
                             ingestStats,
                             ingestStats == null ? 0 : ingestStats.getPipelineStats().size(),
                             inferenceStats,

+ 13 - 9
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java

@@ -45,6 +45,13 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
 
     public static final TimeValue DEFAULT_TIMEOUT = new TimeValue(20, TimeUnit.SECONDS);
 
+    /**
+     * This has been found to be approximately 300MB on linux by manual testing.
+     * We also subtract 30MB that we always add as overhead (see MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD).
+     * TODO Check if it is substantially different in other platforms.
+     */
+    private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);
+
     public StartTrainedModelDeploymentAction() {
         super(NAME, CreateTrainedModelAllocationAction.Response::new);
     }
@@ -265,13 +272,6 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return PARSER.apply(parser, null);
         }
 
-        /**
-         * This has been found to be approximately 300MB on linux by manual testing.
-         * We also subtract 30MB that we always add as overhead (see MachineLearning.NATIVE_EXECUTABLE_CODE_OVERHEAD).
-         * TODO Check if it is substantially different in other platforms.
-         */
-        private static final ByteSizeValue MEMORY_OVERHEAD = ByteSizeValue.ofMb(270);
-
         private final String modelId;
         private final long modelBytes;
         // How many threads are used by the model during inference. Used to increase inference speed.
@@ -301,8 +301,7 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
         }
 
         public long estimateMemoryUsageBytes() {
-            // While loading the model in the process we need twice the model size.
-            return MEMORY_OVERHEAD.getBytes() + 2 * modelBytes;
+            return StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(modelBytes);
         }
 
         public Version getMinimalSupportedVersion() {
@@ -388,4 +387,9 @@ public class StartTrainedModelDeploymentAction extends ActionType<CreateTrainedM
             return false;
         }
     }
+
+    public static long estimateMemoryUsageBytes(long totalDefinitionLength) {
+        // While loading the model in the process we need twice the model size.
+        return MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;
+    }
 }

+ 7 - 26
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStats.java

@@ -221,8 +221,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
     private AllocationStatus allocationStatus;
     private String reason;
     @Nullable
-    private final ByteSizeValue modelSize;
-    @Nullable
     private final Integer inferenceThreads;
     @Nullable
     private final Integer modelThreads;
@@ -233,7 +231,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
 
     public AllocationStats(
         String modelId,
-        @Nullable ByteSizeValue modelSize,
         @Nullable Integer inferenceThreads,
         @Nullable Integer modelThreads,
         @Nullable Integer queueCapacity,
@@ -241,7 +238,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
         List<AllocationStats.NodeStats> nodeStats
     ) {
         this.modelId = modelId;
-        this.modelSize = modelSize;
         this.inferenceThreads = inferenceThreads;
         this.modelThreads = modelThreads;
         this.queueCapacity = queueCapacity;
@@ -253,7 +249,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
 
     public AllocationStats(StreamInput in) throws IOException {
         modelId = in.readString();
-        modelSize = in.readOptionalWriteable(ByteSizeValue::new);
+        if (in.getVersion().before(Version.V_8_1_0)) {
+            in.readOptionalWriteable(ByteSizeValue::new);
+        }
         inferenceThreads = in.readOptionalVInt();
         modelThreads = in.readOptionalVInt();
         queueCapacity = in.readOptionalVInt();
@@ -268,10 +266,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
         return modelId;
     }
 
-    public ByteSizeValue getModelSize() {
-        return modelSize;
-    }
-
     @Nullable
     public Integer getInferenceThreads() {
         return inferenceThreads;
@@ -322,9 +316,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
         builder.field("model_id", modelId);
-        if (modelSize != null) {
-            builder.humanReadableField("model_size_bytes", "model_size", modelSize);
-        }
         if (inferenceThreads != null) {
             builder.field(StartTrainedModelDeploymentAction.TaskParams.INFERENCE_THREADS.getPreferredName(), inferenceThreads);
         }
@@ -356,7 +347,9 @@ public class AllocationStats implements ToXContentObject, Writeable {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         out.writeString(modelId);
-        out.writeOptionalWriteable(modelSize);
+        if (out.getVersion().before(Version.V_8_1_0)) {
+            out.writeOptionalWriteable(null);
+        }
         out.writeOptionalVInt(inferenceThreads);
         out.writeOptionalVInt(modelThreads);
         out.writeOptionalVInt(queueCapacity);
@@ -373,7 +366,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
         if (o == null || getClass() != o.getClass()) return false;
         AllocationStats that = (AllocationStats) o;
         return Objects.equals(modelId, that.modelId)
-            && Objects.equals(modelSize, that.modelSize)
             && Objects.equals(inferenceThreads, that.inferenceThreads)
             && Objects.equals(modelThreads, that.modelThreads)
             && Objects.equals(queueCapacity, that.queueCapacity)
@@ -386,17 +378,6 @@ public class AllocationStats implements ToXContentObject, Writeable {
 
     @Override
     public int hashCode() {
-        return Objects.hash(
-            modelId,
-            modelSize,
-            inferenceThreads,
-            modelThreads,
-            queueCapacity,
-            startTime,
-            nodeStats,
-            state,
-            reason,
-            allocationStatus
-        );
+        return Objects.hash(modelId, inferenceThreads, modelThreads, queueCapacity, startTime, nodeStats, state, reason, allocationStatus);
     }
 }

+ 70 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelSizeStats.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Objects;
+
+public class TrainedModelSizeStats implements ToXContentObject, Writeable {
+
+    private static final ParseField MODEL_SIZE_BYTES = new ParseField("model_size_bytes");
+    private static final ParseField REQUIRED_NATIVE_MEMORY_BYTES = new ParseField("required_native_memory_bytes");
+
+    private final long modelSizeBytes;
+    private final long requiredNativeMemoryBytes;
+
+    public TrainedModelSizeStats(long modelSizeBytes, long requiredNativeMemoryBytes) {
+        this.modelSizeBytes = modelSizeBytes;
+        this.requiredNativeMemoryBytes = requiredNativeMemoryBytes;
+    }
+
+    public TrainedModelSizeStats(StreamInput in) throws IOException {
+        modelSizeBytes = in.readLong();
+        requiredNativeMemoryBytes = in.readLong();
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeLong(modelSizeBytes);
+        out.writeLong(requiredNativeMemoryBytes);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.humanReadableField(MODEL_SIZE_BYTES.getPreferredName(), "model_size", ByteSizeValue.ofBytes(modelSizeBytes));
+        builder.humanReadableField(
+            REQUIRED_NATIVE_MEMORY_BYTES.getPreferredName(),
+            "required_native_memory",
+            ByteSizeValue.ofBytes(requiredNativeMemoryBytes)
+        );
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        TrainedModelSizeStats that = (TrainedModelSizeStats) o;
+        return modelSizeBytes == that.modelSizeBytes && requiredNativeMemoryBytes == that.requiredNativeMemoryBytes;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelSizeBytes, requiredNativeMemoryBytes);
+    }
+}

+ 4 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java

@@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction.Respon
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatsTests;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStatsTests;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStatsTests;
 
 import java.util.List;
 import java.util.function.Function;
@@ -33,6 +34,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
             .map(
                 id -> new Response.TrainedModelStats(
                     id,
+                    randomBoolean() ? TrainedModelSizeStatsTests.createRandom() : null,
                     randomBoolean() ? randomIngestStats() : null,
                     randomIntBetween(0, 10),
                     randomBoolean() ? InferenceStatsTests.createTestInstance(id, null) : null,
@@ -81,6 +83,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                         .map(
                             stats -> new Response.TrainedModelStats(
                                 stats.getModelId(),
+                                null,
                                 stats.getIngestStats(),
                                 stats.getPipelineCount(),
                                 stats.getInferenceStats(),
@@ -101,6 +104,7 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                         .map(
                             stats -> new Response.TrainedModelStats(
                                 stats.getModelId(),
+                                null,
                                 stats.getIngestStats(),
                                 stats.getPipelineCount(),
                                 stats.getInferenceStats(),
@@ -108,7 +112,6 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer
                                     ? null
                                     : new AllocationStats(
                                         stats.getDeploymentStats().getModelId(),
-                                        stats.getDeploymentStats().getModelSize(),
                                         stats.getDeploymentStats().getInferenceThreads(),
                                         stats.getDeploymentStats().getModelThreads(),
                                         stats.getDeploymentStats().getQueueCapacity(),

+ 0 - 2
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/allocation/AllocationStatsTests.java

@@ -11,7 +11,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.transport.TransportAddress;
-import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.AbstractWireSerializingTestCase;
 
 import java.net.InetAddress;
@@ -44,7 +43,6 @@ public class AllocationStatsTests extends AbstractWireSerializingTestCase<Alloca
 
         return new AllocationStats(
             randomAlphaOfLength(5),
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),

+ 28 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModelSizeStatsTests.java

@@ -0,0 +1,28 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+public class TrainedModelSizeStatsTests extends AbstractWireSerializingTestCase<TrainedModelSizeStats> {
+
+    @Override
+    protected Writeable.Reader<TrainedModelSizeStats> instanceReader() {
+        return TrainedModelSizeStats::new;
+    }
+
+    @Override
+    protected TrainedModelSizeStats createTestInstance() {
+        return createRandom();
+    }
+
+    public static TrainedModelSizeStats createRandom() {
+        return new TrainedModelSizeStats(randomNonNegativeLong(), randomNonNegativeLong());
+    }
+}

+ 34 - 19
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -13,6 +13,7 @@ import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
 import org.elasticsearch.common.CheckedBiConsumer;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.TimeValue;
@@ -241,24 +242,38 @@ public class PyTorchModelIT extends ESRestTestCase {
             assertThat(responseMap.toString(), statusState, is(not(nullValue())));
             assertThat(AllocationStatus.State.fromString(statusState), greaterThanOrEqualTo(state));
 
-            // starting models do not know their model size yet
-            if (state.isAnyOf(AllocationStatus.State.STARTED, AllocationStatus.State.FULLY_ALLOCATED)) {
-                Integer byteSize = (Integer) XContentMapValues.extractValue("deployment_stats.model_size_bytes", stats.get(0));
-                assertThat(responseMap.toString(), byteSize, is(not(nullValue())));
-                assertThat(byteSize, equalTo((int) RAW_MODEL_SIZE));
-
-                Response humanResponse = client().performRequest(new Request("GET", "/_ml/trained_models/" + modelId + "/_stats?human"));
-                var humanResponseMap = entityAsMap(humanResponse);
-                stats = (List<Map<String, Object>>) humanResponseMap.get("trained_model_stats");
-                assertThat(stats, hasSize(1));
-                String stringBytes = (String) XContentMapValues.extractValue("deployment_stats.model_size", stats.get(0));
-                assertThat(
-                    "stats response: " + responseMap + " human stats response" + humanResponseMap,
-                    stringBytes,
-                    is(not(nullValue()))
-                );
-                assertThat(stringBytes, equalTo("1.5kb"));
-            }
+            Integer byteSize = (Integer) XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0));
+            assertThat(responseMap.toString(), byteSize, is(not(nullValue())));
+            assertThat(byteSize, equalTo((int) RAW_MODEL_SIZE));
+
+            Integer requiredNativeMemory = (Integer) XContentMapValues.extractValue(
+                "model_size_stats.required_native_memory_bytes",
+                stats.get(0)
+            );
+            assertThat(responseMap.toString(), requiredNativeMemory, is(not(nullValue())));
+            assertThat(requiredNativeMemory, equalTo((int) (ByteSizeValue.ofMb(270).getBytes() + 2 * RAW_MODEL_SIZE)));
+
+            Response humanResponse = client().performRequest(new Request("GET", "/_ml/trained_models/" + modelId + "/_stats?human"));
+            var humanResponseMap = entityAsMap(humanResponse);
+            stats = (List<Map<String, Object>>) humanResponseMap.get("trained_model_stats");
+            assertThat(stats, hasSize(1));
+            String stringModelSizeBytes = (String) XContentMapValues.extractValue("model_size_stats.model_size", stats.get(0));
+            assertThat(
+                "stats response: " + responseMap + " human stats response" + humanResponseMap,
+                stringModelSizeBytes,
+                is(not(nullValue()))
+            );
+            assertThat(stringModelSizeBytes, equalTo("1.5kb"));
+            String stringRequiredNativeMemory = (String) XContentMapValues.extractValue(
+                "model_size_stats.required_native_memory",
+                stats.get(0)
+            );
+            assertThat(
+                "stats response: " + responseMap + " human stats response" + humanResponseMap,
+                stringRequiredNativeMemory,
+                is(not(nullValue()))
+            );
+            assertThat(stringRequiredNativeMemory, equalTo("270mb"));
             stopDeployment(modelId);
         };
 
@@ -281,7 +296,7 @@ public class PyTorchModelIT extends ESRestTestCase {
         List<Map<String, Object>> stats = (List<Map<String, Object>>) entityAsMap(response).get("trained_model_stats");
         assertThat(stats, hasSize(1));
         assertThat(XContentMapValues.extractValue("deployment_stats.model_id", stats.get(0)), equalTo(modelA));
-        assertThat(XContentMapValues.extractValue("deployment_stats.model_size_bytes", stats.get(0)), equalTo((int) RAW_MODEL_SIZE));
+        assertThat(XContentMapValues.extractValue("model_size_stats.model_size_bytes", stats.get(0)), equalTo((int) RAW_MODEL_SIZE));
         List<Map<String, Object>> nodes = (List<Map<String, Object>>) XContentMapValues.extractValue(
             "deployment_stats.nodes",
             stats.get(0)

+ 7 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearningUsageTransportAction.java

@@ -46,12 +46,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
 import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.job.config.Job;
 import org.elasticsearch.xpack.core.ml.job.config.JobState;
 import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
 import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
 import org.elasticsearch.xpack.core.ml.stats.StatsAccumulator;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
 
@@ -368,11 +370,15 @@ public class MachineLearningUsageTransportAction extends XPackUsageFeatureTransp
     }
 
     private void addDeploymentStats(GetDeploymentStatsAction.Response response, Map<String, Object> inferenceUsage) {
+        TrainedModelAllocationMetadata trainedModelAllocationMetadata = TrainedModelAllocationMetadata.fromState(clusterService.state());
         StatsAccumulator modelSizes = new StatsAccumulator();
         double avgTimeSum = 0.0;
         StatsAccumulator nodeDistribution = new StatsAccumulator();
         for (var stats : response.getStats().results()) {
-            modelSizes.add(stats.getModelSize().getBytes());
+            TrainedModelAllocation allocation = trainedModelAllocationMetadata.getModelAllocation(stats.getModelId());
+            if (allocation != null) {
+                modelSizes.add(allocation.getTaskParams().getModelBytes());
+            }
             for (var nodeStats : stats.getNodeStats()) {
                 long nodeInferenceCount = nodeStats.getInferenceCount().orElse(0L);
                 avgTimeSum += nodeStats.getAvgInferenceTime().orElse(0.0) * nodeInferenceCount;

+ 1 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java

@@ -18,7 +18,6 @@ import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.inject.Inject;
-import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -236,7 +235,6 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
                 updatedAllocationStats.add(
                     new AllocationStats(
                         stat.getModelId(),
-                        stat.getModelSize(),
                         stat.getInferenceThreads(),
                         stat.getModelThreads(),
                         stat.getQueueCapacity(),
@@ -270,7 +268,7 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
 
                 nodeStats.sort(Comparator.comparing(n -> n.getNode().getId()));
 
-                updatedAllocationStats.add(new AllocationStats(modelId, null, null, null, null, allocation.getStartTime(), nodeStats));
+                updatedAllocationStats.add(new AllocationStats(modelId, null, null, null, allocation.getStartTime(), nodeStats));
             }
         }
 
@@ -317,7 +315,6 @@ public class TransportGetDeploymentStatsAction extends TransportTasksAction<
         listener.onResponse(
             new AllocationStats(
                 task.getModelId(),
-                ByteSizeValue.ofBytes(task.getParams().getModelBytes()),
                 task.getParams().getInferenceThreads(),
                 task.getParams().getModelThreads(),
                 task.getParams().getQueueCapacity(),

+ 97 - 12
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java

@@ -12,27 +12,41 @@ import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
 import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
+import org.elasticsearch.action.search.SearchAction;
+import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.document.DocumentField;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.metrics.CounterMetric;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.Tuple;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.ingest.IngestMetadata;
 import org.elasticsearch.ingest.IngestService;
 import org.elasticsearch.ingest.IngestStats;
 import org.elasticsearch.ingest.Pipeline;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModelSizeStats;
 import org.elasticsearch.xpack.ml.inference.ModelAliasMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
 
 import java.util.ArrayList;
@@ -85,17 +99,17 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
         final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state());
         GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();
 
-        ActionListener<GetDeploymentStatsAction.Response> getDeploymentStats = ActionListener.wrap(
-            deploymentStats -> listener.onResponse(
-                responseBuilder.setDeploymentStatsByModelId(
-                    deploymentStats.getStats()
-                        .results()
-                        .stream()
-                        .collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
-                ).build()
-            ),
-            listener::onFailure
-        );
+        ActionListener<Map<String, TrainedModelSizeStats>> modelSizeStatsListener = ActionListener.wrap(modelSizeStatsByModelId -> {
+            responseBuilder.setModelSizeStatsByModelId(modelSizeStatsByModelId);
+            listener.onResponse(responseBuilder.build());
+        }, listener::onFailure);
+
+        ActionListener<GetDeploymentStatsAction.Response> deploymentStatsListener = ActionListener.wrap(deploymentStats -> {
+            responseBuilder.setDeploymentStatsByModelId(
+                deploymentStats.getStats().results().stream().collect(Collectors.toMap(AllocationStats::getModelId, Function.identity()))
+            );
+            modelSizeStats(responseBuilder.getExpandedIdsWithAliases(), request.isAllowNoResources(), modelSizeStatsListener);
+        }, listener::onFailure);
 
         ActionListener<List<InferenceStats>> inferenceStatsListener = ActionListener.wrap(inferenceStats -> {
             responseBuilder.setInferenceStatsByModelId(
@@ -106,7 +120,7 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
                 ML_ORIGIN,
                 GetDeploymentStatsAction.INSTANCE,
                 new GetDeploymentStatsAction.Request(request.getResourceId()),
-                getDeploymentStats
+                deploymentStatsListener
             );
         }, listener::onFailure);
 
@@ -150,6 +164,77 @@ public class TransportGetTrainedModelsStatsAction extends HandledTransportAction
         );
     }
 
+    private void modelSizeStats(
+        Map<String, Set<String>> expandedIdsWithAliases,
+        boolean allowNoResources,
+        ActionListener<Map<String, TrainedModelSizeStats>> listener
+    ) {
+        ActionListener<List<TrainedModelConfig>> modelsListener = ActionListener.wrap(models -> {
+            final List<String> pytorchModelIds = models.stream()
+                .filter(m -> m.getModelType() == TrainedModelType.PYTORCH)
+                .map(TrainedModelConfig::getModelId)
+                .toList();
+            definitionLengths(pytorchModelIds, ActionListener.wrap(pytorchTotalDefinitionLengthsByModelId -> {
+                Map<String, TrainedModelSizeStats> modelSizeStatsByModelId = new HashMap<>();
+                for (TrainedModelConfig model : models) {
+                    if (model.getModelType() == TrainedModelType.PYTORCH) {
+                        long totalDefinitionLength = pytorchTotalDefinitionLengthsByModelId.getOrDefault(model.getModelId(), 0L);
+                        modelSizeStatsByModelId.put(
+                            model.getModelId(),
+                            new TrainedModelSizeStats(
+                                totalDefinitionLength,
+                                totalDefinitionLength > 0L
+                                    ? StartTrainedModelDeploymentAction.estimateMemoryUsageBytes(totalDefinitionLength)
+                                    : 0L
+                            )
+                        );
+                    } else {
+                        modelSizeStatsByModelId.put(model.getModelId(), new TrainedModelSizeStats(model.getModelSize(), 0));
+                    }
+                }
+                listener.onResponse(modelSizeStatsByModelId);
+            }, listener::onFailure));
+        }, listener::onFailure);
+
+        trainedModelProvider.getTrainedModels(
+            expandedIdsWithAliases,
+            GetTrainedModelsAction.Includes.empty(),
+            allowNoResources,
+            modelsListener
+        );
+    }
+
+    private void definitionLengths(List<String> modelIds, ActionListener<Map<String, Long>> listener) {
+        QueryBuilder query = QueryBuilders.boolQuery()
+            .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME))
+            .filter(QueryBuilders.termsQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelIds))
+            .filter(QueryBuilders.termQuery(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(), 0));
+        SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN)
+            .setQuery(QueryBuilders.constantScoreQuery(query))
+            .setFetchSource(false)
+            .addDocValueField(TrainedModelConfig.MODEL_ID.getPreferredName())
+            .addDocValueField(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())
+            // First find the latest index
+            .addSort("_index", SortOrder.DESC)
+            .request();
+
+        executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap(searchResponse -> {
+            Map<String, Long> totalDefinitionLengthByModelId = new HashMap<>();
+            for (SearchHit hit : searchResponse.getHits().getHits()) {
+                DocumentField modelIdField = hit.field(TrainedModelConfig.MODEL_ID.getPreferredName());
+                if (modelIdField != null && modelIdField.getValue()instanceof String modelId) {
+                    DocumentField totalDefinitionLengthField = hit.field(
+                        TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName()
+                    );
+                    if (totalDefinitionLengthField != null && totalDefinitionLengthField.getValue()instanceof Long totalDefinitionLength) {
+                        totalDefinitionLengthByModelId.put(modelId, totalDefinitionLength);
+                    }
+                }
+            }
+            listener.onResponse(totalDefinitionLengthByModelId);
+        }, listener::onFailure));
+    }
+
     static Map<String, IngestStats> inferenceIngestStatsByModelId(
         NodesStatsResponse response,
         ModelAliasMetadata currentMetadata,

+ 40 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java

@@ -50,6 +50,7 @@ import org.elasticsearch.xpack.core.ml.action.GetDatafeedsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
 import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
@@ -61,6 +62,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStats;
 import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
+import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
@@ -73,6 +75,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeSta
 import org.elasticsearch.xpack.core.ml.stats.ForecastStats;
 import org.elasticsearch.xpack.core.ml.stats.ForecastStatsTests;
 import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource;
+import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.job.JobManager;
 import org.elasticsearch.xpack.ml.job.JobManagerHolder;
@@ -340,12 +343,9 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
                 List.of(),
                 List.of(),
                 List.of(
-                    new AllocationStats("model_3", ByteSizeValue.ofMb(100), null, null, null, Instant.now(), List.of()).setState(
-                        AllocationState.STOPPING
-                    ),
+                    new AllocationStats("model_3", null, null, null, Instant.now(), List.of()).setState(AllocationState.STOPPING),
                     new AllocationStats(
                         "model_4",
-                        ByteSizeValue.ofMb(200),
                         2,
                         2,
                         1000,
@@ -378,6 +378,42 @@ public class MachineLearningInfoTransportActionTests extends ESTestCase {
             )
         );
 
+        final ClusterState cs = ClusterState.builder(new ClusterName("_name"))
+            .metadata(
+                Metadata.builder()
+                    .putCustom(
+                        TrainedModelAllocationMetadata.NAME,
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(
+                                "model_3",
+                                TrainedModelAllocation.Builder.empty(
+                                    new StartTrainedModelDeploymentAction.TaskParams(
+                                        "model_3",
+                                        ByteSizeValue.ofMb(100).getBytes(),
+                                        1,
+                                        1,
+                                        1024
+                                    )
+                                )
+                            )
+                            .addNewAllocation(
+                                "model_4",
+                                TrainedModelAllocation.Builder.empty(
+                                    new StartTrainedModelDeploymentAction.TaskParams(
+                                        "model_4",
+                                        ByteSizeValue.ofMb(200).getBytes(),
+                                        1,
+                                        1,
+                                        1024
+                                    )
+                                ).addNewRoutingEntry("foo").addNewRoutingEntry("bar")
+                            )
+                            .build()
+                    )
+            )
+            .build();
+        when(clusterService.state()).thenReturn(cs);
+
         var usageAction = newUsageAction(settings.build());
         PlainActionFuture<XPackUsageFeatureResponse> future = new PlainActionFuture<>();
         usageAction.masterOperation(null, null, ClusterState.EMPTY_STATE, future);

+ 0 - 3
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java

@@ -11,7 +11,6 @@ import org.elasticsearch.Version;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
 import org.elasticsearch.common.transport.TransportAddress;
-import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction;
 import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsActionResponseTests;
@@ -80,7 +79,6 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
 
         var model1 = new AllocationStats(
             "model1",
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),
@@ -116,7 +114,6 @@ public class TransportGetDeploymentStatsActionTests extends ESTestCase {
 
         var model1 = new AllocationStats(
             "model1",
-            ByteSizeValue.ofBytes(randomNonNegativeLong()),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 8),
             randomBoolean() ? null : randomIntBetween(1, 10000),