Browse Source

[ML] Refactor stream metrics (#125092)

Remove the use of DelegatingProcessor and replace it with an inline
processor.
Pat Whelan 7 months ago
parent
commit
b8db0eee96

+ 47 - 42
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

@@ -39,7 +39,6 @@ import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.InferencePlugin;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
-import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.common.InferenceServiceNodeLocalRateLimitCalculator;
 import org.elasticsearch.xpack.inference.common.InferenceServiceRateLimitCalculator;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
@@ -297,8 +296,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
                 );
                 inferenceResults.publisher().subscribe(taskProcessor);
 
-                var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
-                taskProcessor.subscribe(instrumentedStream);
+                var instrumentedStream = publisherWithMetrics(timer, model, request, localNodeId, taskProcessor);
 
                 var streamErrorHandler = streamErrorHandler(instrumentedStream);
 
@@ -313,7 +311,52 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         }));
     }
 
-    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
+    private <T> Flow.Publisher<T> publisherWithMetrics( // wraps upstream so request-duration metrics are recorded on cancel, error, or completion
+        InferenceTimer timer,
+        Model model,
+        Request request,
+        String localNodeId,
+        Flow.Processor<T, T> upstream // only subscribe(...) is invoked on it in this method
+    ) {
+        return downstream -> { // upstream is not subscribed to until a downstream subscriber attaches
+            upstream.subscribe(new Flow.Subscriber<>() {
+                @Override
+                public void onSubscribe(Flow.Subscription subscription) {
+                    downstream.onSubscribe(new Flow.Subscription() { // downstream gets a wrapper subscription so cancel() can be intercepted
+                        @Override
+                        public void request(long n) {
+                            subscription.request(n); // demand is forwarded unchanged
+                        }
+
+                        @Override
+                        public void cancel() {
+                            recordRequestDurationMetrics(model, timer, request, localNodeId, null); // recorded with a null throwable, same as onComplete
+                            subscription.cancel(); // then propagate the cancellation upstream
+                        }
+                    });
+                }
+
+                @Override
+                public void onNext(T item) {
+                    downstream.onNext(item); // items pass through untouched; no metrics per item
+                }
+
+                @Override
+                public void onError(Throwable throwable) {
+                    recordRequestDurationMetrics(model, timer, request, localNodeId, throwable); // metrics are recorded before the error is forwarded
+                    downstream.onError(throwable);
+                }
+
+                @Override
+                public void onComplete() {
+                    recordRequestDurationMetrics(model, timer, request, localNodeId, null); // null throwable: no error to attribute
+                    downstream.onComplete();
+                }
+            });
+        };
+    }
+
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) { // default is a pass-through; TransportUnifiedCompletionInferenceAction overrides it
         return upstream;
     }
 
@@ -386,44 +429,6 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
         );
     }
 
-    private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceResults.Result, InferenceServiceResults.Result> {
-
-        private final InferenceTimer timer;
-        private final Model model;
-        private final Request request;
-        private final String localNodeId;
-
-        private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
-            this.timer = timer;
-            this.model = model;
-            this.request = request;
-            this.localNodeId = localNodeId;
-        }
-
-        @Override
-        protected void next(InferenceServiceResults.Result item) {
-            downstream().onNext(item);
-        }
-
-        @Override
-        public void onError(Throwable throwable) {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
-            super.onError(throwable);
-        }
-
-        @Override
-        protected void onCancel() {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
-            super.onCancel();
-        }
-
-        @Override
-        public void onComplete() {
-            recordRequestDurationMetrics(model, timer, request, localNodeId, null);
-            super.onComplete();
-        }
-    }
-
     private record NodeRoutingDecision(boolean currentNodeShouldHandleRequest, DiscoveryNode targetNode) {
         static NodeRoutingDecision handleLocally() {
             return new NodeRoutingDecision(true, null);

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java

@@ -102,7 +102,7 @@ public class TransportUnifiedCompletionInferenceAction extends BaseTransportInfe
      * as {@link UnifiedChatCompletionException}.
      */
     @Override
-    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream) {
+    protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Publisher<T> upstream) {
         return downstream -> {
             upstream.subscribe(new Flow.Subscriber<>() {
                 @Override

+ 3 - 6
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

@@ -291,7 +291,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
     }
 
     public void testMetricsAfterStreamInferSuccess() {
-        mockStreamResponse(Flow.Subscriber::onComplete);
+        mockStreamResponse(Flow.Subscriber::onComplete).subscribe(mock());
         verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
             assertThat(attributes.get("service"), is(serviceId));
             assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -306,10 +306,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
     public void testMetricsAfterStreamInferFailure() {
         var expectedException = new IllegalStateException("hello");
         var expectedError = expectedException.getClass().getSimpleName();
-        mockStreamResponse(subscriber -> {
-            subscriber.subscribe(mock());
-            subscriber.onError(expectedException);
-        });
+        mockStreamResponse(subscriber -> subscriber.onError(expectedException)).subscribe(mock());
         verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
             assertThat(attributes.get("service"), is(serviceId));
             assertThat(attributes.get("task_type"), is(taskType.toString()));
@@ -388,7 +385,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
         assertThat(threadContext.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
     }
 
-    protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
+    protected Flow.Publisher<InferenceServiceResults.Result> mockStreamResponse(Consumer<Flow.Subscriber<?>> action) {
         mockService(true, Set.of(), listener -> {
             Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
             doAnswer(innerAns -> {