Browse Source

[ML] add new trained model deployment cache clear API (#89074)

This adds a new `_ml/trained_models/<model_id>/deployment/cache/_clear` API. This will clear the inference cache on every node where the model is allocated.
Benjamin Trent 3 years ago
parent
commit
d588d456f0
19 changed files with 638 additions and 28 deletions
  1. 5 0
      docs/changelog/89074.yaml
  2. 57 0
      docs/reference/ml/trained-models/apis/clear-trained-model-deployment-cache.asciidoc
  3. 2 0
      docs/reference/ml/trained-models/apis/index.asciidoc
  4. 31 0
      rest-api-spec/src/main/resources/rest-api-spec/api/ml.clear_trained_model_deployment_cache.json
  5. 102 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheAction.java
  6. 23 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheActionRequestTests.java
  7. 23 0
      x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheActionResponseTests.java
  8. 5 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java
  9. 97 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportClearDeploymentCacheAction.java
  10. 6 2
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java
  11. 18 17
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java
  12. 43 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ClearCacheControlMessagePytorchAction.java
  13. 22 1
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java
  14. 49 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ThreadSettingsControlMessagePytorchAction.java
  15. 4 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java
  16. 47 0
      x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestClearDeploymentCacheAction.java
  17. 17 8
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/ThreadSettingsControlMessagePytorchActionTests.java
  18. 1 0
      x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java
  19. 86 0
      x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

+ 5 - 0
docs/changelog/89074.yaml

@@ -0,0 +1,5 @@
+pr: 89074
+summary: Add new trained model deployment cache clear API
+area: Machine Learning
+type: enhancement
+issues: []

+ 57 - 0
docs/reference/ml/trained-models/apis/clear-trained-model-deployment-cache.asciidoc

@@ -0,0 +1,57 @@
+[role="xpack"]
+[[clear-trained-model-deployment-cache]]
+= Clear trained model deployment cache API
+[subs="attributes"]
+++++
+<titleabbrev>Clear trained model deployment cache</titleabbrev>
+++++
+
+Clears a trained model deployment cache on all nodes where the trained model is assigned.
+
+preview::[]
+
+[[clear-trained-model-deployment-cache-request]]
+== {api-request-title}
+
+`POST _ml/trained_models/<model_id>/deployment/cache/_clear`
+
+[[clear-trained-model-deployment-cache-prereq]]
+== {api-prereq-title}
+
+Requires the `manage_ml` cluster privilege. This privilege is included in the
+`machine_learning_admin` built-in role.
+
+[[clear-trained-model-deployment-cache-desc]]
+== {api-description-title}
+
+A trained model deployment may have an inference cache enabled. As requests are handled by each allocated node,
+their responses may be cached on that individual node. Calling this API clears the caches without restarting the
+deployment.
+
+[[clear-trained-model-deployment-cache-path-params]]
+== {api-path-parms-title}
+
+`<model_id>`::
+(Required, string)
+include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
+
+[[clear-trained-model-deployment-cache-example]]
+== {api-examples-title}
+
+The following example clears the cache for the new deployment for the
+`elastic__distilbert-base-uncased-finetuned-conll03-english` trained model:
+
+[source,console]
+--------------------------------------------------
+POST _ml/trained_models/elastic__distilbert-base-uncased-finetuned-conll03-english/deployment/cache/_clear
+--------------------------------------------------
+// TEST[skip:TBD]
+
+The API returns the following results:
+
+[source,console-result]
+----
+{
+   "cleared": true
+}
+----

+ 2 - 0
docs/reference/ml/trained-models/apis/index.asciidoc

@@ -12,6 +12,8 @@ include::get-trained-models.asciidoc[leveloffset=+2]
 include::get-trained-models-stats.asciidoc[leveloffset=+2]
 //INFER
 include::infer-trained-model.asciidoc[leveloffset=+2][leveloffset=+2]
+//UPDATE
+include::clear-trained-model-deployment-cache.asciidoc[leveloffset=+2]
 //START/STOP
 include::start-trained-model-deployment.asciidoc[leveloffset=+2]
 include::stop-trained-model-deployment.asciidoc[leveloffset=+2]

+ 31 - 0
rest-api-spec/src/main/resources/rest-api-spec/api/ml.clear_trained_model_deployment_cache.json

@@ -0,0 +1,31 @@
+{
+  "ml.clear_trained_model_deployment_cache":{
+    "documentation":{
+      "url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/clear-trained-model-deployment-cache.html",
+      "description":"Clear the cached results from a trained model deployment"
+    },
+    "stability":"experimental",
+    "visibility":"public",
+    "headers":{
+      "accept": [ "application/json"],
+      "content_type": ["application/json"]
+    },
+    "url":{
+      "paths":[
+        {
+          "path":"/_ml/trained_models/{model_id}/deployment/cache/_clear",
+          "methods":[
+            "POST"
+          ],
+          "parts":{
+            "model_id":{
+              "type":"string",
+              "description":"The unique identifier of the trained model.",
+              "required":true
+            }
+          }
+        }
+      ]
+    }
+  }
+}

+ 102 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheAction.java

@@ -0,0 +1,102 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.support.tasks.BaseTasksRequest;
+import org.elasticsearch.action.support.tasks.BaseTasksResponse;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Objects;
+
+public class ClearDeploymentCacheAction extends ActionType<ClearDeploymentCacheAction.Response> {
+    public static final ClearDeploymentCacheAction INSTANCE = new ClearDeploymentCacheAction();
+    public static final String NAME = "cluster:admin/xpack/ml/trained_models/deployment/clear_cache";
+
+    private ClearDeploymentCacheAction() {
+        super(NAME, Response::new);
+    }
+
+    public static class Request extends BaseTasksRequest<Request> {
+        private final String deploymentId;
+
+        public Request(String deploymentId) {
+            this.deploymentId = ExceptionsHelper.requireNonNull(deploymentId, "deployment_id");
+        }
+
+        public Request(StreamInput in) throws IOException {
+            super(in);
+            this.deploymentId = in.readString();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeString(deploymentId);
+        }
+
+        public String getDeploymentId() {
+            return deploymentId;
+        }
+
+        @Override
+        public boolean match(Task task) {
+            return StartTrainedModelDeploymentAction.TaskMatcher.match(task, deploymentId);
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            Request request = (Request) o;
+            return Objects.equals(deploymentId, request.deploymentId);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(deploymentId);
+        }
+    }
+
+    public static class Response extends BaseTasksResponse implements ToXContentObject {
+
+        private final boolean cleared;
+
+        public Response(boolean cleared) {
+            super(Collections.emptyList(), Collections.emptyList());
+            this.cleared = cleared;
+        }
+
+        public Response(StreamInput in) throws IOException {
+            super(in);
+            this.cleared = in.readBoolean();
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            super.writeTo(out);
+            out.writeBoolean(cleared);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+            builder.startObject();
+            builder.field("cleared", cleared);
+            builder.endObject();
+            return builder;
+        }
+    }
+}

+ 23 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheActionRequestTests.java

@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+public class ClearDeploymentCacheActionRequestTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Request> {
+    @Override
+    protected Writeable.Reader<ClearDeploymentCacheAction.Request> instanceReader() {
+        return ClearDeploymentCacheAction.Request::new;
+    }
+
+    @Override
+    protected ClearDeploymentCacheAction.Request createTestInstance() {
+        return new ClearDeploymentCacheAction.Request(randomAlphaOfLength(5));
+    }
+}

+ 23 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/ClearDeploymentCacheActionResponseTests.java

@@ -0,0 +1,23 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.ml.action;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+
+public class ClearDeploymentCacheActionResponseTests extends AbstractWireSerializingTestCase<ClearDeploymentCacheAction.Response> {
+    @Override
+    protected Writeable.Reader<ClearDeploymentCacheAction.Response> instanceReader() {
+        return ClearDeploymentCacheAction.Response::new;
+    }
+
+    @Override
+    protected ClearDeploymentCacheAction.Response createTestInstance() {
+        return new ClearDeploymentCacheAction.Response(randomBoolean());
+    }
+}

+ 5 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

@@ -92,6 +92,7 @@ import org.elasticsearch.xpack.core.ml.MlMetadata;
 import org.elasticsearch.xpack.core.ml.MlStatsIndex;
 import org.elasticsearch.xpack.core.ml.MlTasks;
 import org.elasticsearch.xpack.core.ml.action.CancelJobModelSnapshotUpgradeAction;
+import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
 import org.elasticsearch.xpack.core.ml.action.CloseJobAction;
 import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
 import org.elasticsearch.xpack.core.ml.action.DeleteCalendarAction;
@@ -189,6 +190,7 @@ import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskS
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.core.template.TemplateUtils;
 import org.elasticsearch.xpack.ml.action.TransportCancelJobModelSnapshotUpgradeAction;
+import org.elasticsearch.xpack.ml.action.TransportClearDeploymentCacheAction;
 import org.elasticsearch.xpack.ml.action.TransportCloseJobAction;
 import org.elasticsearch.xpack.ml.action.TransportCreateTrainedModelAssignmentAction;
 import org.elasticsearch.xpack.ml.action.TransportDeleteCalendarAction;
@@ -391,6 +393,7 @@ import org.elasticsearch.xpack.ml.rest.filter.RestDeleteFilterAction;
 import org.elasticsearch.xpack.ml.rest.filter.RestGetFiltersAction;
 import org.elasticsearch.xpack.ml.rest.filter.RestPutFilterAction;
 import org.elasticsearch.xpack.ml.rest.filter.RestUpdateFilterAction;
+import org.elasticsearch.xpack.ml.rest.inference.RestClearDeploymentCacheAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestDeleteTrainedModelAliasAction;
 import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction;
@@ -1254,6 +1257,7 @@ public class MachineLearning extends Plugin
             new RestPutTrainedModelDefinitionPartAction(),
             new RestPutTrainedModelVocabularyAction(),
             new RestInferTrainedModelAction(),
+            new RestClearDeploymentCacheAction(),
             // CAT Handlers
             new RestCatJobsAction(),
             new RestCatTrainedModelsAction(),
@@ -1358,6 +1362,7 @@ public class MachineLearning extends Plugin
                 UpdateTrainedModelAssignmentRoutingInfoAction.INSTANCE,
                 TransportUpdateTrainedModelAssignmentStateAction.class
             ),
