Browse Source

[ML] Test that AD job and model are not overallocating new node (#85283)

When a new node is added, if there are unassigned jobs and models,
we try to assign them based on the same cluster state. As the persistent
task service and the trained model allocation service are different,
they could decide to assign tasks disregarding one another.

This is currently not possible because they both add assignments by
cluster state updates. Thus whichever cluster state gets applied first
wins and the other will be rejected.

This commit adds a test to ensure there are no regressions of this
behavior.
Dimitris Athanasiou 3 năm trước cách đây
mục cha
commit
bfdd1d1f31
15 tập tin đã thay đổi với 427 bổ sung54 xóa
  1. 0 4
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java
  2. 211 0
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java
  3. 0 11
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java
  4. 0 5
      x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlNodeShutdownIT.java
  5. 2 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  6. 10 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java
  7. 2 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  8. 95 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/BlackHolePyTorchProcess.java
  9. 4 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java
  10. 43 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java
  11. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java
  12. 1 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
  13. 2 28
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java
  14. 52 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/BlackHoleResultIterator.java
  15. 4 0
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java

+ 0 - 4
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/CategorizationAggregationIT.java

@@ -98,10 +98,6 @@ public class CategorizationAggregationIT extends BaseMlIntegTestCase {
         assertThat(((Min) bucket.getAggregations().get("min")).value(), not(notANumber()));
     }
 
-    private void ensureStableCluster() {
-        ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60));
-    }
-
     private void createSourceData() {
         client().admin().indices().prepareCreate(DATA_INDEX).setMapping("time", "type=date,format=epoch_millis", "msg", "type=text").get();
 

+ 211 - 0
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/JobsAndModelsIT.java

@@ -0,0 +1,211 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.integration;
+
+import org.elasticsearch.action.index.IndexAction;
+import org.elasticsearch.action.index.IndexRequest;
+import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.cluster.node.DiscoveryNodeRole;
+import org.elasticsearch.common.bytes.BytesArray;
+import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
+import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
+import org.elasticsearch.xpack.core.ml.action.MlMemoryAction;
+import org.elasticsearch.xpack.core.ml.action.OpenJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutJobAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
+import org.elasticsearch.xpack.core.ml.action.PutTrainedModelVocabularyAction;
+import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelType;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
+import org.elasticsearch.xpack.core.ml.job.config.Job;
+import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
+import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
+
+import java.util.List;
+import java.util.Set;
+
+import static org.elasticsearch.test.NodeRoles.onlyRoles;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.empty;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.not;
+import static org.hamcrest.Matchers.nullValue;
+
+/**
+ * Tests that involve interactions of ML jobs that are persistent tasks
+ * and trained models.
+ */
+public class JobsAndModelsIT extends BaseMlIntegTestCase {
+
+    public void testCluster_GivenAnomalyDetectionJobAndTrainedModelDeployment_ShouldNotAllocateBothOnSameNode() throws Exception {
+        // This test starts 2 ML nodes and then starts an anomaly detection job and a
+        // trained model deployment that do not both fit in one node. We then proceed
+        // to stop both ML nodes and start a single ML node back up. We should see
+        // that both the job and the model cannot be allocated on that node.
+
+        internalCluster().ensureAtMostNumDataNodes(0);
+        logger.info("Starting dedicated master node...");
+        internalCluster().startMasterOnlyNode();
+        logger.info("Starting dedicated data node...");
+        internalCluster().startDataOnlyNode();
+        logger.info("Starting dedicated ml node...");
+        internalCluster().startNode(onlyRoles(Set.of(DiscoveryNodeRole.ML_ROLE)));
+        logger.info("Starting dedicated ml node...");
+        internalCluster().startNode(onlyRoles(Set.of(DiscoveryNodeRole.ML_ROLE)));
+        ensureStableCluster();
+
+        MlMemoryAction.Response memoryStats = client().execute(MlMemoryAction.INSTANCE, new MlMemoryAction.Request("ml:true")).actionGet();
+
+        long maxNativeBytesPerNode = 0;
+        for (MlMemoryAction.Response.MlMemoryStats stats : memoryStats.getNodes()) {
+            maxNativeBytesPerNode = stats.getMlMax().getBytes();
+        }
+
+        String jobId = "test-node-goes-down-while-running-job";
+        Job.Builder job = createJob(jobId, ByteSizeValue.ofBytes((long) (0.8 * maxNativeBytesPerNode)));
+
+        PutJobAction.Request putJobRequest = new PutJobAction.Request(job);
+        client().execute(PutJobAction.INSTANCE, putJobRequest).actionGet();
+        client().execute(OpenJobAction.INSTANCE, new OpenJobAction.Request(job.getId())).actionGet();
+
+        TrainedModelConfig model = TrainedModelConfig.builder()
+            .setModelId("test_model")
+            .setModelType(TrainedModelType.PYTORCH)
+            .setModelSize((long) (0.3 * maxNativeBytesPerNode))
+            .setInferenceConfig(new PassThroughConfig(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()), null, null))
+            .setLocation(new IndexLocation(InferenceIndexConstants.nativeDefinitionStore()))
+            .build();
+
+        TrainedModelDefinitionDoc modelDefinitionDoc = new TrainedModelDefinitionDoc(
+            new BytesArray(""),
+            model.getModelId(),
+            0,
+            model.getModelSize(),
+            model.getModelSize(),
+            1,
+            true
+        );
+        try (XContentBuilder builder = JsonXContent.contentBuilder()) {
+            modelDefinitionDoc.toXContent(builder, null);
+            client().execute(
+                IndexAction.INSTANCE,
+                new IndexRequest(InferenceIndexConstants.nativeDefinitionStore()).source(builder)
+                    .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
+            ).actionGet();
+        }
+
+        client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(model, true)).actionGet();
+        client().execute(
+            PutTrainedModelVocabularyAction.INSTANCE,
+            new PutTrainedModelVocabularyAction.Request(
+                model.getModelId(),
+                List.of(
+                    "these",
+                    "are",
+                    "my",
+                    "words",
+                    BertTokenizer.SEPARATOR_TOKEN,
+                    BertTokenizer.CLASS_TOKEN,
+                    BertTokenizer.UNKNOWN_TOKEN,
+                    BertTokenizer.PAD_TOKEN
+                ),
+                List.of()
+            )
+        ).actionGet();
+
+        client().execute(StartTrainedModelDeploymentAction.INSTANCE, new StartTrainedModelDeploymentAction.Request(model.getModelId()))
+            .actionGet();
+
+        setMlIndicesDelayedNodeLeftTimeoutToZero();
+
+        String jobNode = client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId()))
+            .actionGet()
+            .getResponse()
+            .results()
+            .get(0)
+            .getNode()
+            .getName();
+        String modelNode = client().execute(
+            GetTrainedModelsStatsAction.INSTANCE,
+            new GetTrainedModelsStatsAction.Request(model.getModelId())
+        ).actionGet().getResources().results().get(0).getDeploymentStats().getNodeStats().get(0).getNode().getName();
+
+        // Assert the job and model were assigned to different nodes as they would not fit in the same node
+        assertThat(jobNode, not(equalTo(modelNode)));
+
+        // Stop both ML nodes
+        logger.info("Stopping both ml nodes...");
+        assertThat(internalCluster().stopNode(jobNode), is(true));
+        assertThat(internalCluster().stopNode(modelNode), is(true));
+
+        // Wait for both the job and model to be unassigned
+        assertBusy(() -> {
+            GetJobsStatsAction.Response jobStats = client().execute(
+                GetJobsStatsAction.INSTANCE,
+                new GetJobsStatsAction.Request(job.getId())
+            ).actionGet();
+            assertThat(jobStats.getResponse().results().get(0).getNode(), is(nullValue()));
+        });
+        assertBusy(() -> {
+            GetTrainedModelsStatsAction.Response modelStats = client().execute(
+                GetTrainedModelsStatsAction.INSTANCE,
+                new GetTrainedModelsStatsAction.Request(model.getModelId())
+            ).actionGet();
+            assertThat(modelStats.getResources().results().get(0).getDeploymentStats().getNodeStats(), is(empty()));
+        });
+
+        // Start a new ML node
+        logger.info("Starting dedicated ml node...");
+        String lastMlNodeName = internalCluster().startNode(onlyRoles(Set.of(DiscoveryNodeRole.ML_ROLE)));
+        ensureStableCluster();
+
+        // Here we make the assumption that models are assigned before persistent tasks.
+        // The reason this holds follows. Allocation service is a plugin component listening to
+        // cluster states updates. Persistent tasks have executors that listen to cluster
+        // states. Plugin components get created before persistent task executors. Thus,
+        // the allocation service will be producing each cluster state updates first.
+        // As this assumption might be critical, the test should break if the assumption
+        // breaks to give us a warning about potential impact.
+
+        // Wait until the model is assigned
+        assertBusy(() -> {
+            GetTrainedModelsStatsAction.Response modelStatsResponse = client().execute(
+                GetTrainedModelsStatsAction.INSTANCE,
+                new GetTrainedModelsStatsAction.Request(model.getModelId())
+            ).actionGet();
+            GetTrainedModelsStatsAction.Response.TrainedModelStats modelStats = modelStatsResponse.getResources().results().get(0);
+            assertThat(modelStats.getDeploymentStats().getNodeStats().isEmpty(), is(false));
+            assertThat(modelStats.getDeploymentStats().getNodeStats().get(0).getNode().getName(), equalTo(lastMlNodeName));
+        });
+
+        // Check the job is unassigned due to insufficient memory
+        GetJobsStatsAction.Response jobStatsResponse = client().execute(
+            GetJobsStatsAction.INSTANCE,
+            new GetJobsStatsAction.Request(job.getId())
+        ).actionGet();
+        GetJobsStatsAction.Response.JobStats jobStats = jobStatsResponse.getResponse().results().get(0);
+        assertThat(jobStats.getNode(), is(nullValue()));
+        assertThat(jobStats.getAssignmentExplanation(), containsString("insufficient available memory"));
+
+        // Clean up
+        client().execute(CloseJobAction.INSTANCE, new CloseJobAction.Request(jobId).setForce(true)).actionGet();
+        client().execute(StopTrainedModelDeploymentAction.INSTANCE, new StopTrainedModelDeploymentAction.Request(model.getModelId()))
+            .actionGet();
+    }
+}

