|
|
@@ -14,14 +14,18 @@ import org.elasticsearch.common.io.Streams;
|
|
|
import org.elasticsearch.common.settings.SecureString;
|
|
|
import org.elasticsearch.common.xcontent.XContentHelper;
|
|
|
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
|
|
|
+import org.elasticsearch.inference.InputType;
|
|
|
import org.elasticsearch.inference.SimilarityMeasure;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
|
+import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
|
|
+import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.CustomModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.CustomSecretSettings;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.CustomServiceSettings;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.CustomTaskSettings;
|
|
|
+import org.elasticsearch.xpack.inference.services.custom.InputTypeTranslator;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.QueryParameters;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
|
|
|
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
|
|
|
@@ -46,7 +50,8 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
Map<String, String> headers = Map.of(HttpHeaders.AUTHORIZATION, Strings.format("${api_key}"));
|
|
|
var requestContentString = """
|
|
|
{
|
|
|
- "input": ${input}
|
|
|
+ "input": ${input},
|
|
|
+ "input_type": ${input_type}
|
|
|
}
|
|
|
""";
|
|
|
|
|
|
@@ -62,7 +67,9 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new QueryParameters(List.of(new QueryParameters.Parameter("key", "value"), new QueryParameters.Parameter("key", "value2"))),
|
|
|
requestContentString,
|
|
|
new TextEmbeddingResponseParser("$.result.embeddings"),
|
|
|
- new RateLimitSettings(10_000)
|
|
|
+ new RateLimitSettings(10_000),
|
|
|
+ null,
|
|
|
+ new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
|
|
|
);
|
|
|
|
|
|
var model = CustomModelTests.createModel(
|
|
|
@@ -73,7 +80,13 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var request = new CustomRequest(null, List.of("abc", "123"), model);
|
|
|
+ var request = new CustomRequest(
|
|
|
+ EmbeddingParameters.of(
|
|
|
+ new EmbeddingsInput(List.of("abc", "123"), null, null),
|
|
|
+ model.getServiceSettings().getInputTypeTranslator()
|
|
|
+ ),
|
|
|
+ model
|
|
|
+ );
|
|
|
var httpRequest = request.createHttpRequest();
|
|
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
|
|
|
@@ -84,18 +97,20 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
|
|
|
var expectedBody = XContentHelper.stripWhitespace("""
|
|
|
{
|
|
|
- "input": ["abc", "123"]
|
|
|
+ "input": ["abc", "123"],
|
|
|
+ "input_type": "default"
|
|
|
}
|
|
|
""");
|
|
|
|
|
|
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
|
|
|
}
|
|
|
|
|
|
- public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() {
|
|
|
+ public void testCreateRequest_QueryParametersAreEscaped_AndEncoded() throws IOException {
|
|
|
var inferenceId = "inferenceId";
|
|
|
var requestContentString = """
|
|
|
{
|
|
|
- "input": ${input}
|
|
|
+ "input": ${input},
|
|
|
+ "input_type": ${input_type}
|
|
|
}
|
|
|
""";
|
|
|
|
|
|
@@ -115,7 +130,9 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
),
|
|
|
requestContentString,
|
|
|
new TextEmbeddingResponseParser("$.result.embeddings"),
|
|
|
- new RateLimitSettings(10_000)
|
|
|
+ new RateLimitSettings(10_000),
|
|
|
+ null,
|
|
|
+ new InputTypeTranslator(Map.of(InputType.INGEST, "value"), "default")
|
|
|
);
|
|
|
|
|
|
var model = CustomModelTests.createModel(
|
|
|
@@ -126,7 +143,13 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var request = new CustomRequest(null, List.of("abc", "123"), model);
|
|
|
+ var request = new CustomRequest(
|
|
|
+ EmbeddingParameters.of(
|
|
|
+ new EmbeddingsInput(List.of("abc", "123"), null, InputType.INGEST),
|
|
|
+ model.getServiceSettings().getInputTypeTranslator()
|
|
|
+ ),
|
|
|
+ model
|
|
|
+ );
|
|
|
var httpRequest = request.createHttpRequest();
|
|
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
|
|
|
@@ -136,6 +159,14 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
// To visually verify that this is correct, input the query parameters into here: https://www.urldecoder.org/
|
|
|
is("http://www.elastic.co?key=+%3C%3E%23%25%2B%7B%7D%7C%5C%5E%7E%5B%5D%60%3B%2F%3F%3A%40%3D%26%24&key=%CE%A3+%F0%9F%98%80")
|
|
|
);
|
|
|
+
|
|
|
+ var expectedBody = XContentHelper.stripWhitespace("""
|
|
|
+ {
|
|
|
+ "input": ["abc", "123"],
|
|
|
+ "input_type": "value"
|
|
|
+ }
|
|
|
+ """);
|
|
|
+ assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
|
|
|
}
|
|
|
|
|
|
public void testCreateRequest_SecretsInTheJsonBody_AreEncodedCorrectly() throws IOException {
|
|
|
@@ -173,7 +204,13 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var request = new CustomRequest(null, List.of("abc", "123"), model);
|
|
|
+ var request = new CustomRequest(
|
|
|
+ EmbeddingParameters.of(
|
|
|
+ new EmbeddingsInput(List.of("abc", "123"), null, InputType.SEARCH),
|
|
|
+ model.getServiceSettings().getInputTypeTranslator()
|
|
|
+ ),
|
|
|
+ model
|
|
|
+ );
|
|
|
var httpRequest = request.createHttpRequest();
|
|
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
|
|
|
@@ -220,7 +257,7 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var request = new CustomRequest("query string", List.of("abc", "123"), model);
|
|
|
+ var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model);
|
|
|
var httpRequest = request.createHttpRequest();
|
|
|
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
|
|
|
@@ -236,6 +273,56 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
|
|
|
}
|
|
|
|
|
|
+ public void testCreateRequest_HandlesQuery_WithReturnDocsAndTopN() throws IOException {
|
|
|
+ var inferenceId = "inference_id";
|
|
|
+ var requestContentString = """
|
|
|
+ {
|
|
|
+ "input": ${input},
|
|
|
+ "query": ${query},
|
|
|
+ "return_documents": ${return_documents},
|
|
|
+ "top_n": ${top_n}
|
|
|
+ }
|
|
|
+ """;
|
|
|
+
|
|
|
+ var serviceSettings = new CustomServiceSettings(
|
|
|
+ CustomServiceSettings.TextEmbeddingSettings.NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS,
|
|
|
+ "http://www.elastic.co",
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ requestContentString,
|
|
|
+ new RerankResponseParser("$.result.score"),
|
|
|
+ new RateLimitSettings(10_000)
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = CustomModelTests.createModel(
|
|
|
+ inferenceId,
|
|
|
+ TaskType.RERANK,
|
|
|
+ serviceSettings,
|
|
|
+ new CustomTaskSettings(Map.of()),
|
|
|
+ new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
+ );
|
|
|
+
|
|
|
+ var request = new CustomRequest(
|
|
|
+ RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"), false, 2, false)),
|
|
|
+ model
|
|
|
+ );
|
|
|
+ var httpRequest = request.createHttpRequest();
|
|
|
+ assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
|
|
+
|
|
|
+ var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
|
|
+
|
|
|
+ var expectedBody = XContentHelper.stripWhitespace("""
|
|
|
+ {
|
|
|
+ "input": ["abc", "123"],
|
|
|
+ "query": "query string",
|
|
|
+ "return_documents": false,
|
|
|
+ "top_n": 2
|
|
|
+ }
|
|
|
+ """);
|
|
|
+
|
|
|
+ assertThat(convertToString(httpPost.getEntity().getContent()), is(expectedBody));
|
|
|
+ }
|
|
|
+
|
|
|
public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IOException {
|
|
|
var inferenceId = "inference_id";
|
|
|
var requestContentString = """
|
|
|
@@ -262,7 +349,7 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var request = new CustomRequest(null, List.of("abc", "123"), model);
|
|
|
+ var request = new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model);
|
|
|
var exception = expectThrows(IllegalStateException.class, request::createHttpRequest);
|
|
|
assertThat(
|
|
|
exception.getMessage(),
|
|
|
@@ -299,7 +386,10 @@ public class CustomRequestTests extends ESTestCase {
|
|
|
new CustomSecretSettings(Map.of("api_key", new SecureString("my-secret-key".toCharArray())))
|
|
|
);
|
|
|
|
|
|
- var exception = expectThrows(IllegalStateException.class, () -> new CustomRequest(null, List.of("abc", "123"), model));
|
|
|
+ var exception = expectThrows(
|
|
|
+ IllegalStateException.class,
|
|
|
+ () -> new CustomRequest(RerankParameters.of(new QueryAndDocsInputs("query string", List.of("abc", "123"))), model)
|
|
|
+ );
|
|
|
assertThat(exception.getMessage(), is("Failed to build URI, error: Illegal character in path at index 0: ^"));
|
|
|
}
|
|
|
|