Browse Source

[ML] Add a model memory estimation endpoint for anomaly detection (#53507)

A new endpoint for estimating anomaly detection job
model memory requirements:

POST _ml/anomaly_detectors/estimate_model_memory

Closes #53219
David Roberts 5 years ago
parent
commit
8ee770560a

+ 12 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

@@ -40,6 +40,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
 import org.elasticsearch.client.ml.DeleteJobRequest;
 import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
 import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.FindFileStructureRequest;
@@ -593,6 +594,17 @@ final class MLRequestConverters {
         return new Request(HttpDelete.METHOD_NAME, endpoint);
     }
 
+    static Request estimateModelMemory(EstimateModelMemoryRequest estimateModelMemoryRequest) throws IOException {
+        String endpoint = new EndpointBuilder()
+            .addPathPartAsIs("_ml")
+            .addPathPartAsIs("anomaly_detectors")
+            .addPathPartAsIs("_estimate_model_memory")
+            .build();
+        Request request = new Request(HttpPost.METHOD_NAME, endpoint);
+        request.setEntity(createEntity(estimateModelMemoryRequest, REQUEST_BODY_CONTENT_TYPE));
+        return request;
+    }
+
     static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException {
         String endpoint = new EndpointBuilder()
             .addPathPartAsIs("_ml", "data_frame", "analytics")

+ 44 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

@@ -23,6 +23,8 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.ml.CloseJobRequest;
 import org.elasticsearch.client.ml.CloseJobResponse;
 import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
 import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
@@ -1951,6 +1953,48 @@ public final class MachineLearningClient {
             Collections.emptySet());
     }
 
+    /**
+     * Estimate the model memory an analysis config is likely to need given supplied field cardinalities
+     * <p>
+     * For additional info
+     * see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
+     *
+     * @param request The {@link EstimateModelMemoryRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @return {@link EstimateModelMemoryResponse} response object
+     */
+    public EstimateModelMemoryResponse estimateModelMemory(EstimateModelMemoryRequest request,
+                                                           RequestOptions options) throws IOException {
+        return restHighLevelClient.performRequestAndParseEntity(request,
+            MLRequestConverters::estimateModelMemory,
+            options,
+            EstimateModelMemoryResponse::fromXContent,
+            Collections.emptySet());
+    }
+
+    /**
+     * Estimate the model memory an analysis config is likely to need given supplied field cardinalities and notifies listener upon
+     * completion
+     * <p>
+     * For additional info
+     * see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
+     *
+     * @param request The {@link EstimateModelMemoryRequest}
+     * @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
+     * @param listener Listener to be notified upon request completion
+     * @return cancellable that may be used to cancel the request
+     */
+    public Cancellable estimateModelMemoryAsync(EstimateModelMemoryRequest request,
+                                                RequestOptions options,
+                                                ActionListener<EstimateModelMemoryResponse> listener) {
+        return restHighLevelClient.performRequestAsyncAndParseEntity(request,
+            MLRequestConverters::estimateModelMemory,
+            options,
+            EstimateModelMemoryResponse::fromXContent,
+            listener,
+            Collections.emptySet());
+    }
+
     /**
      * Creates a new Data Frame Analytics config
      * <p>

+ 110 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EstimateModelMemoryRequest.java

@@ -0,0 +1,110 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.client.Validatable;
+import org.elasticsearch.client.ValidationException;
+import org.elasticsearch.client.ml.job.config.AnalysisConfig;
+import org.elasticsearch.common.xcontent.ToXContentObject;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * Request to estimate the model memory an analysis config is likely to need given supplied field cardinalities.
+ */
+public class EstimateModelMemoryRequest implements Validatable, ToXContentObject {
+
+    public static final String ANALYSIS_CONFIG = "analysis_config";
+    public static final String OVERALL_CARDINALITY = "overall_cardinality";
+    public static final String MAX_BUCKET_CARDINALITY = "max_bucket_cardinality";
+
+    private final AnalysisConfig analysisConfig;
+    private Map<String, Long> overallCardinality = Collections.emptyMap();
+    private Map<String, Long> maxBucketCardinality = Collections.emptyMap();
+
+    @Override
+    public Optional<ValidationException> validate() {
+        return Optional.empty();
+    }
+
+    public EstimateModelMemoryRequest(AnalysisConfig analysisConfig) {
+        this.analysisConfig = Objects.requireNonNull(analysisConfig);
+    }
+
+    public AnalysisConfig getAnalysisConfig() {
+        return analysisConfig;
+    }
+
+    public Map<String, Long> getOverallCardinality() {
+        return overallCardinality;
+    }
+
+    public void setOverallCardinality(Map<String, Long> overallCardinality) {
+        this.overallCardinality = Collections.unmodifiableMap(overallCardinality);
+    }
+
+    public Map<String, Long> getMaxBucketCardinality() {
+        return maxBucketCardinality;
+    }
+
+    public void setMaxBucketCardinality(Map<String, Long> maxBucketCardinality) {
+        this.maxBucketCardinality = Collections.unmodifiableMap(maxBucketCardinality);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ANALYSIS_CONFIG, analysisConfig);
+        if (overallCardinality.isEmpty() == false) {
+            builder.field(OVERALL_CARDINALITY, overallCardinality);
+        }
+        if (maxBucketCardinality.isEmpty() == false) {
+            builder.field(MAX_BUCKET_CARDINALITY, maxBucketCardinality);
+        }
+        builder.endObject();
+        return builder;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(analysisConfig, overallCardinality, maxBucketCardinality);
+    }
+
+    @Override
+    public boolean equals(Object other) {
+        if (this == other) {
+            return true;
+        }
+
+        if (other == null || getClass() != other.getClass()) {
+            return false;
+        }
+
+        EstimateModelMemoryRequest that = (EstimateModelMemoryRequest) other;
+        return Objects.equals(analysisConfig, that.analysisConfig) &&
+            Objects.equals(overallCardinality, that.overallCardinality) &&
+            Objects.equals(maxBucketCardinality, that.maxBucketCardinality);
+    }
+}

