@@ -11,6 +11,7 @@ import org.apache.http.HttpHeaders;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.bytes.BytesArray;
 import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.settings.Settings;
@@ -58,6 +59,8 @@ import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticI
 import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
 import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
+import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModelTests;
 import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
 import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels;
 import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -149,6 +152,23 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    public void testParseRequestConfig_CreatesARerankModel() throws IOException {
+        try (var service = createServiceWithMockSender()) {
+            ActionListener<Model> modelListener = ActionListener.wrap(model -> {
+                assertThat(model, instanceOf(ElasticInferenceServiceRerankModel.class));
+                ElasticInferenceServiceRerankModel rerankModel = (ElasticInferenceServiceRerankModel) model;
+                assertThat(rerankModel.getServiceSettings().modelId(), is("my-rerank-model-id"));
+            }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage()));
+
+            service.parseRequestConfig(
+                "id",
+                TaskType.RERANK,
+                getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, "my-rerank-model-id"), Map.of(), Map.of()),
+                modelListener
+            );
+        }
+    }
+
     public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException {
         try (var service = createServiceWithMockSender()) {
             var config = getRequestConfigMap(Map.of(ServiceFields.MODEL_ID, ElserModels.ELSER_V2_MODEL), Map.of(), Map.of());
@@ -367,6 +387,39 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         verifyNoMoreInteractions(sender);
     }
 
+    public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
+        var sender = mock(Sender.class);
+
+        var factory = mock(HttpRequestSender.Factory.class);
+        when(factory.createSender()).thenReturn(sender);
+
+        try (var service = createServiceWithMockSender()) {
+            var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), "my-rerank-model-id");
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            var thrownException = expectThrows(
+                ValidationException.class,
+                () -> service.infer(
+                    model,
+                    "search query",
+                    Boolean.TRUE,
+                    10,
+                    List.of("doc1", "doc2", "doc3"),
+                    false,
+                    new HashMap<>(),
+                    InputType.SEARCH,
+                    InferenceAction.Request.DEFAULT_TIMEOUT,
+                    listener
+                )
+            );
+
+            assertThat(
+                thrownException.getMessage(),
+                is("Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this service;")
+            );
+        }
+    }
+
     public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
         var sender = mock(Sender.class);
 
@@ -395,7 +448,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             thrownException.getMessage(),
             is(
                 "Inference entity [model_id] does not support task type [text_embedding] "
-                    + "for inference, the task type must be one of [sparse_embedding]."
+                    + "for inference, the task type must be one of [sparse_embedding, rerank]."
             )
         );
 
@@ -436,7 +489,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
             thrownException.getMessage(),
             is(
                 "Inference entity [model_id] does not support task type [chat_completion] "
-                    + "for inference, the task type must be one of [sparse_embedding]. "
+                    + "for inference, the task type must be one of [sparse_embedding, rerank]. "
                     + "The task type for the inference entity is chat_completion, "
                     + "please use the _inference/chat_completion/model_id/_stream URL."
             )
@@ -504,6 +557,76 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
         }
     }
 
+    @SuppressWarnings("unchecked")
+    public void testRerank_SendsRerankRequest() throws IOException {
+        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
+        var elasticInferenceServiceURL = getUrl(webServer);
+
+        try (var service = createService(senderFactory, elasticInferenceServiceURL)) {
+            var modelId = "my-model-id";
+            var topN = 2;
+            String responseJson = """
+                {
+                    "results": [
+                        {"index": 0, "relevance_score": 0.95},
+                        {"index": 1, "relevance_score": 0.85},
+                        {"index": 2, "relevance_score": 0.75}
+                    ]
+                }
+                """;
+
+            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
+
+            var model = ElasticInferenceServiceRerankModelTests.createModel(elasticInferenceServiceURL, modelId);
+            PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
+
+            service.infer(
+                model,
+                "search query",
+                null,
+                topN,
+                List.of("doc1", "doc2", "doc3"),
+                false,
+                new HashMap<>(),
+                InputType.SEARCH,
+                InferenceAction.Request.DEFAULT_TIMEOUT,
+                listener
+            );
+            var result = listener.actionGet(TIMEOUT);
+
+            var resultMap = result.asMap();
+            var rerankResults = (List<Map<String, Object>>) resultMap.get("rerank");
+            assertThat(rerankResults.size(), Matchers.is(3));
+
+            Map<String, Object> rankedDocOne = (Map<String, Object>) rerankResults.get(0).get("ranked_doc");
+            Map<String, Object> rankedDocTwo = (Map<String, Object>) rerankResults.get(1).get("ranked_doc");
+            Map<String, Object> rankedDocThree = (Map<String, Object>) rerankResults.get(2).get("ranked_doc");
+
+            assertThat(rankedDocOne.get("index"), equalTo(0));
+            assertThat(rankedDocTwo.get("index"), equalTo(1));
+            assertThat(rankedDocThree.get("index"), equalTo(2));
+
+            // Verify the outgoing HTTP request
+            var request = webServer.requests().get(0);
+            assertNull(request.getUri().getQuery());
+            assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType()));
+
+            // Verify the outgoing request body
+            Map<String, Object> requestMap = entityAsMap(request.getBody());
+            Map<String, Object> expectedRequestMap = Map.of(
+                "query",
+                "search query",
+                "model",
+                modelId,
+                "top_n",
+                topN,
+                "documents",
+                List.of("doc1", "doc2", "doc3")
+            );
+            assertThat(requestMap, is(expectedRequestMap));
+        }
+    }
+
     public void testInfer_PropagatesProductUseCaseHeader() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
         var elasticInferenceServiceURL = getUrl(webServer);
@@ -850,7 +973,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "int",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "model_id": {
                     "description": "The name of the model to use for the inference task.",
@@ -859,7 +982,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "str",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "max_input_tokens": {
                     "description": "Allows you to specify the maximum number of tokens per input.",
@@ -905,7 +1028,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "int",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "model_id": {
                     "description": "The name of the model to use for the inference task.",
@@ -914,7 +1037,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "str",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "max_input_tokens": {
                     "description": "Allows you to specify the maximum number of tokens per input.",
@@ -974,7 +1097,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "int",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "model_id": {
                     "description": "The name of the model to use for the inference task.",
@@ -983,7 +1106,7 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase {
                     "sensitive": false,
                     "updatable": false,
                     "type": "str",
-                    "supported_task_types": ["sparse_embedding" , "chat_completion"]
+                    "supported_task_types": ["sparse_embedding" , "rerank", "chat_completion"]
                 },
                 "max_input_tokens": {
                     "description": "Allows you to specify the maximum number of tokens per input.",