
[ML] Robustness improvements on inference processor running allocated… (#79243)

Using a trained model deployment with an inference processor currently results
in timeouts, and many documents never make it through to the model.

This commit improves that by:

  - reusing the process worker executor service that queues up operations
  for the process, thus providing a buffer and avoiding blocking the ML thread pool
  - effectively removing the inference timeout when inference runs via an inference
  processor. Requests no longer time out, but if operations take too long the queue
  fills up and 429 errors are returned (a sketch of this backpressure path follows
  below).
  - integrating PyTorch models with the `xpack.ml.max_open_jobs` setting
  - setting the allocation state to `failed` when the process crashes
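
The backpressure path described in the second bullet boils down to a small pattern,
sketched below under assumptions: the helper class and method names are hypothetical
and not part of this commit; only AbstractRunnable, the bounded queue, and the 429
rejection are taken from the change itself.

    import org.elasticsearch.common.util.concurrent.AbstractRunnable;

    import java.util.concurrent.ExecutorService;

    // Hypothetical helper, not part of the commit: hand an operation to the bounded
    // per-process worker executor and surface any rejection instead of blocking.
    final class ProcessBackpressureSketch {
        static void submit(ExecutorService processWorker, AbstractRunnable operation) {
            try {
                // The worker's queue is bounded, so this call never blocks the ML thread.
                processWorker.execute(operation);
            } catch (Exception e) {
                // Queue full (ElasticsearchStatusException, HTTP 429) or worker already
                // shut down (EsRejectedExecutionException): fail the operation rather
                // than lose it silently.
                operation.onFailure(e);
            }
        }
    }
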
Dimitris Athanasiou, 4 years ago
commit d3835381cc
18 files changed, 292 additions and 152 deletions
  1. +2 -2    x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java
  2. +2 -1    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  3. +3 -1    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java
  4. +21 -2   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java
  5. +13 -0   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java
  6. +18 -5   x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  7. +4 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  8. +1 -1    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java
  9. +1 -0    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java
  10. +134 -0  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java
  11. +8 -6    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java
  12. +5 -110  x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java
  13. +2 -1    x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/JobModelSnapshotUpgrader.java
  14. +59 -2   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java
  15. +0 -9    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java
  16. +6 -4    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java
  17. +1 -1    x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java
  18. +12 -7   x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java

+ 2 - 2
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java

@@ -176,8 +176,8 @@ public class TooManyJobsIT extends BaseMlIntegTestCase {
                             (expectedJobsAlreadyOpenOnNode * memoryFootprintPerJob) + "], estimated memory required for this job [" +
                             memoryFootprintPerJob + "].]"));
                 } else {
-                    assertTrue(detailedMessage, detailedMessage.endsWith("node is full. Number of opened jobs [" +
-                        maxNumberOfJobsPerNode + "], xpack.ml.max_open_jobs [" + maxNumberOfJobsPerNode + "].]"));
+                    assertTrue(detailedMessage, detailedMessage.endsWith("node is full. Number of opened jobs and allocated native " +
+                        "inference processes [" + maxNumberOfJobsPerNode + "], xpack.ml.max_open_jobs [" + maxNumberOfJobsPerNode + "].]"));
                 }
                 logger.info("good news everybody --> reached maximum number of allowed opened jobs, after trying to open the {}th job", i);
 

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -531,6 +531,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     // as the current node could be running in a cluster where some nodes are still using
     // that setting.  From 8.0.0 onwards we have the flexibility to increase it...
     private static final int MAX_MAX_OPEN_JOBS_PER_NODE = 512;
+    public static final int DEFAULT_MAX_OPEN_JOBS_PER_NODE = MAX_MAX_OPEN_JOBS_PER_NODE;
     // This setting is cluster-wide and can be set dynamically. However, prior to version 7.1 it was
     // a non-dynamic per-node setting. In a mixed version cluster containing 6.7 or 7.0 nodes those
     // older nodes will not react to the dynamic changes. Therefore, in such mixed version clusters
@@ -538,7 +539,7 @@ public class MachineLearning extends Plugin implements SystemIndexPlugin,
     public static final Setting<Integer> MAX_OPEN_JOBS_PER_NODE =
             Setting.intSetting(
                 "xpack.ml.max_open_jobs",
-                MAX_MAX_OPEN_JOBS_PER_NODE,
+                DEFAULT_MAX_OPEN_JOBS_PER_NODE,
                 1,
                 MAX_MAX_OPEN_JOBS_PER_NODE,
                 Property.Dynamic,

+ 3 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

@@ -14,6 +14,7 @@ import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.rest.RestStatus;
@@ -162,7 +163,8 @@ public class TransportInternalInferModelAction extends HandledTransportAction<Re
         executeAsyncWithOrigin(client,
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
-            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc), null),
+            new InferTrainedModelDeploymentAction.Request(modelId, inferenceConfigUpdate, Collections.singletonList(doc),
+                TimeValue.MAX_VALUE),
             ActionListener.wrap(
                 r -> listener.onResponse(r.getResults()),
                 e -> {

+ 21 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterService.java

@@ -57,12 +57,14 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     private final NodeLoadDetector nodeLoadDetector;
     private volatile int maxMemoryPercentage;
     private volatile boolean useAuto;
+    private volatile int maxOpenJobs;
 
     public TrainedModelAllocationClusterService(Settings settings, ClusterService clusterService, NodeLoadDetector nodeLoadDetector) {
         this.clusterService = clusterService;
         this.nodeLoadDetector = nodeLoadDetector;
         this.maxMemoryPercentage = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings);
         this.useAuto = MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT.get(settings);
+        this.maxOpenJobs = MachineLearning.MAX_OPEN_JOBS_PER_NODE.get(settings);
         // Only nodes that can possibly be master nodes really need this service running
         if (DiscoveryNode.isMasterNode(settings)) {
             clusterService.addListener(this);
@@ -70,6 +72,7 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
                 .addSettingsUpdateConsumer(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, this::setMaxMemoryPercentage);
             clusterService.getClusterSettings()
                 .addSettingsUpdateConsumer(MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT, this::setUseAuto);
+            clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.MAX_OPEN_JOBS_PER_NODE, this::setMaxOpenJobs);
         }
     }
 
@@ -81,6 +84,10 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
         this.useAuto = useAuto;
     }
 
+    private void setMaxOpenJobs(int maxOpenJobs) {
+        this.maxOpenJobs = maxOpenJobs;
+    }
+
     @Override
     public void clusterChanged(ClusterChangedEvent event) {
         if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) {
@@ -437,7 +444,7 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
     }
 
     Optional<String> nodeHasCapacity(ClusterState state, StartTrainedModelDeploymentAction.TaskParams params, DiscoveryNode node) {
-        NodeLoad load = nodeLoadDetector.detectNodeLoad(state, true, node, Integer.MAX_VALUE, maxMemoryPercentage, useAuto);
+        NodeLoad load = nodeLoadDetector.detectNodeLoad(state, true, node, maxOpenJobs, maxMemoryPercentage, useAuto);
         return handleNodeLoad(load, node.getId(), params);
     }
 
@@ -455,7 +462,7 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
             builder.build(),
             true,
             node,
-            Integer.MAX_VALUE,
+            maxOpenJobs,
             maxMemoryPercentage,
             useAuto
         );
@@ -467,6 +474,18 @@ public class TrainedModelAllocationClusterService implements ClusterStateListene
             logger.warn("[{}] failed to calculate current node load with error [{}]", params.getModelId(), nodeId);
             return Optional.of(load.getError());
         }
+        if (load.remainingJobs() == 0) {
+            return Optional.of(
+                ParameterizedMessage.format(
+                    "This node is full. Number of opened jobs and allocated native inference processes [{}], {} [{}].",
+                    new Object[] {
+                        load.getNumAssignedJobs(),
+                        MachineLearning.MAX_OPEN_JOBS_PER_NODE.getKey(),
+                        maxOpenJobs
+                    }
+                )
+            );
+        }
         if (load.getFreeMemory() < params.estimateMemoryUsageBytes()) {
             return Optional.of(
                 ParameterizedMessage.format(

+ 13 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationNodeService.java

@@ -434,4 +434,17 @@ public class TrainedModelAllocationNodeService implements ClusterStateListener {
             ActionListener.wrap(r -> stopTask.run(), e -> stopTask.run())
         );
     }
+
+    public void failAllocation(TrainedModelDeploymentTask task, String reason) {
+        updateStoredState(
+            task.getModelId(),
+            new RoutingStateAndReason(RoutingState.FAILED, reason),
+            ActionListener.wrap(r -> logger.debug(
+                    new ParameterizedMessage("[{}] Successfully updating allocation state to [{}] with reason [{}]",
+                        task.getModelId(), RoutingState.FAILED, reason))
+            , e -> logger.error(new ParameterizedMessage("[{}] Error while updating allocation state to [{}] with reason [{}]",
+                        task.getModelId(), RoutingState.FAILED, reason), e)
+            )
+        );
+    }
 }

+ 18 - 5
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -19,16 +19,16 @@ import org.elasticsearch.action.search.SearchRequest;
 import org.elasticsearch.client.Client;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
-import org.elasticsearch.xcontent.NamedXContentRegistry;
-import org.elasticsearch.xcontent.XContentFactory;
-import org.elasticsearch.xcontent.XContentParser;
-import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.IdsQueryBuilder;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.threadpool.Scheduler;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentParser;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
@@ -48,6 +48,7 @@ import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchProcessFactory;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchStateStreamer;
+import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 
 import java.io.IOException;
 import java.io.InputStream;
@@ -231,7 +232,12 @@ public class DeploymentManager {
         }
 
         final long requestId = requestIdCounter.getAndIncrement();
-        executorServiceForProcess.execute(new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener));
+        InferenceAction inferenceAction = new InferenceAction(requestId, timeout, processContext, config, doc, threadPool, listener);
+        try {
+            processContext.executorService.execute(inferenceAction);
+        } catch (Exception e) {
+            inferenceAction.onFailure(e);
+        }
     }
 
     static class InferenceAction extends AbstractRunnable {
@@ -380,11 +386,13 @@ public class DeploymentManager {
         private final SetOnce<TrainedModelInput> modelInput = new SetOnce<>();
         private final PyTorchResultProcessor resultProcessor;
         private final PyTorchStateStreamer stateStreamer;
+        private final ProcessWorkerExecutorService executorService;
 
         ProcessContext(TrainedModelDeploymentTask task, ExecutorService executorService) {
             this.task = Objects.requireNonNull(task);
             resultProcessor = new PyTorchResultProcessor(task.getModelId());
             this.stateStreamer = new PyTorchStateStreamer(client, executorService, xContentRegistry);
+            this.executorService = new ProcessWorkerExecutorService(threadPool.getThreadContext(), "pytorch_inference", 1024);
         }
 
         PyTorchResultProcessor getResultProcessor() {
@@ -393,10 +401,12 @@ public class DeploymentManager {
 
         synchronized void startProcess() {
             process.set(pyTorchProcessFactory.createProcess(task, executorServiceForProcess, onProcessCrash()));
+            executorServiceForProcess.submit(executorService::start);
         }
 
         synchronized void stopProcess() {
             resultProcessor.stop();
+            executorService.shutdown();
             if (process.get() == null) {
                 return;
             }
@@ -412,7 +422,10 @@ public class DeploymentManager {
         private Consumer<String> onProcessCrash() {
             return reason -> {
                 logger.error("[{}] process crashed due to reason [{}]", task.getModelId(), reason);
+                resultProcessor.stop();
+                executorService.shutdown();
                 processContextByAllocation.remove(task.getId());
+                task.setFailed("process crashed due to reason [" + reason + "]");
             };
         }
 

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -120,4 +120,8 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
     public Optional<ModelStats> modelStats() {
         return trainedModelAllocationNodeService.modelStats(this);
     }
+
+    public void setFailed(String reason) {
+        trainedModelAllocationNodeService.failAllocation(this, reason);
+    }
 }

+ 1 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/JobNodeSelector.java

@@ -201,7 +201,7 @@ public class JobNodeSelector {
             if (currentLoad.remainingJobs() == 0) {
                 reason = createReason(jobId,
                     nodeNameAndMlAttributes(node),
-                    "This node is full. Number of opened jobs [{}], {} [{}].",
+                    "This node is full. Number of opened jobs and allocated native inference processes [{}], {} [{}].",
                     currentLoad.getNumAssignedJobs(),
                     MAX_OPEN_JOBS_PER_NODE.getKey(),
                     maxNumberOfOpenJobs);

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/NodeLoadDetector.java

@@ -129,6 +129,7 @@ public class NodeLoadDetector {
                     .map(RoutingStateAndReason::getState)
                     .orElse(RoutingState.STOPPED)
                     .consumesMemory()) {
+                    nodeLoad.incNumAssignedJobs();
                     nodeLoad.incAssignedJobMemory(allocation.getTaskParams().estimateMemoryUsageBytes());
                 }
             }

+ 134 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorService.java

@@ -0,0 +1,134 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+package org.elasticsearch.xpack.ml.job.process;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.common.util.concurrent.AbstractRunnable;
+import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.SuppressForbidden;
+import org.elasticsearch.rest.RestStatus;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.concurrent.AbstractExecutorService;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/*
+ * Native ML processes can only handle a single operation at a time. In order to guarantee that, all
+ * operations are initially added to a queue and a worker thread from an ML threadpool will process each
+ * operation at a time.
+ */
+public class ProcessWorkerExecutorService extends AbstractExecutorService {
+
+    private static final Logger logger = LogManager.getLogger(ProcessWorkerExecutorService.class);
+
+    private final ThreadContext contextHolder;
+    private final String processName;
+    private final CountDownLatch awaitTermination = new CountDownLatch(1);
+    private final BlockingQueue<Runnable> queue;
+
+    private volatile boolean running = true;
+
+    /**
+     * @param contextHolder the thread context holder
+     * @param processName the name of the process to be used in logging
+     * @param queueSize the size of the queue holding operations. If an operation is added
+     *                  for execution when the queue is full a 429 error is thrown.
+     */
+    @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
+    public ProcessWorkerExecutorService(ThreadContext contextHolder, String processName, int queueSize) {
+        this.contextHolder = Objects.requireNonNull(contextHolder);
+        this.processName = Objects.requireNonNull(processName);
+        this.queue = new LinkedBlockingQueue<>(queueSize);
+    }
+
+    @Override
+    public void shutdown() {
+        running = false;
+    }
+
+    @Override
+    public List<Runnable> shutdownNow() {
+        throw new UnsupportedOperationException("not supported");
+    }
+
+    @Override
+    public boolean isShutdown() {
+        return running == false;
+    }
+
+    @Override
+    public boolean isTerminated() {
+        return awaitTermination.getCount() == 0;
+    }
+
+    @Override
+    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
+        return awaitTermination.await(timeout, unit);
+    }
+
+    @Override
+    public synchronized void execute(Runnable command) {
+        if (isShutdown()) {
+            EsRejectedExecutionException rejected = new EsRejectedExecutionException(processName + " worker service has shutdown", true);
+            if (command instanceof AbstractRunnable) {
+                ((AbstractRunnable) command).onRejection(rejected);
+            } else {
+                throw rejected;
+            }
+        }
+
+        boolean added = queue.offer(contextHolder.preserveContext(command));
+        if (added == false) {
+            throw new ElasticsearchStatusException("Unable to execute on [{}] as queue is full", RestStatus.TOO_MANY_REQUESTS, processName);
+        }
+    }
+
+    public void start() {
+        try {
+            while (running) {
+                Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
+                if (runnable != null) {
+                    try {
+                        runnable.run();
+                    } catch (Exception e) {
+                        logger.error(() -> new ParameterizedMessage("error handling process [{}] operation", processName), e);
+                    }
+                    EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
+                }
+            }
+
+            synchronized (this) {
+                // if shutdown with tasks pending notify the handlers
+                if (queue.isEmpty() == false) {
+                    List<Runnable> notExecuted = new ArrayList<>();
+                    queue.drainTo(notExecuted);
+
+                    String msg = "unable to process as " + processName + " worker service has shutdown";
+                    for (Runnable runnable : notExecuted) {
+                        if (runnable instanceof AbstractRunnable) {
+                            ((AbstractRunnable) runnable).onRejection( new EsRejectedExecutionException(msg, true));
+                        }
+                    }
+                }
+            }
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt();
+        } finally {
+            awaitTermination.countDown();
+        }
+    }
+}
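
For reference, here is a minimal usage sketch of the new ProcessWorkerExecutorService,
modelled on how AutodetectProcessManager and the tests wire it up. The demo class name,
the process name "my_process", and the queue size are illustrative assumptions; only the
constructor, start(), execute(), shutdown(), and awaitTermination() come from the class
added above, and TestThreadPool is a test-framework helper.

    import org.elasticsearch.threadpool.TestThreadPool;
    import org.elasticsearch.threadpool.ThreadPool;
    import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;

    import java.util.concurrent.TimeUnit;

    // Hypothetical demo class, not part of the commit.
    public class ProcessWorkerSketch {
        public static void main(String[] args) throws Exception {
            ThreadPool threadPool = new TestThreadPool("process-worker-sketch");
            try {
                // Bounded queue of 100 operations; the process name only appears in errors and logs.
                ProcessWorkerExecutorService worker =
                    new ProcessWorkerExecutorService(threadPool.getThreadContext(), "my_process", 100);

                // The worker loop runs on a borrowed thread and polls the queue until shutdown().
                threadPool.generic().execute(worker::start);

                // Operations run one at a time, in submission order, off the caller's thread.
                worker.execute(() -> System.out.println("operation 1"));
                worker.execute(() -> System.out.println("operation 2"));

                // When the queue is full, execute() throws ElasticsearchStatusException (HTTP 429);
                // after shutdown(), queued AbstractRunnable tasks have onRejection() called.
                worker.shutdown();
                worker.awaitTermination(10, TimeUnit.SECONDS);
            } finally {
                ThreadPool.terminate(threadPool, 10, TimeUnit.SECONDS);
            }
        }
    }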

+ 8 - 6
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java

@@ -7,6 +7,7 @@
 package org.elasticsearch.xpack.ml.job.process.autodetect;
 
 import joptsimple.internal.Strings;
+
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
@@ -19,22 +20,22 @@ import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateListener;
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.core.CheckedConsumer;
-import org.elasticsearch.core.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
-import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.xcontent.NamedXContentRegistry;
 import org.elasticsearch.common.xcontent.XContentElasticsearchExtension;
-import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.core.CheckedConsumer;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.core.Tuple;
 import org.elasticsearch.core.internal.io.IOUtils;
 import org.elasticsearch.index.analysis.AnalysisRegistry;
 import org.elasticsearch.indices.InvalidAliasNameException;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.NamedXContentRegistry;
+import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.action.util.PageParams;
 import org.elasticsearch.xpack.core.action.util.QueryPage;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
@@ -65,6 +66,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider;
 import org.elasticsearch.xpack.ml.job.persistence.ScheduledEventsQueryBuilder;
 import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
 import org.elasticsearch.xpack.ml.job.process.DataCountsReporter;
+import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 import org.elasticsearch.xpack.ml.job.process.autodetect.output.AutodetectResultProcessor;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.DataLoadParams;
@@ -960,7 +962,7 @@ public class AutodetectProcessManager implements ClusterStateListener {
     }
 
     ExecutorService createAutodetectExecutorService(ExecutorService executorService) {
-        AutodetectWorkerExecutorService autodetectWorkerExecutor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
+        ProcessWorkerExecutorService autodetectWorkerExecutor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
         executorService.submit(autodetectWorkerExecutor::start);
         return autodetectWorkerExecutor;
     }

+ 5 - 110
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorService.java

@@ -4,120 +4,15 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
+
 package org.elasticsearch.xpack.ml.job.process.autodetect;
 
-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.core.SuppressForbidden;
-import org.elasticsearch.common.util.concurrent.AbstractRunnable;
-import org.elasticsearch.common.util.concurrent.EsExecutors;
-import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
-import org.elasticsearch.rest.RestStatus;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.AbstractExecutorService;
-import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.TimeUnit;
-
-/*
- * The autodetect native process can only handle a single operation at a time. In order to guarantee that, all
- * operations are initially added to a queue and a worker thread from ml autodetect threadpool will process each
- * operation at a time.
- */
-class AutodetectWorkerExecutorService extends AbstractExecutorService {
-
-    private static final Logger logger = LogManager.getLogger(AutodetectWorkerExecutorService.class);
-
-    private final ThreadContext contextHolder;
-    private final CountDownLatch awaitTermination = new CountDownLatch(1);
-    private final BlockingQueue<Runnable> queue = new LinkedBlockingQueue<>(100);
-
-    private volatile boolean running = true;
-
-    @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
-    AutodetectWorkerExecutorService(ThreadContext contextHolder) {
-        this.contextHolder = contextHolder;
-    }
-
-    @Override
-    public void shutdown() {
-        running = false;
-    }
-
-    @Override
-    public List<Runnable> shutdownNow() {
-        throw new UnsupportedOperationException("not supported");
-    }
-
-    @Override
-    public boolean isShutdown() {
-        return running == false;
-    }
-
-    @Override
-    public boolean isTerminated() {
-        return awaitTermination.getCount() == 0;
-    }
-
-    @Override
-    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
-        return awaitTermination.await(timeout, unit);
-    }
-
-    @Override
-    public synchronized void execute(Runnable command) {
-        if (isShutdown()) {
-            EsRejectedExecutionException rejected = new EsRejectedExecutionException("autodetect worker service has shutdown", true);
-            if (command instanceof AbstractRunnable) {
-                ((AbstractRunnable) command).onRejection(rejected);
-            } else {
-                throw rejected;
-            }
-        }
-
-        boolean added = queue.offer(contextHolder.preserveContext(command));
-        if (added == false) {
-            throw new ElasticsearchStatusException("Unable to submit operation", RestStatus.TOO_MANY_REQUESTS);
-        }
-    }
-
-    void start() {
-        try {
-            while (running) {
-                Runnable runnable = queue.poll(500, TimeUnit.MILLISECONDS);
-                if (runnable != null) {
-                    try {
-                        runnable.run();
-                    } catch (Exception e) {
-                        logger.error("error handling job operation", e);
-                    }
-                    EsExecutors.rethrowErrors(contextHolder.unwrap(runnable));
-                }
-            }
+import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 
-            synchronized (this) {
-                // if shutdown with tasks pending notify the handlers
-                if (queue.isEmpty() == false) {
-                    List<Runnable> notExecuted = new ArrayList<>();
-                    queue.drainTo(notExecuted);
+public class AutodetectWorkerExecutorService extends ProcessWorkerExecutorService {
 
-                    for (Runnable runnable : notExecuted) {
-                        if (runnable instanceof AbstractRunnable) {
-                            ((AbstractRunnable) runnable).onRejection(
-                                new EsRejectedExecutionException("unable to process as autodetect worker service has shutdown", true));
-                        }
-                    }
-                }
-            }
-        } catch (InterruptedException e) {
-            Thread.currentThread().interrupt();
-        } finally {
-            awaitTermination.countDown();
-        }
+    public AutodetectWorkerExecutorService(ThreadContext contextHolder) {
+        super(contextHolder, "autodetect", 100);
     }
 }

+ 2 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/JobModelSnapshotUpgrader.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.job.persistence.JobResultsPersister;
 import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
+import org.elasticsearch.xpack.ml.job.process.ProcessWorkerExecutorService;
 import org.elasticsearch.xpack.ml.job.process.autodetect.output.JobSnapshotUpgraderResultProcessor;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
 import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams;
@@ -117,7 +118,7 @@ public final class JobModelSnapshotUpgrader {
             snapshotId,
             jobResultsPersister,
             process);
-        AutodetectWorkerExecutorService autodetectWorkerExecutor;
+        ProcessWorkerExecutorService autodetectWorkerExecutor;
         try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
             autodetectWorkerExecutor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
             autodetectExecutorService.submit(autodetectWorkerExecutor::start);

+ 59 - 2
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/allocation/TrainedModelAllocationClusterServiceTests.java

@@ -26,6 +26,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.set.Sets;
+import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
@@ -36,6 +37,7 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReaso
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.job.NodeLoadDetector;
+import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutorTests;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 import org.junit.Before;
 
@@ -45,6 +47,7 @@ import java.util.function.Function;
 
 import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.anEmptyMap;
+import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasKey;
@@ -65,7 +68,11 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         clusterService = mock(ClusterService.class);
         ClusterSettings clusterSettings = new ClusterSettings(
             Settings.EMPTY,
-            Sets.newHashSet(MachineLearning.MAX_MACHINE_MEMORY_PERCENT, MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT)
+            Sets.newHashSet(
+                MachineLearning.MAX_MACHINE_MEMORY_PERCENT,
+                MachineLearning.USE_AUTO_MACHINE_MEMORY_PERCENT,
+                MachineLearning.MAX_OPEN_JOBS_PER_NODE
+            )
         );
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
         MlMemoryTracker memoryTracker = mock(MlMemoryTracker.class);
@@ -447,6 +454,57 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
         );
     }
 
+    public void testAddRemoveAllocationNodes_GivenNodeThatReachedMaxOpenJobs() {
+
+        PersistentTasksCustomMetadata.Builder tasksBuilder = PersistentTasksCustomMetadata.builder();
+        for (int i = 0; i < MachineLearning.DEFAULT_MAX_OPEN_JOBS_PER_NODE; i++) {
+            OpenJobPersistentTasksExecutorTests.addJobTask("job_id_" + i, "ml-node-full-load", null, tasksBuilder);
+        }
+        PersistentTasksCustomMetadata persistentTasks = tasksBuilder.build();
+
+        ClusterState currentState = ClusterState.builder(new ClusterName("testAddRemoveAllocationNodes"))
+            .nodes(
+                DiscoveryNodes.builder()
+                    .add(buildNode("ml-node-full-load", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .add(buildNode("ml-node-no-load", true, ByteSizeValue.ofGb(4).getBytes()))
+                    .build()
+            )
+            .metadata(
+                Metadata.builder()
+                    .putCustom(
+                        TrainedModelAllocationMetadata.NAME,
+                        TrainedModelAllocationMetadata.Builder.empty()
+                            .addNewAllocation(
+                                "model-1",
+                                TrainedModelAllocation.Builder.empty(newParams("model-1", 10_000))
+                                    .addNewRoutingEntry("ml-node-no-load")
+                                    .updateExistingRoutingEntry("ml-node-no-load", started())
+                            )
+                            .build()
+                    )
+                    .putCustom(
+                        PersistentTasksCustomMetadata.TYPE,
+                        persistentTasks
+                    )
+            )
+            .build();
+        TrainedModelAllocationClusterService trainedModelAllocationClusterService = createClusterService();
+
+        ClusterState modified = trainedModelAllocationClusterService.addRemoveAllocationNodes(currentState);
+        TrainedModelAllocationMetadata trainedModelAllocationMetadata = TrainedModelAllocationMetadata.fromState(modified);
+        assertThat(trainedModelAllocationMetadata.modelAllocations().keySet(), contains("model-1"));
+
+        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().keySet(), hasSize(1));
+        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().keySet(),
+            contains("ml-node-no-load"));
+        assertThat(trainedModelAllocationMetadata.getModelAllocation("model-1").getNodeRoutingTable().get("ml-node-no-load").getState(),
+            equalTo(RoutingState.STARTED));
+
+        TrainedModelAllocation allocation = trainedModelAllocationMetadata.getModelAllocation("model-1");
+        assertThat(allocation.getReason().get(), equalTo("Not allocating on node [ml-node-full-load]." +
+            " Reason: This node is full. Number of opened jobs and allocated native inference processes [512], " +
+            "xpack.ml.max_open_jobs [512]."));
+    }
 
     public void testShouldAllocateModels() {
         String model1 = "model-1";
@@ -867,7 +925,6 @@ public class TrainedModelAllocationClusterServiceTests extends ESTestCase {
             MapBuilder.<String, String>newMapBuilder()
                 .put(MachineLearning.MACHINE_MEMORY_NODE_ATTR, String.valueOf(nativeMemory))
                 .put(MachineLearning.MAX_JVM_SIZE_NODE_ATTR, String.valueOf(10))
-                .put(MachineLearning.MAX_OPEN_JOBS_NODE_ATTR, String.valueOf(10))
                 .map(),
             isML ? DiscoveryNodeRole.roles() : Set.of(DiscoveryNodeRole.DATA_ROLE, DiscoveryNodeRole.MASTER_ROLE),
             version

+ 0 - 9
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManagerTests.java

@@ -8,7 +8,6 @@
 package org.elasticsearch.xpack.ml.inference.deployment;
 
 import org.elasticsearch.action.ActionListener;
-import org.elasticsearch.client.Client;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.ScalingExecutorBuilder;
@@ -17,7 +16,6 @@ import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig;
-import org.elasticsearch.xpack.ml.inference.pytorch.process.NativePyTorchProcess;
 import org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor;
 import org.junit.After;
 import org.junit.Before;
@@ -32,7 +30,6 @@ import static org.mockito.Mockito.when;
 
 public class DeploymentManagerTests extends ESTestCase {
 
-    private DeploymentManager deploymentManager;
     private ThreadPool tp;
 
     @Before
@@ -42,12 +39,6 @@ public class DeploymentManagerTests extends ESTestCase {
             new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME,1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool"),
             new ScalingExecutorBuilder(JOB_COMMS_THREAD_POOL_NAME,1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.job_comms_thread_pool")
         );
-        deploymentManager = new DeploymentManager(
-            mock(Client.class),
-            xContentRegistry(),
-            tp,
-            (task, executorService, onProcessCrash) -> mock(NativePyTorchProcess.class)
-        );
     }
 
     @After

+ 6 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/JobNodeSelectorTests.java

@@ -123,8 +123,9 @@ public class JobNodeSelectorTests extends ESTestCase {
             MAX_JOB_BYTES,
             false);
         assertNull(result.getExecutorNode());
-        assertThat(result.getExplanation(), containsString("node is full. Number of opened jobs ["
-            + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]"));
+        assertThat(result.getExplanation(), containsString(
+            "node is full. Number of opened jobs and allocated native inference processes [" + maxRunningJobsPerNode
+                + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]"));
     }
 
     public void testSelectLeastLoadedMlNodeForDataFrameAnalyticsJob_maxCapacityCountLimiting() {
@@ -153,8 +154,9 @@ public class JobNodeSelectorTests extends ESTestCase {
             MAX_JOB_BYTES,
             false);
         assertNull(result.getExecutorNode());
-        assertThat(result.getExplanation(), containsString("node is full. Number of opened jobs ["
-            + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]"));
+        assertThat(result.getExplanation(), containsString(
+            "node is full. Number of opened jobs and allocated native inference processes [" + maxRunningJobsPerNode
+                + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]"));
     }
 
     public void testSelectLeastLoadedMlNodeForAnomalyDetectorJob_maxCapacityMemoryLimiting() {

+ 1 - 1
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java

@@ -131,7 +131,7 @@ public class NodeLoadDetectorTests extends ESTestCase {
         load = nodeLoadDetector.detectNodeLoad(cs, true, nodes.get("_node_id4"), 5, 30, false);
         assertThat(load.getAssignedJobMemory(), equalTo(429916160L));
         assertThat(load.getNumAllocatingJobs(), equalTo(0L));
-        assertThat(load.getNumAssignedJobs(), equalTo(1L));
+        assertThat(load.getNumAssignedJobs(), equalTo(2L));
         assertThat(load.getMaxJobs(), equalTo(5));
         assertThat(load.getMaxMlMemory(), equalTo(0L));
     }

+ 12 - 7
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectWorkerExecutorServiceTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/ProcessWorkerExecutorServiceTests.java

@@ -4,12 +4,10 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-package org.elasticsearch.xpack.ml.job.process.autodetect;
+package org.elasticsearch.xpack.ml.job.process;
 
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
-import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.threadpool.TestThreadPool;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -22,7 +20,10 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.hamcrest.Matchers.containsString;
 
-public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
+public class ProcessWorkerExecutorServiceTests extends ESTestCase {
+
+    private static final String TEST_PROCESS = "test";
+    private static final int QUEUE_SIZE = 100;
 
     private ThreadPool threadPool = new TestThreadPool("AutodetectWorkerExecutorServiceTests");
 
@@ -32,7 +33,7 @@ public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
     }
 
     public void testAutodetectWorkerExecutorService_SubmitAfterShutdown() {
-        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
+        ProcessWorkerExecutorService executor = createExecutorService();
 
         threadPool.generic().execute(() -> executor.start());
         executor.shutdown();
@@ -40,7 +41,7 @@ public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
     }
 
     public void testAutodetectWorkerExecutorService_TasksNotExecutedCallHandlerOnShutdown() throws Exception {
-        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(new ThreadContext(Settings.EMPTY));
+        ProcessWorkerExecutorService executor = createExecutorService();
 
         CountDownLatch latch = new CountDownLatch(1);
 
@@ -85,7 +86,7 @@ public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
     }
 
     public void testAutodetectWorkerExecutorServiceDoesNotSwallowErrors() {
-        AutodetectWorkerExecutorService executor = new AutodetectWorkerExecutorService(threadPool.getThreadContext());
+        ProcessWorkerExecutorService executor = createExecutorService();
         if (randomBoolean()) {
             executor.submit(() -> {
                 throw new Error("future error");
@@ -98,4 +99,8 @@ public class AutodetectWorkerExecutorServiceTests extends ESTestCase {
         Error e = expectThrows(Error.class, () -> executor.start());
         assertThat(e.getMessage(), containsString("future error"));
     }
+
+    private ProcessWorkerExecutorService createExecutorService() {
+        return new ProcessWorkerExecutorService(threadPool.getThreadContext(), TEST_PROCESS, QUEUE_SIZE);
+    }
 }