浏览代码

[ML] Handle data frame analytics state spreading over multiple docs (#62564)

When state persistence was first implemented for data frame analytics,
we assumed that state would always fit in a single document.
However, this is no longer the case.

This commit adds handling of state that spreads over multiple documents.
Dimitris Athanasiou 5 年之前
父节点
当前提交
2723a119ff
共有 19 个文件被更改，包括 117 次插入和 73 次删除
  1. 4 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java
  2. 2 2
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java
  3. 1 1
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java
  4. 4 4
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java
  5. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java
  6. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java
  7. 1 1
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java
  8. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  9. 36 8
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java
  10. 4 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java
  11. 2 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java
  12. 19 23
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java
  13. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java
  14. 29 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java
  15. 2 4
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java
  16. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java
  17. 1 3
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java
  18. 4 4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java
  19. 1 1
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java

+ 4 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

@@ -55,7 +55,7 @@ public class Classification implements DataFrameAnalysis {
     public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
 
-    private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_classification_state#";
 
     private static final String NUM_CLASSES = "num_classes";
 
@@ -411,8 +411,8 @@ public class Classification implements DataFrameAnalysis {
     }
 
     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }
 
     @Override
@@ -437,7 +437,7 @@ public class Classification implements DataFrameAnalysis {
     }
 
     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }
 

+ 2 - 2
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

@@ -63,9 +63,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
     boolean persistsState();
 
     /**
-     * Returns the document id for the analysis state
+     * Returns the document id prefix for the analysis state
      */
