|
|
@@ -24,29 +24,18 @@ import static org.hamcrest.Matchers.containsString;
|
|
|
public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
|
|
|
|
|
|
public void testPutE5Small_withNoModelVariant() throws IOException {
|
|
|
- // Model downloaded automatically & test infer with no model variant
|
|
|
{
|
|
|
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
|
|
|
- putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, noModelIdVariantJsonEntity());
|
|
|
- var models = getTrainedModel("_all");
|
|
|
- assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
|
|
|
-
|
|
|
- Map<String, Object> results = inferOnMockService(
|
|
|
- inferenceEntityId,
|
|
|
- TaskType.TEXT_EMBEDDING,
|
|
|
- List.of("hello world", "this is the second document")
|
|
|
+ expectThrows(
|
|
|
+ org.elasticsearch.client.ResponseException.class,
|
|
|
+ () -> putTextEmbeddingModel(inferenceEntityId, noModelIdVariantJsonEntity())
|
|
|
);
|
|
|
- assertTrue(((List) ((Map) ((List) results.get("text_embedding")).get(0)).get("embedding")).size() > 1);
|
|
|
- // there exists embeddings
|
|
|
- assertTrue(((List) results.get("text_embedding")).size() == 2);
|
|
|
- // there are two sets of embeddings
|
|
|
- deleteTextEmbeddingModel(inferenceEntityId);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
public void testPutE5Small_withPlatformAgnosticVariant() throws IOException {
|
|
|
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
|
|
|
- putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformAgnosticModelVariantJsonEntity());
|
|
|
+ putTextEmbeddingModel(inferenceEntityId, platformAgnosticModelVariantJsonEntity());
|
|
|
var models = getTrainedModel("_all");
|
|
|
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
|
|
|
|
|
|
@@ -65,7 +54,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
|
|
|
public void testPutE5Small_withPlatformSpecificVariant() throws IOException {
|
|
|
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
|
|
|
if ("linux-x86_64".equals(Platforms.PLATFORM_NAME)) {
|
|
|
- putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformSpecificModelVariantJsonEntity());
|
|
|
+ putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity());
|
|
|
var models = getTrainedModel("_all");
|
|
|
assertThat(models.toString(), containsString("deployment_id=" + inferenceEntityId));
|
|
|
|
|
|
@@ -82,7 +71,7 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
|
|
|
} else {
|
|
|
expectThrows(
|
|
|
org.elasticsearch.client.ResponseException.class,
|
|
|
- () -> putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, platformSpecificModelVariantJsonEntity())
|
|
|
+ () -> putTextEmbeddingModel(inferenceEntityId, platformSpecificModelVariantJsonEntity())
|
|
|
);
|
|
|
}
|
|
|
}
|
|
|
@@ -91,9 +80,15 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
|
|
|
String inferenceEntityId = randomAlphaOfLength(10).toLowerCase();
|
|
|
expectThrows(
|
|
|
org.elasticsearch.client.ResponseException.class,
|
|
|
- () -> putTextEmbeddingModel(inferenceEntityId, TaskType.TEXT_EMBEDDING, fakeModelVariantJsonEntity())
|
|
|
+ () -> putTextEmbeddingModel(inferenceEntityId, fakeModelVariantJsonEntity())
|
|
|
);
|
|
|
+ }
|
|
|
|
|
|
+ public void testPutE5WithTrainedModelAndInference() throws IOException {
|
|
|
+ putE5TrainedModels();
|
|
|
+ deployE5TrainedModels();
|
|
|
+ putTextEmbeddingModel("an-e5-deployment", platformAgnosticModelVariantJsonEntity());
|
|
|
+ getTrainedModel("an-e5-deployment");
|
|
|
}
|
|
|
|
|
|
private Map<String, Object> deleteTextEmbeddingModel(String inferenceEntityId) throws IOException {
|
|
|
@@ -104,8 +99,8 @@ public class TextEmbeddingCrudIT extends InferenceBaseRestTest {
|
|
|
return entityAsMap(response);
|
|
|
}
|
|
|
|
|
|
- private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, TaskType taskType, String jsonEntity) throws IOException {
|
|
|
- var endpoint = Strings.format("_inference/%s/%s", taskType, inferenceEntityId);
|
|
|
+ private Map<String, Object> putTextEmbeddingModel(String inferenceEntityId, String jsonEntity) throws IOException {
|
|
|
+ var endpoint = Strings.format("_inference/%s/%s", TaskType.TEXT_EMBEDDING, inferenceEntityId);
|
|
|
var request = new Request("PUT", endpoint);
|
|
|
|
|
|
request.setJsonEntity(jsonEntity);
|