|
@@ -38,6 +38,7 @@ import org.elasticsearch.xcontent.XContentType;
|
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
|
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
|
|
|
+import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
|
|
|
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
|
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
@@ -54,6 +55,10 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModel;
|
|
|
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankModelTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankServiceSettingsTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.azureaistudio.rerank.AzureAiStudioRerankTaskSettingsTests;
|
|
|
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
|
|
import org.hamcrest.CoreMatchers;
|
|
|
import org.hamcrest.Matchers;
|
|
@@ -219,6 +224,33 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParseRequestConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(model -> {
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+
|
|
|
+ var rerankModel = (AzureAiStudioRerankModel) model;
|
|
|
+ assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
|
|
|
+ assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
|
|
+ assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
|
|
+ assertThat(rerankModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
+ assertNull(rerankModel.getTaskSettings().returnDocuments());
|
|
|
+ assertNull(rerankModel.getTaskSettings().topN());
|
|
|
+ }, exception -> fail("Unexpected exception: " + exception));
|
|
|
+
|
|
|
+ service.parseRequestConfig(
|
|
|
+ "id",
|
|
|
+ TaskType.RERANK,
|
|
|
+ getRequestConfigMap(
|
|
|
+ getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
|
|
+ getRerankTaskSettingsMap(null, null),
|
|
|
+ getSecretSettingsMap("secret")
|
|
|
+ ),
|
|
|
+ modelVerificationListener
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException {
|
|
|
try (var service = createService()) {
|
|
|
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
@@ -441,6 +473,80 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
|
|
+ serviceSettings.put("extra_key", "value");
|
|
|
+
|
|
|
+ var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
|
|
+
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
|
+ model -> fail("Expected exception, but got model: " + model),
|
|
|
+ exception -> {
|
|
|
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
|
|
+ assertThat(
|
|
|
+ exception.getMessage(),
|
|
|
+ is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
|
|
+ );
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var taskSettings = getRerankTaskSettingsMap(null, null);
|
|
|
+ taskSettings.put("extra_key", "value");
|
|
|
+
|
|
|
+ var config = getRequestConfigMap(
|
|
|
+ getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
|
|
+ taskSettings,
|
|
|
+ getSecretSettingsMap("secret")
|
|
|
+ );
|
|
|
+
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
|
+ model -> fail("Expected exception, but got model: " + model),
|
|
|
+ exception -> {
|
|
|
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
|
|
+ assertThat(
|
|
|
+ exception.getMessage(),
|
|
|
+ is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
|
|
+ );
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var secretSettings = getSecretSettingsMap("secret");
|
|
|
+ secretSettings.put("extra_key", "value");
|
|
|
+
|
|
|
+ var config = getRequestConfigMap(
|
|
|
+ getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
|
|
+ getRerankTaskSettingsMap(null, null),
|
|
|
+ secretSettings
|
|
|
+ );
|
|
|
+
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
|
+ model -> fail("Expected exception, but got model: " + model),
|
|
|
+ exception -> {
|
|
|
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
|
|
+ assertThat(
|
|
|
+ exception.getMessage(),
|
|
|
+ is("Configuration contains settings [{extra_key=value}] unknown to the [azureaistudio] service")
|
|
|
+ );
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForEmbeddings() throws IOException {
|
|
|
try (var service = createService()) {
|
|
|
var serviceSettings = getEmbeddingsServiceSettingsMap("http://target.local", "databricks", "token", null, null, null, null);
|
|
@@ -505,6 +611,45 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParseRequestConfig_ThrowsWhenProviderIsNotValidForRerank() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "databricks", "token");
|
|
|
+
|
|
|
+ var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
|
|
+
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
|
+ model -> fail("Expected exception, but got model: " + model),
|
|
|
+ exception -> {
|
|
|
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
|
|
+ assertThat(exception.getMessage(), is("The [rerank] task type for provider [databricks] is not available"));
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParseRequestConfig_ThrowsWhenEndpointTypeIsNotValidForRerankProvider() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "realtime");
|
|
|
+
|
|
|
+ var config = getRequestConfigMap(serviceSettings, getRerankTaskSettingsMap(null, null), getSecretSettingsMap("secret"));
|
|
|
+
|
|
|
+ ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
|
+ model -> fail("Expected exception, but got model: " + model),
|
|
|
+ exception -> {
|
|
|
+ assertThat(exception, instanceOf(ElasticsearchStatusException.class));
|
|
|
+ assertThat(
|
|
|
+ exception.getMessage(),
|
|
|
+ is("The [realtime] endpoint type with [rerank] task type for provider [cohere] is not available")
|
|
|
+ );
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ service.parseRequestConfig("id", TaskType.RERANK, config, modelVerificationListener);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParsePersistedConfig_CreatesAnAzureAiStudioEmbeddingsModel() throws IOException {
|
|
|
try (var service = createService()) {
|
|
|
var config = getPersistedConfigMap(
|
|
@@ -603,6 +748,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParsePersistedConfig_CreatesAnAzureAiStudioRerankModel() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var config = getPersistedConfigMap(
|
|
|
+ getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
|
|
+ getRerankTaskSettingsMap(true, 2),
|
|
|
+ getSecretSettingsMap("secret")
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+
|
|
|
+ var chatCompletionModel = (AzureAiStudioRerankModel) model;
|
|
|
+ assertThat(chatCompletionModel.getServiceSettings().target(), is("http://target.local"));
|
|
|
+ assertThat(chatCompletionModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
|
|
+ assertThat(chatCompletionModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
|
|
+ assertThat(chatCompletionModel.getTaskSettings().returnDocuments(), is(true));
|
|
|
+ assertThat(chatCompletionModel.getTaskSettings().topN(), is(2));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException {
|
|
|
try (var service = createService()) {
|
|
|
ActionListener<Model> modelVerificationListener = ActionListener.wrap(
|
|
@@ -747,6 +913,48 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankServiceSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
|
|
+ serviceSettings.put("extra_key", "value");
|
|
|
+ var taskSettings = getRerankTaskSettingsMap(true, 2);
|
|
|
+ var secretSettings = getSecretSettingsMap("secret");
|
|
|
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankTaskSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
|
|
+ var taskSettings = getRerankTaskSettingsMap(true, 2);
|
|
|
+ taskSettings.put("extra_key", "value");
|
|
|
+ var secretSettings = getSecretSettingsMap("secret");
|
|
|
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInRerankSecretSettingsMap() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var serviceSettings = getRerankServiceSettingsMap("http://target.local", "cohere", "token");
|
|
|
+ var taskSettings = getRerankTaskSettingsMap(true, 2);
|
|
|
+ var secretSettings = getSecretSettingsMap("secret");
|
|
|
+ secretSettings.put("extra_key", "value");
|
|
|
+ var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings);
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfigWithSecrets("id", TaskType.RERANK, config.config(), config.secrets());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException {
|
|
|
try (var service = createService()) {
|
|
|
var config = getPersistedConfigMap(
|
|
@@ -842,6 +1050,27 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testParsePersistedConfig_WithoutSecretsCreatesRerankModel() throws IOException {
|
|
|
+ try (var service = createService()) {
|
|
|
+ var config = getPersistedConfigMap(
|
|
|
+ getRerankServiceSettingsMap("http://target.local", "cohere", "token"),
|
|
|
+ getRerankTaskSettingsMap(true, 2),
|
|
|
+ Map.of()
|
|
|
+ );
|
|
|
+
|
|
|
+ var model = service.parsePersistedConfig("id", TaskType.RERANK, config.config());
|
|
|
+
|
|
|
+ assertThat(model, instanceOf(AzureAiStudioRerankModel.class));
|
|
|
+
|
|
|
+ var rerankModel = (AzureAiStudioRerankModel) model;
|
|
|
+ assertThat(rerankModel.getServiceSettings().target(), is("http://target.local"));
|
|
|
+ assertThat(rerankModel.getServiceSettings().provider(), is(AzureAiStudioProvider.COHERE));
|
|
|
+ assertThat(rerankModel.getServiceSettings().endpointType(), is(AzureAiStudioEndpointType.TOKEN));
|
|
|
+ assertThat(rerankModel.getTaskSettings().returnDocuments(), is(true));
|
|
|
+ assertThat(rerankModel.getTaskSettings().topN(), is(2));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
@@ -1184,6 +1413,47 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testInfer_WithRerankModel() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testRerankTokenResponseJson));
|
|
|
+
|
|
|
+ var model = AzureAiStudioRerankModelTests.createModel(
|
|
|
+ "id",
|
|
|
+ getUrl(webServer),
|
|
|
+ AzureAiStudioProvider.COHERE,
|
|
|
+ AzureAiStudioEndpointType.TOKEN,
|
|
|
+ "apikey"
|
|
|
+ );
|
|
|
+
|
|
|
+ PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
|
|
+ service.infer(
|
|
|
+ model,
|
|
|
+ "query",
|
|
|
+ false,
|
|
|
+ 2,
|
|
|
+ List.of("abc"),
|
|
|
+ false,
|
|
|
+ new HashMap<>(),
|
|
|
+ InputType.INGEST,
|
|
|
+ InferenceAction.Request.DEFAULT_TIMEOUT,
|
|
|
+ listener
|
|
|
+ );
|
|
|
+
|
|
|
+ var result = listener.actionGet(TIMEOUT);
|
|
|
+ assertThat(result, CoreMatchers.instanceOf(RankedDocsResults.class));
|
|
|
+
|
|
|
+ var rankedDocsResults = (RankedDocsResults) result;
|
|
|
+ var rankedDocs = rankedDocsResults.getRankedDocs();
|
|
|
+ assertThat(rankedDocs.size(), is(2));
|
|
|
+ assertThat(rankedDocs.get(0).relevanceScore(), is(0.1111111F));
|
|
|
+ assertThat(rankedDocs.get(0).index(), is(0));
|
|
|
+ assertThat(rankedDocs.get(1).relevanceScore(), is(0.2222222F));
|
|
|
+ assertThat(rankedDocs.get(1).index(), is(1));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testInfer_UnauthorisedResponse() throws IOException {
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
|
|
@@ -1320,7 +1590,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
{
|
|
|
"service": "azureaistudio",
|
|
|
"name": "Azure AI Studio",
|
|
|
- "task_types": ["text_embedding", "completion"],
|
|
|
+ "task_types": ["text_embedding", "rerank", "completion"],
|
|
|
"configurations": {
|
|
|
"dimensions": {
|
|
|
"description": "The number of dimensions the resulting embeddings should have. For more information refer to https://learn.microsoft.com/en-us/azure/ai-studio/reference/reference-model-inference-embeddings.",
|
|
@@ -1338,7 +1608,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["text_embedding", "completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "rerank", "completion"]
|
|
|
},
|
|
|
"provider": {
|
|
|
"description": "The model provider for your deployment.",
|
|
@@ -1347,7 +1617,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["text_embedding", "completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "rerank", "completion"]
|
|
|
},
|
|
|
"api_key": {
|
|
|
"description": "API Key for the provider you're connecting to.",
|
|
@@ -1356,7 +1626,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
"sensitive": true,
|
|
|
"updatable": true,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["text_embedding", "completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "rerank", "completion"]
|
|
|
},
|
|
|
"rate_limit.requests_per_minute": {
|
|
|
"description": "Minimize the number of rate limit errors.",
|
|
@@ -1365,7 +1635,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "int",
|
|
|
- "supported_task_types": ["text_embedding", "completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "rerank", "completion"]
|
|
|
},
|
|
|
"target": {
|
|
|
"description": "The target URL of your Azure AI Studio model deployment.",
|
|
@@ -1374,7 +1644,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
"sensitive": false,
|
|
|
"updatable": false,
|
|
|
"type": "str",
|
|
|
- "supported_task_types": ["text_embedding", "completion"]
|
|
|
+ "supported_task_types": ["text_embedding", "rerank", "completion"]
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1462,6 +1732,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
return AzureAiStudioChatCompletionServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
|
|
|
}
|
|
|
|
|
|
+ private static HashMap<String, Object> getRerankServiceSettingsMap(String target, String provider, String endpointType) {
|
|
|
+ return AzureAiStudioRerankServiceSettingsTests.createRequestSettingsMap(target, provider, endpointType);
|
|
|
+ }
|
|
|
+
|
|
|
public static Map<String, Object> getChatCompletionTaskSettingsMap(
|
|
|
@Nullable Double temperature,
|
|
|
@Nullable Double topP,
|
|
@@ -1471,6 +1745,10 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
return AzureAiStudioChatCompletionTaskSettingsTests.getTaskSettingsMap(temperature, topP, doSample, maxNewTokens);
|
|
|
}
|
|
|
|
|
|
+ public static Map<String, Object> getRerankTaskSettingsMap(@Nullable Boolean returnDocuments, @Nullable Integer topN) {
|
|
|
+ return AzureAiStudioRerankTaskSettingsTests.getTaskSettingsMap(returnDocuments, topN);
|
|
|
+ }
|
|
|
+
|
|
|
private static Map<String, Object> getSecretSettingsMap(String apiKey) {
|
|
|
return new HashMap<>(Map.of(API_KEY_FIELD, apiKey));
|
|
|
}
|
|
@@ -1520,4 +1798,28 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
""";
|
|
|
+
|
|
|
+ private static final String testRerankTokenResponseJson = """
|
|
|
+ {
|
|
|
+ "id": "ff2feb42-5d3a-45d7-ba29-c3dabf59988b",
|
|
|
+ "results": [
|
|
|
+ {
|
|
|
+ "index": 0,
|
|
|
+ "relevance_score": 0.1111111
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "index": 1,
|
|
|
+ "relevance_score": 0.2222222
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "search_units": 1
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ """;
|
|
|
}
|