소스 검색

[8.x] [ML] Inference duration and error metrics (#115876) (#118700)

* [ML] Inference duration and error metrics (#115876)

Add `es.inference.requests.time` metric around `infer` API.

As recommended by OTel spec, errors are determined by the
presence or absence of the `error.type` attribute in the metric.
"error.type" will be the http status code (as a string) if it is
available, otherwise it will be the name of the exception (e.g.
NullPointerException).

Additional notes:
- ApmInferenceStats is merged into InferenceStats. Originally we planned
  to have multiple implementations, but now we're only using APM.
- Request count is now always recorded, even when there are failures
  loading the endpoint configuration.
- Added a hook in streaming for cancel messages, so we can close the
  metrics when a user cancels the stream.

(cherry picked from commit 26870ef38d3c76b6d78897d8657cad77ce0bc35c)

* fixing switch with class issue

---------

Co-authored-by: Pat Whelan <pat.whelan@elastic.co>
Jonathan Buttner 10 달 전
부모
커밋
7eaf3805e6

+ 5 - 0
docs/changelog/115876.yaml

@@ -0,0 +1,5 @@
+pr: 115876
+summary: Inference duration and error metrics
+area: Machine Learning
+type: enhancement
+issues: []

+ 1 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

@@ -105,7 +105,6 @@ import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceE
 import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
 import org.elasticsearch.xpack.inference.services.mistral.MistralService;
 import org.elasticsearch.xpack.inference.services.openai.OpenAiService;
-import org.elasticsearch.xpack.inference.telemetry.ApmInferenceStats;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
 
 import java.util.ArrayList;
@@ -240,7 +239,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();
-        var stats = new PluginComponentBinding<>(InferenceStats.class, ApmInferenceStats.create(meterRegistry));
+        var stats = new PluginComponentBinding<>(InferenceStats.class, InferenceStats.create(meterRegistry));
 
         return List.of(modelRegistry, registry, httpClientManager, stats);
     }

+ 102 - 24
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

@@ -7,6 +7,8 @@
 
 package org.elasticsearch.xpack.inference.action;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
@@ -14,6 +16,7 @@ import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.common.logging.DeprecationLogger;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceRegistry;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -26,20 +29,22 @@ import org.elasticsearch.tasks.Task;
 import org.elasticsearch.transport.TransportService;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.common.DelegatingProcessor;
 import org.elasticsearch.xpack.inference.registry.ModelRegistry;
 import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
 
-import java.util.Set;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.core.Strings.format;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
 
 public class TransportInferenceAction extends HandledTransportAction<InferenceAction.Request, InferenceAction.Response> {
+    private static final Logger log = LogManager.getLogger(TransportInferenceAction.class);
     private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
     private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
 
-    private static final Set<Class<? extends InferenceService>> supportsStreaming = Set.of();
-
     private final ModelRegistry modelRegistry;
     private final InferenceServiceRegistry serviceRegistry;
     private final InferenceStats inferenceStats;
@@ -64,17 +69,22 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
 
     @Override
     protected void doExecute(Task task, InferenceAction.Request request, ActionListener<InferenceAction.Response> listener) {
+        var timer = InferenceTimer.start();
 
-        ActionListener<UnparsedModel> getModelListener = listener.delegateFailureAndWrap((delegate, unparsedModel) -> {
+        var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
             var service = serviceRegistry.getService(unparsedModel.service());
             if (service.isEmpty()) {
-                listener.onFailure(unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()));
+                var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
             if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) {
                 // not the wildcard task type and not the model task type
-                listener.onFailure(incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()));
+                var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType());
+                recordMetrics(unparsedModel, timer, e);
+                listener.onFailure(e);
                 return;
             }
 
@@ -85,20 +95,69 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
                     unparsedModel.settings(),
                     unparsedModel.secrets()
                 );
-            inferOnService(model, request, service.get(), delegate);
+            inferOnServiceWithMetrics(model, request, service.get(), timer, listener);
+        }, e -> {
+            try {
+                inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e));
+            } catch (Exception metricsException) {
+                log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics");
+            }
+            listener.onFailure(e);
         });
 
         modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener);
     }
 
