|
@@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.action.FailedNodeException;
|
|
|
import org.elasticsearch.action.TaskOperationFailure;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
+import org.elasticsearch.action.support.ContextPreservingActionListener;
|
|
|
import org.elasticsearch.action.support.tasks.TransportTasksAction;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
|
import org.elasticsearch.common.util.concurrent.AtomicArray;
|
|
@@ -20,6 +21,7 @@ import org.elasticsearch.inference.InferenceResults;
|
|
|
import org.elasticsearch.injection.guice.Inject;
|
|
|
import org.elasticsearch.rest.RestStatus;
|
|
|
import org.elasticsearch.tasks.CancellableTask;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
import org.elasticsearch.transport.TransportService;
|
|
|
import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
|
|
import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
|
|
@@ -37,11 +39,14 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
|
|
|
InferTrainedModelDeploymentAction.Response,
|
|
|
InferTrainedModelDeploymentAction.Response> {
|
|
|
|
|
|
+ private final ThreadPool threadPool;
|
|
|
+
|
|
|
@Inject
|
|
|
public TransportInferTrainedModelDeploymentAction(
|
|
|
ClusterService clusterService,
|
|
|
TransportService transportService,
|
|
|
- ActionFilters actionFilters
|
|
|
+ ActionFilters actionFilters,
|
|
|
+ ThreadPool threadPool
|
|
|
) {
|
|
|
super(
|
|
|
InferTrainedModelDeploymentAction.NAME,
|
|
@@ -52,6 +57,7 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
|
|
|
InferTrainedModelDeploymentAction.Response::new,
|
|
|
EsExecutors.DIRECT_EXECUTOR_SERVICE
|
|
|
);
|
|
|
+ this.threadPool = threadPool;
|
|
|
}
|
|
|
|
|
|
@Override
|
|
@@ -99,6 +105,9 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
|
|
|
// and return order the results to match the request order
|
|
|
AtomicInteger count = new AtomicInteger();
|
|
|
AtomicArray<InferenceResults> results = new AtomicArray<>(nlpInputs.size());
|
|
|
+
|
|
|
+ var contextPreservingListener = ContextPreservingActionListener.wrapPreservingContext(listener, threadPool.getThreadContext());
|
|
|
+
|
|
|
int slot = 0;
|
|
|
for (var input : nlpInputs) {
|
|
|
task.infer(
|
|
@@ -109,7 +118,7 @@ public class TransportInferTrainedModelDeploymentAction extends TransportTasksAc
|
|
|
request.getPrefixType(),
|
|
|
actionTask,
|
|
|
request.isChunkResults(),
|
|
|
- orderedListener(count, results, slot++, nlpInputs.size(), listener)
|
|
|
+ orderedListener(count, results, slot++, nlpInputs.size(), contextPreservingListener)
|
|
|
);
|
|
|
}
|
|
|
}
|