+            new ActionHandler<>(ClearDeploymentCacheAction.INSTANCE, TransportClearDeploymentCacheAction.class),
             usageAction,
             infoAction
         );

+ 97 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportClearDeploymentCacheAction.java

@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.action;
+
+import org.elasticsearch.ResourceNotFoundException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.FailedNodeException;
+import org.elasticsearch.action.TaskOperationFailure;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.tasks.TransportTasksAction;
+import org.elasticsearch.cluster.ClusterState;
+import org.elasticsearch.cluster.service.ClusterService;
+import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.tasks.Task;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
+import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Request;
+import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction.Response;
+import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
+import org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentMetadata;
+import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.elasticsearch.ExceptionsHelper.convertToElastic;
+
+public class TransportClearDeploymentCacheAction extends TransportTasksAction<TrainedModelDeploymentTask, Request, Response, Response> {
+
+    @Inject
+    public TransportClearDeploymentCacheAction(
+        TransportService transportService,
+        ActionFilters actionFilters,
+        ClusterService clusterService
+    ) {
+        super(
+            ClearDeploymentCacheAction.NAME,
+            clusterService,
+            transportService,
+            actionFilters,
+            Request::new,
+            Response::new,
+            Response::new,
+            ThreadPool.Names.SAME
+        );
+    }
+
+    @Override
+    protected Response newResponse(
+        Request request,
+        List<Response> taskResponse,
+        List<TaskOperationFailure> taskOperationFailures,
+        List<FailedNodeException> failedNodeExceptions
+    ) {
+        if (taskOperationFailures.isEmpty() == false) {
+            throw convertToElastic(taskOperationFailures.get(0).getCause());
+        } else if (failedNodeExceptions.isEmpty() == false) {
+            throw convertToElastic(failedNodeExceptions.get(0));
+        }
+        return new Response(true);
+    }
+
+    @Override
+    protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
+        final ClusterState clusterState = clusterService.state();
+        final TrainedModelAssignmentMetadata assignment = TrainedModelAssignmentMetadata.fromState(clusterState);
+        TrainedModelAssignment trainedModelAssignment = assignment.getModelAssignment(request.getDeploymentId());
+        if (trainedModelAssignment == null) {
+            listener.onFailure(new ResourceNotFoundException("assignment for model with id [{}] not found", request.getDeploymentId()));
+            return;
+        }
+        String[] nodes = trainedModelAssignment.getNodeRoutingTable()
+            .entrySet()
+            .stream()
+            .filter(entry -> entry.getValue().isRoutable())
+            .map(Map.Entry::getKey)
+            .toArray(String[]::new);
+
+        if (nodes.length == 0) {
+            listener.onResponse(new Response(true));
+            return;
+        }
+        request.setNodes(nodes);
+        super.doExecute(task, request, listener);
+    }
+
+    @Override
+    protected void taskOperation(Task actionTask, Request request, TrainedModelDeploymentTask task, ActionListener<Response> listener) {
+        task.clearCache(ActionListener.wrap(r -> listener.onResponse(new Response(true)), listener::onFailure));
+    }
+}