+ 80 - 0
client/rest-high-level/src/main/java/org/elasticsearch/client/ml/EstimateModelMemoryResponse.java

@@ -0,0 +1,80 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.client.ml;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+import java.util.Objects;
+
+import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
+
+public class EstimateModelMemoryResponse {
+
+    public static final ParseField MODEL_MEMORY_ESTIMATE = new ParseField("model_memory_estimate");
+
+    static final ConstructingObjectParser<EstimateModelMemoryResponse, Void> PARSER =
+        new ConstructingObjectParser<>(
+            "estimate_model_memory",
+            true,
+            args -> new EstimateModelMemoryResponse((String) args[0]));
+
+    static {
+        PARSER.declareString(constructorArg(), MODEL_MEMORY_ESTIMATE);
+    }
+
+    public static EstimateModelMemoryResponse fromXContent(final XContentParser parser) {
+        return PARSER.apply(parser, null);
+    }
+
+    private final ByteSizeValue modelMemoryEstimate;
+
+    public EstimateModelMemoryResponse(String modelMemoryEstimate) {
+        this.modelMemoryEstimate = ByteSizeValue.parseBytesSizeValue(modelMemoryEstimate, MODEL_MEMORY_ESTIMATE.getPreferredName());
+    }
+
+    /**
+     * @return An estimate of the model memory the supplied analysis config is likely to need given the supplied field cardinalities.
+     */
+    public ByteSizeValue getModelMemoryEstimate() {
+        return modelMemoryEstimate;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+
+        EstimateModelMemoryResponse other = (EstimateModelMemoryResponse) o;
+        return Objects.equals(this.modelMemoryEstimate, other.modelMemoryEstimate);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(modelMemoryEstimate);
+    }
+}

+ 21 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

@@ -36,6 +36,7 @@ import org.elasticsearch.client.ml.DeleteForecastRequest;
 import org.elasticsearch.client.ml.DeleteJobRequest;
 import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
 import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
@@ -107,6 +108,7 @@ import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.common.xcontent.ToXContent;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.common.xcontent.XContentType;
@@ -695,6 +697,25 @@ public class MLRequestConvertersTests extends ESTestCase {
         assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint());
     }
 