+ 0 - 11
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java

@@ -789,17 +789,6 @@ public class MlDistributedFailureIT extends BaseMlIntegTestCase {
         }, 30, TimeUnit.SECONDS);
     }
 
-    private void waitForJobClosed(String jobId) throws Exception {
-        assertBusy(() -> {
-            JobStats jobStats = getJobStats(jobId);
-            assertEquals(jobStats.getState(), JobState.CLOSED);
-        }, 30, TimeUnit.SECONDS);
-    }
-
-    private void ensureStableCluster() {
-        ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60));
-    }
-
     private void indexModelSnapshotFromCurrentJobStats(String jobId) throws IOException {
         JobStats jobStats = getJobStats(jobId);
         DataCounts dataCounts = jobStats.getDataCounts();

+ 0 - 5
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/MlNodeShutdownIT.java

@@ -11,7 +11,6 @@ import org.apache.lucene.util.SetOnce;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
 import org.elasticsearch.xpack.core.ml.action.GetJobsStatsAction;
@@ -255,10 +254,6 @@ public class MlNodeShutdownIT extends BaseMlIntegTestCase {
         client().execute(StartDatafeedAction.INSTANCE, startDatafeedRequest).get();
     }
 
-    private void ensureStableCluster() {
-        ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60));
-    }
-
     private void createSourceData() {
         client().admin().indices().prepareCreate("data").setMapping("time", "type=date").get();
         long numDocs = randomIntBetween(50, 100);

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -308,6 +308,7 @@ import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
 import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService;
 import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
 import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
+import org.elasticsearch.xpack.ml.inference.pytorch.process.BlackHolePyTorchProcess;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.job.JobManager;
@@ -890,7 +891,7 @@ public class MachineLearning extends Plugin
             normalizerProcessFactory = (jobId, quantilesState, bucketSpan, executorService) -> new MultiplyingNormalizerProcess(1.0);
             analyticsProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
             memoryEstimationProcessFactory = (jobId, analyticsProcessConfig, hasState, executorService, onProcessCrash) -> null;
-            pyTorchProcessFactory = (task, executorService, onProcessCrash) -> null;
+            pyTorchProcessFactory = (task, executorService, onProcessCrash) -> new BlackHolePyTorchProcess();
         }
         NormalizerFactory normalizerFactory = new NormalizerFactory(
             normalizerProcessFactory,

+ 10 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

@@ -108,6 +108,16 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
                     // If the event indicates there were nodes added/removed, this method only looks at the current state and has
                     // no previous knowledge of existing nodes. Consequently, if a model was manually removed (task-kill) from a node
                     // it may get re-allocated to that node when another node is added/removed...
+
+                    // As this produces a cluster state update task, we are certain that if the persistent
+                    // task framework results in assigning some ML tasks on that same cluster state change
+                    // we do not end up over-allocating a node. Both this service and the persistant task service
+                    // will produce a cluster state update but the one that gets applied first wins. The other
+                    // update will be rejected and we will retry to assign getting a correct update on available memory
+                    // on each node.
+                    // Also, note that as this service is a returned as a component of the ML plugin,
+                    // and components are created before persistent task executors, we will try to allocate
+                    // trained models before we try to assign ML persistent tasks.
                     return addRemoveAllocationNodes(currentState);
                 }
 