-    private void inferOnService(
+    private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnServiceWithMetrics(
         Model model,
         InferenceAction.Request request,
         InferenceService service,
+        InferenceTimer timer,
         ActionListener<InferenceAction.Response> listener
+    ) {
+        inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
+        inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
+            if (request.isStreaming()) {
+                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
+                inferenceResults.publisher().subscribe(taskProcessor);
+
+                var instrumentedStream = new PublisherWithMetrics(timer, model);
+                taskProcessor.subscribe(instrumentedStream);
+
+                listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream));
+            } else {
+                recordMetrics(model, timer, null);
+                listener.onResponse(new InferenceAction.Response(inferenceResults));
+            }
+        }, e -> {
+            recordMetrics(model, timer, e);
+            listener.onFailure(e);
+        }));
+    }
+
+    private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
+        try {
+            inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
+        } catch (Exception e) {
+            log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
+        }
+    }
+
+    private void inferOnService(
+        Model model,
+        InferenceAction.Request request,
+        InferenceService service,
+        ActionListener<InferenceServiceResults> listener
     ) {
         if (request.isStreaming() == false || service.canStream(request.getTaskType())) {
-            inferenceStats.incrementRequestCount(model);
             service.infer(
                 model,
                 request.getQuery(),
@@ -107,7 +166,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),
-                createListener(request, listener)
+                listener
             );
         } else {
             listener.onFailure(unsupportedStreamingTaskException(request, service));
@@ -135,20 +194,6 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
         }
     }
 
-    private ActionListener<InferenceServiceResults> createListener(
-        InferenceAction.Request request,
-        ActionListener<InferenceAction.Response> listener
-    ) {
-        if (request.isStreaming()) {
-            return listener.delegateFailureAndWrap((l, inferenceResults) -> {
-                var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
-                inferenceResults.publisher().subscribe(taskProcessor);
-                l.onResponse(new InferenceAction.Response(inferenceResults, taskProcessor));
-            });
-        }
-        return listener.delegateFailureAndWrap((l, inferenceResults) -> l.onResponse(new InferenceAction.Response(inferenceResults)));
-    }
-
     private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) {
         return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId);
     }
@@ -162,4 +207,37 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
         );
     }
 
+    private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent, ChunkedToXContent> {
+        private final InferenceTimer timer;
+        private final Model model;
+
+        private PublisherWithMetrics(InferenceTimer timer, Model model) {
+            this.timer = timer;
+            this.model = model;
+        }
+
+        @Override
+        protected void next(ChunkedToXContent item) {
+            downstream().onNext(item);
+        }
+
+        @Override
+        public void onError(Throwable throwable) {
+            recordMetrics(model, timer, throwable);
+            super.onError(throwable);
+        }
+
+        @Override
+        protected void onCancel() {
+            recordMetrics(model, timer, null);
+            super.onCancel();
+        }
+
+        @Override
+        public void onComplete() {
+            recordMetrics(model, timer, null);
+            super.onComplete();
+        }
+    }
+
 }

+ 3 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java

@@ -61,11 +61,14 @@ public abstract class DelegatingProcessor<T, R> implements Flow.Processor<T, R>
             public void cancel() {
                 if (isClosed.compareAndSet(false, true) && upstream != null) {
                     upstream.cancel();
+                    onCancel();
                 }
             }
         };
     }
 