+    public void testEstimateModelMemory() throws Exception {
+        String byFieldName = randomAlphaOfLength(10);
+        String influencerFieldName = randomAlphaOfLength(10);
+        AnalysisConfig analysisConfig = AnalysisConfig.builder(
+            Collections.singletonList(
+                Detector.builder().setFunction("count").setByFieldName(byFieldName).build()
+            )).setInfluencers(Collections.singletonList(influencerFieldName)).build();
+        EstimateModelMemoryRequest estimateModelMemoryRequest = new EstimateModelMemoryRequest(analysisConfig);
+        estimateModelMemoryRequest.setOverallCardinality(Collections.singletonMap(byFieldName, randomNonNegativeLong()));
+        estimateModelMemoryRequest.setMaxBucketCardinality(Collections.singletonMap(influencerFieldName, randomNonNegativeLong()));
+        Request request = MLRequestConverters.estimateModelMemory(estimateModelMemoryRequest);
+        assertEquals(HttpPost.METHOD_NAME, request.getMethod());
+        assertEquals("/_ml/anomaly_detectors/_estimate_model_memory", request.getEndpoint());
+
+        XContentBuilder builder = JsonXContent.contentBuilder();
+        builder = estimateModelMemoryRequest.toXContent(builder, ToXContent.EMPTY_PARAMS);
+        assertEquals(Strings.toString(builder), requestEntityToString(request));
+    }
+
     public void testPutDataFrameAnalytics() throws IOException {
         PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig());
         Request request = MLRequestConverters.putDataFrameAnalytics(putRequest);

+ 23 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

@@ -46,6 +46,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest;
 import org.elasticsearch.client.ml.DeleteJobResponse;
 import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
 import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
@@ -1244,6 +1246,27 @@ public class MachineLearningIT extends ESRestHighLevelClientTestCase {
         assertThat(remainingIds, not(hasItem(deletedEvent)));
     }
 
+    public void testEstimateModelMemory() throws Exception {
+        MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
+
+        String byFieldName = randomAlphaOfLength(10);
+        String influencerFieldName = randomAlphaOfLength(10);
+        AnalysisConfig analysisConfig = AnalysisConfig.builder(
+            Collections.singletonList(
+                Detector.builder().setFunction("count").setByFieldName(byFieldName).build()
+            )).setInfluencers(Collections.singletonList(influencerFieldName)).build();
+        EstimateModelMemoryRequest estimateModelMemoryRequest = new EstimateModelMemoryRequest(analysisConfig);
+        estimateModelMemoryRequest.setOverallCardinality(Collections.singletonMap(byFieldName, randomNonNegativeLong()));
+        estimateModelMemoryRequest.setMaxBucketCardinality(Collections.singletonMap(influencerFieldName, randomNonNegativeLong()));
+
+        EstimateModelMemoryResponse estimateModelMemoryResponse = execute(
+            estimateModelMemoryRequest,
+            machineLearningClient::estimateModelMemory, machineLearningClient::estimateModelMemoryAsync);
+
+        ByteSizeValue modelMemoryEstimate = estimateModelMemoryResponse.getModelMemoryEstimate();
+        assertThat(modelMemoryEstimate.getBytes(), greaterThanOrEqualTo(10000000L));
+    }
+
     public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
         MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
         String configId = "test-put-df-analytics-outlier-detection";

+ 61 - 0
client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

@@ -49,6 +49,8 @@ import org.elasticsearch.client.ml.DeleteJobRequest;
 import org.elasticsearch.client.ml.DeleteJobResponse;
 import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
 import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
+import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
 import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
 import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
 import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
@@ -4133,6 +4135,65 @@ public class MlClientDocumentationIT extends ESRestHighLevelClientTestCase {
         }
     }
 
