|
@@ -17,6 +17,7 @@ import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
|
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
|
|
import org.elasticsearch.inference.ChunkedInference;
|
|
|
import org.elasticsearch.inference.EmptySecretSettings;
|
|
@@ -43,6 +44,7 @@ import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
|
|
import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
|
|
import org.elasticsearch.xpack.inference.InferencePlugin;
|
|
|
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
|
|
@@ -59,6 +61,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
|
|
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
|
|
@@ -420,47 +423,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
|
|
|
- var sender = mock(Sender.class);
|
|
|
-
|
|
|
- var factory = mock(HttpRequestSender.Factory.class);
|
|
|
- when(factory.createSender()).thenReturn(sender);
|
|
|
-
|
|
|
- var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING);
|
|
|
-
|
|
|
- try (var service = createService(factory)) {
|
|
|
- PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
- service.infer(
|
|
|
- mockModel,
|
|
|
- null,
|
|
|
- null,
|
|
|
- null,
|
|
|
- List.of(""),
|
|
|
- false,
|
|
|
- new HashMap<>(),
|
|
|
- InputType.INGEST,
|
|
|
- InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- listener
|
|
|
- );
|
|
|
-
|
|
|
- var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
|
|
- MatcherAssert.assertThat(
|
|
|
- thrownException.getMessage(),
|
|
|
- is(
|
|
|
- "Inference entity [model_id] does not support task type [text_embedding] "
|
|
|
- + "for inference, the task type must be one of [sparse_embedding, rerank]."
|
|
|
- )
|
|
|
- );
|
|
|
-
|
|
|
- verify(factory, times(1)).createSender();
|
|
|
- verify(sender, times(1)).start();
|
|
|
- }
|
|
|
-
|
|
|
- verify(sender, times(1)).close();
|
|
|
- verifyNoMoreInteractions(factory);
|
|
|
- verifyNoMoreInteractions(sender);
|
|
|
- }
|
|
|
-
|
|
|
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
|
|
var sender = mock(Sender.class);
|
|
|
|
|
@@ -489,7 +451,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
thrownException.getMessage(),
|
|
|
is(
|
|
|
"Inference entity [model_id] does not support task type [chat_completion] "
|
|
|
- + "for inference, the task type must be one of [sparse_embedding, rerank]. "
|
|
|
+ + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
|
|
|
+ "The task type for the inference entity is chat_completion, "
|
|
|
+ "please use the _inference/chat_completion/model_id/_stream URL."
|
|
|
)
|
|
@@ -834,29 +796,43 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testChunkedInfer() throws IOException {
|
|
|
+ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException {
|
|
|
+ var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
|
|
+
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
- var elasticInferenceServiceURL = getUrl(webServer);
|
|
|
|
|
|
- try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
|
|
|
+ try (var service = createService(senderFactory, getUrl(webServer))) {
|
|
|
+
|
|
|
+ // Batching will call the service with 2 inputs
|
|
|
String responseJson = """
|
|
|
{
|
|
|
"data": [
|
|
|
- {
|
|
|
- "hello": 2.1259406,
|
|
|
- "greet": 1.7073475
|
|
|
+ [
|
|
|
+ 0.123,
|
|
|
+ -0.456,
|
|
|
+ 0.789
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 0.987,
|
|
|
+ -0.654,
|
|
|
+ 0.321
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "usage": {
|
|
|
+ "total_tokens": 10
|
|
|
}
|
|
|
- ]
|
|
|
+ }
|
|
|
}
|
|
|
""";
|
|
|
-
|
|
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
|
|
|
- var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
|
|
|
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
|
|
+ // 2 inputs
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
- List.of(new ChunkInferenceInput("input text")),
|
|
|
+ null,
|
|
|
+ List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
|
|
|
new HashMap<>(),
|
|
|
InputType.INGEST,
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
@@ -864,32 +840,41 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
);
|
|
|
|
|
|
var results = listener.actionGet(TIMEOUT);
|
|
|
- assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
- var sparseResult = (ChunkedInferenceEmbedding) results.get(0);
|
|
|
- assertThat(
|
|
|
- sparseResult.chunks(),
|
|
|
- is(
|
|
|
- List.of(
|
|
|
- new EmbeddingResults.Chunk(
|
|
|
- new SparseEmbeddingResults.Embedding(
|
|
|
- List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
|
|
|
- false
|
|
|
- ),
|
|
|
- new ChunkedInference.TextOffset(0, "input text".length())
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- );
|
|
|
+ assertThat(results, hasSize(2));
|
|
|
+
|
|
|
+ // First result
|
|
|
+ {
|
|
|
+ assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
+ var denseResult = (ChunkedInferenceEmbedding) results.get(0);
|
|
|
+ assertThat(denseResult.chunks(), hasSize(1));
|
|
|
+ assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().get(0).offset());
|
|
|
+ assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
|
|
+
|
|
|
+ var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
|
|
+ assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Second result
|
|
|
+ {
|
|
|
+ assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
+ var denseResult = (ChunkedInferenceEmbedding) results.get(1);
|
|
|
+ assertThat(denseResult.chunks(), hasSize(1));
|
|
|
+ assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().get(0).offset());
|
|
|
+ assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
|
|
+
|
|
|
+ var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
|
|
+ assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
|
|
|
+ }
|
|
|
|
|
|
- MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
assertNull(webServer.requests().get(0).getUri().getQuery());
|
|
|
- MatcherAssert.assertThat(
|
|
|
- webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
|
|
- equalTo(XContentType.JSON.mediaType())
|
|
|
- );
|
|
|
+ assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
|
|
|
|
var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest"))
|
|
|
+ );
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -901,27 +886,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception {
|
|
|
- try (
|
|
|
- var service = createServiceWithMockSender(
|
|
|
- ElasticInferenceServiceAuthorizationModel.of(
|
|
|
- new ElasticInferenceServiceAuthorizationResponseEntity(
|
|
|
- List.of(
|
|
|
- new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
|
|
- "model-1",
|
|
|
- EnumSet.of(TaskType.TEXT_EMBEDDING)
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- ) {
|
|
|
- ensureAuthorizationCallFinished(service);
|
|
|
-
|
|
|
- assertTrue(service.hideFromConfigurationApi());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
public void testHideFromConfigurationApi_ReturnsFalse_WithAvailableModels() throws Exception {
|
|
|
try (
|
|
|
var service = createServiceWithMockSender(
|
|
@@ -951,7 +915,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
List.of(
|
|
|
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
|
|
"model-1",
|
|
|
- EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)
|
|
|
+ EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
|
|
|
)
|
|
|
)
|
|
|
)
|
|
@@ -964,7 +928,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
{
|
|
|
"service": "elastic",
|
|
|
"name": "Elastic",
|
|
|
- "task_types": ["sparse_embedding", "chat_completion"],
|
|
|
+ "task_types": ["sparse_embedding", "chat_completion", "text_embedding"],
|
|
|
"configurations": {
|
|
|
"rate_limit.requests_per_minute": {
|
|
|
"description": "Minimize the number of rate limit errors.",
|
|
@@ -973,7 +937,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -982,7 +946,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -991,7 +955,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1028,7 +992,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -1037,7 +1001,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -1046,7 +1010,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1088,7 +1052,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
{
|
|
|
"service": "elastic",
|
|
|
"name": "Elastic",
|
|
|
- "task_types": [],
|
|
|
+ "task_types": ["text_embedding"],
|
|
|
"configurations": {
|
|
|
"rate_limit.requests_per_minute": {
|
|
|
"description": "Minimize the number of rate limit errors.",
|
|
@@ -1097,7 +1061,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -1106,7 +1070,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -1115,7 +1079,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1291,6 +1255,14 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
{
|
|
|
"model_name": "elser-v2",
|
|
|
"task_types": ["embed/text/sparse"]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "model_name": "multilingual-embed-v1",
|
|
|
+ "task_types": ["embed/text/dense"]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "model_name": "rerank-v1",
|
|
|
+ "task_types": ["rerank/text/text-similarity"]
|
|
|
}
|
|
|
]
|
|
|
}
|
|
@@ -1312,22 +1284,42 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
|
|
service
|
|
|
),
|
|
|
+ new InferenceService.DefaultConfigId(
|
|
|
+ ".multilingual-embed-v1-elastic",
|
|
|
+ MinimalServiceSettings.textEmbedding(
|
|
|
+ ElasticInferenceService.NAME,
|
|
|
+ ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
|
|
+ ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
|
|
|
+ DenseVectorFieldMapper.ElementType.FLOAT
|
|
|
+ ),
|
|
|
+ service
|
|
|
+ ),
|
|
|
new InferenceService.DefaultConfigId(
|
|
|
".rainbow-sprinkles-elastic",
|
|
|
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
|
|
service
|
|
|
+ ),
|
|
|
+ new InferenceService.DefaultConfigId(
|
|
|
+ ".rerank-v1-elastic",
|
|
|
+ MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
|
|
|
+ service
|
|
|
)
|
|
|
)
|
|
|
)
|
|
|
);
|
|
|
- assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
|
|
|
+ assertThat(
|
|
|
+ service.supportedTaskTypes(),
|
|
|
+ is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
|
|
|
+ );
|
|
|
|
|
|
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
|
|
service.defaultConfigs(listener);
|
|
|
var models = listener.actionGet(TIMEOUT);
|
|
|
- assertThat(models.size(), is(2));
|
|
|
+ assertThat(models.size(), is(4));
|
|
|
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
|
|
- assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
|
|
+ assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic"));
|
|
|
+ assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
|
|
+ assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
|
|
}
|
|
|
}
|
|
|
|