+    protected void onCancel() {}
+
     @Override
     public void onSubscribe(Flow.Subscription subscription) {
         if (upstream != null) {

+ 0 - 49
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStats.java

@@ -1,49 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.telemetry;
-
-import org.elasticsearch.inference.Model;
-import org.elasticsearch.telemetry.metric.LongCounter;
-import org.elasticsearch.telemetry.metric.MeterRegistry;
-
-import java.util.HashMap;
-import java.util.Objects;
-
-public class ApmInferenceStats implements InferenceStats {
-    private final LongCounter inferenceAPMRequestCounter;
-
-    public ApmInferenceStats(LongCounter inferenceAPMRequestCounter) {
-        this.inferenceAPMRequestCounter = Objects.requireNonNull(inferenceAPMRequestCounter);
-    }
-
-    @Override
-    public void incrementRequestCount(Model model) {
-        var service = model.getConfigurations().getService();
-        var taskType = model.getTaskType();
-        var modelId = model.getServiceSettings().modelId();
-
-        var attributes = new HashMap<String, Object>(5);
-        attributes.put("service", service);
-        attributes.put("task_type", taskType.toString());
-        if (modelId != null) {
-            attributes.put("model_id", modelId);
-        }
-
-        inferenceAPMRequestCounter.incrementBy(1, attributes);
-    }
-
-    public static ApmInferenceStats create(MeterRegistry meterRegistry) {
-        return new ApmInferenceStats(
-            meterRegistry.registerLongCounter(
-                "es.inference.requests.count.total",
-                "Inference API request counts for a particular service, task type, model ID",
-                "operations"
-            )
-        );
-    }
-}

+ 81 - 7
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java

@@ -7,15 +7,89 @@
 
 package org.elasticsearch.xpack.inference.telemetry;
 
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.telemetry.metric.LongCounter;
+import org.elasticsearch.telemetry.metric.LongHistogram;
+import org.elasticsearch.telemetry.metric.MeterRegistry;
 
-public interface InferenceStats {
+import java.util.Map;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
-    /**
-     * Increment the counter for a particular value in a thread safe manner.
-     * @param model the model to increment request count for
-     */
-    void incrementRequestCount(Model model);
+import static java.util.Map.entry;
+import static java.util.stream.Stream.concat;
 
-    InferenceStats NOOP = model -> {};
+public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
+
+    public InferenceStats {
+        Objects.requireNonNull(requestCount);
+        Objects.requireNonNull(inferenceDuration);
+    }
+
+    public static InferenceStats create(MeterRegistry meterRegistry) {
+        return new InferenceStats(
+            meterRegistry.registerLongCounter(
+                "es.inference.requests.count.total",
+                "Inference API request counts for a particular service, task type, model ID",
+                "operations"
+            ),
+            meterRegistry.registerLongHistogram(
+                "es.inference.requests.time",
+                "Inference API request counts for a particular service, task type, model ID",
+                "ms"
+            )
+        );
+    }
+
+    public static Map<String, Object> modelAttributes(Model model) {
+        return toMap(modelAttributeEntries(model));
+    }
+
+    private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
+        var stream = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.getConfigurations().getService()))
+            .add(entry("task_type", model.getTaskType().toString()));
+        if (model.getServiceSettings().modelId() != null) {
+            stream.add(entry("model_id", model.getServiceSettings().modelId()));
+        }
+        return stream.build();
+    }
+
+    private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
+        return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    }
+
+    public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
+        return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
+        var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
+            .add(entry("service", model.service()))
+            .add(entry("task_type", model.taskType().toString()))
+            .build();
+
+        return toMap(concat(unknownModelAttributes, errorAttributes(t)));
+    }
+
+    public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
+        return toMap(errorAttributes(t));
+    }
+
+    private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
+        if (t == null) {
+            return Stream.of(entry("status_code", 200));
+        } else if (t instanceof ElasticsearchStatusException ese) {
+            return Stream.<Map.Entry<String, Object>>builder()
+                .add(entry("status_code", ese.status().getStatus()))
+                .add(entry("error.type", String.valueOf(ese.status().getStatus())))
+                .build();
+        } else {
+            return Stream.of(entry("error.type", t.getClass().getSimpleName()));
+        }
+    }
 }

+ 33 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimer.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.telemetry;
+
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Objects;
+
+public record InferenceTimer(Instant startTime, Clock clock) {
+
+    public InferenceTimer {
+        Objects.requireNonNull(startTime);
+        Objects.requireNonNull(clock);
+    }
+
+    public static InferenceTimer start() {
+        return start(Clock.systemUTC());
+    }
+
+    public static InferenceTimer start(Clock clock) {
+        return new InferenceTimer(clock.instant(), clock);
+    }
+
+    public long elapsedMillis() {
+        return Duration.between(startTime(), clock().instant()).toMillis();
+    }
+}

+ 354 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java

