|
@@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.inference.InferencePlugin;
|
|
|
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
|
|
|
-import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
|
|
|
import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
|
|
|
import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
|
|
|
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
|
|
@@ -297,8 +296,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
);
|
|
|
inferenceResults.publisher().subscribe(taskProcessor);
|
|
|
|
|
|
- var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
|
|
|
- taskProcessor.subscribe(instrumentedStream);
|
|
|
+ var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
|
|
|
|
|
|
var streamErrorHandler = streamErrorHandler(instrumentedStream);
|
|
|
|
|
@@ -313,7 +311,52 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
}));
|
|
|
}
|
|
|
|
|
|
- protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
|
|
|
+ private <T> Flow.Publisher<T> publisherWithMetrics(
|
|
|
+ InferenceTimer timer,
|
|
|
+ Model model,
|
|
|
+ Request request,
|
|
|
+ String localNodeId,
|
|
|
+ Flow.Processor<T, T> upstream
|
|
|
+ ) {
|
|
|
+ return downstream -> {
|
|
|
+ upstream.subscribe(new Flow.Subscriber<>() {
|
|
|
+ @Override
|
|
|
+ public void onSubscribe(Flow.Subscription subscription) {
|
|
|
+ downstream.onSubscribe(new Flow.Subscription() {
|
|
|
+ @Override
|
|
|
+ public void request(long n) {
|
|
|
+ subscription.request(n);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void cancel() {
|
|
|
+ recordRequestDurationMetrics(model, timer, request, localNodeId, null);
|
|
|
+ subscription.cancel();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onNext(T item) {
|
|
|
+ downstream.onNext(item);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onError(Throwable throwable) {
|
|
|
+ recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
|
|
|
+ downstream.onError(throwable);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void onComplete() {
|
|
|
+ recordRequestDurationMetrics(model, timer, request, localNodeId, null);
|
|
|
+ downstream.onComplete();
|
|
|
+ }
|
|
|
+ });
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
|
|
|
return upstream;
|
|
|
}
|
|
|
|
|
@@ -386,44 +429,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
|
|
|
);
|
|
|
}
|
|
|
|
|
|
- private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
|
|
|
-
|
|
|
- private final InferenceTimer timer;
|
|
|
- private final Model model;
|
|
|
- private final Request request;
|
|
|
- private final String localNodeId;
|
|
|
-
|
|
|
- private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
|
|
|
- this.timer = timer;
|
|
|
- this.model = model;
|
|
|
- this.request = request;
|
|
|
- this.localNodeId = localNodeId;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- protected void next(InferenceServiceResults.Result item) {
|
|
|
- downstream().onNext(item);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onError(Throwable throwable) {
|
|
|
- recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
|
|
|
- super.onError(throwable);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- protected void onCancel() {
|
|
|
- recordRequestDurationMetrics(model, timer, request, localNodeId, null);
|
|
|
- super.onCancel();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void onComplete() {
|
|
|
- recordRequestDurationMetrics(model, timer, request, localNodeId, null);
|
|
|
- super.onComplete();
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
|
|
|
static NodeRoutingDecision handleLocally() {
|
|
|
return new NodeRoutingDecision(true, null);
|