-    String getStateDocId(String jobId);
+    String getStateDocIdPrefix(String jobId);
 
     /**
      * Returns the progress phases the analysis goes through in order

+ 1 - 1
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

@@ -264,7 +264,7 @@ public class OutlierDetection implements DataFrameAnalysis {
     }
 
     @Override
-    public String getStateDocId(String jobId) {
+    public String getStateDocIdPrefix(String jobId) {
         throw new UnsupportedOperationException("Outlier detection does not support state");
     }
 

+ 4 - 4
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

@@ -51,7 +51,7 @@ public class Regression implements DataFrameAnalysis {
     public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
     public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
 
-    private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
+    private static final String STATE_DOC_ID_INFIX = "_regression_state#";
 
     private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
     private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
@@ -305,8 +305,8 @@ public class Regression implements DataFrameAnalysis {
     }
 
     @Override
-    public String getStateDocId(String jobId) {
-        return jobId + STATE_DOC_ID_SUFFIX;
+    public String getStateDocIdPrefix(String jobId) {
+        return jobId + STATE_DOC_ID_INFIX;
     }
 
     @Override
@@ -328,7 +328,7 @@ public class Regression implements DataFrameAnalysis {
     }
 
     public static String extractJobIdFromStateDoc(String stateDocId) {
-        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
+        int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_INFIX);
         return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);
     }
 

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

@@ -446,7 +446,7 @@ public class ClassificationTests extends AbstractBWCSerializationTestCase<Classi
         Classification classification = createRandom();
         assertThat(classification.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
+        assertThat(classification.getStateDocIdPrefix(randomId), equalTo(randomId + "_classification_state#"));
     }
 
     public void testExtractJobIdFromStateDoc() {

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

@@ -120,7 +120,7 @@ public class OutlierDetectionTests extends AbstractBWCSerializationTestCase<Outl
     public void testGetStateDocId() {
         OutlierDetection outlierDetection = createRandom();
         assertThat(outlierDetection.persistsState(), is(false));
-        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
+        expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocIdPrefix("foo"));
     }
 
     public void testInferenceConfig() {

+ 1 - 1
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

@@ -331,7 +331,7 @@ public class RegressionTests extends AbstractBWCSerializationTestCase<Regression
         Regression regression = createRandom();
         assertThat(regression.persistsState(), is(true));
         String randomId = randomAlphaOfLength(10);
-        assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
+        assertThat(regression.getStateDocIdPrefix(randomId), equalTo(randomId + "_regression_state#"));
     }
 
     public void testExtractJobIdFromStateDoc() {

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -660,8 +660,8 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
                     new BlackHoleAutodetectProcess(job.getId(), onProcessCrash);
             // factor of 1.0 makes renormalization a no-op
             normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
-            analyticsProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
-            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, state, executorService, onProcessCrash) -> null;
+            analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
+            memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
         }
         NormalizerFactory normalizerFactory = new NormalizerFactory(normalizerProcessFactory,
                 threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME));

+ 36 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteDataFrameAnalyticsAction.java

@@ -28,6 +28,7 @@ import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.index.reindex.AbstractBulkByScrollRequest;
@@ -57,8 +58,6 @@ import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.elasticsearch.xpack.ml.utils.MlIndicesUtils;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
@@ -242,17 +241,46 @@ public class TransportDeleteDataFrameAnalyticsAction
                              DataFrameAnalyticsConfig config,
                              TimeValue timeout,
                              ActionListener<BulkByScrollResponse> listener) {
-        List<String> ids = new ArrayList<>();
-        ids.add(StoredProgress.documentId(config.getId()));
-        if (config.getAnalysis().persistsState()) {
-            ids.add(config.getAnalysis().getStateDocId(config.getId()));
+        ActionListener<Boolean> deleteModelStateListener = ActionListener.wrap(
+            r -> executeDeleteByQuery(
+                    parentTaskClient,
+                    AnomalyDetectorsIndex.jobStateIndexPattern(),
+                    QueryBuilders.idsQuery().addIds(StoredProgress.documentId(config.getId())),
+                    timeout,
+                    listener
+                )
+            , listener::onFailure
+        );
+
+        deleteModelState(parentTaskClient, config, timeout, 1, deleteModelStateListener);
+    }
+
+    private void deleteModelState(ParentTaskAssigningClient parentTaskClient,
+                                  DataFrameAnalyticsConfig config,
+                                  TimeValue timeout,
+                                  int docNum,
+                                  ActionListener<Boolean> listener) {
+        if (config.getAnalysis().persistsState() == false) {
+            listener.onResponse(true);
+            return;
         }
+
+        IdsQueryBuilder query = QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + docNum);
         executeDeleteByQuery(
             parentTaskClient,
             AnomalyDetectorsIndex.jobStateIndexPattern(),
-            QueryBuilders.idsQuery().addIds(ids.toArray(String[]::new)),
+            query,
             timeout,
-            listener
+            ActionListener.wrap(
+                response -> {
+                    if (response.getDeleted() > 0) {
+                        deleteModelState(parentTaskClient, config, timeout, docNum + 1, listener);
+                        return;
+                    }
+                    listener.onResponse(true);
+                },
+                listener::onFailure
+            )
         );
     }
 

+ 4 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcess.java

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.xpack.ml.process.NativeProcess;
 
 import java.io.IOException;
@@ -41,7 +41,8 @@ public interface AnalyticsProcess<ProcessResult> extends NativeProcess {
 
     /**
      * Restores the model state from a previously persisted one
-     * @param state the state to restore
+     * @param client the client to use for fetching the state documents
+     * @param stateDocIdPrefix the prefix of ids of the state documents
      */
-    void restoreState(BytesReference state) throws IOException;
+    void restoreState(Client client, String stateDocIdPrefix) throws IOException;
 }

+ 2 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessFactory.java

@@ -5,8 +5,6 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
 
 import java.util.concurrent.ExecutorService;
@@ -19,12 +17,12 @@ public interface AnalyticsProcessFactory<ProcessResult> {
      *
      * @param config                 The data frame analytics config
      * @param analyticsProcessConfig The process configuration
-     * @param state                  The state document to restore from if there is one available
+     * @param hasState               Whether there is state to restore from
      * @param executorService        Executor service used to start the async tasks a job needs to operate the analytical process
      * @param onProcessCrash         Callback to execute if the process stops unexpectedly
      * @return The process
      */
     AnalyticsProcess<ProcessResult> createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                           @Nullable BytesReference state, ExecutorService executorService,
+                                                           boolean hasState, ExecutorService executorService,
                                                            Consumer<String> onProcessCrash);
 }

+ 19 - 23
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

