|
|
@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
|
|
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
|
|
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
|
|
|
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
|
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
|
@@ -53,7 +54,9 @@ import java.util.List;
|
|
|
import java.util.Map;
|
|
|
|
|
|
import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
|
|
|
+import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
|
|
import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
|
|
|
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
|
|
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
|
|
|
@@ -312,6 +315,93 @@ public class CustomServiceTests extends AbstractInferenceServiceTests {
|
|
|
: CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
|
|
|
}
|
|
|
|
|
|
+ public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
|
|
|
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
|
|
|
+
|
|
|
+ try (var service = createService(threadPool, clientManager)) {
|
|
|
+ var config = getRequestConfigMap(
|
|
|
+ createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
|
|
|
+ createTaskSettingsMap(),
|
|
|
+ chunkingSettingsMap,
|
|
|
+ createSecretSettingsMap()
|
|
|
+ );
|
|
|
+
|
|
|
+ var listener = new PlainActionFuture<Model>();
|
|
|
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
|
|
|
+ var model = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ assertModel(model, TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
|
|
|
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
|
|
|
+ try (var service = createService(threadPool, clientManager)) {
|
|
|
+ var config = getRequestConfigMap(
|
|
|
+ createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
|
|
|
+ createTaskSettingsMap(),
|
|
|
+ createSecretSettingsMap()
|
|
|
+ );
|
|
|
+
|
|
|
+ var listener = new PlainActionFuture<Model>();
|
|
|
+ service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
|
|
|
+ var model = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ assertModel(model, TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
|
|
|
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
|
|
|
+ var chunkingSettingsMap = createRandomChunkingSettingsMap();
|
|
|
+
|
|
|
+ try (var service = createService(threadPool, clientManager)) {
|
|
|
+ var persistedConfigMap = getPersistedConfigMap(
|
|
|
+ createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
|
|
|
+ createTaskSettingsMap(),
|
|
|
+ chunkingSettingsMap,
|
|
|
+ createSecretSettingsMap()
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets(
|
|
|
+ "id",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ persistedConfigMap.config(),
|
|
|
+ persistedConfigMap.secrets()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertModel(model, TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
|
|
|
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
|
|
|
+ try (var service = createService(threadPool, clientManager)) {
|
|
|
+ var persistedConfigMap = getPersistedConfigMap(
|
|
|
+ createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
|
|
|
+ createTaskSettingsMap(),
|
|
|
+ createSecretSettingsMap()
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets(
|
|
|
+ "id",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ persistedConfigMap.config(),
|
|
|
+ persistedConfigMap.secrets()
|
|
|
+ );
|
|
|
+ assertModel(model, TaskType.TEXT_EMBEDDING);
|
|
|
+
|
|
|
+ var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(null);
|
|
|
+ assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException {
|
|
|
try (var service = createService(threadPool, clientManager)) {
|
|
|
String responseJson = "error";
|