|
@@ -17,6 +17,7 @@ import org.elasticsearch.common.bytes.BytesReference;
|
|
|
import org.elasticsearch.common.settings.Settings;
|
|
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.core.TimeValue;
|
|
|
+import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
import org.elasticsearch.inference.ChunkInferenceInput;
|
|
|
import org.elasticsearch.inference.ChunkedInference;
|
|
|
import org.elasticsearch.inference.EmptySecretSettings;
|
|
@@ -29,7 +30,6 @@ import org.elasticsearch.inference.MinimalServiceSettings;
|
|
|
import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
|
|
-import org.elasticsearch.inference.WeightedToken;
|
|
|
import org.elasticsearch.plugins.Plugin;
|
|
|
import org.elasticsearch.test.ESSingleNodeTestCase;
|
|
|
import org.elasticsearch.test.http.MockResponse;
|
|
@@ -40,9 +40,8 @@ import org.elasticsearch.xcontent.XContentFactory;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
|
|
-import org.elasticsearch.xpack.core.inference.results.EmbeddingResults;
|
|
|
-import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResultsTests;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
|
|
import org.elasticsearch.xpack.inference.InferencePlugin;
|
|
|
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
|
|
@@ -59,6 +58,7 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
|
|
|
+import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
|
|
@@ -421,47 +421,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
|
|
|
- var sender = mock(Sender.class);
|
|
|
-
|
|
|
- var factory = mock(HttpRequestSender.Factory.class);
|
|
|
- when(factory.createSender()).thenReturn(sender);
|
|
|
-
|
|
|
- var mockModel = getInvalidModel("model_id", "service_name", TaskType.TEXT_EMBEDDING);
|
|
|
-
|
|
|
- try (var service = createService(factory)) {
|
|
|
- PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
- service.infer(
|
|
|
- mockModel,
|
|
|
- null,
|
|
|
- null,
|
|
|
- null,
|
|
|
- List.of(""),
|
|
|
- false,
|
|
|
- new HashMap<>(),
|
|
|
- InputType.INGEST,
|
|
|
- InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- listener
|
|
|
- );
|
|
|
-
|
|
|
- var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
|
|
- MatcherAssert.assertThat(
|
|
|
- thrownException.getMessage(),
|
|
|
- is(
|
|
|
- "Inference entity [model_id] does not support task type [text_embedding] "
|
|
|
- + "for inference, the task type must be one of [sparse_embedding, rerank]."
|
|
|
- )
|
|
|
- );
|
|
|
-
|
|
|
- verify(factory, times(1)).createSender();
|
|
|
- verify(sender, times(1)).start();
|
|
|
- }
|
|
|
-
|
|
|
- verify(sender, times(1)).close();
|
|
|
- verifyNoMoreInteractions(factory);
|
|
|
- verifyNoMoreInteractions(sender);
|
|
|
- }
|
|
|
-
|
|
|
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid_ChatCompletion() throws IOException {
|
|
|
var sender = mock(Sender.class);
|
|
|
|
|
@@ -490,7 +449,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
thrownException.getMessage(),
|
|
|
is(
|
|
|
"Inference entity [model_id] does not support task type [chat_completion] "
|
|
|
- + "for inference, the task type must be one of [sparse_embedding, rerank]. "
|
|
|
+ + "for inference, the task type must be one of [text_embedding, sparse_embedding, rerank]. "
|
|
|
+ "The task type for the inference entity is chat_completion, "
|
|
|
+ "please use the _inference/chat_completion/model_id/_stream URL."
|
|
|
)
|
|
@@ -701,82 +660,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
|
|
|
- var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
- var elasticInferenceServiceURL = getUrl(webServer);
|
|
|
-
|
|
|
- try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
|
|
|
- String responseJson = """
|
|
|
- {
|
|
|
- "data": [
|
|
|
- {
|
|
|
- "hello": 2.1259406,
|
|
|
- "greet": 1.7073475
|
|
|
- }
|
|
|
- ]
|
|
|
- }
|
|
|
- """;
|
|
|
-
|
|
|
- webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
-
|
|
|
- // Set up the product use case in the thread context
|
|
|
- String productUseCase = "test-product-use-case";
|
|
|
- threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
|
|
|
-
|
|
|
- var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
|
|
|
- PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
|
|
-
|
|
|
- try {
|
|
|
- service.chunkedInfer(
|
|
|
- model,
|
|
|
- null,
|
|
|
- List.of(new ChunkInferenceInput("input text")),
|
|
|
- new HashMap<>(),
|
|
|
- InputType.INGEST,
|
|
|
- InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
- listener
|
|
|
- );
|
|
|
-
|
|
|
- var results = listener.actionGet(TIMEOUT);
|
|
|
-
|
|
|
- // Verify the response was processed correctly
|
|
|
- ChunkedInference inferenceResult = results.getFirst();
|
|
|
- assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
- var sparseResult = (ChunkedInferenceEmbedding) inferenceResult;
|
|
|
- assertThat(
|
|
|
- sparseResult.chunks(),
|
|
|
- is(
|
|
|
- List.of(
|
|
|
- new EmbeddingResults.Chunk(
|
|
|
- new SparseEmbeddingResults.Embedding(
|
|
|
- List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
|
|
|
- false
|
|
|
- ),
|
|
|
- new ChunkedInference.TextOffset(0, "input text".length())
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- );
|
|
|
-
|
|
|
- // Verify the request was sent and contains expected headers
|
|
|
- MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
|
|
- var request = webServer.requests().getFirst();
|
|
|
- assertNull(request.getUri().getQuery());
|
|
|
- MatcherAssert.assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
|
-
|
|
|
- // Check that the product use case header was set correctly
|
|
|
- assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
|
|
|
-
|
|
|
- // Verify request body
|
|
|
- var requestMap = entityAsMap(request.getBody());
|
|
|
- assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
|
|
|
- } finally {
|
|
|
- // Clean up the thread context
|
|
|
- threadPool.getThreadContext().stashContext();
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws IOException {
|
|
|
var elasticInferenceServiceURL = getUrl(webServer);
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
@@ -835,30 +718,45 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testChunkedInfer() throws IOException {
|
|
|
+ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException {
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
- var elasticInferenceServiceURL = getUrl(webServer);
|
|
|
|
|
|
- try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
|
|
|
+ try (var service = createService(senderFactory, getUrl(webServer))) {
|
|
|
+
|
|
|
+ // Batching will call the service with 2 inputs
|
|
|
String responseJson = """
|
|
|
{
|
|
|
"data": [
|
|
|
- {
|
|
|
- "hello": 2.1259406,
|
|
|
- "greet": 1.7073475
|
|
|
+ [
|
|
|
+ 0.123,
|
|
|
+ -0.456,
|
|
|
+ 0.789
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 0.987,
|
|
|
+ -0.654,
|
|
|
+ 0.321
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "usage": {
|
|
|
+ "total_tokens": 10
|
|
|
}
|
|
|
- ]
|
|
|
+ }
|
|
|
}
|
|
|
""";
|
|
|
-
|
|
|
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+ var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
|
|
+
|
|
|
+ String productUseCase = "test-product-use-case";
|
|
|
+ threadPool.getThreadContext().putHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER, productUseCase);
|
|
|
|
|
|
- var model = ElasticInferenceServiceSparseEmbeddingsModelTests.createModel(elasticInferenceServiceURL, "my-model-id");
|
|
|
PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
|
|
+ // 2 inputs
|
|
|
service.chunkedInfer(
|
|
|
model,
|
|
|
null,
|
|
|
- List.of(new ChunkInferenceInput("input text")),
|
|
|
+ List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
|
|
|
new HashMap<>(),
|
|
|
InputType.INGEST,
|
|
|
InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
@@ -866,58 +764,111 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
);
|
|
|
|
|
|
var results = listener.actionGet(TIMEOUT);
|
|
|
- assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
- var sparseResult = (ChunkedInferenceEmbedding) results.get(0);
|
|
|
- assertThat(
|
|
|
- sparseResult.chunks(),
|
|
|
- is(
|
|
|
- List.of(
|
|
|
- new EmbeddingResults.Chunk(
|
|
|
- new SparseEmbeddingResults.Embedding(
|
|
|
- List.of(new WeightedToken("hello", 2.1259406f), new WeightedToken("greet", 1.7073475f)),
|
|
|
- false
|
|
|
- ),
|
|
|
- new ChunkedInference.TextOffset(0, "input text".length())
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- );
|
|
|
+ assertThat(results, hasSize(2));
|
|
|
|
|
|
- MatcherAssert.assertThat(webServer.requests(), hasSize(1));
|
|
|
- assertNull(webServer.requests().get(0).getUri().getQuery());
|
|
|
- MatcherAssert.assertThat(
|
|
|
- webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE),
|
|
|
- equalTo(XContentType.JSON.mediaType())
|
|
|
- );
|
|
|
+ // Verify the response was processed correctly
|
|
|
+ ChunkedInference inferenceResult = results.getFirst();
|
|
|
+ assertThat(inferenceResult, instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
+
|
|
|
+ // Verify the request was sent and contains expected headers
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
+ var request = webServer.requests().getFirst();
|
|
|
+ assertNull(request.getUri().getQuery());
|
|
|
+ assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
|
|
|
|
- var requestMap = entityAsMap(webServer.requests().get(0).getBody());
|
|
|
- assertThat(requestMap, is(Map.of("input", List.of("input text"), "model", "my-model-id", "usage_context", "ingest")));
|
|
|
+ // Check that the product use case header was set correctly
|
|
|
+ assertThat(request.getHeader(InferencePlugin.X_ELASTIC_PRODUCT_USE_CASE_HTTP_HEADER), is(productUseCase));
|
|
|
+
|
|
|
+ } finally {
|
|
|
+ // Clean up the thread context
|
|
|
+ threadPool.getThreadContext().stashContext();
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
|
|
|
- try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) {
|
|
|
- ensureAuthorizationCallFinished(service);
|
|
|
+ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException {
|
|
|
+ var model = ElasticInferenceServiceDenseTextEmbeddingsModelTests.createModel(getUrl(webServer), "my-dense-model-id");
|
|
|
|
|
|
- assertTrue(service.hideFromConfigurationApi());
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var service = createService(senderFactory, getUrl(webServer))) {
|
|
|
+
|
|
|
+ // Batching will call the service with 2 inputs
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "data": [
|
|
|
+ [
|
|
|
+ 0.123,
|
|
|
+ -0.456,
|
|
|
+ 0.789
|
|
|
+ ],
|
|
|
+ [
|
|
|
+ 0.987,
|
|
|
+ -0.654,
|
|
|
+ 0.321
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "usage": {
|
|
|
+ "total_tokens": 10
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ PlainActionFuture<List<ChunkedInference>> listener = new PlainActionFuture<>();
|
|
|
+ // 2 inputs
|
|
|
+ service.chunkedInfer(
|
|
|
+ model,
|
|
|
+ null,
|
|
|
+ List.of(new ChunkInferenceInput("hello world"), new ChunkInferenceInput("dense embedding")),
|
|
|
+ new HashMap<>(),
|
|
|
+ InputType.INGEST,
|
|
|
+ InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
+ listener
|
|
|
+ );
|
|
|
+
|
|
|
+ var results = listener.actionGet(TIMEOUT);
|
|
|
+ assertThat(results, hasSize(2));
|
|
|
+
|
|
|
+ // First result
|
|
|
+ {
|
|
|
+ assertThat(results.getFirst(), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
+ var denseResult = (ChunkedInferenceEmbedding) results.getFirst();
|
|
|
+ assertThat(denseResult.chunks(), hasSize(1));
|
|
|
+ assertEquals(new ChunkedInference.TextOffset(0, "hello world".length()), denseResult.chunks().getFirst().offset());
|
|
|
+ assertThat(denseResult.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
|
|
+
|
|
|
+ var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
|
|
+ assertArrayEquals(new float[] { 0.123f, -0.456f, 0.789f }, embedding.values(), 0.0f);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Second result
|
|
|
+ {
|
|
|
+ assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class));
|
|
|
+ var denseResult = (ChunkedInferenceEmbedding) results.get(1);
|
|
|
+ assertThat(denseResult.chunks(), hasSize(1));
|
|
|
+ assertEquals(new ChunkedInference.TextOffset(0, "dense embedding".length()), denseResult.chunks().getFirst().offset());
|
|
|
+ assertThat(denseResult.chunks().getFirst().embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class));
|
|
|
+
|
|
|
+ var embedding = (TextEmbeddingFloatResults.Embedding) denseResult.chunks().get(0).embedding();
|
|
|
+ assertArrayEquals(new float[] { 0.987f, -0.654f, 0.321f }, embedding.values(), 0.0f);
|
|
|
+ }
|
|
|
+
|
|
|
+ assertThat(webServer.requests(), hasSize(1));
|
|
|
+ assertNull(webServer.requests().getFirst().getUri().getQuery());
|
|
|
+ assertThat(webServer.requests().getFirst().getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType()));
|
|
|
+
|
|
|
+ var requestMap = entityAsMap(webServer.requests().getFirst().getBody());
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ requestMap,
|
|
|
+ is(Map.of("input", List.of("hello world", "dense embedding"), "model", "my-dense-model-id", "usage_context", "ingest"))
|
|
|
+ );
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- public void testHideFromConfigurationApi_ReturnsTrue_WithModelTaskTypesThatAreNotImplemented() throws Exception {
|
|
|
- try (
|
|
|
- var service = createServiceWithMockSender(
|
|
|
- ElasticInferenceServiceAuthorizationModel.of(
|
|
|
- new ElasticInferenceServiceAuthorizationResponseEntity(
|
|
|
- List.of(
|
|
|
- new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
|
|
- "model-1",
|
|
|
- EnumSet.of(TaskType.TEXT_EMBEDDING)
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- )
|
|
|
- ) {
|
|
|
+ public void testHideFromConfigurationApi_ReturnsTrue_WithNoAvailableModels() throws Exception {
|
|
|
+ try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) {
|
|
|
ensureAuthorizationCallFinished(service);
|
|
|
|
|
|
assertTrue(service.hideFromConfigurationApi());
|
|
@@ -953,7 +904,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
List.of(
|
|
|
new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel(
|
|
|
"model-1",
|
|
|
- EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION)
|
|
|
+ EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING)
|
|
|
)
|
|
|
)
|
|
|
)
|
|
@@ -966,7 +917,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
{
|
|
|
"service": "elastic",
|
|
|
"name": "Elastic",
|
|
|
- "task_types": ["sparse_embedding", "chat_completion"],
|
|
|
+ "task_types": ["sparse_embedding", "chat_completion", "text_embedding"],
|
|
|
"configurations": {
|
|
|
"rate_limit.requests_per_minute": {
|
|
|
"description": "Minimize the number of rate limit errors.",
|
|
@@ -975,7 +926,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -984,7 +935,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -993,7 +944,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1030,7 +981,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -1039,7 +990,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -1048,7 +999,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1090,7 +1041,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
{
|
|
|
"service": "elastic",
|
|
|
"name": "Elastic",
|
|
|
- "task_types": [],
|
|
|
+ "task_types": ["text_embedding"],
|
|
|
"configurations": {
|
|
|
"rate_limit.requests_per_minute": {
|
|
|
"description": "Minimize the number of rate limit errors.",
|
|
@@ -1099,7 +1050,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
|
|
},
|
|
|
"model_id": {
|
|
|
"description": "The name of the model to use for the inference task.",
|
|
@@ -1108,7 +1059,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
|
|
|
+ "supported_task_types": ["text_embedding" , "sparse_embedding", "rerank", "chat_completion"]
|
|
|
},
|
|
|
"max_input_tokens": {
|
|
|
"description": "Allows you to specify the maximum number of tokens per input.",
|
|
@@ -1117,7 +1068,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["sparse_embedding"]
|
|
|
+ "supported_task_types": ["text_embedding", "sparse_embedding"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1296,6 +1247,10 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
"task_types": ["embed/text/sparse"]
|
|
|
},
|
|
|
{
|
|
|
+ "model_name": "multilingual-embed-v1",
|
|
|
+ "task_types": ["embed/text/dense"]
|
|
|
+ },
|
|
|
+ {
|
|
|
"model_name": "rerank-v1",
|
|
|
"task_types": ["rerank/text/text-similarity"]
|
|
|
}
|
|
@@ -1319,6 +1274,16 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
|
|
|
service
|
|
|
),
|
|
|
+ new InferenceService.DefaultConfigId(
|
|
|
+ ".multilingual-embed-v1-elastic",
|
|
|
+ MinimalServiceSettings.textEmbedding(
|
|
|
+ ElasticInferenceService.NAME,
|
|
|
+ ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
|
|
|
+ ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
|
|
|
+ DenseVectorFieldMapper.ElementType.FLOAT
|
|
|
+ ),
|
|
|
+ service
|
|
|
+ ),
|
|
|
new InferenceService.DefaultConfigId(
|
|
|
".rainbow-sprinkles-elastic",
|
|
|
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
|
|
@@ -1332,16 +1297,19 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
|
|
|
)
|
|
|
)
|
|
|
);
|
|
|
- assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
|
|
|
+ assertThat(
|
|
|
+ service.supportedTaskTypes(),
|
|
|
+ is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
|
|
|
+ );
|
|
|
|
|
|
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
|
|
|
service.defaultConfigs(listener);
|
|
|
var models = listener.actionGet(TIMEOUT);
|
|
|
- assertThat(models.size(), is(3));
|
|
|
+ assertThat(models.size(), is(4));
|
|
|
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
|
|
|
- assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
|
|
- assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
|
|
-
|
|
|
+ assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".multilingual-embed-v1-elastic"));
|
|
|
+ assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
|
|
|
+ assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
|
|
|
}
|
|
|
}
|
|
|
|