|
@@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingByte
|
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.ml.search.WeightedToken;
|
|
|
+import org.hamcrest.Matchers;
|
|
|
|
|
|
import java.util.ArrayList;
|
|
|
import java.util.List;
|
|
@@ -31,16 +32,62 @@ import static org.hamcrest.Matchers.startsWith;
|
|
|
|
|
|
public class EmbeddingRequestChunkerTests extends ESTestCase {
|
|
|
|
|
|
- public void testEmptyInput() {
|
|
|
+ public void testEmptyInput_WordChunker() {
|
|
|
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
var batches = new EmbeddingRequestChunker(List.of(), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
|
|
|
assertThat(batches, empty());
|
|
|
}
|
|
|
|
|
|
- public void testBlankInput() {
|
|
|
+ public void testEmptyInput_SentenceChunker() {
|
|
|
+ var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
+ var batches = new EmbeddingRequestChunker(List.of(), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
|
|
|
+ .batchRequestsWithListeners(testListener());
|
|
|
+ assertThat(batches, empty());
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testWhitespaceInput_SentenceChunker() {
|
|
|
+ var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
+ var batches = new EmbeddingRequestChunker(List.of(" "), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
|
|
|
+ .batchRequestsWithListeners(testListener());
|
|
|
+ assertThat(batches, hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(" "));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testBlankInput_WordChunker() {
|
|
|
var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
var batches = new EmbeddingRequestChunker(List.of(""), 100, 100, 10, embeddingType).batchRequestsWithListeners(testListener());
|
|
|
assertThat(batches, hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testBlankInput_SentenceChunker() {
|
|
|
+ var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
+ var batches = new EmbeddingRequestChunker(List.of(""), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
|
|
|
+ .batchRequestsWithListeners(testListener());
|
|
|
+ assertThat(batches, hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs().get(0), Matchers.is(""));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testInputThatDoesNotChunk_WordChunker() {
|
|
|
+ var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
+ var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 100, 100, 10, embeddingType).batchRequestsWithListeners(
|
|
|
+ testListener()
|
|
|
+ );
|
|
|
+ assertThat(batches, hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testInputThatDoesNotChunk_SentenceChunker() {
|
|
|
+ var embeddingType = randomFrom(EmbeddingRequestChunker.EmbeddingType.values());
|
|
|
+ var batches = new EmbeddingRequestChunker(List.of("ABBAABBA"), 10, embeddingType, new SentenceBoundaryChunkingSettings(250, 1))
|
|
|
+ .batchRequestsWithListeners(testListener());
|
|
|
+ assertThat(batches, hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs(), hasSize(1));
|
|
|
+ assertThat(batches.get(0).batch().inputs().get(0), Matchers.is("ABBAABBA"));
|
|
|
}
|
|
|
|
|
|
public void testShortInputsAreSingleBatch() {
|