Browse Source

Adding chunking settings parser fix and tests (#135726)

Jonathan Buttner 3 weeks ago
parent
commit
7f9ba0f641

+ 16 - 4
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

@@ -96,7 +96,12 @@ public class CustomService extends SenderService {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
-            var chunkingSettings = extractChunkingSettings(config, taskType);
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
 
             CustomModel model = createModel(
                 inferenceEntityId,
@@ -147,7 +152,14 @@ public class CustomService extends SenderService {
         };
     }
 
-    private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
+    private static ChunkingSettings extractPersistentChunkingSettings(Map<String, Object> config, TaskType taskType) {
+        /*
+         * There's a subtle difference between how the chunking settings are parsed for the request context vs the persistent context.
+         * For persistent context, to support backwards compatibility, if the chunking settings are not present, removeFromMap will
+         * return null which results in the older word boundary chunking settings being used as the default.
+         * For request context, removeFromMapOrDefaultEmpty returns an empty map which results in the newer sentence boundary chunking
+         * settings being used as the default.
+         */
         if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
             return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
         }
@@ -220,7 +232,7 @@ public class CustomService extends SenderService {
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
 
-        var chunkingSettings = extractChunkingSettings(config, taskType);
+        var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
 
         return createModelWithoutLoggingDeprecations(
             inferenceEntityId,
@@ -237,7 +249,7 @@ public class CustomService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
 
-        var chunkingSettings = extractChunkingSettings(config, taskType);
+        var chunkingSettings = extractPersistentChunkingSettings(config, taskType);
 
         return createModelWithoutLoggingDeprecations(
             inferenceEntityId,

+ 90 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java

@@ -31,6 +31,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
 import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
 import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
 import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
 import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
@@ -53,7 +54,9 @@ import java.util.List;
 import java.util.Map;
 
 import static org.elasticsearch.xpack.inference.Utils.TIMEOUT;
+import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
@@ -312,6 +315,93 @@ public class CustomServiceTests extends AbstractInferenceServiceTests {
             : CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS;
     }
 
+    public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
+        var chunkingSettingsMap = createRandomChunkingSettingsMap();
+
+        try (var service = createService(threadPool, clientManager)) {
+            var config = getRequestConfigMap(
+                createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+                createTaskSettingsMap(),
+                chunkingSettingsMap,
+                createSecretSettingsMap()
+            );
+
+            var listener = new PlainActionFuture<Model>();
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
+            var model = listener.actionGet(TIMEOUT);
+
+            assertModel(model, TaskType.TEXT_EMBEDDING);
+
+            var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
+            assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+        }
+    }
+
+    public void testParseRequestConfig_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
+        try (var service = createService(threadPool, clientManager)) {
+            var config = getRequestConfigMap(
+                createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+                createTaskSettingsMap(),
+                createSecretSettingsMap()
+            );
+
+            var listener = new PlainActionFuture<Model>();
+            service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, listener);
+            var model = listener.actionGet(TIMEOUT);
+
+            assertModel(model, TaskType.TEXT_EMBEDDING);
+
+            var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(Map.of());
+            assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsProvided() throws Exception {
+        var chunkingSettingsMap = createRandomChunkingSettingsMap();
+
+        try (var service = createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+                createTaskSettingsMap(),
+                chunkingSettingsMap,
+                createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+
+            assertModel(model, TaskType.TEXT_EMBEDDING);
+
+            var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(chunkingSettingsMap);
+            assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_CreatesAnEmbeddingsModel_WhenChunkingSettingsNotProvided() throws Exception {
+        try (var service = createService(threadPool, clientManager)) {
+            var persistedConfigMap = getPersistedConfigMap(
+                createServiceSettingsMap(TaskType.TEXT_EMBEDDING),
+                createTaskSettingsMap(),
+                createSecretSettingsMap()
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfigMap.config(),
+                persistedConfigMap.secrets()
+            );
+            assertModel(model, TaskType.TEXT_EMBEDDING);
+
+            var expectedChunkingSettings = ChunkingSettingsBuilder.fromMap(null);
+            assertThat(model.getConfigurations().getChunkingSettings(), is(expectedChunkingSettings));
+        }
+    }
+
     public void testInfer_ReturnsAnError_WithoutParsingTheResponseBody() throws IOException {
         try (var service = createService(threadPool, clientManager)) {
             String responseJson = "error";