@@ -15,13 +15,10 @@ import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.support.IndicesOptions;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.client.ParentTaskAssigningClient;
-import org.elasticsearch.common.Nullable;
 import org.elasticsearch.common.Strings;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.index.query.QueryBuilders;
-import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ClientHelper;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
@@ -139,11 +136,11 @@ public class AnalyticsProcessManager {
             }
 
             // Fetch existing model state (if any)
-            BytesReference state = getModelState(config);
+            final boolean hasState = hasModelState(config);
 
             boolean isProcessStarted;
             try {
-                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, state);
+                isProcessStarted = processContext.startProcess(dataExtractorFactory, task, hasState);
             } catch (Exception e) {
                 processContext.stop();
                 task.setFailed(processContext.getFailureReason() == null ?
@@ -153,7 +150,7 @@ public class AnalyticsProcessManager {
 
             if (isProcessStarted) {
                 executorServiceForProcess.execute(() -> processContext.resultProcessor.get().process(processContext.process.get()));
-                executorServiceForProcess.execute(() -> processData(task, processContext, state));
+                executorServiceForProcess.execute(() -> processData(task, processContext, hasState));
             } else {
                 processContextByAllocation.remove(task.getAllocationId());
                 auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_FINISHED_ANALYSIS);
@@ -162,23 +159,22 @@ public class AnalyticsProcessManager {
         });
     }
 
-    @Nullable
-    private BytesReference getModelState(DataFrameAnalyticsConfig config) {
+    private boolean hasModelState(DataFrameAnalyticsConfig config) {
         if (config.getAnalysis().persistsState() == false) {
-            return null;
+            return false;
         }
 
         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
             SearchResponse searchResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
                 .setSize(1)
-                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocId(config.getId())))
+                .setFetchSource(false)
+                .setQuery(QueryBuilders.idsQuery().addIds(config.getAnalysis().getStateDocIdPrefix(config.getId()) + "1"))
                 .get();
-            SearchHit[] hits = searchResponse.getHits().getHits();
-            return hits.length == 0 ? null : hits[0].getSourceRef();
+            return searchResponse.getHits().getHits().length == 1;
         }
     }
 
-    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, BytesReference state) {
+    private void processData(DataFrameAnalyticsTask task, ProcessContext processContext, boolean hasState) {
         LOGGER.info("[{}] Started loading data", processContext.config.getId());
         auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_LOADING_DATA));
 
@@ -193,7 +189,7 @@ public class AnalyticsProcessManager {
             process.writeEndOfDataMessage();
             process.flushStream();
 
-            restoreState(task, config, state, process);
+            restoreState(task, config, process, hasState);
 
             LOGGER.info("[{}] Started analyzing", processContext.config.getId());
             auditor.info(processContext.config.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_ANALYZING));
@@ -297,14 +293,14 @@ public class AnalyticsProcessManager {
         process.writeRecord(headerRecord);
     }
 
