|  | @@ -12,6 +12,7 @@ package org.elasticsearch.xpack.inference.services.elasticsearch;
 | 
	
		
			
				|  |  |  import org.apache.logging.log4j.Level;
 | 
	
		
			
				|  |  |  import org.elasticsearch.ElasticsearchStatusException;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.ActionListener;
 | 
	
		
			
				|  |  | +import org.elasticsearch.action.LatchedActionListener;
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.support.PlainActionFuture;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.internal.Client;
 | 
	
		
			
				|  |  |  import org.elasticsearch.cluster.service.ClusterService;
 | 
	
	
		
			
				|  | @@ -65,6 +66,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfig
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.InferencePlugin;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
 | 
	
		
			
				|  |  |  import org.elasticsearch.xpack.inference.services.ServiceFields;
 | 
	
		
			
				|  |  |  import org.junit.After;
 | 
	
		
			
				|  |  |  import org.junit.Before;
 | 
	
	
		
			
				|  | @@ -72,12 +74,14 @@ import org.mockito.ArgumentCaptor;
 | 
	
		
			
				|  |  |  import org.mockito.Mockito;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import java.util.ArrayList;
 | 
	
		
			
				|  |  | +import java.util.Arrays;
 | 
	
		
			
				|  |  |  import java.util.EnumSet;
 | 
	
		
			
				|  |  |  import java.util.HashMap;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Map;
 | 
	
		
			
				|  |  |  import java.util.Optional;
 | 
	
		
			
				|  |  |  import java.util.Set;
 | 
	
		
			
				|  |  | +import java.util.concurrent.CountDownLatch;
 | 
	
		
			
				|  |  |  import java.util.concurrent.atomic.AtomicBoolean;
 | 
	
		
			
				|  |  |  import java.util.concurrent.atomic.AtomicInteger;
 | 
	
		
			
				|  |  |  import java.util.concurrent.atomic.AtomicReference;
 | 
	
	
		
			
				|  | @@ -832,16 +836,16 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_E5WithNullChunkingSettings() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_E5WithNullChunkingSettings() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_e5(null);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_E5ChunkingSettingsSet() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_E5ChunkingSettingsSet() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_e5(ChunkingSettingsTests.createRandomChunkingSettings());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | -    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) {
 | 
	
		
			
				|  |  | +    private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws InterruptedException {
 | 
	
		
			
				|  |  |          var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
 | 
	
	
		
			
				|  | @@ -889,6 +893,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              gotResults.set(true);
 | 
	
		
			
				|  |  |          }, ESTestCase::fail);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        var latch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          service.chunkedInfer(
 | 
	
		
			
				|  |  |              model,
 | 
	
		
			
				|  |  |              null,
 | 
	
	
		
			
				|  | @@ -897,22 +904,23 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              InputType.SEARCH,
 | 
	
		
			
				|  |  |              new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  |              InferenceAction.Request.DEFAULT_TIMEOUT,
 | 
	
		
			
				|  |  | -            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
 | 
	
		
			
				|  |  | +            latchedListener
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        latch.await();
 | 
	
		
			
				|  |  |          assertTrue("Listener not called", gotResults.get());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_SparseWithNullChunkingSettings() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_SparseWithNullChunkingSettings() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_Sparse(null);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_SparseWithChunkingSettingsSet() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_SparseWithChunkingSettingsSet() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_Sparse(ChunkingSettingsTests.createRandomChunkingSettings());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | -    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) {
 | 
	
		
			
				|  |  | +    private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws InterruptedException {
 | 
	
		
			
				|  |  |          var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
 | 
	
	
		
			
				|  | @@ -936,6 +944,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |          var service = createService(client);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |          var gotResults = new AtomicBoolean();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
 | 
	
		
			
				|  |  |              assertThat(chunkedResponse, hasSize(2));
 | 
	
		
			
				|  |  |              assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedSparseEmbeddingResults.class));
 | 
	
	
		
			
				|  | @@ -955,6 +964,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              gotResults.set(true);
 | 
	
		
			
				|  |  |          }, ESTestCase::fail);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        var latch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          service.chunkedInfer(
 | 
	
		
			
				|  |  |              model,
 | 
	
		
			
				|  |  |              null,
 | 
	
	
		
			
				|  | @@ -963,22 +975,23 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              InputType.SEARCH,
 | 
	
		
			
				|  |  |              new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  |              InferenceAction.Request.DEFAULT_TIMEOUT,
 | 
	
		
			
				|  |  | -            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
 | 
	
		
			
				|  |  | +            latchedListener
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        latch.await();
 | 
	
		
			
				|  |  |          assertTrue("Listener not called", gotResults.get());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_ElserWithNullChunkingSettings() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_ElserWithNullChunkingSettings() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_Elser(null);
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -    public void testChunkInfer_ElserWithChunkingSettingsSet() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_ElserWithChunkingSettingsSet() throws InterruptedException {
 | 
	
		
			
				|  |  |          testChunkInfer_Elser(ChunkingSettingsTests.createRandomChunkingSettings());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | -    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) {
 | 
	
		
			
				|  |  | +    private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws InterruptedException {
 | 
	
		
			
				|  |  |          var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(TextExpansionResultsTests.createRandomResults());
 | 
	
	
		
			
				|  | @@ -1022,6 +1035,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              gotResults.set(true);
 | 
	
		
			
				|  |  |          }, ESTestCase::fail);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        var latch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          service.chunkedInfer(
 | 
	
		
			
				|  |  |              model,
 | 
	
		
			
				|  |  |              null,
 | 
	
	
		
			
				|  | @@ -1030,9 +1046,10 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              InputType.SEARCH,
 | 
	
		
			
				|  |  |              new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  |              InferenceAction.Request.DEFAULT_TIMEOUT,
 | 
	
		
			
				|  |  | -            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
 | 
	
		
			
				|  |  | +            latchedListener
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        latch.await();
 | 
	
		
			
				|  |  |          assertTrue("Listener not called", gotResults.get());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -1093,7 +1110,7 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |      @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | -    public void testChunkInfer_FailsBatch() {
 | 
	
		
			
				|  |  | +    public void testChunkInfer_FailsBatch() throws InterruptedException {
 | 
	
		
			
				|  |  |          var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
 | 
	
		
			
				|  |  |          mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
 | 
	
	
		
			
				|  | @@ -1129,6 +1146,9 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              gotResults.set(true);
 | 
	
		
			
				|  |  |          }, ESTestCase::fail);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        var latch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |          service.chunkedInfer(
 | 
	
		
			
				|  |  |              model,
 | 
	
		
			
				|  |  |              null,
 | 
	
	
		
			
				|  | @@ -1137,12 +1157,86 @@ public class ElasticsearchInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |              InputType.SEARCH,
 | 
	
		
			
				|  |  |              new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  |              InferenceAction.Request.DEFAULT_TIMEOUT,
 | 
	
		
			
				|  |  | -            ActionListener.runAfter(resultsListener, () -> terminate(threadPool))
 | 
	
		
			
				|  |  | +            latchedListener
 | 
	
		
			
				|  |  |          );
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +        latch.await();
 | 
	
		
			
				|  |  |          assertTrue("Listener not called", gotResults.get());
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | +    public void testChunkingLargeDocument() throws InterruptedException {
 | 
	
		
			
				|  |  | +        int numBatches = randomIntBetween(3, 6);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // how many response objects to return in each batch
 | 
	
		
			
				|  |  | +        int[] numResponsesPerBatch = new int[numBatches];
 | 
	
		
			
				|  |  | +        for (int i = 0; i < numBatches - 1; i++) {
 | 
	
		
			
				|  |  | +            numResponsesPerBatch[i] = ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE;
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        numResponsesPerBatch[numBatches - 1] = randomIntBetween(1, ElasticsearchInternalService.EMBEDDING_MAX_BATCH_SIZE);
 | 
	
		
			
				|  |  | +        int numChunks = Arrays.stream(numResponsesPerBatch).sum();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // build a doc with enough words to make numChunks of chunks
 | 
	
		
			
				|  |  | +        int wordsPerChunk = 10;
 | 
	
		
			
				|  |  | +        int numWords = numChunks * wordsPerChunk;
 | 
	
		
			
				|  |  | +        var input = "word ".repeat(numWords);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        Client client = mock(Client.class);
 | 
	
		
			
				|  |  | +        when(client.threadPool()).thenReturn(threadPool);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // mock the inference response
 | 
	
		
			
				|  |  | +        doAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            var request = (InferModelAction.Request) invocationOnMock.getArguments()[1];
 | 
	
		
			
				|  |  | +            var listener = (ActionListener<InferModelAction.Response>) invocationOnMock.getArguments()[2];
 | 
	
		
			
				|  |  | +            var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  | +            for (int i = 0; i < request.numberOfDocuments(); i++) {
 | 
	
		
			
				|  |  | +                mlTrainedModelResults.add(MlTextEmbeddingResultsTests.createRandomResults());
 | 
	
		
			
				|  |  | +            }
 | 
	
		
			
				|  |  | +            var response = new InferModelAction.Response(mlTrainedModelResults, "foo", true);
 | 
	
		
			
				|  |  | +            listener.onResponse(response);
 | 
	
		
			
				|  |  | +            return null;
 | 
	
		
			
				|  |  | +        }).when(client).execute(same(InferModelAction.INSTANCE), any(InferModelAction.Request.class), any(ActionListener.class));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        var service = createService(client);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        var gotResults = new AtomicBoolean();
 | 
	
		
			
				|  |  | +        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
 | 
	
		
			
				|  |  | +            assertThat(chunkedResponse, hasSize(1));
 | 
	
		
			
				|  |  | +            assertThat(chunkedResponse.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
 | 
	
		
			
				|  |  | +            var sparseResults = (InferenceChunkedTextEmbeddingFloatResults) chunkedResponse.get(0);
 | 
	
		
			
				|  |  | +            assertThat(sparseResults.chunks(), hasSize(numChunks));
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +            gotResults.set(true);
 | 
	
		
			
				|  |  | +        }, ESTestCase::fail);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // Create model using the word boundary chunker.
 | 
	
		
			
				|  |  | +        var model = new MultilingualE5SmallModel(
 | 
	
		
			
				|  |  | +            "foo",
 | 
	
		
			
				|  |  | +            TaskType.TEXT_EMBEDDING,
 | 
	
		
			
				|  |  | +            "e5",
 | 
	
		
			
				|  |  | +            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null),
 | 
	
		
			
				|  |  | +            new WordBoundaryChunkingSettings(wordsPerChunk, 0)
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        var latch = new CountDownLatch(1);
 | 
	
		
			
				|  |  | +        var latchedListener = new LatchedActionListener<>(resultsListener, latch);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        // For the given input we know how many requests will be made
 | 
	
		
			
				|  |  | +        service.chunkedInfer(
 | 
	
		
			
				|  |  | +            model,
 | 
	
		
			
				|  |  | +            null,
 | 
	
		
			
				|  |  | +            List.of(input),
 | 
	
		
			
				|  |  | +            Map.of(),
 | 
	
		
			
				|  |  | +            InputType.SEARCH,
 | 
	
		
			
				|  |  | +            new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  | +            InferenceAction.Request.DEFAULT_TIMEOUT,
 | 
	
		
			
				|  |  | +            latchedListener
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        latch.await();
 | 
	
		
			
				|  |  | +        assertTrue("Listener not called with results", gotResults.get());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      public void testParsePersistedConfig_Rerank() {
 | 
	
		
			
				|  |  |          // with task settings
 | 
	
		
			
				|  |  |          {
 |