|
@@ -22,13 +22,13 @@ import static org.hamcrest.Matchers.hasSize;
|
|
|
import static org.hamcrest.Matchers.is;
|
|
|
import static org.hamcrest.Matchers.oneOf;
|
|
|
|
|
|
-public class DefaultElserIT extends InferenceBaseRestTest {
|
|
|
+public class DefaultEndPointsIT extends InferenceBaseRestTest {
|
|
|
|
|
|
private TestThreadPool threadPool;
|
|
|
|
|
|
@Before
|
|
|
public void createThreadPool() {
|
|
|
- threadPool = new TestThreadPool(DefaultElserIT.class.getSimpleName());
|
|
|
+ threadPool = new TestThreadPool(DefaultEndPointsIT.class.getSimpleName());
|
|
|
}
|
|
|
|
|
|
@After
|
|
@@ -38,7 +38,7 @@ public class DefaultElserIT extends InferenceBaseRestTest {
|
|
|
}
|
|
|
|
|
|
@SuppressWarnings("unchecked")
|
|
|
- public void testInferCreatesDefaultElser() throws IOException {
|
|
|
+ public void testInferDeploysDefaultElser() throws IOException {
|
|
|
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
|
|
|
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
|
|
|
assertDefaultElserConfig(model);
|
|
@@ -67,4 +67,39 @@ public class DefaultElserIT extends InferenceBaseRestTest {
|
|
|
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
|
|
|
);
|
|
|
}
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public void testInferDeploysDefaultE5() throws IOException {
|
|
|
+ assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
|
|
|
+ var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
|
|
|
+ assertDefaultE5Config(model);
|
|
|
+
|
|
|
+ var inputs = List.of("Hello World", "Goodnight moon");
|
|
|
+ var queryParams = Map.of("timeout", "120s");
|
|
|
+ var results = infer(ElasticsearchInternalService.DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, inputs, queryParams);
|
|
|
+ var embeddings = (List<Map<String, Object>>) results.get("text_embedding");
|
|
|
+ assertThat(results.toString(), embeddings, hasSize(2));
|
|
|
+ }
|
|
|
+
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
|
|
|
+ assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_E5_ID, modelConfig.get("inference_id"));
|
|
|
+ assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
|
|
|
+ assertEquals(modelConfig.toString(), TaskType.TEXT_EMBEDDING.toString(), modelConfig.get("task_type"));
|
|
|
+
|
|
|
+ var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
|
|
|
+ assertThat(
|
|
|
+ modelConfig.toString(),
|
|
|
+ serviceSettings.get("model_id"),
|
|
|
+ is(oneOf(".multilingual-e5-small", ".multilingual-e5-small_linux-x86_64"))
|
|
|
+ );
|
|
|
+ assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
|
|
|
+
|
|
|
+ var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
|
|
|
+ assertThat(
|
|
|
+ modelConfig.toString(),
|
|
|
+ adaptiveAllocations,
|
|
|
+ Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 1, "max_number_of_allocations", 8))
|
|
|
+ );
|
|
|
+ }
|
|
|
}
|