|
@@ -0,0 +1,125 @@
|
|
|
+/*
|
|
|
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
|
|
+ * or more contributor license agreements. Licensed under the Elastic License
|
|
|
+ * 2.0; you may not use this file except in compliance with the Elastic License
|
|
|
+ * 2.0.
|
|
|
+ */
|
|
|
+
|
|
|
+package org.elasticsearch.xpack.inference.services.huggingface;
|
|
|
+
|
|
|
+import org.apache.http.HttpHeaders;
|
|
|
+import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
+import org.elasticsearch.common.settings.Settings;
|
|
|
+import org.elasticsearch.core.TimeValue;
|
|
|
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
|
+import org.elasticsearch.inference.ChunkingOptions;
|
|
|
+import org.elasticsearch.inference.InputType;
|
|
|
+import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.test.http.MockResponse;
|
|
|
+import org.elasticsearch.test.http.MockWebServer;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
+import org.elasticsearch.xcontent.XContentType;
|
|
|
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
|
|
|
+import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
|
+import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
|
+import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
|
|
+import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserModelTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService;
|
|
|
+import org.hamcrest.MatcherAssert;
|
|
|
+import org.hamcrest.Matchers;
|
|
|
+import org.junit.After;
|
|
|
+import org.junit.Before;
|
|
|
+
|
|
|
+import java.io.IOException;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.concurrent.TimeUnit;
|
|
|
+
|
|
|
+import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
|
|
+import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
|
|
+import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
|
|
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
|
|
+import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
|
|
+import static org.hamcrest.Matchers.equalTo;
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
+
|
|
|
+public class HuggingFaceElserServiceTests extends ESTestCase {
|
|
|
+
|
|
|
+ private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
|
|
+
|
|
|
+ private final MockWebServer webServer = new MockWebServer();
|
|
|
+ private ThreadPool threadPool;
|
|
|
+ private HttpClientManager clientManager;
|
|
|
+
|
|
|
+ @Before
|
|
|
+ public void init() throws Exception {
|
|
|
+ webServer.start();
|
|
|
+ threadPool = createThreadPool(inferenceUtilityPool());
|
|
|
+ clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
|
|
+ }
|
|
|
+
|
|
|
+ @After
|
|
|
+ public void shutdown() throws IOException {
|
|
|
+ clientManager.close();
|
|
|
+ terminate(threadPool);
|
|
|
+ webServer.close();
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var service = new HuggingFaceElserService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+
|
|
|
+ String responseJson = """
|
|
|
+ [
|
|
|
+ {
|
|
|
+ ".": 0.133155956864357
|
|
|
+ }
|
|
|
+ ]
|
|
|
+ """;
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ var model = HuggingFaceElserModelTests.createModel(getUrl(webServer), "secret");
|
|
|
+ PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
|
|
|
+ service.chunkedInfer(
|
|
|
+ model,
|
|
|
+ List.of("abc"),
|
|
|
+ new HashMap<>(),
|
|
|
+ InputType.INGEST,
|
|
|
+ new ChunkingOptions(null, null),
|
|
|
+ InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
+ listener
|
|
|
+ );
|
|
|
+
|
|
|
+ var result = listener.actionGet(TIMEOUT).get(0);
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ result.asMap(),
|
|
|
+ Matchers.is(
|
|
|
+ Map.of(
|
|
|
+ InferenceChunkedSparseEmbeddingResults.FIELD_NAME,
|
|
|
+ List.of(
|
|
|
+ Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, Map.of(".", 0.13315596f))
|
|
|
+ )
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
+ assertNull(webServer.requests().get(0).getUri().getQuery());
|
|
|
+ assertThat(
|
|
|
+ webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
|
|
+ equalTo(XContentType.JSON.mediaTypeWithoutParameters())
|
|
|
+ );
|
|
|
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
|
|
|
+
|
|
|
+ var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
+ assertThat(requestMap.size(), Matchers.is(1));
|
|
|
+ assertThat(requestMap.get("inputs"), Matchers.is(List.of("abc")));
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|