+ 6 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java

@@ -66,7 +66,7 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
     private static final String NODE_NO_LONGER_REFERENCED = "node no longer referenced in model routing table";
     private static final String ASSIGNMENT_NO_LONGER_EXISTS = "model assignment no longer exists";
     private static final TimeValue MODEL_LOADING_CHECK_INTERVAL = TimeValue.timeValueSeconds(1);
-    private static final TimeValue UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT = TimeValue.timeValueSeconds(60);
+    private static final TimeValue CONTROL_MESSAGE_TIMEOUT = TimeValue.timeValueSeconds(60);
     private static final Logger logger = LogManager.getLogger(TrainedModelAssignmentNodeService.class);
     private final TrainedModelAssignmentService trainedModelAssignmentService;
     private final DeploymentManager deploymentManager;
@@ -286,6 +286,10 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
         return deploymentManager.getStats(task);
     }
 
+    public void clearCache(TrainedModelDeploymentTask task, ActionListener<AcknowledgedResponse> listener) {
+        deploymentManager.clearCache(task, CONTROL_MESSAGE_TIMEOUT, listener);
+    }
+
     private TaskAwareRequest taskAwareRequest(StartTrainedModelDeploymentAction.TaskParams params) {
         final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = this;
         return new TaskAwareRequest() {
@@ -419,7 +423,7 @@ public class TrainedModelAssignmentNodeService implements ClusterStateListener {
             deploymentManager.updateNumAllocations(
                 task,
                 assignment.getNodeRoutingTable().get(nodeId).getTargetAllocations(),
-                UPDATE_NUMBER_OF_ALLOCATIONS_TIMEOUT,
+                CONTROL_MESSAGE_TIMEOUT,
                 ActionListener.wrap(threadSettings -> {
                     logger.debug("[{}] Updated number of allocations to [{}]", assignment.getModelId(), threadSettings.numAllocations());
                     task.updateNumberOfAllocations(threadSettings.numAllocations());

+ 18 - 17
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ControlMessagePyTorchAction.java → x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/AbstractControlMessagePyTorchAction.java

@@ -17,35 +17,37 @@ import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
-import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
 
 import java.io.IOException;
 
 import static org.elasticsearch.core.Strings.format;
 
-class ControlMessagePyTorchAction extends AbstractPyTorchAction<ThreadSettings> {
+abstract class AbstractControlMessagePyTorchAction<T> extends AbstractPyTorchAction<T> {
 
     private static final Logger logger = LogManager.getLogger(InferencePyTorchAction.class);
 
-    private final int numAllocationThreads;
-
-    private enum ControlMessageTypes {
-        AllocationThreads
+    enum ControlMessageTypes {
+        AllocationThreads,
+        ClearCache
     };
 
-    ControlMessagePyTorchAction(
+    AbstractControlMessagePyTorchAction(
         String modelId,
         long requestId,
-        int numAllocationThreads,
         TimeValue timeout,
         DeploymentManager.ProcessContext processContext,
         ThreadPool threadPool,
-        ActionListener<ThreadSettings> listener
+        ActionListener<T> listener
     ) {
         super(modelId, requestId, timeout, processContext, threadPool, listener);
-        this.numAllocationThreads = numAllocationThreads;
     }
 
+    abstract int controlOrdinal();
+
+    abstract void writeMessage(XContentBuilder builder) throws IOException;
+
+    abstract T getResult(PyTorchResult result);
+
     @Override
     protected void doRun() throws Exception {
         if (isNotified()) {
@@ -56,7 +58,7 @@ class ControlMessagePyTorchAction extends AbstractPyTorchAction<ThreadSettings>
 
         final String requestIdStr = String.valueOf(getRequestId());
         try {
-            var message = buildControlMessage(requestIdStr, numAllocationThreads);
+            var message = buildControlMessage(requestIdStr);
 
             getProcessContext().getResultProcessor()
                 .registerRequest(requestIdStr, ActionListener.wrap(this::processResponse, this::onFailure));
@@ -70,24 +72,23 @@ class ControlMessagePyTorchAction extends AbstractPyTorchAction<ThreadSettings>
         }
     }
 
-    public static BytesReference buildControlMessage(String requestId, int numAllocationThreads) throws IOException {
+    final BytesReference buildControlMessage(String requestId) throws IOException {
         XContentBuilder builder = XContentFactory.jsonBuilder();
         builder.startObject();
         builder.field("request_id", requestId);
-        builder.field("control", ControlMessageTypes.AllocationThreads.ordinal());
-        builder.field("num_allocations", numAllocationThreads);
+        builder.field("control", controlOrdinal());
+        writeMessage(builder);
         builder.endObject();
-
         // BytesReference.bytes closes the builder
         return BytesReference.bytes(builder);
     }
 
-    public void processResponse(PyTorchResult result) {
+    private void processResponse(PyTorchResult result) {
         if (result.isError()) {
             onFailure(result.errorResult().error());
             return;
         }
-        onSuccess(result.threadSettings());
+        onSuccess(getResult(result));
     }
 
     @Override

+ 43 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ClearCacheControlMessagePytorchAction.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.deployment;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
+
+public class ClearCacheControlMessagePytorchAction extends AbstractControlMessagePyTorchAction<Boolean> {
+
+    ClearCacheControlMessagePytorchAction(
+        String modelId,
+        long requestId,
+        TimeValue timeout,
+        DeploymentManager.ProcessContext processContext,
+        ThreadPool threadPool,
+        ActionListener<Boolean> listener
+    ) {
+        super(modelId, requestId, timeout, processContext, threadPool, listener);
+    }
+
+    @Override
+    int controlOrdinal() {
+        return ControlMessageTypes.ClearCache.ordinal();
+    }
+
+    @Override
+    void writeMessage(XContentBuilder builder) {
+        // Nothing is written
+    }
+
+    @Override
+    Boolean getResult(PyTorchResult result) {
+        return true;
+    }
+}

+ 22 - 1
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

@@ -14,6 +14,7 @@ import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.search.SearchAction;
 import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.support.master.AcknowledgedResponse;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
@@ -296,7 +297,7 @@ public class DeploymentManager {
         }
 
         final long requestId = requestIdCounter.getAndIncrement();
-        ControlMessagePyTorchAction controlMessageAction = new ControlMessagePyTorchAction(
+        ThreadSettingsControlMessagePytorchAction controlMessageAction = new ThreadSettingsControlMessagePytorchAction(
             task.getModelId(),
             requestId,
             numAllocationThreads,
@@ -309,6 +310,26 @@ public class DeploymentManager {
         executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
     }
 
+    public void clearCache(TrainedModelDeploymentTask task, TimeValue timeout, ActionListener<AcknowledgedResponse> listener) {
+        var processContext = getProcessContext(task, listener::onFailure);
+        if (processContext == null) {
+            // error reporting handled in the call to getProcessContext
+            return;
+        }
+
+        final long requestId = requestIdCounter.getAndIncrement();
+        ClearCacheControlMessagePytorchAction controlMessageAction = new ClearCacheControlMessagePytorchAction(
+            task.getModelId(),
+            requestId,
+            timeout,
+            processContext,
+            threadPool,
+            ActionListener.wrap(b -> listener.onResponse(AcknowledgedResponse.TRUE), listener::onFailure)
+        );
+
+        executePyTorchAction(processContext, PriorityProcessWorkerExecutorService.RequestPriority.HIGHEST, controlMessageAction);
+    }
+
     public void executePyTorchAction(
         ProcessContext processContext,
         PriorityProcessWorkerExecutorService.RequestPriority priority,

+ 49 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/ThreadSettingsControlMessagePytorchAction.java

@@ -0,0 +1,49 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.deployment;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.ThreadSettings;
+
+import java.io.IOException;
+
+public class ThreadSettingsControlMessagePytorchAction extends AbstractControlMessagePyTorchAction<ThreadSettings> {
+    private final int numAllocationThreads;
+
+    ThreadSettingsControlMessagePytorchAction(
+        String modelId,
+        long requestId,
+        int numAllocationThreads,
+        TimeValue timeout,
+        DeploymentManager.ProcessContext processContext,
+        ThreadPool threadPool,
+        ActionListener<ThreadSettings> listener
+    ) {
+        super(modelId, requestId, timeout, processContext, threadPool, listener);
+        this.numAllocationThreads = numAllocationThreads;
+    }
+
+    @Override
+    int controlOrdinal() {
+        return ControlMessageTypes.AllocationThreads.ordinal();
+    }
+
+    @Override
+    void writeMessage(XContentBuilder builder) throws IOException {
+        builder.field("num_allocations", numAllocationThreads);
+    }
+
+    @Override
+    ThreadSettings getResult(PyTorchResult result) {
+        return result.threadSettings();
+    }
+}

+ 4 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

@@ -167,6 +167,10 @@ public class TrainedModelDeploymentTask extends CancellableTask implements Start
         return trainedModelAssignmentNodeService.modelStats(this);
     }
 
+    public void clearCache(ActionListener<AcknowledgedResponse> listener) {
+        trainedModelAssignmentNodeService.clearCache(this, listener);
+    }
+
     public void setFailed(String reason) {
         failed = true;
         trainedModelAssignmentNodeService.failAssignment(this, reason);

+ 47 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestClearDeploymentCacheAction.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.rest.inference;
+
+import org.elasticsearch.client.internal.node.NodeClient;
+import org.elasticsearch.rest.BaseRestHandler;
+import org.elasticsearch.rest.RestRequest;
+import org.elasticsearch.rest.action.RestToXContentListener;
+import org.elasticsearch.xpack.core.ml.action.ClearDeploymentCacheAction;
+import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import static org.elasticsearch.rest.RestRequest.Method.POST;
+import static org.elasticsearch.xpack.ml.MachineLearning.BASE_PATH;
+
+public class RestClearDeploymentCacheAction extends BaseRestHandler {
+
+    @Override
+    public String getName() {
+        return "xpack_ml_clear_deployment_cache_action";
+    }
+
+    @Override
+    public List<Route> routes() {
+        return Collections.singletonList(
+            new Route(POST, BASE_PATH + "trained_models/{" + TrainedModelConfig.MODEL_ID.getPreferredName() + "}/deployment/cache/_clear")
+        );
+    }
+
+    @Override
+    protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
+        String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
+        return channel -> client.execute(
+            ClearDeploymentCacheAction.INSTANCE,
+            new ClearDeploymentCacheAction.Request(modelId),
+            new RestToXContentListener<>(channel)
+        );
+    }
+}

+ 17 - 8
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/ControlMessagePyTorchActionTests.java → x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/deployment/ThreadSettingsControlMessagePytorchActionTests.java

@@ -32,13 +32,22 @@ import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
-public class ControlMessagePyTorchActionTests extends ESTestCase {
-
-    private ThreadPool tp;
+public class ThreadSettingsControlMessagePytorchActionTests extends ESTestCase {
 
     public void testBuildControlMessage() throws IOException {
-        var message = ControlMessagePyTorchAction.buildControlMessage("foo", 4);
-
+        DeploymentManager.ProcessContext processContext = mock(DeploymentManager.ProcessContext.class);
+        ThreadPool tp = mock(ThreadPool.class);
+        @SuppressWarnings("unchecked")
+        ThreadSettingsControlMessagePytorchAction action = new ThreadSettingsControlMessagePytorchAction(
+            "model_id",
+            1,
+            4,
+            TimeValue.MINUS_ONE,
+            processContext,
+            tp,
+            ActionListener.NOOP
+        );
+        var message = action.buildControlMessage("foo");
         assertEquals("{\"request_id\":\"foo\",\"control\":0,\"num_allocations\":4}", message.utf8ToString());
     }
 
@@ -56,7 +65,7 @@ public class ControlMessagePyTorchActionTests extends ESTestCase {
 
         {
             ActionListener<ThreadSettings> listener = mock(ActionListener.class);
-            ControlMessagePyTorchAction action = new ControlMessagePyTorchAction(
+            ThreadSettingsControlMessagePytorchAction action = new ThreadSettingsControlMessagePytorchAction(
                 "test-model",
                 1,
                 1,
@@ -75,7 +84,7 @@ public class ControlMessagePyTorchActionTests extends ESTestCase {
         }
         {
             ActionListener<ThreadSettings> listener = mock(ActionListener.class);
-            ControlMessagePyTorchAction action = new ControlMessagePyTorchAction(
+            ThreadSettingsControlMessagePytorchAction action = new ThreadSettingsControlMessagePytorchAction(
                 "test-model",
                 1,
                 1,
@@ -114,7 +123,7 @@ public class ControlMessagePyTorchActionTests extends ESTestCase {
         ArgumentCaptor<BytesReference> messageCapture = ArgumentCaptor.forClass(BytesReference.class);
         doNothing().when(pp).writeInferenceRequest(messageCapture.capture());
 
-        ControlMessagePyTorchAction action = new ControlMessagePyTorchAction(
+        ThreadSettingsControlMessagePytorchAction action = new ThreadSettingsControlMessagePytorchAction(
             "test-model",
             1,
             1,

+ 1 - 0
x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java

@@ -160,6 +160,7 @@ public class Constants {
         "cluster:admin/xpack/ml/job/update",
         "cluster:admin/xpack/ml/job/validate",
         "cluster:admin/xpack/ml/job/validate/detector",
+        "cluster:admin/xpack/ml/trained_models/deployment/clear_cache",
         "cluster:admin/xpack/ml/trained_models/deployment/start",
         "cluster:admin/xpack/ml/trained_models/deployment/stop",
         "cluster:admin/xpack/ml/trained_models/part/put",

+ 86 - 0
x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml

@@ -130,6 +130,92 @@ setup:
   - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_count: 3 }
   - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_cache_hit_count: 1 }
 
+  - do:
+      ml.stop_trained_model_deployment:
+        model_id: test_model
+  - match: { stopped: true }
+---
+"Test clear deployment cache":
+  - skip:
+      features: allowed_warnings
+
+  - do:
+      ml.start_trained_model_deployment:
+        model_id: test_model
+        cache_size: 10kb
+        wait_for: started
+  - match: {assignment.assignment_state: started}
+  - match: {assignment.task_parameters.model_id: test_model}
+  - match: {assignment.task_parameters.cache_size: 10kb}
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "words" }
+            ]
+          }
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "are" }
+            ]
+          }
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "words" }
+            ]
+          }
+
+  - do:
+      ml.get_trained_models_stats:
+        model_id: "test_model"
+  - match: { count: 1 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_count: 3 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_cache_hit_count: 1 }
+
+
+  - do:
+      ml.clear_trained_model_deployment_cache:
+        model_id: test_model
+  - match: { cleared: true }
+
+  - do:
+      allowed_warnings:
+        - '[POST /_ml/trained_models/{model_id}/deployment/_infer] is deprecated! Use [POST /_ml/trained_models/{model_id}/_infer] instead.'
+      ml.infer_trained_model:
+        model_id: "test_model"
+        body: >
+          {
+            "docs": [
+              { "input": "words" }
+            ]
+          }
+
+  - do:
+      ml.get_trained_models_stats:
+        model_id: "test_model"
+  - match: { count: 1 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_count: 4 }
+  - match: { trained_model_stats.0.deployment_stats.nodes.0.inference_cache_hit_count: 1 }
+
   - do:
       ml.stop_trained_model_deployment:
         model_id: test_model