+    public void testEstimateModelMemory() throws Exception {
+        RestHighLevelClient client = highLevelClient();
+        {
+            // tag::estimate-model-memory-request
+            Detector.Builder detectorBuilder = new Detector.Builder()
+                .setFunction("count")
+                .setPartitionFieldName("status");
+            AnalysisConfig.Builder analysisConfigBuilder =
+                new AnalysisConfig.Builder(Collections.singletonList(detectorBuilder.build()))
+                .setBucketSpan(TimeValue.timeValueMinutes(10))
+                .setInfluencers(Collections.singletonList("src_ip"));
+            EstimateModelMemoryRequest request = new EstimateModelMemoryRequest(analysisConfigBuilder.build()); // <1>
+            request.setOverallCardinality(Collections.singletonMap("status", 50L));                             // <2>
+            request.setMaxBucketCardinality(Collections.singletonMap("src_ip", 30L));                           // <3>
+            // end::estimate-model-memory-request
+
+            // tag::estimate-model-memory-execute
+            EstimateModelMemoryResponse estimateModelMemoryResponse =
+                client.machineLearning().estimateModelMemory(request, RequestOptions.DEFAULT);
+            // end::estimate-model-memory-execute
+
+            // tag::estimate-model-memory-response
+            ByteSizeValue modelMemoryEstimate = estimateModelMemoryResponse.getModelMemoryEstimate(); // <1>
+            long estimateInBytes = modelMemoryEstimate.getBytes();
+            // end::estimate-model-memory-response
+            assertThat(estimateInBytes, greaterThan(10000000L));
+        }
+        {
+            AnalysisConfig analysisConfig =
+                AnalysisConfig.builder(Collections.singletonList(Detector.builder().setFunction("count").build())).build();
+            EstimateModelMemoryRequest request = new EstimateModelMemoryRequest(analysisConfig);
+
+            // tag::estimate-model-memory-execute-listener
+            ActionListener<EstimateModelMemoryResponse> listener = new ActionListener<EstimateModelMemoryResponse>() {
+                @Override
+                public void onResponse(EstimateModelMemoryResponse estimateModelMemoryResponse) {
+                    // <1>
+                }
+
+                @Override
+                public void onFailure(Exception e) {
+                    // <2>
+                }
+            };
+            // end::estimate-model-memory-execute-listener
+
+            // Replace the empty listener by a blocking listener in test
+            final CountDownLatch latch = new CountDownLatch(1);
+            listener = new LatchedActionListener<>(listener, latch);
+
+            // tag::estimate-model-memory-execute-async
+            client.machineLearning()
+                .estimateModelMemoryAsync(request, RequestOptions.DEFAULT, listener); // <1>
+            // end::estimate-model-memory-execute-async
+
+            assertTrue(latch.await(30L, TimeUnit.SECONDS));
+        }
+    }
+
     private String createFilter(RestHighLevelClient client) throws IOException {
         MlFilter.Builder filterBuilder = MlFilter.builder("my_safe_domains")
             .setDescription("A list of safe domains")

+ 42 - 0
docs/java-rest/high-level/ml/estimate-model-memory.asciidoc

@@ -0,0 +1,42 @@
+--
+:api: estimate-model-memory
+:request: EstimateModelMemoryRequest
+:response: EstimateModelMemoryResponse
+--
+[role="xpack"]
+[id="{upid}-{api}"]
+=== Estimate {anomaly-job} model memory API
+
+Estimate the model memory an analysis config is likely to need for
+the given cardinality of the fields it references.
+
+[id="{upid}-{api}-request"]
+==== Estimate {anomaly-job} model memory request
+
+A +{request}+ can be set up as follows:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-request]
+--------------------------------------------------
+<1> Pass an `AnalysisConfig` to the constructor.
+<2> For any `by_field_name`, `over_field_name` or
+    `partition_field_name` fields referenced by the
+    detectors, supply overall cardinality estimates
+    in a `Map`.
+<3> For any `influencers`, supply a `Map` containing
+    estimates of the highest cardinality expected in
+    any single bucket.
+
+include::../execution.asciidoc[]
+
+[id="{upid}-{api}-response"]
+==== Estimate {anomaly-job} model memory response
+
+The returned +{response}+ contains the model memory estimate:
+
+["source","java",subs="attributes,callouts,macros"]
+--------------------------------------------------
+include-tagged::{doc-tests-file}[{api}-response]
+--------------------------------------------------
+<1> The model memory estimate.

+ 2 - 0
docs/java-rest/high-level/supported-apis.asciidoc

@@ -295,6 +295,7 @@ The Java High Level REST Client supports the following Machine Learning APIs:
 * <<{upid}-put-calendar-job>>
 * <<{upid}-delete-calendar-job>>
 * <<{upid}-delete-calendar>>
+* <<{upid}-estimate-model-memory>>
 * <<{upid}-get-data-frame-analytics>>
 * <<{upid}-get-data-frame-analytics-stats>>
 * <<{upid}-put-data-frame-analytics>>
@@ -351,6 +352,7 @@ include::ml/delete-calendar-event.asciidoc[]
 include::ml/put-calendar-job.asciidoc[]
 include::ml/delete-calendar-job.asciidoc[]
 include::ml/delete-calendar.asciidoc[]
+include::ml/estimate-model-memory.asciidoc[]
 include::ml/get-data-frame-analytics.asciidoc[]
 include::ml/get-data-frame-analytics-stats.asciidoc[]
 include::ml/put-data-frame-analytics.asciidoc[]

