Browse Source

[ML] Preserve lost thread context in node inference action (#132973)

Fixes a memory leak if APM tracing is enabled as the lost context meant the trace
was never closed and the span never released
David Kyle 2 months ago
parent
commit
62c84a486b

+ 5 - 0
docs/changelog/132973.yaml

@@ -0,0 +1,5 @@
+pr: 132973
+summary: Preserve lost thread context in node inference action. A lost context causes a memory leak if APM tracing is enabled
+area: Machine Learning
+type: bug
+issues: []

+ 11 - 2
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferTrainedModelDeploymentAction.java

@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.TaskOperationFailure;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.tasks.TransportTasksAction;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
@@ -20,6 +21,7 @@ import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.injection.guice.Inject;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
+import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
@@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
     InferTrainedModelDeploymentAction.Response,
     InferTrainedModelDeploymentAction.Response> {
 
+    private final ThreadPool threadPool;
+
     @Inject
     public TransportInferTrainedModelDeploymentAction(
         ClusterService clusterService,
         TransportService transportService,
-        ActionFilters actionFilters
+        ActionFilters actionFilters,
+        ThreadPool threadPool
     ) {
         super(
             InferTrainedModelDeploymentAction.NAME,
@@ -52,6 +57,7 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
             InferTrainedModelDeploymentAction.Response::new,
             EsExecutors.DIRECT_EXECUTOR_SERVICE
         );
+        this.threadPool = threadPool;
     }
 
     @Override
@@ -99,6 +105,9 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
         // and return order the results to match the request order
         AtomicInteger count = new AtomicInteger();
         AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());
+
+        var contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
+
         int slot = 0;
         for (var input : nlpInputs) {
             task.infer(
@@ -109,7 +118,7 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
                 request.getPrefixType(),
                 actionTask,
                 request.isChunkResults(),
-                orderedListener(count, results, slot++, nlpInputs.size(), listener)
+                orderedListener(count, results, slot++, nlpInputs.size(), contextPreservingListener)
             );
         }
     }