@@ -0,0 +1,354 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.action;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.common.xcontent.ChunkedToXContent;
+import org.elasticsearch.inference.InferenceService;
+import org.elasticsearch.inference.InferenceServiceRegistry;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.transport.TransportService;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
+import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
+import org.elasticsearch.xpack.inference.registry.ModelRegistry;
+import org.elasticsearch.xpack.inference.telemetry.InferenceStats;
+import org.junit.Before;
+import org.mockito.ArgumentCaptor;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.Flow;
+import java.util.function.Consumer;
+
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.isA;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.assertArg;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TransportInferenceActionTests extends ESTestCase {
+    private static final String serviceId = "serviceId";
+    private static final TaskType taskType = TaskType.COMPLETION;
+    private static final String inferenceId = "inferenceEntityId";
+    private ModelRegistry modelRegistry;
+    private InferenceServiceRegistry serviceRegistry;
+    private InferenceStats inferenceStats;
+    private StreamingTaskManager streamingTaskManager;
+    private TransportInferenceAction action;
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        TransportService transportService = mock();
+        ActionFilters actionFilters = mock();
+        modelRegistry = mock();
+        serviceRegistry = mock();
+        inferenceStats = new InferenceStats(mock(), mock());
+        streamingTaskManager = mock();
+        action = new TransportInferenceAction(
+            transportService,
+            actionFilters,
+            modelRegistry,
+            serviceRegistry,
+            inferenceStats,
+            streamingTaskManager
+        );
+    }
+
+    public void testMetricsAfterModelRegistryError() {
+        var expectedException = new IllegalStateException("hello");
+        var expectedError = expectedException.getClass().getSimpleName();
+
+        doAnswer(ans -> {
+            ActionListener<?> listener = ans.getArgument(1);
+            listener.onFailure(expectedException);
+            return null;
+        }).when(modelRegistry).getModelWithSecrets(any(), any());
+
+        var listener = doExecute(taskType);
+        verify(listener).onFailure(same(expectedException));
+
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), nullValue());
+            assertThat(attributes.get("task_type"), nullValue());
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    private ActionListener<InferenceAction.Response> doExecute(TaskType taskType) {
+        return doExecute(taskType, false);
+    }
+
+    private ActionListener<InferenceAction.Response> doExecute(TaskType taskType, boolean stream) {
+        InferenceAction.Request request = mock();
+        when(request.getInferenceEntityId()).thenReturn(inferenceId);
+        when(request.getTaskType()).thenReturn(taskType);
+        when(request.isStreaming()).thenReturn(stream);
+        ActionListener<InferenceAction.Response> listener = mock();
+        action.doExecute(mock(), request, listener);
+        return listener;
+    }
+
+    public void testMetricsAfterMissingService() {
+        mockModelRegistry(taskType);
+
+        when(serviceRegistry.getService(any())).thenReturn(Optional.empty());
+
+        var listener = doExecute(taskType);
+
+        verify(listener).onFailure(assertArg(e -> {
+            assertThat(e, isA(ElasticsearchStatusException.class));
+            assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. "));
+            assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
+        }));
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
+            assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
+        }));
+    }
+
+    private void mockModelRegistry(TaskType expectedTaskType) {
+        var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of());
+        doAnswer(ans -> {
+            ActionListener<UnparsedModel> listener = ans.getArgument(1);
+            listener.onResponse(unparsedModel);
+            return null;
+        }).when(modelRegistry).getModelWithSecrets(any(), any());
+    }
+
+    public void testMetricsAfterUnknownTaskType() {
+        var modelTaskType = TaskType.RERANK;
+        var requestTaskType = TaskType.SPARSE_EMBEDDING;
+        mockModelRegistry(modelTaskType);
+        when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock()));
+
+        var listener = doExecute(requestTaskType);
+
+        verify(listener).onFailure(assertArg(e -> {
+            assertThat(e, isA(ElasticsearchStatusException.class));
+            assertThat(
+                e.getMessage(),
+                is(
+                    "Incompatible task_type, the requested type ["
+                        + requestTaskType
+                        + "] does not match the model type ["
+                        + modelTaskType
+                        + "]"
+                )
+            );
+            assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST));
+        }));
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(modelTaskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
+            assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
+        }));
+    }
+
+    public void testMetricsAfterInferError() {
+        var expectedException = new IllegalStateException("hello");
+        var expectedError = expectedException.getClass().getSimpleName();
+        mockService(listener -> listener.onFailure(expectedException));
+
+        var listener = doExecute(taskType);
+
+        verify(listener).onFailure(same(expectedException));
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testMetricsAfterStreamUnsupported() {
+        var expectedStatus = RestStatus.METHOD_NOT_ALLOWED;
+        var expectedError = String.valueOf(expectedStatus.getStatus());
+        mockService(l -> {});
+
+        var listener = doExecute(taskType, true);
+
+        verify(listener).onFailure(assertArg(e -> {
+            assertThat(e, isA(ElasticsearchStatusException.class));
+            var ese = (ElasticsearchStatusException) e;
+            assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "]."));
+            assertThat(ese.status(), is(expectedStatus));
+        }));
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testMetricsAfterInferSuccess() {
+        mockService(listener -> listener.onResponse(mock()));
+
+        var listener = doExecute(taskType);
+
+        verify(listener).onResponse(any());
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(200));
+            assertThat(attributes.get("error.type"), nullValue());
+        }));
+    }
+
+    public void testMetricsAfterStreamInferSuccess() {
+        mockStreamResponse(Flow.Subscriber::onComplete);
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(200));
+            assertThat(attributes.get("error.type"), nullValue());
+        }));
+    }
+
+    public void testMetricsAfterStreamInferFailure() {
+        var expectedException = new IllegalStateException("hello");
+        var expectedError = expectedException.getClass().getSimpleName();
+        mockStreamResponse(subscriber -> {
+            subscriber.subscribe(mock());
+            subscriber.onError(expectedException);
+        });
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testMetricsAfterStreamCancel() {
+        var response = mockStreamResponse(s -> s.onSubscribe(mock()));
+        response.subscribe(new Flow.Subscriber<>() {
+            @Override
+            public void onSubscribe(Flow.Subscription subscription) {
+                subscription.cancel();
+            }
+
+            @Override
+            public void onNext(ChunkedToXContent item) {
+
+            }
+
+            @Override
+            public void onError(Throwable throwable) {
+
+            }
+
+            @Override
+            public void onComplete() {
+
+            }
+        });
+
+        verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is(serviceId));
+            assertThat(attributes.get("task_type"), is(taskType.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(200));
+            assertThat(attributes.get("error.type"), nullValue());
+        }));
+    }
+
+    private Flow.Publisher<ChunkedToXContent> mockStreamResponse(Consumer<Flow.Processor<?, ?>> action) {
+        mockService(true, Set.of(), listener -> {
+            Flow.Processor<ChunkedToXContent, ChunkedToXContent> taskProcessor = mock();
+            doAnswer(innerAns -> {
+                action.accept(innerAns.getArgument(0));
+                return null;
+            }).when(taskProcessor).subscribe(any());
+            when(streamingTaskManager.<ChunkedToXContent>create(any(), any())).thenReturn(taskProcessor);
+            var inferenceServiceResults = mock(InferenceServiceResults.class);
+            when(inferenceServiceResults.publisher()).thenReturn(mock());
+            listener.onResponse(inferenceServiceResults);
+        });
+
+        var listener = doExecute(taskType, true);
+        var captor = ArgumentCaptor.forClass(InferenceAction.Response.class);
+        verify(listener).onResponse(captor.capture());
+        assertTrue(captor.getValue().isStreaming());
+        assertNotNull(captor.getValue().publisher());
+        return captor.getValue().publisher();
+    }
+
+    private void mockService(Consumer<ActionListener<InferenceServiceResults>> listenerAction) {
+        mockService(false, Set.of(), listenerAction);
+    }
+
+    private void mockService(
+        boolean stream,
+        Set<TaskType> supportedStreamingTasks,
+        Consumer<ActionListener<InferenceServiceResults>> listenerAction
+    ) {
+        InferenceService service = mock();
+        Model model = mockModel();
+        when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model);
+        when(service.name()).thenReturn(serviceId);
+
+        when(service.canStream(any())).thenReturn(stream);
+        when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks);
+        doAnswer(ans -> {
+            listenerAction.accept(ans.getArgument(7));
+            return null;
+        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
+        mockModelAndServiceRegistry(service);
+    }
+
+    private Model mockModel() {
+        Model model = mock();
+        ModelConfigurations modelConfigurations = mock();
+        when(modelConfigurations.getService()).thenReturn(serviceId);
+        when(model.getConfigurations()).thenReturn(modelConfigurations);
+        when(model.getTaskType()).thenReturn(taskType);
+        when(model.getServiceSettings()).thenReturn(mock());
+        return model;
+    }
+
+    private void mockModelAndServiceRegistry(InferenceService service) {
+        var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of());
+        doAnswer(ans -> {
+            ActionListener<UnparsedModel> listener = ans.getArgument(1);
+            listener.onResponse(unparsedModel);
+            return null;
+        }).when(modelRegistry).getModelWithSecrets(any(), any());
+
+        when(serviceRegistry.getService(any())).thenReturn(Optional.of(service));
+    }
+}