-    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, @Nullable BytesReference state,
-                              AnalyticsProcess<AnalyticsResult> process) {
+    private void restoreState(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config, AnalyticsProcess<AnalyticsResult> process,
+                              boolean hasState) {
         if (config.getAnalysis().persistsState() == false) {
             LOGGER.debug("[{}] Analysis does not support state", config.getId());
             return;
         }
 
-        if (state == null) {
+        if (hasState == false) {
             LOGGER.debug("[{}] No model state available to restore", config.getId());
             return;
         }
@@ -313,7 +309,7 @@ public class AnalyticsProcessManager {
         auditor.info(config.getId(), Messages.DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE);
 
         try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
-            process.restoreState(state);
+            process.restoreState(client, config.getAnalysis().getStateDocIdPrefix(config.getId()));
         } catch (Exception e) {
             LOGGER.error(new ParameterizedMessage("[{}] Failed to restore state", process.getConfig().jobId()), e);
             task.setFailed(ExceptionsHelper.serverError("Failed to restore state: " + e.getMessage()));
@@ -321,9 +317,9 @@ public class AnalyticsProcessManager {
     }
 
     private AnalyticsProcess<AnalyticsResult> createProcess(DataFrameAnalyticsTask task, DataFrameAnalyticsConfig config,
-                                                            AnalyticsProcessConfig analyticsProcessConfig, @Nullable BytesReference state) {
-        AnalyticsProcess<AnalyticsResult> process =
-            processFactory.createAnalyticsProcess(config, analyticsProcessConfig, state, executorServiceForProcess, onProcessCrash(task));
+                                                            AnalyticsProcessConfig analyticsProcessConfig, boolean hasState) {
+        AnalyticsProcess<AnalyticsResult> process = processFactory.createAnalyticsProcess(
+            config, analyticsProcessConfig, hasState, executorServiceForProcess, onProcessCrash(task));
         if (process.isProcessAlive() == false) {
             throw ExceptionsHelper.serverError("Failed to start data frame analytics process");
         }
@@ -467,7 +463,7 @@ public class AnalyticsProcessManager {
          */
         synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtractorFactory,
                                           DataFrameAnalyticsTask task,
-                                          @Nullable BytesReference state) {
+                                          boolean hasState) {
             if (task.isStopping()) {
                 // The job was stopped before we started the process so no need to start it
                 return false;
@@ -483,7 +479,7 @@ public class AnalyticsProcessManager {
                 LOGGER.info("[{}] no data found to analyze. Will not start analytics native process.", config.getId());
                 return false;
             }
-            process.set(createProcess(task, config, analyticsProcessConfig, state));
+            process.set(createProcess(task, config, analyticsProcessConfig, hasState));
             resultProcessor.set(createResultProcessor(task, dataExtractorFactory));
             return true;
         }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManager.java

@@ -85,7 +85,7 @@ public class MemoryUsageEstimationProcessManager {
             processFactory.createAnalyticsProcess(
                 config,
                 processConfig,
-                null,
+                false,
                 executorServiceForProcess,
                 // The handler passed here will never be called as AbstractNativeProcess.detectCrash method returns early when
                 // (processInStream == null) which is the case for MemoryUsageEstimationProcess.

+ 29 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcess.java

@@ -5,8 +5,15 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
-import org.elasticsearch.common.bytes.BytesReference;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
+import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
 import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
 import org.elasticsearch.xpack.ml.process.NativeController;
 import org.elasticsearch.xpack.ml.process.ProcessPipes;
@@ -22,6 +29,8 @@ import java.util.function.Consumer;
 
 public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<AnalyticsResult> {
 
+    private static final Logger logger = LogManager.getLogger(NativeAnalyticsProcess.class);
+
     private static final String NAME = "analytics";
 
     private final AnalyticsProcessConfig config;
@@ -56,10 +65,26 @@ public class NativeAnalyticsProcess extends AbstractNativeAnalyticsProcess<Analy
     }
 
     @Override
-    public void restoreState(BytesReference state) throws IOException {
-        Objects.requireNonNull(state);
+    public void restoreState(Client client, String stateDocIdPrefix) throws IOException {
+        Objects.requireNonNull(stateDocIdPrefix);
         try (OutputStream restoreStream = processRestoreStream()) {
-            StateToProcessWriterHelper.writeStateToStream(state, restoreStream);
+            int docNum = 0;
+            while (true) {
+                if (isProcessKilled()) {
+                    return;
+                }
+
+                // We fetch the documents one at a time because all together they can amount to too much memory
+                SearchResponse stateResponse = client.prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
+                    .setSize(1)
+                    .setQuery(QueryBuilders.idsQuery().addIds(stateDocIdPrefix + ++docNum)).get();
+                if (stateResponse.getHits().getHits().length == 0) {
+                    break;
+                }
+                SearchHit stateDoc = stateResponse.getHits().getAt(0);
+                logger.debug(() -> new ParameterizedMessage("[{}] Restoring state document [{}]", config.jobId(), stateDoc.getId()));
+                StateToProcessWriterHelper.writeStateToStream(stateDoc.getSourceRef(), restoreStream);
+            }
         }
     }
 }

+ 2 - 4
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeAnalyticsProcessFactory.java

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
@@ -70,12 +68,12 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An
 
     @Override
     public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig config, AnalyticsProcessConfig analyticsProcessConfig,
-                                                         @Nullable BytesReference state, ExecutorService executorService,
+                                                         boolean hasState, ExecutorService executorService,
                                                          Consumer<String> onProcessCrash) {
         String jobId = config.getId();
         List<Path> filesToDelete = new ArrayList<>();
         ProcessPipes processPipes = new ProcessPipes(env, NAMED_PIPE_HELPER, AnalyticsBuilder.ANALYTICS, jobId,
-                false, true, true, state != null, config.getAnalysis().persistsState());
+                false, true, true, hasState, config.getAnalysis().persistsState());
 
         // The extra 2 are for the checksum and the control field
         int numberOfFields = analyticsProcessConfig.cols() + 2;

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcess.java

@@ -5,7 +5,7 @@
  */
 package org.elasticsearch.xpack.ml.dataframe.process;
 
-import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.client.Client;
 import org.elasticsearch.common.xcontent.NamedXContentRegistry;
 import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult;
 import org.elasticsearch.xpack.ml.process.NativeController;
@@ -33,7 +33,7 @@ public class NativeMemoryUsageEstimationProcess extends AbstractNativeAnalyticsP
     }
 
     @Override
-    public void restoreState(BytesReference state) {
+    public void restoreState(Client client, String stateDocIdPrefix) {
         throw new UnsupportedOperationException();
     }
 }

+ 1 - 3
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/NativeMemoryUsageEstimationProcessFactory.java

@@ -8,8 +8,6 @@ package org.elasticsearch.xpack.ml.dataframe.process;
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.Nullable;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.core.internal.io.IOUtils;
@@ -60,7 +58,7 @@ public class NativeMemoryUsageEstimationProcessFactory implements AnalyticsProce
     public NativeMemoryUsageEstimationProcess createAnalyticsProcess(
             DataFrameAnalyticsConfig config,
             AnalyticsProcessConfig analyticsProcessConfig,
-            @Nullable BytesReference state,
+            boolean hasState,
             ExecutorService executorService,
             Consumer<String> onProcessCrash) {
         List<Path> filesToDelete = new ArrayList<>();

+ 4 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java

@@ -92,7 +92,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(process.isProcessAlive()).thenReturn(true);
         when(process.readAnalyticsResults()).thenReturn(List.of(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);
         auditor = mock(DataFrameAnalyticsAuditor.class);
         trainedModelProvider = mock(TrainedModelProvider.class);
 
@@ -226,7 +226,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
 
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
         processContext.stop();
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));
 
         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -237,7 +237,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(0, NUM_COLS));
 
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(false));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(false));
 
         InOrder inOrder = inOrder(dataExtractor, process, task);
         inOrder.verify(task).isStopping();
@@ -248,7 +248,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
 
     public void testProcessContext_StartAndStop() throws Exception {
         AnalyticsProcessManager.ProcessContext processContext = processManager.new ProcessContext(dataFrameAnalyticsConfig);
-        assertThat(processContext.startProcess(dataExtractorFactory, task, null), is(true));
+        assertThat(processContext.startProcess(dataExtractorFactory, task, false), is(true));
         processContext.stop();
 
         InOrder inOrder = inOrder(dataExtractor, process, task);

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/MemoryUsageEstimationProcessManagerTests.java

@@ -66,7 +66,7 @@ public class MemoryUsageEstimationProcessManagerTests extends ESTestCase {
         process = mock(AnalyticsProcess.class);
         when(process.readAnalyticsResults()).thenReturn(List.of(PROCESS_RESULT).iterator());
         processFactory = mock(AnalyticsProcessFactory.class);
-        when(processFactory.createAnalyticsProcess(any(), any(), any(), any(), any())).thenReturn(process);
+        when(processFactory.createAnalyticsProcess(any(), any(), anyBoolean(), any(), any())).thenReturn(process);
         dataExtractor = mock(DataFrameDataExtractor.class);
         when(dataExtractor.collectDataSummary()).thenReturn(new DataFrameDataExtractor.DataSummary(NUM_ROWS, NUM_COLS));
         dataExtractorFactory = mock(DataFrameDataExtractorFactory.class);