|
@@ -0,0 +1,121 @@
|
|
|
+/*
|
|
|
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
+ * or more contributor license agreements. Licensed under the Elastic License
|
|
|
+ * 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
+ * 2.0.
|
|
|
+ */
|
|
|
+
|
|
|
+package org.elasticsearch.xpack.inference.action;
|
|
|
+
|
|
|
+import org.elasticsearch.action.ActionListener;
|
|
|
+import org.elasticsearch.action.support.ActionFilters;
|
|
|
+import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
+import org.elasticsearch.client.internal.Client;
|
|
|
+import org.elasticsearch.cluster.ClusterState;
|
|
|
+import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
|
|
|
+import org.elasticsearch.cluster.service.ClusterService;
|
|
|
+import org.elasticsearch.common.io.stream.BytesStreamOutput;
|
|
|
+import org.elasticsearch.inference.ModelConfigurations;
|
|
|
+import org.elasticsearch.inference.ServiceSettings;
|
|
|
+import org.elasticsearch.inference.TaskType;
|
|
|
+import org.elasticsearch.protocol.xpack.XPackUsageRequest;
|
|
|
+import org.elasticsearch.tasks.Task;
|
|
|
+import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.test.MockUtils;
|
|
|
+import org.elasticsearch.threadpool.TestThreadPool;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
+import org.elasticsearch.transport.TransportService;
|
|
|
+import org.elasticsearch.xcontent.ToXContent;
|
|
|
+import org.elasticsearch.xcontent.XContentBuilder;
|
|
|
+import org.elasticsearch.xcontent.XContentFactory;
|
|
|
+import org.elasticsearch.xpack.core.XPackFeatureSet;
|
|
|
+import org.elasticsearch.xpack.core.XPackField;
|
|
|
+import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse;
|
|
|
+import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
|
|
|
+import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
|
|
|
+import org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource;
|
|
|
+import org.junit.After;
|
|
|
+import org.junit.Before;
|
|
|
+
|
|
|
+import java.util.List;
|
|
|
+
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
+import static org.hamcrest.core.Is.is;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.Mockito.doAnswer;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
+
|
|
|
+public class TransportInferenceUsageActionTests extends ESTestCase {
|
|
|
+
|
|
|
+ private Client client;
|
|
|
+ private TransportInferenceUsageAction action;
|
|
|
+
|
|
|
+ @Before
|
|
|
+ public void init() {
|
|
|
+ client = mock(Client.class);
|
|
|
+ ThreadPool threadPool = new TestThreadPool("test");
|
|
|
+ when(client.threadPool()).thenReturn(threadPool);
|
|
|
+
|
|
|
+ TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(mock(ThreadPool.class));
|
|
|
+
|
|
|
+ action = new TransportInferenceUsageAction(
|
|
|
+ transportService,
|
|
|
+ mock(ClusterService.class),
|
|
|
+ mock(ThreadPool.class),
|
|
|
+ mock(ActionFilters.class),
|
|
|
+ mock(IndexNameExpressionResolver.class),
|
|
|
+ client
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
+ @After
|
|
|
+ public void close() {
|
|
|
+ client.threadPool().shutdown();
|
|
|
+ }
|
|
|
+
|
|
|
+ public void test() throws Exception {
|
|
|
+ doAnswer(invocation -> {
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ var listener = (ActionListener<GetInferenceModelAction.Response>) invocation.getArguments()[2];
|
|
|
+ listener.onResponse(
|
|
|
+ new GetInferenceModelAction.Response(
|
|
|
+ List.of(
|
|
|
+ new ModelConfigurations("model-001", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)),
|
|
|
+ new ModelConfigurations("model-002", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)),
|
|
|
+ new ModelConfigurations("model-003", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class)),
|
|
|
+ new ModelConfigurations("model-004", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)),
|
|
|
+ new ModelConfigurations("model-005", TaskType.SPARSE_EMBEDDING, "openai", mock(ServiceSettings.class)),
|
|
|
+ new ModelConfigurations("model-006", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class))
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ return Void.TYPE;
|
|
|
+ }).when(client).execute(any(GetInferenceModelAction.class), any(), any());
|
|
|
+
|
|
|
+ PlainActionFuture<XPackUsageFeatureResponse> future = new PlainActionFuture<>();
|
|
|
+ action.masterOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future);
|
|
|
+
|
|
|
+ BytesStreamOutput out = new BytesStreamOutput();
|
|
|
+ future.get().getUsage().writeTo(out);
|
|
|
+ XPackFeatureSet.Usage usage = new InferenceFeatureSetUsage(out.bytes().streamInput());
|
|
|
+
|
|
|
+ assertThat(usage.name(), is(XPackField.INFERENCE));
|
|
|
+ assertTrue(usage.enabled());
|
|
|
+ assertTrue(usage.available());
|
|
|
+
|
|
|
+ XContentBuilder builder = XContentFactory.jsonBuilder();
|
|
|
+ usage.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
|
|
+ XContentSource source = new XContentSource(builder);
|
|
|
+ assertThat(source.getValue("models"), hasSize(3));
|
|
|
+ assertThat(source.getValue("models.0.service"), is("hugging_face_elser"));
|
|
|
+ assertThat(source.getValue("models.0.task_type"), is("SPARSE_EMBEDDING"));
|
|
|
+ assertThat(source.getValue("models.0.count"), is(2));
|
|
|
+ assertThat(source.getValue("models.1.service"), is("openai"));
|
|
|
+ assertThat(source.getValue("models.1.task_type"), is("SPARSE_EMBEDDING"));
|
|
|
+ assertThat(source.getValue("models.1.count"), is(1));
|
|
|
+ assertThat(source.getValue("models.2.service"), is("openai"));
|
|
|
+ assertThat(source.getValue("models.2.task_type"), is("TEXT_EMBEDDING"));
|
|
|
+ assertThat(source.getValue("models.2.count"), is(3));
|
|
|
+ }
|
|
|
+}
|