+ 0 - 69
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/ApmInferenceStatsTests.java

@@ -1,69 +0,0 @@
-/*
- * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
- * or more contributor license agreements. Licensed under the Elastic License
- * 2.0; you may not use this file except in compliance with the Elastic License
- * 2.0.
- */
-
-package org.elasticsearch.xpack.inference.telemetry;
-
-import org.elasticsearch.inference.Model;
-import org.elasticsearch.inference.ModelConfigurations;
-import org.elasticsearch.inference.ServiceSettings;
-import org.elasticsearch.inference.TaskType;
-import org.elasticsearch.telemetry.metric.LongCounter;
-import org.elasticsearch.telemetry.metric.MeterRegistry;
-import org.elasticsearch.test.ESTestCase;
-
-import java.util.Map;
-
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class ApmInferenceStatsTests extends ESTestCase {
-
-    public void testRecordWithModel() {
-        var longCounter = mock(LongCounter.class);
-
-        var stats = new ApmInferenceStats(longCounter);
-
-        stats.incrementRequestCount(model("service", TaskType.ANY, "modelId"));
-
-        verify(longCounter).incrementBy(
-            eq(1L),
-            eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
-        );
-    }
-
-    public void testRecordWithoutModel() {
-        var longCounter = mock(LongCounter.class);
-
-        var stats = new ApmInferenceStats(longCounter);
-
-        stats.incrementRequestCount(model("service", TaskType.ANY, null));
-
-        verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
-    }
-
-    public void testCreation() {
-        assertNotNull(ApmInferenceStats.create(MeterRegistry.NOOP));
-    }
-
-    private Model model(String service, TaskType taskType, String modelId) {
-        var configuration = mock(ModelConfigurations.class);
-        when(configuration.getService()).thenReturn(service);
-        var settings = mock(ServiceSettings.class);
-        if (modelId != null) {
-            when(settings.modelId()).thenReturn(modelId);
-        }
-
-        var model = mock(Model.class);
-        when(model.getTaskType()).thenReturn(taskType);
-        when(model.getConfigurations()).thenReturn(configuration);
-        when(model.getServiceSettings()).thenReturn(settings);
-
-        return model;
-    }
-}

