
[ML] Create default word based chunker (#107303)

WordBoundaryChunker uses ICU4J to split text at word boundaries,
creating chunks from long inputs. The chunk size and overlap
parameters are measured in words. The chunked text is then processed
in batches sized according to the inference service's supported batch size.
David Kyle, 1 year ago
commit ecc406edfc
18 files changed, 1190 insertions and 157 deletions
  1. docs/changelog/107303.yaml (+5 -0)
  2. x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java (+116 -0)
  3. x-pack/plugin/inference/build.gradle (+2 -0)
  4. x-pack/plugin/inference/licenses/icu4j-LICENSE.txt (+33 -0)
  5. x-pack/plugin/inference/licenses/icu4j-NOTICE.txt (+3 -0)
  6. x-pack/plugin/inference/src/main/java/module-info.java (+1 -0)
  7. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java (+8 -0)
  8. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java (+264 -0)
  9. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/WordBoundaryChunker.java (+111 -0)
  10. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java (+12 -24)
  11. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java (+5 -0)
  12. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java (+14 -4)
  13. x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceFields.java (+5 -0)
  14. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java (+278 -0)
  15. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/WordBoundaryChunkerTests.java (+221 -0)
  16. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java (+56 -0)
  17. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java (+27 -105)
  18. x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java (+29 -24)

+ 5 - 0
docs/changelog/107303.yaml

@@ -0,0 +1,5 @@
+pr: 107303
+summary: Create default word based chunker
+area: Machine Learning
+type: feature
+issues: []

+ 116 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingFloatResults.java

@@ -0,0 +1,116 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.core.inference.results;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.InferenceResults;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public record ChunkedTextEmbeddingFloatResults(List<EmbeddingChunk> chunks) implements ChunkedInferenceServiceResults {
+
+    public static final String NAME = "chunked_text_embedding_service_float_results";
+    public static final String FIELD_NAME = "text_embedding_float_chunk";
+
+    public ChunkedTextEmbeddingFloatResults(StreamInput in) throws IOException {
+        this(in.readCollectionAsList(EmbeddingChunk::new));
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
+        // TODO add isTruncated flag
+        builder.startArray(FIELD_NAME);
+        for (var embedding : chunks) {
+            embedding.toXContent(builder, params);
+        }
+        builder.endArray();
+        return builder;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeCollection(chunks);
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToCoordinationFormat() {
+        throw new UnsupportedOperationException("Chunked results are not returned in the coordinated action");
+    }
+
+    @Override
+    public List<? extends InferenceResults> transformToLegacyFormat() {
+        throw new UnsupportedOperationException("Chunked results are not returned in the legacy format");
+    }
+
+    @Override
+    public Map<String, Object> asMap() {
+        return Map.of(FIELD_NAME, chunks.stream().map(EmbeddingChunk::asMap).collect(Collectors.toList()));
+    }
+
+    @Override
+    public String getWriteableName() {
+        return NAME;
+    }
+
+    public List<EmbeddingChunk> getChunks() {
+        return chunks;
+    }
+
+    public record EmbeddingChunk(String matchedText, List<Float> embedding) implements Writeable, ToXContentObject {
+
+        public EmbeddingChunk(StreamInput in) throws IOException {
+            this(in.readString(), in.readCollectionAsImmutableList(StreamInput::readFloat));
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeString(matchedText);
+            out.writeCollection(embedding, StreamOutput::writeFloat);
+        }
+
+        @Override
+        public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(ChunkedNlpInferenceResults.TEXT, matchedText);
+
+            builder.startArray(ChunkedNlpInferenceResults.INFERENCE);
+            for (Float value : embedding) {
+                builder.value(value);
+            }
+            builder.endArray();
+
+            builder.endObject();
+            return builder;
+        }
+
+        public Map<String, Object> asMap() {
+            var map = new HashMap<String, Object>();
+            map.put(ChunkedNlpInferenceResults.TEXT, matchedText);
+            map.put(ChunkedNlpInferenceResults.INFERENCE, embedding);
+            return map;
+        }
+
+        @Override
+        public String toString() {
+            return Strings.toString(this);
+        }
+    }
+
+}
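
A minimal sketch (not part of the diff) of how the new result type can be built and read back. It assumes the class shown above is on the classpath and only illustrates the shape of asMap():

import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;

import java.util.List;
import java.util.Map;

class ChunkedFloatResultsSketch {
    Map<String, Object> example() {
        // One chunk of matched text paired with its float embedding values.
        var chunk = new ChunkedTextEmbeddingFloatResults.EmbeddingChunk("a chunk of the input text", List.of(0.1f, 0.2f));
        var results = new ChunkedTextEmbeddingFloatResults(List.of(chunk));
        // A single entry keyed on "text_embedding_float_chunk" holding, for each chunk,
        // the matched text and its embedding.
        return results.asMap();
    }
}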

+ 2 - 0
x-pack/plugin/inference/build.gradle

@@ -24,4 +24,6 @@ dependencies {
   compileOnly project(path: xpackModule('core'))
   testImplementation(testArtifact(project(xpackModule('core'))))
   testImplementation project(':modules:reindex')
+
+  api "com.ibm.icu:icu4j:${versions.icu4j}"
 }

+ 33 - 0
x-pack/plugin/inference/licenses/icu4j-LICENSE.txt

@@ -0,0 +1,33 @@
+ICU License - ICU 1.8.1 and later
+
+COPYRIGHT AND PERMISSION NOTICE
+
+Copyright (c) 1995-2012 International Business Machines Corporation and others
+
+All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, and/or sell copies of the
+Software, and to permit persons to whom the Software is furnished to do so,
+provided that the above copyright notice(s) and this permission notice appear
+in all copies of the Software and that both the above copyright notice(s) and
+this permission notice appear in supporting documentation.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF THIRD PARTY RIGHTS.
+IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE BE
+LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, OR
+ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER
+IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+Except as contained in this notice, the name of a copyright holder shall not
+be used in advertising or otherwise to promote the sale, use or other
+dealings in this Software without prior written authorization of the
+copyright holder.
+
+All trademarks and registered trademarks mentioned herein are the property of
+their respective owners.

+ 3 - 0
x-pack/plugin/inference/licenses/icu4j-NOTICE.txt

@@ -0,0 +1,3 @@
+ICU4J, (under lucene/analysis/icu) is licensed under an MIT style license
+(modules/analysis/icu/lib/icu4j-LICENSE-BSD_LIKE.txt) and Copyright (c) 1995-2012
+International Business Machines Corporation and others

+ 1 - 0
x-pack/plugin/inference/src/main/java/module-info.java

@@ -17,6 +17,7 @@ module org.elasticsearch.inference {
     requires org.apache.httpcomponents.httpasyncclient;
     requires org.apache.httpcomponents.httpcore.nio;
     requires org.apache.lucene.core;
+    requires com.ibm.icu;
 
     exports org.elasticsearch.xpack.inference.action;
     exports org.elasticsearch.xpack.inference.registry;

+ 8 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

@@ -17,6 +17,7 @@ import org.elasticsearch.inference.TaskSettings;
 import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
 import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults;
@@ -105,6 +106,13 @@ public class InferenceNamedWriteablesProvider {
                 ChunkedTextEmbeddingResults::new
             )
         );
+        namedWriteables.add(
+            new NamedWriteableRegistry.Entry(
+                InferenceServiceResults.class,
+                ChunkedTextEmbeddingFloatResults.NAME,
+                ChunkedTextEmbeddingFloatResults::new
+            )
+        );
         namedWriteables.add(
             new NamedWriteableRegistry.Entry(
                 InferenceServiceResults.class,

+ 264 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunker.java

@@ -0,0 +1,264 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.InferenceServiceResults;
+import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+/**
+ * This class creates optimally sized batches of input strings
+ * for batched processing, splitting long strings into smaller
+ * chunks. Multiple inputs may be fitted into a single batch, or
+ * a single large input that has been chunked may be spread over
+ * multiple batches.
+ *
+ * The final aspect is to gather the responses from the batch
+ * processing and map the results back to the original element
+ * in the input list.
+ */
+public class EmbeddingRequestChunker {
+
+    public static final int DEFAULT_WORDS_PER_CHUNK = 250;
+    public static final int DEFAULT_CHUNK_OVERLAP = 100;
+
+    private final List<BatchRequest> batchedRequests = new ArrayList<>();
+    private final AtomicInteger resultCount = new AtomicInteger();
+    private final int maxNumberOfInputsPerBatch;
+    private final int wordsPerChunk;
+    private final int chunkOverlap;
+
+    private List<List<String>> chunkedInputs;
+    private List<AtomicArray<List<TextEmbeddingResults.Embedding>>> results;
+    private AtomicArray<ErrorChunkedInferenceResults> errors;
+    private ActionListener<List<ChunkedInferenceServiceResults>> finalListener;
+
+    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch) {
+        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
+        this.wordsPerChunk = DEFAULT_WORDS_PER_CHUNK;
+        this.chunkOverlap = DEFAULT_CHUNK_OVERLAP;
+        splitIntoBatchedRequests(inputs);
+    }
+
+    public EmbeddingRequestChunker(List<String> inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) {
+        this.maxNumberOfInputsPerBatch = maxNumberOfInputsPerBatch;
+        this.wordsPerChunk = wordsPerChunk;
+        this.chunkOverlap = chunkOverlap;
+        splitIntoBatchedRequests(inputs);
+    }
+
+    private void splitIntoBatchedRequests(List<String> inputs) {
+        var chunker = new WordBoundaryChunker();
+        chunkedInputs = new ArrayList<>(inputs.size());
+        results = new ArrayList<>(inputs.size());
+        errors = new AtomicArray<>(inputs.size());
+
+        for (int i = 0; i < inputs.size(); i++) {
+            var chunks = chunker.chunk(inputs.get(i), wordsPerChunk, chunkOverlap);
+            int numberOfSubBatches = addToBatches(chunks, i);
+            // size the results array with the expected number of request/responses
+            results.add(new AtomicArray<>(numberOfSubBatches));
+            chunkedInputs.add(chunks);
+        }
+    }
+
+    private int addToBatches(List<String> chunks, int inputIndex) {
+        BatchRequest lastBatch;
+        if (batchedRequests.isEmpty()) {
+            lastBatch = new BatchRequest(new ArrayList<>());
+            batchedRequests.add(lastBatch);
+        } else {
+            lastBatch = batchedRequests.get(batchedRequests.size() - 1);
+        }
+
+        int freeSpace = maxNumberOfInputsPerBatch - lastBatch.size();
+        assert freeSpace >= 0;
+
+        // chunks may span multiple batches,
+        // the chunkIndex keeps them ordered.
+        int chunkIndex = 0;
+
+        if (freeSpace > 0) {
+            // use any free space in the previous batch before creating new batches
+            int toAdd = Math.min(freeSpace, chunks.size());
+            lastBatch.addSubBatch(new SubBatch(chunks.subList(0, toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd)));
+        }
+
+        int start = freeSpace;
+        while (start < chunks.size()) {
+            int toAdd = Math.min(maxNumberOfInputsPerBatch, chunks.size() - start);
+            var batch = new BatchRequest(new ArrayList<>());
+            batch.addSubBatch(
+                new SubBatch(chunks.subList(start, start + toAdd), new SubBatchPositionsAndCount(inputIndex, chunkIndex++, toAdd))
+            );
+            batchedRequests.add(batch);
+            start += toAdd;
+        }
+
+        return chunkIndex;
+    }
+
+    /**
+     * Returns a list of batched inputs and an ActionListener for each batch.
+     * @param finalListener The listener to call once all the batches are processed
+     * @return Batches and listeners
+     */
+    public List<BatchRequestAndListener> batchRequestsWithListeners(ActionListener<List<ChunkedInferenceServiceResults>> finalListener) {
+        this.finalListener = finalListener;
+
+        int numberOfRequests = batchedRequests.size();
+
+        var requests = new ArrayList<BatchRequestAndListener>(numberOfRequests);
+        for (var batch : batchedRequests) {
+            requests.add(
+                new BatchRequestAndListener(
+                    batch,
+                    new DebatchingListener(
+                        batch.subBatches().stream().map(SubBatch::positions).collect(Collectors.toList()),
+                        numberOfRequests
+                    )
+                )
+            );
+        }
+
+        return requests;
+    }
+
+    /**
+     * A grouping listener that calls the final listener only when
+     * all responses have been received.
+     * Long inputs that were split into chunks are reassembled and
+     * returned as a single chunked response.
+     * The listener knows where in the results array to insert the
+     * response so that order is preserved.
+     */
+    private class DebatchingListener implements ActionListener<InferenceServiceResults> {
+
+        private final List<SubBatchPositionsAndCount> positions;
+        private final int totalNumberOfRequests;
+
+        DebatchingListener(List<SubBatchPositionsAndCount> positions, int totalNumberOfRequests) {
+            this.positions = positions;
+            this.totalNumberOfRequests = totalNumberOfRequests;
+        }
+
+        @Override
+        public void onResponse(InferenceServiceResults inferenceServiceResults) {
+            if (inferenceServiceResults instanceof TextEmbeddingResults textEmbeddingResults) { // TODO byte embeddings
+                int numRequests = positions.stream().mapToInt(SubBatchPositionsAndCount::embeddingCount).sum();
+                if (numRequests != textEmbeddingResults.embeddings().size()) {
+                    onFailure(
+                        new ElasticsearchStatusException(
+                            "Error the number of embedding responses [{}] does not equal the number of " + "requests [{}]",
+                            RestStatus.BAD_REQUEST,
+                            textEmbeddingResults.embeddings().size(),
+                            numRequests
+                        )
+                    );
+                    return;
+                }
+
+                int start = 0;
+                for (var pos : positions) {
+                    results.get(pos.inputIndex())
+                        .setOnce(pos.chunkIndex(), textEmbeddingResults.embeddings().subList(start, start + pos.embeddingCount()));
+                    start += pos.embeddingCount();
+                }
+            }
+
+            if (resultCount.incrementAndGet() == totalNumberOfRequests) {
+                sendResponse();
+            }
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            var errorResult = new ErrorChunkedInferenceResults(e);
+            for (var pos : positions) {
+                errors.setOnce(pos.inputIndex(), errorResult);
+            }
+
+            if (resultCount.incrementAndGet() == totalNumberOfRequests) {
+                sendResponse();
+            }
+        }
+
+        private void sendResponse() {
+            var response = new ArrayList<ChunkedInferenceServiceResults>(chunkedInputs.size());
+            for (int i = 0; i < chunkedInputs.size(); i++) {
+                if (errors.get(i) != null) {
+                    response.add(errors.get(i));
+                } else {
+                    response.add(merge(chunkedInputs.get(i), results.get(i)));
+                }
+            }
+
+            finalListener.onResponse(response);
+        }
+
+        private ChunkedTextEmbeddingFloatResults merge(
+            List<String> chunks,
+            AtomicArray<List<TextEmbeddingResults.Embedding>> debatchedResults
+        ) {
+            var all = new ArrayList<TextEmbeddingResults.Embedding>();
+            for (int i = 0; i < debatchedResults.length(); i++) {
+                var subBatch = debatchedResults.get(i);
+                all.addAll(subBatch);
+            }
+
+            assert chunks.size() == all.size();
+
+            var embeddingChunks = new ArrayList<ChunkedTextEmbeddingFloatResults.EmbeddingChunk>();
+            for (int i = 0; i < chunks.size(); i++) {
+                embeddingChunks.add(new ChunkedTextEmbeddingFloatResults.EmbeddingChunk(chunks.get(i), all.get(i).values()));
+            }
+
+            return new ChunkedTextEmbeddingFloatResults(embeddingChunks);
+        }
+    }
+
+    public record BatchRequest(List<SubBatch> subBatches) {
+        public int size() {
+            return subBatches.stream().mapToInt(SubBatch::size).sum();
+        }
+
+        public void addSubBatch(SubBatch sb) {
+            subBatches.add(sb);
+        }
+
+        public List<String> inputs() {
+            return subBatches.stream().flatMap(s -> s.requests().stream()).collect(Collectors.toList());
+        }
+    }
+
+    public record BatchRequestAndListener(BatchRequest batch, ActionListener<InferenceServiceResults> listener) {
+
+    }
+
+    /**
+     * Used for mapping batched requests back to the original input
+     */
+    record SubBatchPositionsAndCount(int inputIndex, int chunkIndex, int embeddingCount) {}
+
+    record SubBatch(List<String> requests, SubBatchPositionsAndCount positions) {
+        public int size() {
+            return requests.size();
+        }
+    }
+}
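
A hedged usage sketch, not part of the commit, showing how a caller drives the chunker: batches are dispatched, each per-batch listener receives a TextEmbeddingResults, and the final listener fires once with one chunked result per original input. The in-process "service call" below is a stand-in for the real HTTP actions used by the Cohere and OpenAI services.

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;

import java.util.List;

class EmbeddingRequestChunkerSketch {
    void demo() {
        List<String> inputs = List.of("1st input", "a much longer passage that will be chunked ...");

        ActionListener<List<ChunkedInferenceServiceResults>> finalListener = ActionListener.wrap(
            results -> { /* one result per original input, in input order */ },
            e -> { /* failure assembling the final response */ }
        );

        // 96 is Cohere's documented maximum batch size; long inputs use the default
        // 250-word chunks with a 100-word overlap.
        var batches = new EmbeddingRequestChunker(inputs, 96).batchRequestsWithListeners(finalListener);
        for (var batch : batches) {
            // A real service sends batch.batch().inputs() to the remote API; this sketch
            // fabricates one single-value embedding per request in the batch.
            var embeddings = batch.batch()
                .inputs()
                .stream()
                .map(s -> new TextEmbeddingResults.Embedding(List.of(0.0f)))
                .toList();
            batch.listener().onResponse(new TextEmbeddingResults(embeddings));
        }
        // finalListener is called once every batch listener has responded.
    }
}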

+ 111 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/WordBoundaryChunker.java

@@ -0,0 +1,111 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import com.ibm.icu.text.BreakIterator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Breaks text into smaller strings, or chunks, on word boundaries.
+ * Whitespace is preserved and included at the start of the
+ * following chunk, not at the end of the chunk. If the chunk ends
+ * on a punctuation mark the punctuation is included in the
+ * next chunk.
+ *
+ * The overlap value must not exceed (chunkSize / 2), to avoid the
+ * complexity of tracking the start positions of multiple
+ * chunks within a chunk.
+ */
+public class WordBoundaryChunker {
+
+    private BreakIterator wordIterator;
+
+    public WordBoundaryChunker() {
+        wordIterator = BreakIterator.getWordInstance(Locale.ROOT);
+    }
+
+    /**
+     * Break the input text into small chunks as dictated
+     * by the chunking parameters
+     * @param input Text to chunk
+     * @param chunkSize The number of words in each chunk
+     * @param overlap The number of words to overlap each chunk.
+     *                Can be 0 but must be non-negative.
+     * @return List of chunked text
+     */
+    public List<String> chunk(String input, int chunkSize, int overlap) {
+        if (overlap > 0 && overlap > chunkSize / 2) {
+            throw new IllegalArgumentException(
+                "Invalid chunking parameters, overlap ["
+                    + overlap
+                    + "] must be < chunk size / 2 ["
+                    + chunkSize
+                    + " / 2 = "
+                    + chunkSize / 2
+                    + "]"
+            );
+        }
+
+        if (overlap < 0) {
+            throw new IllegalArgumentException("Invalid chunking parameters, overlap [" + overlap + "] must be >= 0");
+        }
+
+        if (input.isEmpty()) {
+            return List.of("");
+        }
+
+        var chunks = new ArrayList<String>();
+
+        // This position in the chunk is where the next overlapping chunk will start
+        final int chunkSizeLessOverlap = chunkSize - overlap;
+        // includes the count of words from the overlap portion in the previous chunk
+        int wordsInChunkCountIncludingOverlap = 0;
+        int nextWindowStart = 0;
+        int windowStart = 0;
+        int wordsSinceStartWindowWasMarked = 0;
+
+        wordIterator.setText(input);
+        int boundary = wordIterator.next();
+
+        while (boundary != BreakIterator.DONE) {
+            if (wordIterator.getRuleStatus() != BreakIterator.WORD_NONE) {
+                wordsInChunkCountIncludingOverlap++;
+                wordsSinceStartWindowWasMarked++;
+
+                if (wordsInChunkCountIncludingOverlap >= chunkSize) {
+                    chunks.add(input.substring(windowStart, boundary));
+                    wordsInChunkCountIncludingOverlap = overlap;
+
+                    if (overlap == 0) {
+                        nextWindowStart = boundary;
+                    }
+
+                    windowStart = nextWindowStart;
+                }
+
+                if (wordsSinceStartWindowWasMarked == chunkSizeLessOverlap) {
+                    nextWindowStart = boundary;
+                    wordsSinceStartWindowWasMarked = 0;
+                }
+            }
+            boundary = wordIterator.next();
+        }
+
+        // Get the last chunk that was shorter than the required chunk size
+    // if it ends on a boundary then the count should equal the overlap, in which case
+    // we can ignore it, unless this is the first chunk, in which case we want to add it
+        if (wordsInChunkCountIncludingOverlap > overlap || chunks.isEmpty()) {
+            chunks.add(input.substring(windowStart));
+        }
+
+        return chunks;
+    }
+}
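
A brief usage sketch of the chunker above (not part of the diff): 20-word chunks with a 10-word overlap, so every chunk after the first repeats the last 10 words of its predecessor.

import org.elasticsearch.xpack.inference.common.WordBoundaryChunker;

import java.util.List;

class WordBoundaryChunkerSketch {
    List<String> demo(String passage) {
        var chunker = new WordBoundaryChunker();
        // chunkSize = 20 words, overlap = 10 words. The overlap must be non-negative and
        // no more than chunkSize / 2, otherwise chunk() throws an IllegalArgumentException.
        return chunker.chunk(passage, 20, 10);
    }
}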

+ 12 - 24
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -23,12 +23,7 @@ import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
-import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.cohere.CohereActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -45,12 +40,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
-import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 
 public class CohereService extends SenderService {
     public static final String NAME = "cohere";
@@ -229,25 +224,18 @@ public class CohereService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
+        if (model instanceof CohereModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
 
-        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
-    }
+        CohereModel cohereModel = (CohereModel) model;
+        var actionCreator = new CohereActionCreator(getSender(), getServiceComponents());
 
-    private static List<ChunkedInferenceServiceResults> translateToChunkedResults(
-        List<String> inputs,
-        InferenceServiceResults inferenceResults
-    ) {
-        if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) {
-            return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults);
-        } else if (inferenceResults instanceof TextEmbeddingByteResults textEmbeddingByteResults) {
-            return ChunkedTextEmbeddingByteResults.of(inputs, textEmbeddingByteResults);
-        } else if (inferenceResults instanceof ErrorInferenceResults error) {
-            return List.of(new ErrorChunkedInferenceResults(error.getException()));
-        } else {
-            throw createInvalidChunkedResultException(inferenceResults.getWriteableName());
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE).batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = cohereModel.accept(actionCreator, taskSettings, inputType);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
         }
     }
 

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceFields.java

@@ -9,4 +9,9 @@ package org.elasticsearch.xpack.inference.services.cohere;
 
 public class CohereServiceFields {
     public static final String TRUNCATE = "truncate";
+
+    /**
+     * Taken from https://docs.cohere.com/reference/embed
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 96;
 }

+ 14 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -28,6 +28,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResult
 import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
 import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults;
+import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -50,6 +51,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersi
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
+import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.EMBEDDING_MAX_BATCH_SIZE;
 
 public class OpenAiService extends SenderService {
     public static final String NAME = "openai";
@@ -232,11 +234,19 @@ public class OpenAiService extends SenderService {
         TimeValue timeout,
         ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
-        ActionListener<InferenceServiceResults> inferListener = listener.delegateFailureAndWrap(
-            (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response))
-        );
+        if (model instanceof OpenAiModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
 
-        doInfer(model, input, taskSettings, inputType, timeout, inferListener);
+        OpenAiModel openAiModel = (OpenAiModel) model;
+        var actionCreator = new OpenAiActionCreator(getSender(), getServiceComponents());
+
+        var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE).batchRequestsWithListeners(listener);
+        for (var request : batchedRequests) {
+            var action = openAiModel.accept(actionCreator, taskSettings);
+            action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
+        }
     }
 
     private static List<ChunkedInferenceServiceResults> translateToChunkedResults(

+ 5 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceFields.java

@@ -13,4 +13,9 @@ public class OpenAiServiceFields {
 
     public static final String ORGANIZATION = "organization_id";
 
+    /**
+     * Taken from https://platform.openai.com/docs/api-reference/embeddings/create
+     */
+    static final int EMBEDDING_MAX_BATCH_SIZE = 2048;
+
 }

+ 278 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/EmbeddingRequestChunkerTests.java

@@ -0,0 +1,278 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
+import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.hamcrest.Matchers.startsWith;
+
+public class EmbeddingRequestChunkerTests extends ESTestCase {
+
+    public void testShortInputsAreSingleBatch() {
+        String input = "one chunk";
+
+        var batches = new EmbeddingRequestChunker(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertThat(batches.get(0).batch().inputs(), contains(input));
+    }
+
+    public void testMultipleShortInputsAreSingleBatch() {
+        List<String> inputs = List.of("1st small", "2nd small", "3rd small");
+
+        var batches = new EmbeddingRequestChunker(inputs, 100, 100, 10).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(1));
+        assertEquals(batches.get(0).batch().inputs(), inputs);
+        var subBatches = batches.get(0).batch().subBatches();
+        for (int i = 0; i < inputs.size(); i++) {
+            var subBatch = subBatches.get(i);
+            assertThat(subBatch.requests(), contains(inputs.get(i)));
+            assertEquals(0, subBatch.positions().chunkIndex());
+            assertEquals(i, subBatch.positions().inputIndex());
+            assertEquals(1, subBatch.positions().embeddingCount());
+        }
+    }
+
+    public void testManyInputsMakeManyBatches() {
+        int maxNumInputsPerBatch = 10;
+        int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches
+        var inputs = new ArrayList<String>();
+        //
+        for (int i = 0; i < numInputs; i++) {
+            inputs.add("input " + i);
+        }
+
+        var batches = new EmbeddingRequestChunker(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(4));
+        assertThat(batches.get(0).batch().inputs(), hasSize(maxNumInputsPerBatch));
+        assertThat(batches.get(1).batch().inputs(), hasSize(maxNumInputsPerBatch));
+        assertThat(batches.get(2).batch().inputs(), hasSize(maxNumInputsPerBatch));
+        assertThat(batches.get(3).batch().inputs(), hasSize(1));
+
+        assertEquals("input 0", batches.get(0).batch().inputs().get(0));
+        assertEquals("input 9", batches.get(0).batch().inputs().get(9));
+        assertThat(
+            batches.get(1).batch().inputs(),
+            contains("input 10", "input 11", "input 12", "input 13", "input 14", "input 15", "input 16", "input 17", "input 18", "input 19")
+        );
+        assertEquals("input 20", batches.get(2).batch().inputs().get(0));
+        assertEquals("input 29", batches.get(2).batch().inputs().get(9));
+        assertThat(batches.get(3).batch().inputs(), contains("input 30"));
+
+        int inputIndex = 0;
+        var subBatches = batches.get(0).batch().subBatches();
+        for (int i = 0; i < batches.size(); i++) {
+            var subBatch = subBatches.get(i);
+            assertThat(subBatch.requests(), contains(inputs.get(i)));
+            assertEquals(0, subBatch.positions().chunkIndex());
+            assertEquals(inputIndex, subBatch.positions().inputIndex());
+            assertEquals(1, subBatch.positions().embeddingCount());
+            inputIndex++;
+        }
+    }
+
+    public void testLongInputChunkedOverMultipleBatches() {
+        int batchSize = 5;
+        int chunkSize = 20;
+        int overlap = 0;
+        // passage will be chunked into batchSize + 1 parts
+        // and spread over 2 batch requests
+        int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
+
+        var passageBuilder = new StringBuilder();
+        for (int i = 0; i < numberOfWordsInPassage; i++) {
+            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
+        }
+
+        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
+
+        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener());
+        assertThat(batches, hasSize(2));
+        {
+            var batch = batches.get(0).batch();
+            assertThat(batch.inputs(), hasSize(batchSize));
+            assertEquals(batchSize, batch.size());
+            assertThat(batch.subBatches(), hasSize(2));
+            {
+                var subBatch = batch.subBatches().get(0);
+                assertEquals(0, subBatch.positions().inputIndex());
+                assertEquals(0, subBatch.positions().chunkIndex());
+                assertEquals(1, subBatch.positions().embeddingCount());
+                assertThat(subBatch.requests(), contains("1st small"));
+            }
+            {
+                var subBatch = batch.subBatches().get(1);
+                assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
+                assertEquals(0, subBatch.positions().chunkIndex());  // 1st part of the 2nd input
+                assertEquals(4, subBatch.positions().embeddingCount()); // 4 chunks
+                assertThat(subBatch.requests().get(0), startsWith("passage_input0 "));
+                assertThat(subBatch.requests().get(1), startsWith(" passage_input20 "));
+                assertThat(subBatch.requests().get(2), startsWith(" passage_input40 "));
+                assertThat(subBatch.requests().get(3), startsWith(" passage_input60 "));
+            }
+        }
+        {
+            var batch = batches.get(1).batch();
+            assertThat(batch.inputs(), hasSize(4));
+            assertEquals(4, batch.size());
+            assertThat(batch.subBatches(), hasSize(3));
+            {
+                var subBatch = batch.subBatches().get(0);
+                assertEquals(1, subBatch.positions().inputIndex()); // 2nd input
+                assertEquals(1, subBatch.positions().chunkIndex()); // 2nd part of the 2nd input
+                assertEquals(2, subBatch.positions().embeddingCount());
+                assertThat(subBatch.requests().get(0), startsWith(" passage_input80 "));
+                assertThat(subBatch.requests().get(1), startsWith(" passage_input100 "));
+            }
+            {
+                var subBatch = batch.subBatches().get(1);
+                assertEquals(2, subBatch.positions().inputIndex()); // 3rd input
+                assertEquals(0, subBatch.positions().chunkIndex());  // 1st and only part
+                assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
+                assertThat(subBatch.requests(), contains("2nd small"));
+            }
+            {
+                var subBatch = batch.subBatches().get(2);
+                assertEquals(3, subBatch.positions().inputIndex());  // 4th input
+                assertEquals(0, subBatch.positions().chunkIndex());  // 1st and only part
+                assertEquals(1, subBatch.positions().embeddingCount()); // 1 chunk
+                assertThat(subBatch.requests(), contains("3rd small"));
+            }
+        }
+    }
+
+    public void testMergingListener() {
+        int batchSize = 5;
+        int chunkSize = 20;
+        int overlap = 0;
+        // passage will be chunked into batchSize + 1 parts
+        // and spread over 2 batch requests
+        int numberOfWordsInPassage = (chunkSize * batchSize) + 5;
+
+        var passageBuilder = new StringBuilder();
+        for (int i = 0; i < numberOfWordsInPassage; i++) {
+            passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace
+        }
+        List<String> inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small");
+
+        var finalListener = testListener();
+        var batches = new EmbeddingRequestChunker(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener);
+        assertThat(batches, hasSize(2));
+
+        // 4 inputs in 2 batches
+        {
+            var embeddings = new ArrayList<TextEmbeddingResults.Embedding>();
+            for (int i = 0; i < batchSize; i++) {
+                embeddings.add(new TextEmbeddingResults.Embedding(List.of(randomFloat())));
+            }
+            batches.get(0).listener().onResponse(new TextEmbeddingResults(embeddings));
+        }
+        {
+            var embeddings = new ArrayList<TextEmbeddingResults.Embedding>();
+            for (int i = 0; i < 4; i++) { // 4 requests in the 2nd batch
+                embeddings.add(new TextEmbeddingResults.Embedding(List.of(randomFloat())));
+            }
+            batches.get(1).listener().onResponse(new TextEmbeddingResults(embeddings));
+        }
+
+        assertNotNull(finalListener.results);
+        assertThat(finalListener.results, hasSize(4));
+        {
+            var chunkedResult = finalListener.results.get(0);
+            assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class));
+            var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult;
+            assertThat(chunkedFloatResult.chunks(), hasSize(1));
+            assertEquals("1st small", chunkedFloatResult.chunks().get(0).matchedText());
+        }
+        {
+            // this is the large input split in multiple chunks
+            var chunkedResult = finalListener.results.get(1);
+            assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class));
+            var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult;
+            assertThat(chunkedFloatResult.chunks(), hasSize(6));
+            assertThat(chunkedFloatResult.chunks().get(0).matchedText(), startsWith("passage_input0 "));
+            assertThat(chunkedFloatResult.chunks().get(1).matchedText(), startsWith(" passage_input20 "));
+            assertThat(chunkedFloatResult.chunks().get(2).matchedText(), startsWith(" passage_input40 "));
+            assertThat(chunkedFloatResult.chunks().get(3).matchedText(), startsWith(" passage_input60 "));
+            assertThat(chunkedFloatResult.chunks().get(4).matchedText(), startsWith(" passage_input80 "));
+            assertThat(chunkedFloatResult.chunks().get(5).matchedText(), startsWith(" passage_input100 "));
+        }
+        {
+            var chunkedResult = finalListener.results.get(2);
+            assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class));
+            var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult;
+            assertThat(chunkedFloatResult.chunks(), hasSize(1));
+            assertEquals("2nd small", chunkedFloatResult.chunks().get(0).matchedText());
+        }
+        {
+            var chunkedResult = finalListener.results.get(3);
+            assertThat(chunkedResult, instanceOf(ChunkedTextEmbeddingFloatResults.class));
+            var chunkedFloatResult = (ChunkedTextEmbeddingFloatResults) chunkedResult;
+            assertThat(chunkedFloatResult.chunks(), hasSize(1));
+            assertEquals("3rd small", chunkedFloatResult.chunks().get(0).matchedText());
+        }
+    }
+
+    public void testListenerErrorsWithWrongNumberOfResponses() {
+        List<String> inputs = List.of("1st small", "2nd small", "3rd small");
+
+        var failureMessage = new AtomicReference<String>();
+        var listener = new ActionListener<List<ChunkedInferenceServiceResults>>() {
+
+            @Override
+            public void onResponse(List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults) {
+                assertThat(chunkedInferenceServiceResults.get(0), instanceOf(ErrorChunkedInferenceResults.class));
+                var error = (ErrorChunkedInferenceResults) chunkedInferenceServiceResults.get(0);
+                failureMessage.set(error.getException().getMessage());
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                fail("expected a response with an error");
+            }
+        };
+
+        var batches = new EmbeddingRequestChunker(inputs, 10, 100, 0).batchRequestsWithListeners(listener);
+        assertThat(batches, hasSize(1));
+
+        var embeddings = new ArrayList<TextEmbeddingResults.Embedding>();
+        embeddings.add(new TextEmbeddingResults.Embedding(List.of(randomFloat())));
+        embeddings.add(new TextEmbeddingResults.Embedding(List.of(randomFloat())));
+        batches.get(0).listener().onResponse(new TextEmbeddingResults(embeddings));
+        assertEquals("Error the number of embedding responses [2] does not equal the number of requests [3]", failureMessage.get());
+    }
+
+    private ChunkedResultsListener testListener() {
+        return new ChunkedResultsListener();
+    }
+
+    private static class ChunkedResultsListener implements ActionListener<List<ChunkedInferenceServiceResults>> {
+        List<ChunkedInferenceServiceResults> results;
+
+        @Override
+        public void onResponse(List<ChunkedInferenceServiceResults> chunkedInferenceServiceResults) {
+            this.results = chunkedInferenceServiceResults;
+        }
+
+        @Override
+        public void onFailure(Exception e) {
+            fail(e.getMessage());
+        }
+    }
+}

+ 221 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/WordBoundaryChunkerTests.java

@@ -0,0 +1,221 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.common;
+
+import org.elasticsearch.test.ESTestCase;
+
+import java.util.List;
+
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasSize;
+
+public class WordBoundaryChunkerTests extends ESTestCase {
+
+    private final String TEST_TEXT = "Word segmentation is the problem of dividing a string of written language into its component words.\n"
+        + "In English and many other languages using some form of the Latin alphabet, the space is a good approximation of a word divider "
+        + "(word delimiter), although this concept has limits because of the variability with which languages emically regard collocations "
+        + "and compounds. Many English compound nouns are variably written (for example, ice box = ice-box = icebox; pig sty = pig-sty = "
+        + "pigsty) with a corresponding variation in whether speakers think of them as noun phrases or single nouns; there are trends in "
+        + "how norms are set, such as that open compounds often tend eventually to solidify by widespread convention, but variation remains"
+        + " systemic. In contrast, German compound nouns show less orthographic variation, with solidification being a stronger norm.";
+
+    private final String[] MULTI_LINGUAL = new String[] {
+        "Građevne strukture Mesa Verde dokaz su akumuliranog znanja i vještina koje su se stoljećima prenosile generacijama civilizacije"
+            + " Anasazi. Vrhunce svojih dosega ostvarili su u 12. i 13. stoljeću, kada su sagrađene danas najpoznatije građevine na "
+            + "liticama. Zidali su obrađenim pješčenjakom, tvrđim kamenom oblikovanim do veličine štruce kruha. Kao žbuku između ciglā "
+            + "stavljali su glinu razmočenu vodom. Tim su materijalom gradili prostorije veličine do 6 četvornih metara. U potkrovljima "
+            + "su skladištili žitarice i druge plodine, dok su kive - ceremonijalne prostorije - gradili ispred soba, ali ukopane u zemlju,"
+            + " nešto poput današnjih podruma. Kiva je bila vrhunski dizajnirana prostorija okruglog oblika s prostorom za vatru zimi te s"
+            + " dovodom hladnog zraka za klimatizaciju ljeti. U zidane konstrukcije stavljali su i lokalno posječena stabla, što današnjim"
+            + " arheolozima pomaže u preciznom datiranju nastanka pojedine građevine metodom dendrokronologije. Ta stabla pridonose i"
+            + " teoriji o mogućem konačnom slomu ondašnjeg društva. Nakon što su, tijekom nekoliko stoljeća, šume do kraja srušene, a "
+            + "njihova obnova zbog sušne klime traje i po 200 godina, nije proteklo puno vremena do konačnog urušavanja civilizacije, "
+            + "koja se, na svojem vrhuncu osjećala nepobjedivom. 90 % sagrađenih naseobina ispod stijena ima do deset prostorija. ⅓ od "
+            + "ukupnog broja sagrađenih kuća ima jednu ili dvije kamene prostorije",
+        "Histoarysk wie in acre in stik lân dat 40 roeden (oftewol 1 furlong of ⅛ myl of 660 foet) lang wie, en 4 roeden (of 66 foet) "
+            + "breed. Men is fan tinken dat dat likernôch de grûnmjitte wie dy't men mei in jok oksen yn ien dei beploegje koe.",
+        "創業当初の「太平洋化学工業社」から1959年太平洋化学工業株式会社へ、1987年には太平洋化学㈱に社名を変更。 1990年以降、海外拠点を増やし本格的な国際進出を始動。"
+            + " 創業者がつくりあげた化粧品会社を世界企業へと成長させるべく2002年3月英文社名AMOREPACIFICに改めた。",
+        "۱۔ ھن شق جي مطابق قادياني گروھ يا لاھوري گروھ جي ڪنھن رڪن کي جيڪو پاڻ کي 'احمدي' يا ڪنھن ٻي نالي سان پڪاري جي لاءِ ممنوع قرار "
+            + "ڏنو ويو آھي تہ ھو (الف) ڳالھائي، لکي يا ڪنھن ٻي طريقي سان ڪنھن خليفي يا آنحضور ﷺ جي ڪنھن صحابي کان علاوہڍه ڪنھن کي امير"
+            + " المومنين يا"
+            + " خليفہ المومنين يا خليفہ المسلمين يا صحابی يا رضي الله عنه چئي۔ (ب) آنحضور ﷺ جي گھروارين کان علاوه ڪنھن کي ام المومنين "
+            + "چئي۔ (ج) آنحضور ﷺ جي خاندان جي اھل بيت کان علاوہڍه ڪنھن کي اھل بيت چئي۔ (د) پنھنجي عبادت گاھ کي مسجد چئي۔" };
+
+    public void testSingleSplit() {
+        var chunker = new WordBoundaryChunker();
+        var chunks = chunker.chunk(TEST_TEXT, 10_000, 0);
+        assertThat(chunks, hasSize(1));
+        assertEquals(TEST_TEXT, chunks.get(0));
+    }
+
+    public void testChunkSizeOnWhiteSpaceNoOverlap() {
+        var numWhiteSpaceSeparatedWords = TEST_TEXT.split("\\s+").length;
+        var chunker = new WordBoundaryChunker();
+
+        for (var chunkSize : new int[] { 10, 20, 100, 300 }) {
+            var chunks = chunker.chunk(TEST_TEXT, chunkSize, 0);
+            int expectedNumChunks = (numWhiteSpaceSeparatedWords + chunkSize - 1) / chunkSize;
+            assertThat("chunk size= " + chunkSize, chunks, hasSize(expectedNumChunks));
+        }
+    }
+
+    public void testMultilingual() {
+        var chunker = new WordBoundaryChunker();
+        for (var input : MULTI_LINGUAL) {
+            var chunks = chunker.chunk(input, 10, 0);
+            assertTrue(chunks.size() > 1);
+        }
+    }
+
+    public void testNumberOfChunks() {
+        for (int numWords : new int[] { 10, 22, 50, 73, 100 }) {
+            var sb = new StringBuilder();
+            for (int i = 0; i < numWords; i++) {
+                sb.append(i).append(' ');
+            }
+            var whiteSpacedText = sb.toString();
+            assertExpectedNumberOfChunks(whiteSpacedText, numWords, 10, 4);
+            assertExpectedNumberOfChunks(whiteSpacedText, numWords, 10, 2);
+            assertExpectedNumberOfChunks(whiteSpacedText, numWords, 20, 4);
+            assertExpectedNumberOfChunks(whiteSpacedText, numWords, 20, 10);
+        }
+    }
+
+    public void testWindowSpanningWithOverlapNumWordsInOverlapSection() {
+        int chunkSize = 10;
+        int windowSize = 3;
+        for (int numWords : new int[] { 7, 8, 9, 10 }) {
+            var sb = new StringBuilder();
+            for (int i = 0; i < numWords; i++) {
+                sb.append(i).append(' ');
+            }
+            var chunks = new WordBoundaryChunker().chunk(sb.toString(), chunkSize, windowSize);
+            assertEquals("numWords= " + numWords, 1, chunks.size());
+        }
+
+        var sb = new StringBuilder();
+        for (int i = 0; i < 11; i++) {
+            sb.append(i).append(' ');
+        }
+        var chunks = new WordBoundaryChunker().chunk(sb.toString(), chunkSize, windowSize);
+        assertEquals(2, chunks.size());
+    }
+
+    public void testWindowSpanningWords() {
+        int numWords = randomIntBetween(4, 120);
+        var input = new StringBuilder();
+        for (int i = 0; i < numWords; i++) {
+            input.append(i).append(' ');
+        }
+        var whiteSpacedText = input.toString().stripTrailing();
+
+        var chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 20, 10);
+        assertChunkContents(chunks, numWords, 20, 10);
+        chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 10, 4);
+        assertChunkContents(chunks, numWords, 10, 4);
+        chunks = new WordBoundaryChunker().chunk(whiteSpacedText, 15, 3);
+        assertChunkContents(chunks, numWords, 15, 3);
+    }
+
+    private void assertChunkContents(List<String> chunks, int numWords, int windowSize, int overlap) {
+        int start = 0;
+        int chunkIndex = 0;
+        int newWordsPerWindow = windowSize - overlap;
+        boolean reachedEnd = false;
+        while (reachedEnd == false) {
+            var sb = new StringBuilder();
+            // the trailing whitespace from the previous chunk is
+            // included in this chunk
+            if (chunkIndex > 0) {
+                sb.append(" ");
+            }
+            int end = Math.min(start + windowSize, numWords);
+            for (int i = start; i < end; i++) {
+                sb.append(i).append(' ');
+            }
+            // delete the trailing whitespace
+            sb.deleteCharAt(sb.length() - 1);
+
+            assertEquals("numWords= " + numWords, sb.toString(), chunks.get(chunkIndex));
+
+            reachedEnd = end == numWords;
+            start += newWordsPerWindow;
+            chunkIndex++;
+        }
+
+        assertEquals("numWords= " + numWords, chunks.size(), chunkIndex);
+    }
+
+    public void testWindowSpanning_TextShorterThanWindow() {
+        var sb = new StringBuilder();
+        for (int i = 0; i < 8; i++) {
+            sb.append(i).append(' ');
+        }
+
+        // window size is > num words
+        var chunks = new WordBoundaryChunker().chunk(sb.toString(), 10, 5);
+        assertThat(chunks, hasSize(1));
+    }
+
+    public void testEmptyString() {
+        var chunks = new WordBoundaryChunker().chunk("", 10, 5);
+        assertThat(chunks, contains(""));
+    }
+
+    public void testWhitespace() {
+        var chunks = new WordBoundaryChunker().chunk(" ", 10, 5);
+        assertThat(chunks, contains(" "));
+    }
+
+    public void testPunctuation() {
+        int chunkSize = 1;
+        var chunks = new WordBoundaryChunker().chunk("Comma, separated", chunkSize, 0);
+        assertThat(chunks, contains("Comma", ", separated"));
+
+        chunks = new WordBoundaryChunker().chunk("Mme. Thénardier", chunkSize, 0);
+        assertThat(chunks, contains("Mme", ". Thénardier"));
+
+        chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0);
+        assertThat(chunks, contains("Won't", " you", " chunk"));
+
+        chunkSize = 10;
+        chunks = new WordBoundaryChunker().chunk("Won't you chunk", chunkSize, 0);
+        assertThat(chunks, contains("Won't you chunk"));
+    }
+
+    private void assertExpectedNumberOfChunks(String input, int numWords, int windowSize, int overlap) {
+        var chunks = new WordBoundaryChunker().chunk(input, windowSize, overlap);
+        int expected = expectedNumberOfChunks(numWords, windowSize, overlap);
+        assertEquals(expected, chunks.size());
+    }
+
+    private int expectedNumberOfChunks(int numWords, int windowSize, int overlap) {
+        if (numWords < windowSize) {
+            return 1;
+        }
+
+        // the first chunk contains windowSize words; because of the overlap,
+        // subsequent chunks each consume fewer new words
+        int wordsRemainingAfterFirstChunk = numWords - windowSize;
+        int newWordsPerWindow = windowSize - overlap;
+        int numberOfFollowingChunks = (wordsRemainingAfterFirstChunk + newWordsPerWindow - 1) / newWordsPerWindow;
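+        // worked example with illustrative numbers: numWords=100, windowSize=20, overlap=10
+        // -> wordsRemainingAfterFirstChunk=80, newWordsPerWindow=10,
+        //    numberOfFollowingChunks = (80 + 10 - 1) / 10 = 8, so 9 chunks in total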
+        // the +1 accounts for the first chunk
+        return 1 + numberOfFollowingChunks;
+    }
+
+    public void testInvalidParams() {
+        var chunker = new WordBoundaryChunker();
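+        // overlap values larger than half the chunk size are rejected; with more overlap
+        // than that, consecutive chunks would mostly repeat the same words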
+        var e = expectThrows(IllegalArgumentException.class, () -> chunker.chunk("not evaluated", 4, 10));
+        assertThat(e.getMessage(), containsString("Invalid chunking parameters, overlap [10] must be < chunk size / 2 [4 / 2 = 2]"));
+
+        e = expectThrows(IllegalArgumentException.class, () -> chunker.chunk("not evaluated", 10, 6));
+        assertThat(e.getMessage(), containsString("Invalid chunking parameters, overlap [6] must be < chunk size / 2 [10 / 2 = 5]"));
+    }
+}
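
The chunker these tests exercise is built on ICU4J word-boundary analysis (see the commit description). Below is a minimal sketch of that underlying technique only; the class and method names are hypothetical and this is not the WordBoundaryChunker implementation itself, just an illustration of how ICU4J's BreakIterator reports word boundaries.

import com.ibm.icu.text.BreakIterator;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;

class WordBoundarySketch {
    // Collect the start offset of every word in the input. A chunker can then
    // start a new chunk every (chunkSize - overlap) words using these offsets.
    static List<Integer> wordStartOffsets(String input) {
        List<Integer> offsets = new ArrayList<>();
        BreakIterator words = BreakIterator.getWordInstance(Locale.ROOT);
        words.setText(input);
        int start = words.first();
        for (int end = words.next(); end != BreakIterator.DONE; start = end, end = words.next()) {
            // getRuleStatus() describes the segment [start, end) just returned;
            // WORD_NONE means whitespace or punctuation rather than a word
            if (words.getRuleStatus() != BreakIterator.WORD_NONE) {
                offsets.add(start);
            }
        }
        return offsets;
    }
}

For "Won't you chunk" this yields one offset per word ("Won't", "you", "chunk"): ICU treats the apostrophe as part of the word, which is why the punctuation tests above see "Won't" kept whole while commas and full stops travel with the following chunk's leading text.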

+ 56 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/results/ChunkedTextEmbeddingFloatResultsTests.java

@@ -0,0 +1,56 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.results;
+
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public class ChunkedTextEmbeddingFloatResultsTests extends AbstractWireSerializingTestCase<ChunkedTextEmbeddingFloatResults> {
+
+    public static ChunkedTextEmbeddingFloatResults createRandomResults() {
+        int numChunks = randomIntBetween(1, 5);
+        var chunks = new ArrayList<ChunkedTextEmbeddingFloatResults.EmbeddingChunk>(numChunks);
+
+        for (int i = 0; i < numChunks; i++) {
+            chunks.add(createRandomChunk());
+        }
+
+        return new ChunkedTextEmbeddingFloatResults(chunks);
+    }
+
+    private static ChunkedTextEmbeddingFloatResults.EmbeddingChunk createRandomChunk() {
+        int columns = randomIntBetween(1, 10);
+        List<Float> floats = new ArrayList<>(columns);
+
+        for (int i = 0; i < columns; i++) {
+            floats.add(randomFloat());
+        }
+
+        return new ChunkedTextEmbeddingFloatResults.EmbeddingChunk(randomAlphaOfLength(6), floats);
+    }
+
+    @Override
+    protected Writeable.Reader<ChunkedTextEmbeddingFloatResults> instanceReader() {
+        return ChunkedTextEmbeddingFloatResults::new;
+    }
+
+    @Override
+    protected ChunkedTextEmbeddingFloatResults createTestInstance() {
+        return createRandomResults();
+    }
+
+    @Override
+    protected ChunkedTextEmbeddingFloatResults mutateInstance(ChunkedTextEmbeddingFloatResults instance) throws IOException {
+        return null;
+    }
+}
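
AbstractWireSerializingTestCase drives the serialization round trip for these random instances. A hand-rolled sketch of that round trip, using the reader registered in instanceReader() above, is shown here; it assumes the usual Elasticsearch stream classes and that the result type implements equals/hashCode, and is illustrative rather than part of the PR.

import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;

import java.io.IOException;

class RoundTripSketch {
    static void roundTrip(ChunkedTextEmbeddingFloatResults original) throws IOException {
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            original.writeTo(out);                                    // serialize
            try (StreamInput in = out.bytes().streamInput()) {
                var copy = new ChunkedTextEmbeddingFloatResults(in);  // deserialize via the StreamInput constructor
                assert original.equals(copy) && original.hashCode() == copy.hashCode();
            }
        }
    }
}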

+ 27 - 105
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

@@ -32,9 +32,7 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingByteResults;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -63,7 +61,6 @@ import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
-import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
@@ -1162,11 +1159,12 @@ public class CohereServiceTests extends ESTestCase {
         }
     }
 
-    public void testChunkedInfer_CallsInfer_ConvertsFloatResponse() throws IOException {
+    public void testChunkedInfer_BatchesCalls() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
         try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
 
+            // Batching will call the service with 2 inputs
             String responseJson = """
                 {
                     "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
@@ -1178,6 +1176,10 @@ public class CohereServiceTests extends ESTestCase {
                             [
                                 0.123,
                                 -0.123
+                            ],
+                            [
+                                0.223,
+                                -0.223
                             ]
                         ]
                     },
@@ -1204,9 +1206,10 @@ public class CohereServiceTests extends ESTestCase {
                 null
             );
             PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
+            // 2 inputs
             service.chunkedInfer(
                 model,
-                List.of("abc"),
+                List.of("foo", "bar"),
                 new HashMap<>(),
                 InputType.UNSPECIFIED,
                 new ChunkingOptions(null, null),
@@ -1214,25 +1217,23 @@ public class CohereServiceTests extends ESTestCase {
                 listener
             );
 
-            var result = listener.actionGet(TIMEOUT).get(0);
-            assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class));
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(List.of(0.123f, -0.123f), floatResult.chunks().get(0).embedding());
+            }
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(List.of(0.223f, -0.223f), floatResult.chunks().get(0).embedding());
+            }
 
-            MatcherAssert.assertThat(
-                asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result),
-                Matchers.is(
-                    Map.of(
-                        ChunkedTextEmbeddingResults.FIELD_NAME,
-                        List.of(
-                            Map.of(
-                                ChunkedNlpInferenceResults.TEXT,
-                                "abc",
-                                ChunkedNlpInferenceResults.INFERENCE,
-                                List.of((double) 0.123f, (double) -0.123f)
-                            )
-                        )
-                    )
-                )
-            );
             MatcherAssert.assertThat(webServer.requests(), hasSize(1));
             assertNull(webServer.requests().get(0).getUri().getQuery());
             MatcherAssert.assertThat(
@@ -1244,92 +1245,13 @@ public class CohereServiceTests extends ESTestCase {
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             MatcherAssert.assertThat(
                 requestMap,
-                is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float")))
+                is(Map.of("texts", List.of("foo", "bar"), "model", "model", "embedding_types", List.of("float")))
             );
         }
     }
 
     public void testChunkedInfer_CallsInfer_ConvertsByteResponse() throws IOException {
-        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
-
-        try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
-
-            String responseJson = """
-                {
-                    "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
-                    "texts": [
-                        "hello"
-                    ],
-                    "embeddings": {
-                        "int8": [
-                            [
-                                12,
-                                -12
-                            ]
-                        ]
-                    },
-                    "meta": {
-                        "api_version": {
-                            "version": "1"
-                        },
-                        "billed_units": {
-                            "input_tokens": 1
-                        }
-                    },
-                    "response_type": "embeddings_by_type"
-                }
-                """;
-            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
-
-            var model = CohereEmbeddingsModelTests.createModel(
-                getUrl(webServer),
-                "secret",
-                new CohereEmbeddingsTaskSettings(null, null),
-                1024,
-                1024,
-                "model",
-                CohereEmbeddingType.INT8
-            );
-            PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
-            service.chunkedInfer(
-                model,
-                List.of("abc"),
-                new HashMap<>(),
-                InputType.UNSPECIFIED,
-                new ChunkingOptions(null, null),
-                InferenceAction.Request.DEFAULT_TIMEOUT,
-                listener
-            );
-
-            var result = listener.actionGet(TIMEOUT).get(0);
-
-            MatcherAssert.assertThat(
-                result.asMap(),
-                Matchers.is(
-                    Map.of(
-                        ChunkedTextEmbeddingByteResults.FIELD_NAME,
-                        List.of(
-                            Map.of(
-                                ChunkedNlpInferenceResults.TEXT,
-                                "abc",
-                                ChunkedNlpInferenceResults.INFERENCE,
-                                List.of((byte) 12, (byte) -12)
-                            )
-                        )
-                    )
-                )
-            );
-            MatcherAssert.assertThat(webServer.requests(), hasSize(1));
-            assertNull(webServer.requests().get(0).getUri().getQuery());
-            MatcherAssert.assertThat(
-                webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
-                equalTo(XContentType.JSON.mediaType())
-            );
-            MatcherAssert.assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret"));
-
-            var requestMap = entityAsMap(webServer.requests().get(0).getBody());
-            MatcherAssert.assertThat(requestMap, is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("int8"))));
-        }
+        // TODO byte response not implemented yet
     }
 
     private Map<String, Object> getRequestConfigMap(

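Both the Cohere test above and the OpenAi test below assert the same shape: the two inputs are batched into a single HTTP request, and the response embeddings are mapped back to the input at the same index, producing one single-chunk result per input. A minimal, hypothetical sketch of that mapping step follows (not the EmbeddingRequestChunker itself).

import java.util.ArrayList;
import java.util.List;

class BatchMappingSketch {
    // One chunked result per input: the input text plus its embedding from the batched response.
    record ChunkResult(String matchedText, List<Float> embedding) {}

    static List<ChunkResult> mapBatchedResponse(List<String> inputs, List<List<Float>> embeddings) {
        if (inputs.size() != embeddings.size()) {
            throw new IllegalStateException("expected one embedding per batched input");
        }
        List<ChunkResult> results = new ArrayList<>(inputs.size());
        for (int i = 0; i < inputs.size(); i++) {
            results.add(new ChunkResult(inputs.get(i), embeddings.get(i)));
        }
        return results;
    }
}

With inputs ["foo", "bar"] and embeddings [[0.123, -0.123], [0.223, -0.223]] this produces exactly the two single-chunk results the assertions check.
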
+ 29 - 24
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -31,8 +31,7 @@ import org.elasticsearch.test.http.MockWebServer;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.xcontent.XContentType;
 import org.elasticsearch.xpack.core.inference.action.InferenceAction;
-import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -60,7 +59,6 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
-import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
 import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel;
@@ -1213,11 +1211,12 @@ public class OpenAiServiceTests extends ESTestCase {
         assertEquals("model", serviceSettings.get(ServiceFields.MODEL_ID));
     }
 
-    public void testChunkedInfer_CallsInfer_ConvertsFloatResponse() throws IOException {
+    public void testChunkedInfer_Batches() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 
         try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) {
 
+            // response with 2 embeddings
             String responseJson = """
                 {
                   "object": "list",
@@ -1229,6 +1228,14 @@ public class OpenAiServiceTests extends ESTestCase {
                               0.123,
                               -0.123
                           ]
+                      },
+                      {
+                          "object": "embedding",
+                          "index": 1,
+                          "embedding": [
+                              0.223,
+                              -0.223
+                          ]
                       }
                   ],
                   "model": "text-embedding-ada-002-v2",
@@ -1244,7 +1251,7 @@ public class OpenAiServiceTests extends ESTestCase {
             PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
             service.chunkedInfer(
                 model,
-                List.of("abc"),
+                List.of("foo", "bar"),
                 new HashMap<>(),
                 InputType.INGEST,
                 new ChunkingOptions(null, null),
@@ -1252,25 +1259,23 @@ public class OpenAiServiceTests extends ESTestCase {
                 listener
             );
 
-            var result = listener.actionGet(TIMEOUT).get(0);
-            assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class));
+            var results = listener.actionGet(TIMEOUT);
+            assertThat(results, hasSize(2));
+            {
+                assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("foo", floatResult.chunks().get(0).matchedText());
+                assertEquals(List.of(0.123f, -0.123f), floatResult.chunks().get(0).embedding());
+            }
+            {
+                assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class));
+                var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1);
+                assertThat(floatResult.chunks(), hasSize(1));
+                assertEquals("bar", floatResult.chunks().get(0).matchedText());
+                assertEquals(List.of(0.223f, -0.223f), floatResult.chunks().get(0).embedding());
+            }
 
-            assertThat(
-                asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result),
-                Matchers.is(
-                    Map.of(
-                        ChunkedTextEmbeddingResults.FIELD_NAME,
-                        List.of(
-                            Map.of(
-                                ChunkedNlpInferenceResults.TEXT,
-                                "abc",
-                                ChunkedNlpInferenceResults.INFERENCE,
-                                List.of((double) 0.123f, (double) -0.123f)
-                            )
-                        )
-                    )
-                )
-            );
             assertThat(webServer.requests(), hasSize(1));
             assertNull(webServer.requests().get(0).getUri().getQuery());
             assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
@@ -1279,7 +1284,7 @@ public class OpenAiServiceTests extends ESTestCase {
 
             var requestMap = entityAsMap(webServer.requests().get(0).getBody());
             assertThat(requestMap.size(), Matchers.is(3));
-            assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
+            assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
             assertThat(requestMap.get("model"), Matchers.is("model"));
             assertThat(requestMap.get("user"), Matchers.is("user"));
         }