|
@@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
|
import org.elasticsearch.inference.ChunkedInferenceServiceResults;
|
|
|
import org.elasticsearch.inference.ChunkingOptions;
|
|
|
+import org.elasticsearch.inference.ChunkingSettings;
|
|
|
import org.elasticsearch.inference.EmptyTaskSettings;
|
|
|
import org.elasticsearch.inference.InferenceServiceConfiguration;
|
|
|
import org.elasticsearch.inference.InferenceServiceResults;
|
|
@@ -69,6 +70,8 @@ import static org.elasticsearch.xpack.inference.Utils.getInvalidModel;
|
|
|
import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap;
|
|
|
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
|
|
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
|
|
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettings;
|
|
|
+import static org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests.createRandomChunkingSettingsMap;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
|
|
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
|
|
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
|
|
@@ -124,6 +127,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
|
|
|
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
}, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
|
|
|
|
|
|
service.parseRequestConfig(
|
|
@@ -150,6 +154,45 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParseRequestConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
|
|
+ try (var service = createIbmWatsonxService()) {
|
|
|
+ ActionListener<Model> modelListener = ActionListener.wrap(model -> {
|
|
|
+ assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
|
|
|
+
|
|
|
+ var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
+ }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
|
|
|
+
|
|
|
+ service.parseRequestConfig(
|
|
|
+ "id",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ getRequestConfigMap(
|
|
|
+ new HashMap<>(
|
|
|
+ Map.of(
|
|
|
+ ServiceFields.MODEL_ID,
|
|
|
+ modelId,
|
|
|
+ IbmWatsonxServiceFields.PROJECT_ID,
|
|
|
+ projectId,
|
|
|
+ ServiceFields.URL,
|
|
|
+ url,
|
|
|
+ IbmWatsonxServiceFields.API_VERSION,
|
|
|
+ apiVersion
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ new HashMap<>(Map.of()),
|
|
|
+ createRandomChunkingSettingsMap(),
|
|
|
+ getSecretSettingsMap(apiKey)
|
|
|
+ ),
|
|
|
+ modelListener
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
|
|
try (var service = createIbmWatsonxService()) {
|
|
|
var failureListener = getModelListenerForException(
|
|
@@ -235,6 +278,47 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
|
|
|
assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfigWithSecrets_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
|
|
+ try (var service = createIbmWatsonxService()) {
|
|
|
+ var persistedConfig = getPersistedConfigMap(
|
|
|
+ new HashMap<>(
|
|
|
+ Map.of(
|
|
|
+ ServiceFields.MODEL_ID,
|
|
|
+ modelId,
|
|
|
+ IbmWatsonxServiceFields.PROJECT_ID,
|
|
|
+ projectId,
|
|
|
+ ServiceFields.URL,
|
|
|
+ url,
|
|
|
+ IbmWatsonxServiceFields.API_VERSION,
|
|
|
+ apiVersion
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ getTaskSettingsMapEmpty(),
|
|
|
+ createRandomChunkingSettingsMap(),
|
|
|
+ getSecretSettingsMap(apiKey)
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets(
|
|
|
+ "id",
|
|
|
+ TaskType.TEXT_EMBEDDING,
|
|
|
+ persistedConfig.config(),
|
|
|
+ persistedConfig.secrets()
|
|
|
+ );
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
|
|
|
+
|
|
|
+ var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
|
|
|
+ assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -399,6 +483,73 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException {
|
|
|
+ try (var service = createIbmWatsonxService()) {
|
|
|
+ var persistedConfig = getPersistedConfigMap(
|
|
|
+ new HashMap<>(
|
|
|
+ Map.of(
|
|
|
+ ServiceFields.MODEL_ID,
|
|
|
+ modelId,
|
|
|
+ IbmWatsonxServiceFields.PROJECT_ID,
|
|
|
+ projectId,
|
|
|
+ ServiceFields.URL,
|
|
|
+ url,
|
|
|
+ IbmWatsonxServiceFields.API_VERSION,
|
|
|
+ apiVersion
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ getTaskSettingsMapEmpty(),
|
|
|
+ null
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
|
|
|
+
|
|
|
+ var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfig_CreatesAIbmWatsonxEmbeddingsModelWhenChunkingSettingsProvided() throws IOException {
|
|
|
+ try (var service = createIbmWatsonxService()) {
|
|
|
+ var persistedConfig = getPersistedConfigMap(
|
|
|
+ new HashMap<>(
|
|
|
+ Map.of(
|
|
|
+ ServiceFields.MODEL_ID,
|
|
|
+ modelId,
|
|
|
+ IbmWatsonxServiceFields.PROJECT_ID,
|
|
|
+ projectId,
|
|
|
+ ServiceFields.URL,
|
|
|
+ url,
|
|
|
+ IbmWatsonxServiceFields.API_VERSION,
|
|
|
+ apiVersion
|
|
|
+ )
|
|
|
+ ),
|
|
|
+ getTaskSettingsMapEmpty(),
|
|
|
+ createRandomChunkingSettingsMap(),
|
|
|
+ null
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(IbmWatsonxEmbeddingsModel.class));
|
|
|
+
|
|
|
+ var embeddingsModel = (IbmWatsonxEmbeddingsModel) model;
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().projectId(), is(projectId));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().url(), is(URI.create(url)));
|
|
|
+ assertThat(embeddingsModel.getServiceSettings().apiVersion(), is(apiVersion));
|
|
|
+ assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE));
|
|
|
+ assertThat(embeddingsModel.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOException {
|
|
|
var sender = mock(Sender.class);
|
|
|
|
|
@@ -488,7 +639,15 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testChunkedInfer_Batches() throws IOException {
|
|
|
+ public void testChunkedInfer_ChunkingSettingsNotSet() throws IOException {
|
|
|
+ testChunkedInfer_Batches(null);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testChunkedInfer_ChunkingSettingsSet() throws IOException {
|
|
|
+ testChunkedInfer_Batches(createRandomChunkingSettings());
|
|
|
+ }
|
|
|
+
|
|
|
+ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException {
|
|
|
var input = List.of("foo", "bar");
|
|
|
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
@@ -878,6 +1037,18 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|
|
});
|
|
|
}
|
|
|
|
|
|
+ private Map<String, Object> getRequestConfigMap(
|
|
|
+ Map<String, Object> serviceSettings,
|
|
|
+ Map<String, Object> taskSettings,
|
|
|
+ Map<String, Object> chunkingSettings,
|
|
|
+ Map<String, Object> secretSettings
|
|
|
+ ) {
|
|
|
+ var requestConfigMap = getRequestConfigMap(serviceSettings, taskSettings, secretSettings);
|
|
|
+ requestConfigMap.put(ModelConfigurations.CHUNKING_SETTINGS, chunkingSettings);
|
|
|
+
|
|
|
+ return requestConfigMap;
|
|
|
+ }
|
|
|
+
|
|
|
private Map<String, Object> getRequestConfigMap(
|
|
|
Map<String, Object> serviceSettings,
|
|
|
Map<String, Object> taskSettings,
|