|
@@ -9,10 +9,15 @@ package org.elasticsearch.xpack.inference.action;
|
|
|
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
+import org.elasticsearch.action.ActionRequest;
|
|
|
+import org.elasticsearch.action.ActionResponse;
|
|
|
+import org.elasticsearch.action.ActionType;
|
|
|
import org.elasticsearch.action.support.ActionFilters;
|
|
|
+import org.elasticsearch.action.support.ContextPreservingActionListener;
|
|
|
import org.elasticsearch.action.support.HandledTransportAction;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
import org.elasticsearch.common.util.concurrent.EsExecutors;
|
|
|
+import org.elasticsearch.common.util.concurrent.ThreadContext;
|
|
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.inference.UnparsedModel;
|
|
@@ -30,7 +35,6 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
|
|
import java.io.IOException;
|
|
|
|
|
|
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
|
|
|
-import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
|
|
|
|
|
|
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
|
|
|
private final ModelRegistry modelRegistry;
|
|
@@ -103,7 +107,7 @@ public class TransportInferenceActionProxy extends HandledTransportAction<Infere
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
|
|
|
+ execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
|
|
|
} catch (Exception e) {
|
|
|
unifiedErrorFormatListener.onFailure(e);
|
|
|
}
|
|
@@ -122,6 +126,19 @@ public class TransportInferenceActionProxy extends HandledTransportAction<Infere
|
|
|
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
|
|
|
}
|
|
|
|
|
|
- executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
|
|
|
+ execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
|
|
|
+ }
|
|
|
+
|
|
|
+ private <Request extends ActionRequest, Response extends ActionResponse> void execute(
|
|
|
+ ActionType<Response> action,
|
|
|
+ Request request,
|
|
|
+ ActionListener<Response> listener
|
|
|
+ ) {
|
|
|
+ var threadContext = client.threadPool().getThreadContext();
|
|
|
+ // stash the context so we clear the user's security headers, then restore and copy the response headers
|
|
|
+ var supplier = threadContext.newRestorableContext(true);
|
|
|
+ try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) {
|
|
|
+ client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
|
|
|
+ }
|
|
|
}
|
|
|
}
|