+ 87 - 0
docs/reference/ml/anomaly-detection/apis/estimate-model-memory.asciidoc

@@ -0,0 +1,87 @@
+[role="xpack"]
+[testenv="platinum"]
+[[ml-estimate-model-memory]]
+=== Estimate {anomaly-jobs} model memory API
+++++
+<titleabbrev>Estimate model memory</titleabbrev>
+++++
+
+Estimates the model memory an {anomaly-job} is likely to need based on analysis
+configuration details and cardinality estimates for the fields it references.
+
+[[ml-estimate-model-memory-request]]
+==== {api-request-title}
+
+`POST _ml/anomaly_detectors/_estimate_model_memory`
+
+[[ml-estimate-model-memory-prereqs]]
+==== {api-prereq-title}
+
+* If the {es} {security-features} are enabled, you must have `manage_ml` or
+`manage` cluster privileges to use this API. See
+<<security-privileges>>.
+
+[[ml-estimate-model-memory-request-body]]
+==== {api-request-body-title}
+
+`analysis_config`::
+(Required, object) For a list of the properties that you can specify in the
+`analysis_config` component of the body of this API, see <<put-analysisconfig>>.
+
+`max_bucket_cardinality`::
+(Optional, object) Estimates of the highest cardinality in a single bucket
+that will be observed for influencer fields over the time period that the job
+analyzes data. To produce a good answer, values must be provided for
+all influencer fields. It does not matter if values are provided for fields
+that are not listed as `influencers`. If there are no `influencers` then
+`max_bucket_cardinality` can be omitted from the request.
+
+`overall_cardinality`::
+(Optional, object) Estimates of the cardinality that will be observed for
+fields over the whole time period that the job analyzes data. To produce
+a good answer, values must be provided for fields referenced in the
+`by_field_name`, `over_field_name` and `partition_field_name` of any
+detectors. It does not matter if values are provided for other fields.
+If no detectors have a `by_field_name`, `over_field_name` or
+`partition_field_name` then `overall_cardinality` can be omitted
+from the request.
+
+[[ml-estimate-model-memory-example]]
+==== {api-examples-title}
+
+[source,console]
+--------------------------------------------------
+POST _ml/anomaly_detectors/_estimate_model_memory
+{
+    "analysis_config": {
+        "bucket_span": "5m",
+        "detectors": [
+          {
+            "function": "sum",
+            "field_name": "bytes",
+            "by_field_name": "status",
+            "partition_field_name": "app"
+          }
+        ],
+        "influencers": [ "source_ip", "dest_ip" ]
+    },
+    "overall_cardinality": {
+       "status": 10,
+       "app": 50
+    },
+    "max_bucket_cardinality": {
+       "source_ip": 300,
+       "dest_ip": 30
+    }
+}
+--------------------------------------------------
+// TEST[skip:needs-licence]
+
+The estimate returns the following result:
+
+[source,console-result]
+----
+{
+  "model_memory_estimate": "45mb"
+}
+----

+ 2 - 0
docs/reference/ml/anomaly-detection/apis/ml-api.asciidoc

@@ -118,6 +118,8 @@ include::delete-job.asciidoc[]
 include::delete-calendar-job.asciidoc[]
 include::delete-snapshot.asciidoc[]
 include::delete-expired-data.asciidoc[]
+//ESTIMATE
+include::estimate-model-memory.asciidoc[]
 //FIND
 include::find-file-structure.asciidoc[]
 //FLUSH

+ 57 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportEstimateModelMemoryAction.java

