|
@@ -12,6 +12,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
|
|
|
import org.apache.logging.log4j.Level;
|
|
|
import org.elasticsearch.ElasticsearchStatusException;
|
|
|
import org.elasticsearch.action.ActionListener;
|
|
|
+import org.elasticsearch.action.LatchedActionListener;
|
|
|
import org.elasticsearch.action.support.PlainActionFuture;
|
|
|
import org.elasticsearch.client.internal.Client;
|
|
|
import org.elasticsearch.cluster.service.ClusterService;
|
|
@@ -65,6 +66,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfig
|
|
|
import org.elasticsearch.xpack.inference.InferencePlugin;
|
|
|
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
|
|
|
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
|
|
+import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
|
|
|
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
|
|
import org.junit.After;
|
|
|
import org.junit.Before;
|
|
@@ -72,12 +74,14 @@ import org.mockito.ArgumentCaptor;
|
|
|
import org.mockito.Mockito;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.EnumSet;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
import java.util.Optional;
|
|
|
import java.util.Set;
|
|
|
+import java.util.concurrent.CountDownLatch;
|
|
|
import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
import java.util.concurrent.atomic.AtomicInteger;
|
|
|
import java.util.concurrent.atomic.AtomicReference;
|
|
@@ -832,16 +836,16 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_E5WithNullChunkingSettings() {
|
|
|
+ public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
|
|
|
testChunkInfer_e5(null);
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_E5ChunkingSettingsSet() {
|
|
|
+ public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException {
|
|
|
testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
|
|
|
+ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
|
|
|
var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
|
|
|
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
|
|
@@ -889,6 +893,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
gotResults.set(true);
|
|
|
}, ESTestCase::fail);
|
|
|
|
|
|
+ var latch = new CountDownLatch(1);
|
|
|
+ var latchedListener = new LatchedActionListener<>(resultsListener, latch);
|
|
|
+
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
null,
|
|
@@ -897,22 +904,23 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
InputType.SEARCH,
|
|
|
new ChunkingOptions(null, null),
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
|
|
|
+ latchedListener
|
|
|
);
|
|
|
|
|
|
+ latch.await();
|
|
|
assertTrue("Listener not called", gotResults.get());
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_SparseWithNullChunkingSettings() {
|
|
|
+ public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
|
|
|
testChunkInfer_Sparse(null);
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_SparseWithChunkingSettingsSet() {
|
|
|
+ public void testChunkInfer_SparseWithChunkingSettingsSet() throws InterruptedException {
|
|
|
testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
|
|
|
+ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
|
|
|
var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
|
|
|
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
|
|
@@ -936,6 +944,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
var service = createService(client);
|
|
|
|
|
|
var gotResults = new AtomicBoolean();
|
|
|
+
|
|
|
var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
|
|
|
assertThat(chunkedResponse, hasSize(2));
|
|
|
assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
|
|
@@ -955,6 +964,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
gotResults.set(true);
|
|
|
}, ESTestCase::fail);
|
|
|
|
|
|
+ var latch = new CountDownLatch(1);
|
|
|
+ var latchedListener = new LatchedActionListener<>(resultsListener, latch);
|
|
|
+
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
null,
|
|
@@ -963,22 +975,23 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
InputType.SEARCH,
|
|
|
new ChunkingOptions(null, null),
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
|
|
|
+ latchedListener
|
|
|
);
|
|
|
|
|
|
+ latch.await();
|
|
|
assertTrue("Listener not called", gotResults.get());
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_ElserWithNullChunkingSettings() {
|
|
|
+ public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
|
|
|
testChunkInfer_Elser(null);
|
|
|
}
|
|
|
|
|
|
- public void testChunkInfer_ElserWithChunkingSettingsSet() {
|
|
|
+ public void testChunkInfer_ElserWithChunkingSettingsSet() throws InterruptedException {
|
|
|
testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
|
|
|
+ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
|
|
|
var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
|
|
|
mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
|
|
@@ -1022,6 +1035,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
gotResults.set(true);
|
|
|
}, ESTestCase::fail);
|
|
|
|
|
|
+ var latch = new CountDownLatch(1);
|
|
|
+ var latchedListener = new LatchedActionListener<>(resultsListener, latch);
|
|
|
+
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
null,
|
|
@@ -1030,9 +1046,10 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
InputType.SEARCH,
|
|
|
new ChunkingOptions(null, null),
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
|
|
|
+ latchedListener
|
|
|
);
|
|
|
|
|
|
+ latch.await();
|
|
|
assertTrue("Listener not called", gotResults.get());
|
|
|
}
|
|
|
|
|
@@ -1093,7 +1110,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- public void testChunkInfer_FailsBatch() {
|
|
|
+ public void testChunkInfer_FailsBatch() throws InterruptedException {
|
|
|
var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
|
|
|
mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
|
|
@@ -1129,6 +1146,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
gotResults.set(true);
|
|
|
}, ESTestCase::fail);
|
|
|
|
|
|
+ var latch = new CountDownLatch(1);
|
|
|
+ var latchedListener = new LatchedActionListener<>(resultsListener, latch);
|
|
|
+
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
null,
|
|
@@ -1137,12 +1157,86 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
|
|
|
InputType.SEARCH,
|
|
|
new ChunkingOptions(null, null),
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
|
|
|
+ latchedListener
|
|
|
);
|
|
|
|
|
|
+ latch.await();
|
|
|
assertTrue("Listener not called", gotResults.get());
|
|
|
}
|
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public void testChunkingLargeDocument() throws InterruptedException {
|
|
|
+ int numBatches = randomIntBetween(3, 6);
|
|
|
+
|
|
|
+ // how many response objects to return in each batch
|
|
|
+ int[] numResponsesPerBatch = new int[numBatches];
|
|
|
+ for (int i = 0; i < numBatches - 1; i++) {
|
|
|
+ numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
|
|
|
+ }
|
|
|
+ numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
|
|
|
+ int numChunks = Arrays.stream(numResponsesPerBatch).sum();
|
|
|
+
|
|
|
+ // build a doc with enough words to make numChunks of chunks
|
|
|
+ int wordsPerChunk = 10;
|
|
|
+ int numWords = numChunks * wordsPerChunk;
|
|
|
+ var input = "word ".repeat(numWords);
|
|
|
+
|
|
|
+ Client client = mock(Client.class);
|
|
|
+ when(client.threadPool()).thenReturn(threadPool);
|
|
|
+
|
|
|
+ // mock the inference response
|
|
|
+ doAnswer(invocationOnMock -> {
|
|
|
+ var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
|
|
|
+ var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
|
|
|
+ var mlTrainedModelResults = new ArrayList<InferenceResults>();
|
|
|
+ for (int i = 0; i < request.numberOfDocuments(); i++) {
|
|
|
+ mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
|
|
|
+ }
|
|
|
+ var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
|
|
|
+ listener.onResponse(response);
|
|
|
+ return null;
|
|
|
+ }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));
|
|
|
+
|
|
|
+ var service = createService(client);
|
|
|
+
|
|
|
+ var gotResults = new AtomicBoolean();
|
|
|
+ var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
|
|
|
+ assertThat(chunkedResponse, hasSize(1));
|
|
|
+ assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
|
|
|
+ var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
|
|
|
+ assertThat(sparseResults.chunks(), hasSize(numChunks));
|
|
|
+
|
|
|
+ gotResults.set(true);
|
|
|
+ }, ESTestCase::fail);
|
|
|
+
|
|
|
+ // Create model using the word boundary chunker.
|
|
|
+ var model = new MultilingualE5SmallModel(
|
|
|
+ "foo",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ "e5",
|
|
|
+ new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
|
|
|
+ new WordBoundaryChunkingSettings(wordsPerChunk, 0)
|
|
|
+ );
|
|
|
+
|
|
|
+ var latch = new CountDownLatch(1);
|
|
|
+ var latchedListener = new LatchedActionListener<>(resultsListener, latch);
|
|
|
+
|
|
|
+ // For the given input we know how many requests will be made
|
|
|
+ service.chunkedInfer(
|
|
|
+ model,
|
|
|
+ null,
|
|
|
+ List.of(input),
|
|
|
+ Map.of(),
|
|
|
+ InputType.SEARCH,
|
|
|
+ new ChunkingOptions(null, null),
|
|
|
+ InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
+ latchedListener
|
|
|
+ );
|
|
|
+
|
|
|
+ latch.await();
|
|
|
+ assertTrue("Listener not called with results", gotResults.get());
|
|
|
+ }
|
|
|
+
|
|
|
public void testParsePersistedConfig_Rerank() {
|
|
|
// with task settings
|
|
|
{
|