+ 217 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java

@@ -0,0 +1,217 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.telemetry;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.inference.Model;
+import org.elasticsearch.inference.ModelConfigurations;
+import org.elasticsearch.inference.ServiceSettings;
+import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.inference.UnparsedModel;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.telemetry.metric.LongCounter;
+import org.elasticsearch.telemetry.metric.LongHistogram;
+import org.elasticsearch.telemetry.metric.MeterRegistry;
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
+import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
+import static org.mockito.ArgumentMatchers.assertArg;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class InferenceStatsTests extends ESTestCase {
+
+    public void testRecordWithModel() {
+        var longCounter = mock(LongCounter.class);
+        var stats = new InferenceStats(longCounter, mock());
+
+        stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId")));
+
+        verify(longCounter).incrementBy(
+            eq(1L),
+            eq(Map.of("service", "service", "task_type", TaskType.ANY.toString(), "model_id", "modelId"))
+        );
+    }
+
+    public void testRecordWithoutModel() {
+        var longCounter = mock(LongCounter.class);
+        var stats = new InferenceStats(longCounter, mock());
+
+        stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null)));
+
+        verify(longCounter).incrementBy(eq(1L), eq(Map.of("service", "service", "task_type", TaskType.ANY.toString())));
+    }
+
+    public void testCreation() {
+        assertNotNull(InferenceStats.create(MeterRegistry.NOOP));
+    }
+
+    public void testRecordDurationWithoutError() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), null));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is("service"));
+            assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
+            assertThat(attributes.get("model_id"), is("modelId"));
+            assertThat(attributes.get("status_code"), is(200));
+            assertThat(attributes.get("error.type"), nullValue());
+        }));
+    }
+
+    /**
+     * "If response status code was sent or received and status indicates an error according to HTTP span status definition,
+     * error.type SHOULD be set to the status code number (represented as a string)"
+     * - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
+     */
+    public void testRecordDurationWithElasticsearchStatusException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var statusCode = RestStatus.BAD_REQUEST;
+        var exception = new ElasticsearchStatusException("hello", statusCode);
+        var expectedError = String.valueOf(statusCode.getStatus());
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is("service"));
+            assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
+            assertThat(attributes.get("model_id"), is("modelId"));
+            assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    /**
+     * "If the request fails with an error before response status code was sent or received,
+     * error.type SHOULD be set to exception type"
+     * - https://opentelemetry.io/docs/specs/semconv/http/http-metrics/
+     */
+    public void testRecordDurationWithOtherException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var exception = new IllegalStateException("ahh");
+        var expectedError = exception.getClass().getSimpleName();
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(model("service", TaskType.ANY, "modelId"), exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is("service"));
+            assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
+            assertThat(attributes.get("model_id"), is("modelId"));
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var statusCode = RestStatus.BAD_REQUEST;
+        var exception = new ElasticsearchStatusException("hello", statusCode);
+        var expectedError = String.valueOf(statusCode.getStatus());
+
+        var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is("service"));
+            assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testRecordDurationWithUnparsedModelAndOtherException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var exception = new IllegalStateException("ahh");
+        var expectedError = exception.getClass().getSimpleName();
+
+        var unparsedModel = new UnparsedModel("inferenceEntityId", TaskType.ANY, "service", Map.of(), Map.of());
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(unparsedModel, exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), is("service"));
+            assertThat(attributes.get("task_type"), is(TaskType.ANY.toString()));
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var statusCode = RestStatus.BAD_REQUEST;
+        var exception = new ElasticsearchStatusException("hello", statusCode);
+        var expectedError = String.valueOf(statusCode.getStatus());
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), nullValue());
+            assertThat(attributes.get("task_type"), nullValue());
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), is(statusCode.getStatus()));
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    public void testRecordDurationWithUnknownModelAndOtherException() {
+        var expectedLong = randomLong();
+        var histogramCounter = mock(LongHistogram.class);
+        var stats = new InferenceStats(mock(), histogramCounter);
+        var exception = new IllegalStateException("ahh");
+        var expectedError = exception.getClass().getSimpleName();
+
+        stats.inferenceDuration().record(expectedLong, responseAttributes(exception));
+
+        verify(histogramCounter).record(eq(expectedLong), assertArg(attributes -> {
+            assertThat(attributes.get("service"), nullValue());
+            assertThat(attributes.get("task_type"), nullValue());
+            assertThat(attributes.get("model_id"), nullValue());
+            assertThat(attributes.get("status_code"), nullValue());
+            assertThat(attributes.get("error.type"), is(expectedError));
+        }));
+    }
+
+    private Model model(String service, TaskType taskType, String modelId) {
+        var configuration = mock(ModelConfigurations.class);
+        when(configuration.getService()).thenReturn(service);
+        var settings = mock(ServiceSettings.class);
+        if (modelId != null) {
+            when(settings.modelId()).thenReturn(modelId);
+        }
+
+        var model = mock(Model.class);
+        when(model.getTaskType()).thenReturn(taskType);
+        when(model.getConfigurations()).thenReturn(configuration);
+        when(model.getServiceSettings()).thenReturn(settings);
+
+        return model;
+    }
+}

+ 32 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceTimerTests.java

@@ -0,0 +1,32 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.telemetry;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.time.Clock;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+
+import static org.hamcrest.Matchers.is;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+public class InferenceTimerTests extends ESTestCase {
+
+    public void testElapsedMillis() {
+        var expectedDuration = randomLongBetween(10, 300);
+
+        var startTime = Instant.now();
+        var clock = mock(Clock.class);
+        when(clock.instant()).thenReturn(startTime).thenReturn(startTime.plus(expectedDuration, ChronoUnit.MILLIS));
+        var timer = InferenceTimer.start(clock);
+
+        assertThat(expectedDuration, is(timer.elapsedMillis()));
+    }
+}