|
@@ -33,7 +33,6 @@ import org.elasticsearch.xcontent.XContentType;
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
|
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
|
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults;
|
|
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
@@ -55,14 +54,12 @@ import org.junit.After;
|
|
import org.junit.Before;
|
|
import org.junit.Before;
|
|
|
|
|
|
import java.io.IOException;
|
|
import java.io.IOException;
|
|
-import java.net.URISyntaxException;
|
|
|
|
import java.util.HashMap;
|
|
import java.util.HashMap;
|
|
import java.util.List;
|
|
import java.util.List;
|
|
import java.util.Map;
|
|
import java.util.Map;
|
|
import java.util.Set;
|
|
import java.util.Set;
|
|
import java.util.concurrent.TimeUnit;
|
|
import java.util.concurrent.TimeUnit;
|
|
|
|
|
|
-import static org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResultsTests.asMapWithListsInsteadOfArrays;
|
|
|
|
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
|
import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
|
@@ -849,7 +846,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
verifyNoMoreInteractions(sender);
|
|
verifyNoMoreInteractions(sender);
|
|
}
|
|
}
|
|
|
|
|
|
- public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException {
|
|
|
|
|
|
+ public void testChunkedInfer() throws IOException {
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
|
|
|
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
@@ -865,6 +862,14 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
0.0123,
|
|
0.0123,
|
|
-0.0123
|
|
-0.0123
|
|
]
|
|
]
|
|
|
|
+ },
|
|
|
|
+ {
|
|
|
|
+ "object": "embedding",
|
|
|
|
+ "index": 1,
|
|
|
|
+ "embedding": [
|
|
|
|
+ 1.0123,
|
|
|
|
+ -1.0123
|
|
|
|
+ ]
|
|
}
|
|
}
|
|
],
|
|
],
|
|
"model": "text-embedding-ada-002-v2",
|
|
"model": "text-embedding-ada-002-v2",
|
|
@@ -892,7 +897,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
|
|
PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>();
|
|
service.chunkedInfer(
|
|
service.chunkedInfer(
|
|
model,
|
|
model,
|
|
- List.of("abc"),
|
|
|
|
|
|
+ List.of("foo", "bar"),
|
|
new HashMap<>(),
|
|
new HashMap<>(),
|
|
InputType.INGEST,
|
|
InputType.INGEST,
|
|
new ChunkingOptions(null, null),
|
|
new ChunkingOptions(null, null),
|
|
@@ -900,20 +905,23 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
listener
|
|
listener
|
|
);
|
|
);
|
|
|
|
|
|
- var result = listener.actionGet(TIMEOUT).get(0);
|
|
|
|
- assertThat(result, CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
|
|
|
|
|
|
+ var results = listener.actionGet(TIMEOUT);
|
|
|
|
+ assertThat(results, hasSize(2));
|
|
|
|
+ {
|
|
|
|
+ assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
|
|
|
|
+ var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0);
|
|
|
|
+ assertThat(floatResult.chunks(), hasSize(1));
|
|
|
|
+ assertEquals("foo", floatResult.chunks().get(0).matchedText());
|
|
|
|
+ assertArrayEquals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
|
|
|
|
+ }
|
|
|
|
+ {
|
|
|
|
+ assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class));
|
|
|
|
+ var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1);
|
|
|
|
+ assertThat(floatResult.chunks(), hasSize(1));
|
|
|
|
+ assertEquals("bar", floatResult.chunks().get(0).matchedText());
|
|
|
|
+ assertArrayEquals(new float[] { 1.0123f, -1.0123f }, floatResult.chunks().get(0).embedding(), 0.0f);
|
|
|
|
+ }
|
|
|
|
|
|
- assertThat(
|
|
|
|
- asMapWithListsInsteadOfArrays((InferenceChunkedTextEmbeddingFloatResults) result),
|
|
|
|
- Matchers.is(
|
|
|
|
- Map.of(
|
|
|
|
- InferenceChunkedTextEmbeddingFloatResults.FIELD_NAME,
|
|
|
|
- List.of(
|
|
|
|
- Map.of(ChunkedNlpInferenceResults.TEXT, "abc", ChunkedNlpInferenceResults.INFERENCE, List.of(0.0123f, -0.0123f))
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- );
|
|
|
|
assertThat(webServer.requests(), hasSize(1));
|
|
assertThat(webServer.requests(), hasSize(1));
|
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
@@ -921,7 +929,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
assertThat(requestMap.size(), Matchers.is(2));
|
|
assertThat(requestMap.size(), Matchers.is(2));
|
|
- assertThat(requestMap.get("input"), Matchers.is(List.of("abc")));
|
|
|
|
|
|
+ assertThat(requestMap.get("input"), Matchers.is(List.of("foo", "bar")));
|
|
assertThat(requestMap.get("user"), Matchers.is("user"));
|
|
assertThat(requestMap.get("user"), Matchers.is("user"));
|
|
}
|
|
}
|
|
}
|
|
}
|