+ 2 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -46,7 +46,7 @@ import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.nlp.NlpTask;
 import org.elasticsearch.xpack.ml.inference.nlp.Vocabulary;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
-import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
+import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcess;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
@@ -447,7 +447,7 @@ public class DeploymentManager {
     class ProcessContext {
 
         private final TrainedModelDeploymentTask task;
-        private final SetOnce<NativePyTorchProcess> process = new SetOnce<>();
+        private final SetOnce<PyTorchProcess> process = new SetOnce<>();
         private final SetOnce<NlpTask.Processor> nlpTaskProcessor = new SetOnce<>();
         private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
         private final PyTorchResultProcessor resultProcessor;

+ 95 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/BlackHolePyTorchProcess.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.process;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
+import org.elasticsearch.xpack.ml.process.BlackHoleResultIterator;
+
+import java.io.IOException;
+import java.time.ZonedDateTime;
+import java.util.Iterator;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingDeque;
+
+public class BlackHolePyTorchProcess implements PyTorchProcess {
+
+    private final ZonedDateTime startTime;
+    private volatile boolean running = true;
+    private final BlockingQueue<PyTorchResult> results = new LinkedBlockingDeque<>();
+
+    public BlackHolePyTorchProcess() {
+        startTime = ZonedDateTime.now();
+    }
+
+    @Override
+    public void loadModel(String modelId, String index, PyTorchStateStreamer stateStreamer, ActionListener<Boolean> listener) {
+        listener.onResponse(true);
+    }
+
+    @Override
+    public Iterator<PyTorchResult> readResults() {
+        return new BlackHoleResultIterator<>(results, () -> running);
+    }
+
+    @Override
+    public void writeInferenceRequest(BytesReference jsonRequest) throws IOException {}
+
+    @Override
+    public boolean isReady() {
+        return true;
+    }
+
+    @Override
+    public void writeRecord(String[] record) throws IOException {}
+
+    @Override
+    public void persistState() throws IOException {}
+
+    @Override
+    public void persistState(long snapshotTimestampMs, String snapshotId, String snapshotDescription) throws IOException {}
+
+    @Override
+    public void flushStream() throws IOException {}
+
+    @Override
+    public void kill(boolean awaitCompletion) throws IOException {
+        running = false;
+    }
+
+    @Override
+    public ZonedDateTime getProcessStartTime() {
+        return startTime;
+    }
+
+    @Override
+    public boolean isProcessAlive() {
+        return running;
+    }
+
+    @Override
+    public boolean isProcessAliveAfterWaiting() {
+        try {
+            Thread.sleep(45);
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        }
+        return running;
+    }
+
+    @Override
+    public String readError() {
+        return "";
+    }
+
+    @Override
+    public void close() throws IOException {
+        running = false;
+    }
+}

+ 4 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/NativePyTorchProcess.java

@@ -22,7 +22,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.function.Consumer;
 
-public class NativePyTorchProcess extends AbstractNativeProcess {
+public class NativePyTorchProcess extends AbstractNativeProcess implements PyTorchProcess {
 
     private static final String NAME = "pytorch_inference";
 
@@ -55,14 +55,17 @@ public class NativePyTorchProcess extends AbstractNativeProcess {
         throw new UnsupportedOperationException();
     }
 
+    @Override
     public void loadModel(String modelId, String index, PyTorchStateStreamer stateStreamer, ActionListener<Boolean> listener) {
         stateStreamer.writeStateToStream(modelId, index, processRestoreStream(), listener);
     }
 
+    @Override
     public Iterator<PyTorchResult> readResults() {
         return resultsParser.parseResults(processOutStream());
     }
 
+    @Override
     public void writeInferenceRequest(BytesReference jsonRequest) throws IOException {
         processInStream().write(jsonRequest.array(), jsonRequest.arrayOffset(), jsonRequest.length());
         processInStream().write('\n');

+ 43 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcess.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.process;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.bytes.BytesReference;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
+import org.elasticsearch.xpack.ml.process.NativeProcess;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+/**
+ * Interface representing the native C++ pytorch process
+ */
+public interface PyTorchProcess extends NativeProcess {
+
+    /**
+     * Load the model into the process
+     * @param modelId the model id
+     * @param index the index where the model is stored
+     * @param stateStreamer the pytorch state streamer
+     * @param listener a listener that gets notified when the loading has completed
+     */
+    void loadModel(String modelId, String index, PyTorchStateStreamer stateStreamer, ActionListener<Boolean> listener);
+
+    /**
+     * @return stream of pytorch results
+     */
+    Iterator<PyTorchResult> readResults();
+
+    /**
+     * Writes an inference request to the process
+     * @param jsonRequest the inference request as json
+     * @throws IOException If writing the request fails
+     */
+    void writeInferenceRequest(BytesReference jsonRequest) throws IOException;
+}

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchProcessFactory.java

@@ -14,5 +14,5 @@ import java.util.function.Consumer;
 
 public interface PyTorchProcessFactory {
 
-    NativePyTorchProcess createProcess(TrainedModelDeploymentTask task, ExecutorService executorService, Consumer<String> onProcessCrash);
+    PyTorchProcess createProcess(TrainedModelDeploymentTask task, ExecutorService executorService, Consumer<String> onProcessCrash);
 }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java

@@ -87,7 +87,7 @@ public class PyTorchResultProcessor {
         pendingResults.remove(requestId);
     }
 
-    public void process(NativePyTorchProcess process) {
+    public void process(PyTorchProcess process) {
         try {
             Iterator<PyTorchResult> iterator = process.readResults();
             while (iterator.hasNext()) {

+ 2 - 28
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/BlackHoleAutodetectProcess.java

@@ -20,6 +20,7 @@ import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.ForecastParams;
 import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
+import org.elasticsearch.xpack.ml.process.BlackHoleResultIterator;
 
 import java.time.ZonedDateTime;
 import java.util.Arrays;
@@ -29,7 +30,6 @@ import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.LinkedBlockingDeque;
-import java.util.concurrent.TimeUnit;
 import java.util.function.Consumer;
 
 /**
@@ -161,33 +161,7 @@ public class BlackHoleAutodetectProcess implements AutodetectProcess {
 
     @Override
     public Iterator<AutodetectResult> readAutodetectResults() {
-        // Create a custom iterator here, because LinkedBlockingDeque iterator and stream are not blocking when empty:
-        return new Iterator<AutodetectResult>() {
-
-            AutodetectResult result;
-
-            @Override
-            public boolean hasNext() {
-                try {
-                    while (open) {
-                        result = results.poll(100, TimeUnit.MILLISECONDS);
-                        if (result != null) {
-                            return true;
-                        }
-                    }
-                    result = results.poll();
-                    return result != null;
-                } catch (InterruptedException e) {
-                    Thread.currentThread().interrupt();
-                    return false;
-                }
-            }
-
-            @Override
-            public AutodetectResult next() {
-                return result;
-            }
-        };
+        return new BlackHoleResultIterator<>(results, () -> open);
     }
 
     @Override

+ 52 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/BlackHoleResultIterator.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.process;
+
+import java.util.Iterator;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+/**
+ * A custom iterator that blocks even when there are no results as
+ * a {@link java.util.concurrent.LinkedBlockingDeque} iterator and stream aren't.
+ * @param <T> the result type
+ */
+public class BlackHoleResultIterator<T> implements Iterator<T> {
+
+    private final BlockingQueue<T> results;
+    private final Supplier<Boolean> isRunning;
+    private volatile T latestResult;
+
+    public BlackHoleResultIterator(BlockingQueue<T> results, Supplier<Boolean> isRunning) {
+        this.results = results;
+        this.isRunning = isRunning;
+    }
+
+    @Override
+    public boolean hasNext() {
+        try {
+            while (isRunning.get()) {
+                latestResult = results.poll(100, TimeUnit.MILLISECONDS);
+                if (latestResult != null) {
+                    return true;
+                }
+            }
+            latestResult = results.poll();
+            return latestResult != null;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+            return false;
+        }
+    }
+
+    @Override
+    public T next() {
+        return latestResult;
+    }
+}

+ 4 - 0
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/support/BaseMlIntegTestCase.java

@@ -517,6 +517,10 @@ public abstract class BaseMlIntegTestCase extends ESIntegTestCase {
             .actionGet();
     }
 
+    protected void ensureStableCluster() {
+        ensureStableCluster(internalCluster().getNodeNames().length, TimeValue.timeValueSeconds(60));
+    }
+
     public static class MockPainlessScriptEngine extends MockScriptEngine {
 
         public static final String NAME = "painless";