|  | @@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.InferencePlugin;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
 | 
	
		
			
				|  |  | -import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 | 
	
	
		
			
				|  | @@ -297,8 +296,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
 | 
	
		
			
				|  |  |                  );
 | 
	
		
			
				|  |  |                  inferenceResults.publisher().subscribe(taskProcessor);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -                var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
 | 
	
		
			
				|  |  | -                taskProcessor.subscribe(instrumentedStream);
 | 
	
		
			
				|  |  | +                var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |                  var streamErrorHandler = streamErrorHandler(instrumentedStream);
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -313,7 +311,52 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
 | 
	
		
			
				|  |  |          }));
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
 | 
	
		
			
				|  |  | +    private <T> Flow.Publisher<T> publisherWithMetrics(
 | 
	
		
			
				|  |  | +        InferenceTimer timer,
 | 
	
		
			
				|  |  | +        Model model,
 | 
	
		
			
				|  |  | +        Request request,
 | 
	
		
			
				|  |  | +        String localNodeId,
 | 
	
		
			
				|  |  | +        Flow.Processor<T, T> upstream
 | 
	
		
			
				|  |  | +    ) {
 | 
	
		
			
				|  |  | +        return downstream -> {
 | 
	
		
			
				|  |  | +            upstream.subscribe(new Flow.Subscriber<>() {
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                public void onSubscribe(Flow.Subscription subscription) {
 | 
	
		
			
				|  |  | +                    downstream.onSubscribe(new Flow.Subscription() {
 | 
	
		
			
				|  |  | +                        @Override
 | 
	
		
			
				|  |  | +                        public void request(long n) {
 | 
	
		
			
				|  |  | +                            subscription.request(n);
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                        @Override
 | 
	
		
			
				|  |  | +                        public void cancel() {
 | 
	
		
			
				|  |  | +                            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
 | 
	
		
			
				|  |  | +                            subscription.cancel();
 | 
	
		
			
				|  |  | +                        }
 | 
	
		
			
				|  |  | +                    });
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                public void onNext(T item) {
 | 
	
		
			
				|  |  | +                    downstream.onNext(item);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                public void onError(Throwable throwable) {
 | 
	
		
			
				|  |  | +                    recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
 | 
	
		
			
				|  |  | +                    downstream.onError(throwable);
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +                @Override
 | 
	
		
			
				|  |  | +                public void onComplete() {
 | 
	
		
			
				|  |  | +                    recordRequestDurationMetrics(model, timer, request, localNodeId, null);
 | 
	
		
			
				|  |  | +                    downstream.onComplete();
 | 
	
		
			
				|  |  | +                }
 | 
	
		
			
				|  |  | +            });
 | 
	
		
			
				|  |  | +        };
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
 | 
	
		
			
				|  |  |          return upstream;
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -386,44 +429,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        private final InferenceTimer timer;
 | 
	
		
			
				|  |  | -        private final Model model;
 | 
	
		
			
				|  |  | -        private final Request request;
 | 
	
		
			
				|  |  | -        private final String localNodeId;
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
 | 
	
		
			
				|  |  | -            this.timer = timer;
 | 
	
		
			
				|  |  | -            this.model = model;
 | 
	
		
			
				|  |  | -            this.request = request;
 | 
	
		
			
				|  |  | -            this.localNodeId = localNodeId;
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        @Override
 | 
	
		
			
				|  |  | -        protected void next(InferenceServiceResults.Result item) {
 | 
	
		
			
				|  |  | -            downstream().onNext(item);
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        @Override
 | 
	
		
			
				|  |  | -        public void onError(Throwable throwable) {
 | 
	
		
			
				|  |  | -            recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
 | 
	
		
			
				|  |  | -            super.onError(throwable);
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        @Override
 | 
	
		
			
				|  |  | -        protected void onCancel() {
 | 
	
		
			
				|  |  | -            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
 | 
	
		
			
				|  |  | -            super.onCancel();
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -        @Override
 | 
	
		
			
				|  |  | -        public void onComplete() {
 | 
	
		
			
				|  |  | -            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
 | 
	
		
			
				|  |  | -            super.onComplete();
 | 
	
		
			
				|  |  | -        }
 | 
	
		
			
				|  |  | -    }
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |      private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
 | 
	
		
			
				|  |  |          static NodeRoutingDecision handleLocally() {
 | 
	
		
			
				|  |  |              return new NodeRoutingDecision(true, null);
 |