@@ -23,6 +23,17 @@ import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 
+/**
+ * Calculates the estimated model memory requirement of an anomaly detection job
+ * from its analysis config and estimates of the cardinality of the various fields
+ * referenced in it.
+ *
+ * Answers are capped at <code>Long.MAX_VALUE</code> bytes, to avoid returning
+ * values with bigger units that cannot trivially be converted back to bytes.
+ * (In reality if the memory estimate is greater than <code>Long.MAX_VALUE</code>
+ * bytes then the job will be impossible to run successfully, so this is not a
+ * major limitation.)
+ */
 public class TransportEstimateModelMemoryAction
     extends HandledTransportAction<EstimateModelMemoryAction.Request, EstimateModelMemoryAction.Response> {
 
@@ -47,23 +58,24 @@ public class TransportEstimateModelMemoryAction
         Map<String, Long> overallCardinality = request.getOverallCardinality();
         Map<String, Long> maxBucketCardinality = request.getMaxBucketCardinality();
 
-        long answer = BASIC_REQUIREMENT.getBytes()
-            + calculateDetectorsRequirementBytes(analysisConfig, overallCardinality)
-            + calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality)
-            + calculateCategorizationRequirementBytes(analysisConfig);
+        long answer = BASIC_REQUIREMENT.getBytes();
+        answer = addNonNegativeLongsWithMaxValueCap(answer, calculateDetectorsRequirementBytes(analysisConfig, overallCardinality));
+        answer = addNonNegativeLongsWithMaxValueCap(answer, calculateInfluencerRequirementBytes(analysisConfig, maxBucketCardinality));
+        answer = addNonNegativeLongsWithMaxValueCap(answer, calculateCategorizationRequirementBytes(analysisConfig));
 
         listener.onResponse(new EstimateModelMemoryAction.Response(roundUpToNextMb(answer)));
     }
 
     static long calculateDetectorsRequirementBytes(AnalysisConfig analysisConfig, Map<String, Long> overallCardinality) {
         return analysisConfig.getDetectors().stream().map(detector -> calculateDetectorRequirementBytes(detector, overallCardinality))
-            .reduce(0L, Long::sum);
+            .reduce(0L, TransportEstimateModelMemoryAction::addNonNegativeLongsWithMaxValueCap);
     }
 
     static long calculateDetectorRequirementBytes(Detector detector, Map<String, Long> overallCardinality) {
 
         long answer = 0;
 
+        // These values for detectors assume splitting is via a partition field
         switch (detector.getFunction()) {
             case COUNT:
             case LOW_COUNT:
@@ -71,7 +83,7 @@ public class TransportEstimateModelMemoryAction
             case NON_ZERO_COUNT:
             case LOW_NON_ZERO_COUNT:
             case HIGH_NON_ZERO_COUNT:
-                answer = 1; // TODO add realistic number
+                answer = new ByteSizeValue(32, ByteSizeUnit.KB).getBytes();
                 break;
             case DISTINCT_COUNT:
             case LOW_DISTINCT_COUNT:
@@ -88,7 +100,8 @@ public class TransportEstimateModelMemoryAction
                 answer = 1; // TODO add realistic number
                 break;
             case METRIC:
-                answer = 1; // TODO add realistic number
+                // metric analyses mean, min and max simultaneously, and uses about 2.5 times the memory of one of these
+                answer = new ByteSizeValue(160, ByteSizeUnit.KB).getBytes();
                 break;
             case MEAN:
             case LOW_MEAN:
@@ -104,18 +117,14 @@ public class TransportEstimateModelMemoryAction
             case NON_NULL_SUM:
             case LOW_NON_NULL_SUM:
             case HIGH_NON_NULL_SUM:
-                // 64 comes from https://github.com/elastic/kibana/issues/18722
-                answer = new ByteSizeValue(64, ByteSizeUnit.KB).getBytes();
-                break;
             case MEDIAN:
             case LOW_MEDIAN:
             case HIGH_MEDIAN:
-                answer = 1; // TODO add realistic number
-                break;
             case VARP:
             case LOW_VARP:
             case HIGH_VARP:
-                answer = 1; // TODO add realistic number
+                // 64 comes from https://github.com/elastic/kibana/issues/18722
+                answer = new ByteSizeValue(64, ByteSizeUnit.KB).getBytes();
                 break;
             case TIME_OF_DAY:
             case TIME_OF_WEEK:
@@ -130,19 +139,26 @@ public class TransportEstimateModelMemoryAction
 
         String byFieldName = detector.getByFieldName();
         if (byFieldName != null) {
-            answer *= cardinalityEstimate(Detector.BY_FIELD_NAME_FIELD.getPreferredName(), byFieldName, overallCardinality, true);
+            long cardinalityEstimate =
+                cardinalityEstimate(Detector.BY_FIELD_NAME_FIELD.getPreferredName(), byFieldName, overallCardinality, true);
+            // The memory cost of a by field is about 2/3rds that of a partition field
+            long multiplier = addNonNegativeLongsWithMaxValueCap(cardinalityEstimate, 2) / 3 * 2;
+            answer = multiplyNonNegativeLongsWithMaxValueCap(answer, multiplier);
         }
 
         String overFieldName = detector.getOverFieldName();
         if (overFieldName != null) {
-            cardinalityEstimate(Detector.OVER_FIELD_NAME_FIELD.getPreferredName(), overFieldName, overallCardinality, true);
-            // TODO - how should "over" field cardinality affect estimate?
+            long cardinalityEstimate =
+                cardinalityEstimate(Detector.OVER_FIELD_NAME_FIELD.getPreferredName(), overFieldName, overallCardinality, true);
+            // Over fields don't multiply the whole estimate, just add a small amount (estimate 512 bytes) per value
+            answer = addNonNegativeLongsWithMaxValueCap(answer, multiplyNonNegativeLongsWithMaxValueCap(cardinalityEstimate, 512));
         }
 
         String partitionFieldName = detector.getPartitionFieldName();
         if (partitionFieldName != null) {
-            answer *=
+            long multiplier =
                 cardinalityEstimate(Detector.PARTITION_FIELD_NAME_FIELD.getPreferredName(), partitionFieldName, overallCardinality, true);
+            answer = multiplyNonNegativeLongsWithMaxValueCap(answer, multiplier);
         }
 
         return answer;
@@ -156,10 +172,10 @@ public class TransportEstimateModelMemoryAction
             pureInfluencers.removeAll(detector.extractAnalysisFields());
         }
 
-        return pureInfluencers.stream()
-            .map(influencer -> cardinalityEstimate(AnalysisConfig.INFLUENCERS.getPreferredName(), influencer, maxBucketCardinality, false)
-                * BYTES_PER_INFLUENCER_VALUE)
-            .reduce(0L, Long::sum);
+        long totalInfluencerCardinality = pureInfluencers.stream()
+            .map(influencer -> cardinalityEstimate(AnalysisConfig.INFLUENCERS.getPreferredName(), influencer, maxBucketCardinality, false))
+            .reduce(0L, TransportEstimateModelMemoryAction::addNonNegativeLongsWithMaxValueCap);
+        return multiplyNonNegativeLongsWithMaxValueCap(BYTES_PER_INFLUENCER_VALUE, totalInfluencerCardinality);
     }
 
     static long calculateCategorizationRequirementBytes(AnalysisConfig analysisConfig) {
@@ -187,7 +203,25 @@ public class TransportEstimateModelMemoryAction
     }
 
     static ByteSizeValue roundUpToNextMb(long bytes) {
-        assert bytes >= 0;
-        return new ByteSizeValue((BYTES_IN_MB - 1 + bytes) / BYTES_IN_MB, ByteSizeUnit.MB);
+        assert bytes >= 0 : "negative bytes " + bytes;
+        return new ByteSizeValue(addNonNegativeLongsWithMaxValueCap(bytes, BYTES_IN_MB - 1) / BYTES_IN_MB, ByteSizeUnit.MB);
+    }
+
+    private static long addNonNegativeLongsWithMaxValueCap(long a, long b) {
+        assert a >= 0;
+        assert b >= 0;
+        if (Long.MAX_VALUE - a - b < 0) {
+            return Long.MAX_VALUE;
+        }
+        return a + b;
+    }
+
+    private static long multiplyNonNegativeLongsWithMaxValueCap(long a, long b) {
+        assert a >= 0;
+        assert b >= 0;
+        if (Long.MAX_VALUE / a < b) {
+            return Long.MAX_VALUE;
+        }
+        return a * b;
     }
 }

