Explorar o código

Adding chunking settings to IbmWatsonxService (#114914) (#117278)

* Adding chunking settings to IbmWatsonxService

* Removing feature flag

* Update docs/changelog/114914.yaml

---------

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Dan Rubinstein hai 10 meses
pai
achega
b92e3c7090

+ 5 - 0
docs/changelog/114914.yaml

@@ -0,0 +1,5 @@
+pr: 114914
+summary: Adding chunking settings to `IbmWatsonxService`
+area: Machine Learning
+type: enhancement
+issues: []

+ 28 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java

@@ -16,6 +16,7 @@ import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptySettingsConfiguration;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -30,6 +31,7 @@ import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType;
 import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.ibmwatsonx.IbmWatsonxActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -86,11 +88,19 @@ public class IbmWatsonxService extends SenderService {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
             Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+            ChunkingSettings chunkingSettings = null;
+            if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
             IbmWatsonxModel model = createModel(
                 inferenceEntityId,
                 taskType,
                 serviceSettingsMap,
                 taskSettingsMap,
+                chunkingSettings,
                 serviceSettingsMap,
                 TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
                 ConfigurationParseContext.REQUEST
@@ -112,6 +122,7 @@ public class IbmWatsonxService extends SenderService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secretSettings,
         String failureMessage,
         ConfigurationParseContext context
@@ -123,6 +134,7 @@ public class IbmWatsonxService extends SenderService {
                 NAME,
                 serviceSettings,
                 taskSettings,
+                chunkingSettings,
                 secretSettings,
                 context
             );
@@ -141,11 +153,17 @@ public class IbmWatsonxService extends SenderService {
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
+            chunkingSettings,
             secretSettingsMap,
             parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
         );
@@ -166,6 +184,7 @@ public class IbmWatsonxService extends SenderService {
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         Map<String, Object> secretSettings,
         String failureMessage
     ) {
@@ -174,6 +193,7 @@ public class IbmWatsonxService extends SenderService {
             taskType,
             serviceSettings,
             taskSettings,
+            chunkingSettings,
             secretSettings,
             failureMessage,
             ConfigurationParseContext.PERSISTENT
@@ -185,11 +205,17 @@ public class IbmWatsonxService extends SenderService {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
+            chunkingSettings,
             null,
             parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
         );
@@ -266,7 +292,8 @@ public class IbmWatsonxService extends SenderService {
         var batchedRequests = new EmbeddingRequestChunker(
             input.getInputs(),
             EMBEDDING_MAX_BATCH_SIZE,
-            EmbeddingRequestChunker.EmbeddingType.FLOAT
+            EmbeddingRequestChunker.EmbeddingType.FLOAT,
+            model.getConfigurations().getChunkingSettings()
         ).batchRequestsWithListeners(listener);
         for (var request : batchedRequests) {
             var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings, inputType);

+ 5 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsModel.java

@@ -9,6 +9,7 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings;
 
 import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -40,6 +41,7 @@ public class IbmWatsonxEmbeddingsModel extends IbmWatsonxModel {
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         Map<String, Object> secrets,
         ConfigurationParseContext context
     ) {
@@ -49,6 +51,7 @@ public class IbmWatsonxEmbeddingsModel extends IbmWatsonxModel {
             service,
             IbmWatsonxEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
+            chunkingSettings,
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -64,10 +67,11 @@ public class IbmWatsonxEmbeddingsModel extends IbmWatsonxModel {
         String service,
         IbmWatsonxEmbeddingsServiceSettings serviceSettings,
         TaskSettings taskSettings,
+        ChunkingSettings chunkingsettings,
         @Nullable DefaultSecretSettings secrets
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingsettings),
             new ModelSecrets(secrets),
             serviceSettings
         );

+ 172 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

@@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -69,6 +70,8 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
 import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
 import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
 import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
@@ -124,6 +127,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                 assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
                 assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
                 assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
+                assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
             }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
 
             service.parseRequestConfig(
@@ -150,6 +154,45 @@ public class IbmWatsonxServiceTests extends ESTestCase {
         }
     }
 
+    public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
+        try (var service = createIbmWatsonxService()) {
+            ActionListener<Model> modelListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
+
+                var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
+                assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
+                assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
+                assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
+                assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
+                assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
+                assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+            }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                getRequestConfigMap(
+                    new HashMap<>(
+                        Map.of(
+                            ServiceFields.MODEL_ID,
+                            modelId,
+                            IbmWatsonxServiceFields.PROJECT_ID,
+                            projectId,
+                            ServiceFields.URL,
+                            url,
+                            IbmWatsonxServiceFields.API_VERSION,
+                            apiVersion
+                        )
+                    ),
+                    new HashMap<>(Map.of()),
+                    createRandomChunkingSettingsMap(),
+                    getSecretSettingsMap(apiKey)
+                ),
+                modelListener
+            );
+        }
+    }
+
     public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
         try (var service = createIbmWatsonxService()) {
             var failureListener = getModelListenerForException(
@@ -235,6 +278,47 @@ public class IbmWatsonxServiceTests extends ESTestCase {
             assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
             assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
             assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
+            assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+        }
+    }
+
+    public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
+        try (var service = createIbmWatsonxService()) {
+            var persistedConfig = getPersistedConfigMap(
+                new HashMap<>(
+                    Map.of(
+                        ServiceFields.MODEL_ID,
+                        modelId,
+                        IbmWatsonxServiceFields.PROJECT_ID,
+                        projectId,
+                        ServiceFields.URL,
+                        url,
+                        IbmWatsonxServiceFields.API_VERSION,
+                        apiVersion
+                    )
+                ),
+                getTaskSettingsMapEmpty(),
+                createRandomChunkingSettingsMap(),
+                getSecretSettingsMap(apiKey)
+            );
+
+            var model = service.parsePersistedConfigWithSecrets(
+                "id",
+                TaskType.TEXT_EMBEDDING,
+                persistedConfig.config(),
+                persistedConfig.secrets()
+            );
+
+            assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
+
+            var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
+            assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
+            assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
+            assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
+            assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
+            assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
         }
     }
 
@@ -399,6 +483,73 @@ public class IbmWatsonxServiceTests extends ESTestCase {
         }
     }
 
+    public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
+        try (var service = createIbmWatsonxService()) {
+            var persistedConfig = getPersistedConfigMap(
+                new HashMap<>(
+                    Map.of(
+                        ServiceFields.MODEL_ID,
+                        modelId,
+                        IbmWatsonxServiceFields.PROJECT_ID,
+                        projectId,
+                        ServiceFields.URL,
+                        url,
+                        IbmWatsonxServiceFields.API_VERSION,
+                        apiVersion
+                    )
+                ),
+                getTaskSettingsMapEmpty(),
+                null
+            );
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
+
+            var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
+            assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
+            assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
+            assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
+            assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+        }
+    }
+
+    public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
+        try (var service = createIbmWatsonxService()) {
+            var persistedConfig = getPersistedConfigMap(
+                new HashMap<>(
+                    Map.of(
+                        ServiceFields.MODEL_ID,
+                        modelId,
+                        IbmWatsonxServiceFields.PROJECT_ID,
+                        projectId,
+                        ServiceFields.URL,
+                        url,
+                        IbmWatsonxServiceFields.API_VERSION,
+                        apiVersion
+                    )
+                ),
+                getTaskSettingsMapEmpty(),
+                createRandomChunkingSettingsMap(),
+                null
+            );
+
+            var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
+
+            assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
+
+            var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
+            assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
+            assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
+            assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
+            assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
+            assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
+            assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOException {
         var sender = mock(Sender.class);
 
@@ -488,7 +639,15 @@ public class IbmWatsonxServiceTests extends ESTestCase {
         }
     }
 
-    public void testChunkedInfer_Batches() throws IOException {
+    public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
+        testChunkedInfer_Batches(null);
+    }
+
+    public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
+        testChunkedInfer_Batches(createRandomChunkingSettings());
+    }
+
+    private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException {
         var input = List.of("foo", "bar");
 
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
@@ -878,6 +1037,18 @@ public class IbmWatsonxServiceTests extends ESTestCase {
         });
     }
 
+    private Map<String, Object> getRequestConfigMap(
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        Map<String, Object> chunkingSettings,
+        Map<String, Object> secretSettings
+    ) {
+        var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings);
+        requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings);
+
+        return requestConfigMap;
+    }
+
     private Map<String, Object> getRequestConfigMap(
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/embeddings/IbmWatsonxEmbeddingsModelTests.java

@@ -82,6 +82,7 @@ public class IbmWatsonxEmbeddingsModelTests extends ESTestCase {
                 null
             ),
             EmptyTaskSettings.INSTANCE,
+            null,
             new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
         );
     }