|
|
@@ -11,19 +11,37 @@ package org.elasticsearch.xpack.inference.services.elser;
|
|
|
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
|
+import org.elasticsearch.inference.ChunkingOptions;
|
|
|
+import org.elasticsearch.inference.InferenceResults;
|
|
|
import org.elasticsearch.inference.InferenceServiceExtension;
|
|
|
+import org.elasticsearch.inference.InputType;
|
|
|
import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
+import org.elasticsearch.threadpool.TestThreadPool;
|
|
|
+import org.elasticsearch.threadpool.ThreadPool;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
|
|
|
+import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
|
|
|
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests;
|
|
|
|
|
|
+import java.util.ArrayList;
|
|
|
import java.util.Collections;
|
|
|
import java.util.HashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Set;
|
|
|
+import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
|
|
|
import static org.hamcrest.Matchers.containsString;
|
|
|
+import static org.hamcrest.Matchers.hasSize;
|
|
|
+import static org.hamcrest.Matchers.instanceOf;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.same;
|
|
|
+import static org.mockito.Mockito.doAnswer;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
|
|
|
public class ElserInternalServiceTests extends ESTestCase {
|
|
|
|
|
|
@@ -307,6 +325,69 @@ public class ElserInternalServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public void testChunkInfer() {
|
|
|
+ var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
+ mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
|
|
|
+ mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
|
|
|
+ var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults);
|
|
|
+
|
|
|
+ ThreadPool threadpool = new TestThreadPool("test");
|
|
|
+ Client client = mock(Client.class);
|
|
|
+ when(client.threadPool()).thenReturn(threadpool);
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ var listener = (ActionListener<InferTrainedModelDeploymentAction.Response>) invocationOnMock.getArguments()[2];
|
|
|
+ listener.onResponse(response);
|
|
|
+ return null;
|
|
|
+ }).when(client)
|
|
|
+ .execute(
|
|
|
+ same(InferTrainedModelDeploymentAction.INSTANCE),
|
|
|
+ any(InferTrainedModelDeploymentAction.Request.class),
|
|
|
+ any(ActionListener.class)
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = new ElserInternalModel(
|
|
|
+ "foo",
|
|
|
+ TaskType.SPARSE_EMBEDDING,
|
|
|
+ "elser",
|
|
|
+ new ElserInternalServiceSettings(1, 1, "elser"),
|
|
|
+ new ElserMlNodeTaskSettings()
|
|
|
+ );
|
|
|
+ var service = createService(client);
|
|
|
+
|
|
|
+ var gotResults = new AtomicBoolean();
|
|
|
+ var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
|
|
|
+ assertThat(chunkedResponse, hasSize(2));
|
|
|
+ assertThat(chunkedResponse.get(0), instanceOf(ChunkedSparseEmbeddingResults.class));
|
|
|
+ var result1 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(0);
|
|
|
+ assertEquals(
|
|
|
+ ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(),
|
|
|
+ result1.getChunkedResults()
|
|
|
+ );
|
|
|
+ assertThat(chunkedResponse.get(1), instanceOf(ChunkedSparseEmbeddingResults.class));
|
|
|
+ var result2 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(1);
|
|
|
+ assertEquals(
|
|
|
+ ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(),
|
|
|
+ result2.getChunkedResults()
|
|
|
+ );
|
|
|
+ gotResults.set(true);
|
|
|
+ }, ESTestCase::fail);
|
|
|
+
|
|
|
+ service.chunkedInfer(
|
|
|
+ model,
|
|
|
+ List.of("foo", "bar"),
|
|
|
+ Map.of(),
|
|
|
+ InputType.SEARCH,
|
|
|
+ new ChunkingOptions(null, null),
|
|
|
+ ActionListener.runAfter(resultsListener, () -> terminate(threadpool))
|
|
|
+ );
|
|
|
+
|
|
|
+ if (gotResults.get() == false) {
|
|
|
+ terminate(threadpool);
|
|
|
+ }
|
|
|
+ assertTrue("Listener not called", gotResults.get());
|
|
|
+ }
|
|
|
+
|
|
|
private ElserInternalService createService(Client client) {
|
|
|
var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
|
|
|
return new ElserInternalService(context);
|