+ 6 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportEstimateModelMemoryActionTests.java

@@ -36,7 +36,7 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
 
         Detector withByField = createDetector(function, "field", "buy", null, null);
         assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withByField,
-            overallCardinality), is(200 * 65536L));
+            overallCardinality), is(134 * 65536L));
 
         Detector withPartitionField = createDetector(function, "field", null, null, "part");
         assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withPartitionField,
@@ -44,7 +44,7 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
 
         Detector withByAndPartitionFields = createDetector(function, "field", "buy", null, "part");
         assertThat(TransportEstimateModelMemoryAction.calculateDetectorRequirementBytes(withByAndPartitionFields,
-            overallCardinality), is(200 * 100 * 65536L));
+            overallCardinality), is(134 * 100 * 65536L));
     }
 
     public void testCalculateInfluencerRequirementBytes() {
@@ -98,6 +98,10 @@ public class TransportEstimateModelMemoryActionTests extends ESTestCase {
             equalTo(new ByteSizeValue(2, ByteSizeUnit.MB)));
         assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(2 * 1024 * 1024),
             equalTo(new ByteSizeValue(2, ByteSizeUnit.MB)));
+        // We don't round up at the extremes, to ensure that the resulting value can be represented as bytes in a long
+        // (At such extreme scale it won't be possible to actually run the analysis, so ease of use trumps precision)
+        assertThat(TransportEstimateModelMemoryAction.roundUpToNextMb(Long.MAX_VALUE - randomIntBetween(0, 1000000)),
+            equalTo(new ByteSizeValue(Long.MAX_VALUE / new ByteSizeValue(1, ByteSizeUnit.MB).getBytes() , ByteSizeUnit.MB)));
     }
 
     public static Detector createDetector(String function, String fieldName, String byFieldName,

+ 122 - 6
x-pack/plugin/src/test/resources/rest-api-spec/test/ml/estimate_model_memory.yml

@@ -12,7 +12,7 @@
               "airline": 50000
             }
           }
