浏览代码

[ML] Directly call Inference API from Proxy (#127342)

In order to propagate response headers back from the proxied actions, we
are directly calling the Transport actions via the NodeClient.
Pat Whelan 5 月之前
父节点
当前提交
1bce4d6e66

+ 20 - 3
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

@@ -9,10 +9,15 @@ package org.elasticsearch.xpack.inference.action;
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
@@ -30,7 +35,6 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import java.io.IOException;
 
 import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
-import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 
 public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
     private final ModelRegistry modelRegistry;
@@ -103,7 +107,7 @@ public class TransportInferenceActionProxy extends HandledTransportAction<Infere
                 );
             }
 
-            executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
+            execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
         } catch (Exception e) {
             unifiedErrorFormatListener.onFailure(e);
         }
@@ -122,6 +126,19 @@ public class TransportInferenceActionProxy extends HandledTransportAction<Infere
             inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
         }
 
-        executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
+        execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
+    }
+
+    private <Request extends ActionRequest, Response extends ActionResponse> void execute(
+        ActionType<Response> action,
+        Request request,
+        ActionListener<Response> listener
+    ) {
+        var threadContext = client.threadPool().getThreadContext();
+        // stash the context so we clear the user's security headers, then restore and copy the response headers
+        var supplier = threadContext.newRestorableContext(true);
+        try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) {
+            client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
+        }
     }
 }