|  | @@ -11,19 +11,37 @@ package org.elasticsearch.xpack.inference.services.elser;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import org.elasticsearch.action.ActionListener;
 | 
	
		
			
				|  |  |  import org.elasticsearch.client.internal.Client;
 | 
	
		
			
				|  |  | +import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 | 
	
		
			
				|  |  | +import org.elasticsearch.inference.ChunkingOptions;
 | 
	
		
			
				|  |  | +import org.elasticsearch.inference.InferenceResults;
 | 
	
		
			
				|  |  |  import org.elasticsearch.inference.InferenceServiceExtension;
 | 
	
		
			
				|  |  | +import org.elasticsearch.inference.InputType;
 | 
	
		
			
				|  |  |  import org.elasticsearch.inference.Model;
 | 
	
		
			
				|  |  |  import org.elasticsearch.inference.ModelConfigurations;
 | 
	
		
			
				|  |  |  import org.elasticsearch.inference.TaskType;
 | 
	
		
			
				|  |  |  import org.elasticsearch.test.ESTestCase;
 | 
	
		
			
				|  |  | +import org.elasticsearch.threadpool.TestThreadPool;
 | 
	
		
			
				|  |  | +import org.elasticsearch.threadpool.ThreadPool;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
 | 
	
		
			
				|  |  | +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import java.util.ArrayList;
 | 
	
		
			
				|  |  |  import java.util.Collections;
 | 
	
		
			
				|  |  |  import java.util.HashMap;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Map;
 | 
	
		
			
				|  |  |  import java.util.Set;
 | 
	
		
			
				|  |  | +import java.util.concurrent.atomic.AtomicBoolean;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.hamcrest.Matchers.containsString;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.hasSize;
 | 
	
		
			
				|  |  | +import static org.hamcrest.Matchers.instanceOf;
 | 
	
		
			
				|  |  | +import static org.mockito.ArgumentMatchers.any;
 | 
	
		
			
				|  |  | +import static org.mockito.ArgumentMatchers.same;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.doAnswer;
 | 
	
		
			
				|  |  |  import static org.mockito.Mockito.mock;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.when;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  public class ElserInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -307,6 +325,69 @@ public class ElserInternalServiceTests extends ESTestCase {
 | 
	
		
			
				|  |  |          }
 | 
	
		
			
				|  |  |      }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +    @SuppressWarnings("unchecked")
 | 
	
		
			
				|  |  | +    public void testChunkInfer() {
 | 
	
		
			
				|  |  | +        var mlTrainedModelResults = new ArrayList<InferenceResults>();
 | 
	
		
			
				|  |  | +        mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
 | 
	
		
			
				|  |  | +        mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
 | 
	
		
			
				|  |  | +        var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        ThreadPool threadpool = new TestThreadPool("test");
 | 
	
		
			
				|  |  | +        Client client = mock(Client.class);
 | 
	
		
			
				|  |  | +        when(client.threadPool()).thenReturn(threadpool);
 | 
	
		
			
				|  |  | +        doAnswer(invocationOnMock -> {
 | 
	
		
			
				|  |  | +            var listener = (ActionListener<InferTrainedModelDeploymentAction.Response>) invocationOnMock.getArguments()[2];
 | 
	
		
			
				|  |  | +            listener.onResponse(response);
 | 
	
		
			
				|  |  | +            return null;
 | 
	
		
			
				|  |  | +        }).when(client)
 | 
	
		
			
				|  |  | +            .execute(
 | 
	
		
			
				|  |  | +                same(InferTrainedModelDeploymentAction.INSTANCE),
 | 
	
		
			
				|  |  | +                any(InferTrainedModelDeploymentAction.Request.class),
 | 
	
		
			
				|  |  | +                any(ActionListener.class)
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        var model = new ElserInternalModel(
 | 
	
		
			
				|  |  | +            "foo",
 | 
	
		
			
				|  |  | +            TaskType.SPARSE_EMBEDDING,
 | 
	
		
			
				|  |  | +            "elser",
 | 
	
		
			
				|  |  | +            new ElserInternalServiceSettings(1, 1, "elser"),
 | 
	
		
			
				|  |  | +            new ElserMlNodeTaskSettings()
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +        var service = createService(client);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        var gotResults = new AtomicBoolean();
 | 
	
		
			
				|  |  | +        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
 | 
	
		
			
				|  |  | +            assertThat(chunkedResponse, hasSize(2));
 | 
	
		
			
				|  |  | +            assertThat(chunkedResponse.get(0), instanceOf(ChunkedSparseEmbeddingResults.class));
 | 
	
		
			
				|  |  | +            var result1 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(0);
 | 
	
		
			
				|  |  | +            assertEquals(
 | 
	
		
			
				|  |  | +                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(),
 | 
	
		
			
				|  |  | +                result1.getChunkedResults()
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +            assertThat(chunkedResponse.get(1), instanceOf(ChunkedSparseEmbeddingResults.class));
 | 
	
		
			
				|  |  | +            var result2 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(1);
 | 
	
		
			
				|  |  | +            assertEquals(
 | 
	
		
			
				|  |  | +                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(),
 | 
	
		
			
				|  |  | +                result2.getChunkedResults()
 | 
	
		
			
				|  |  | +            );
 | 
	
		
			
				|  |  | +            gotResults.set(true);
 | 
	
		
			
				|  |  | +        }, ESTestCase::fail);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        service.chunkedInfer(
 | 
	
		
			
				|  |  | +            model,
 | 
	
		
			
				|  |  | +            List.of("foo", "bar"),
 | 
	
		
			
				|  |  | +            Map.of(),
 | 
	
		
			
				|  |  | +            InputType.SEARCH,
 | 
	
		
			
				|  |  | +            new ChunkingOptions(null, null),
 | 
	
		
			
				|  |  | +            ActionListener.runAfter(resultsListener, () -> terminate(threadpool))
 | 
	
		
			
				|  |  | +        );
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +        if (gotResults.get() == false) {
 | 
	
		
			
				|  |  | +            terminate(threadpool);
 | 
	
		
			
				|  |  | +        }
 | 
	
		
			
				|  |  | +        assertTrue("Listener not called", gotResults.get());
 | 
	
		
			
				|  |  | +    }
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |      private ElserInternalService createService(Client client) {
 | 
	
		
			
				|  |  |          var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
 | 
	
		
			
				|  |  |          return new ElserInternalService(context);
 |