-  - match: { model_memory_estimate: "3135mb" }
+  - match: { model_memory_estimate: "2094mb" }
 
 ---
 "Test by field also influencer":
@@ -32,7 +32,7 @@
               "airline": 500
             }
           }
-  - match: { model_memory_estimate: "3135mb" }
+  - match: { model_memory_estimate: "2094mb" }
 
 ---
 "Test by field with independent influencer":
@@ -52,7 +52,63 @@
               "country": 500
             }
           }
-  - match: { model_memory_estimate: "3140mb" }
+  - match: { model_memory_estimate: "2099mb" }
+
+---
+"Test over field":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline"}]
+            },
+            "overall_cardinality": {
+              "airline": 50000
+            }
+          }
+  - match: { model_memory_estimate: "35mb" }
+
+---
+"Test over field also influencer":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline"}],
+              "influencers": [ "airline" ]
+            },
+            "overall_cardinality": {
+              "airline": 50000
+            },
+            "max_bucket_cardinality": {
+              "airline": 500
+            }
+          }
+  - match: { model_memory_estimate: "35mb" }
+
+---
+"Test over field with independent influencer":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline"}],
+              "influencers": [ "country" ]
+            },
+            "overall_cardinality": {
+              "airline": 50000
+            },
+            "max_bucket_cardinality": {
+              "country": 500
+            }
+          }
+  - match: { model_memory_estimate: "40mb" }
 
 ---
 "Test partition field":
@@ -125,7 +181,7 @@
               "country": 600
             }
           }
-  - match: { model_memory_estimate: "150010mb" }
+  - match: { model_memory_estimate: "100060mb" }
 
 ---
 "Test by and partition fields also influencers":
@@ -147,7 +203,7 @@
               "country": 40
             }
           }
-  - match: { model_memory_estimate: "150010mb" }
+  - match: { model_memory_estimate: "100060mb" }
 
 ---
 "Test by and partition fields with independent influencer":
@@ -168,5 +224,65 @@
               "src_ip": 500
             }
           }
-  - match: { model_memory_estimate: "150015mb" }
+  - match: { model_memory_estimate: "100065mb" }
+
+---
+"Test over and partition field":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline", "partition_field_name": "country"}]
+            },
+            "overall_cardinality": {
+              "airline": 4000,
+              "country": 600
+            }
+          }
+  - match: { model_memory_estimate: "1220mb" }
+
+---
+"Test over and partition fields also influencers":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline", "partition_field_name": "country"}],
+              "influencers": [ "airline", "country" ]
+            },
+            "overall_cardinality": {
+              "airline": 4000,
+              "country": 600
+            },
+            "max_bucket_cardinality": {
+              "airline": 60,
+              "country": 40
+            }
+          }
+  - match: { model_memory_estimate: "1220mb" }
+
+---
+"Test over and partition fields with independent influencer":
+  - do:
+      ml.estimate_model_memory:
+        body: >
+          {
+            "analysis_config": {
+              "bucket_span": "1h",
+              "detectors": [{"function": "max", "field_name": "responsetime", "over_field_name": "airline", "partition_field_name": "country"}],
+              "influencers": [ "src_ip" ]
+            },
+            "overall_cardinality": {
+              "airline": 4000,
+              "country": 600
+            },
+            "max_bucket_cardinality": {
+              "src_ip": 500
+            }
+          }
+  - match: { model_memory_estimate: "1225mb" }