|
@@ -53,6 +53,7 @@ import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.Azure
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModelTests;
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsServiceSettingsTests;
|
|
|
import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsTaskSettingsTests;
|
|
|
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
|
|
import org.hamcrest.CoreMatchers;
|
|
|
import org.hamcrest.MatcherAssert;
|
|
|
import org.hamcrest.Matchers;
|
|
@@ -973,6 +974,112 @@ public class AzureAiStudioServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+ try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+ var model = AzureAiStudioChatCompletionModelTests.createModel(
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomFrom(AzureAiStudioProvider.values()),
|
|
|
+ randomFrom(AzureAiStudioEndpointType.values()),
|
|
|
+ randomAlphaOfLength(10)
|
|
|
+ );
|
|
|
+ assertThrows(
|
|
|
+ ElasticsearchStatusException.class,
|
|
|
+ () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
|
|
|
+ testUpdateModelWithEmbeddingDetails_Successful(null);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
|
|
|
+ testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
|
|
|
+ }
|
|
|
+
|
|
|
+ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+ try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+ var embeddingSize = randomNonNegativeInt();
|
|
|
+ var model = AzureAiStudioEmbeddingsModelTests.createModel(
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomFrom(AzureAiStudioProvider.values()),
|
|
|
+ randomFrom(AzureAiStudioEndpointType.values()),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomNonNegativeInt(),
|
|
|
+ randomBoolean(),
|
|
|
+ randomNonNegativeInt(),
|
|
|
+ similarityMeasure,
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ RateLimitSettingsTests.createRandom()
|
|
|
+ );
|
|
|
+
|
|
|
+ Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
|
|
|
+
|
|
|
+ SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
|
|
|
+ assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
|
|
|
+ assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testUpdateModelWithChatCompletionDetails_InvalidModelProvided() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+ try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+ var model = AzureAiStudioEmbeddingsModelTests.createModel(
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomFrom(AzureAiStudioProvider.values()),
|
|
|
+ randomFrom(AzureAiStudioEndpointType.values()),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomNonNegativeInt(),
|
|
|
+ randomBoolean(),
|
|
|
+ randomNonNegativeInt(),
|
|
|
+ randomFrom(SimilarityMeasure.values()),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ RateLimitSettingsTests.createRandom()
|
|
|
+ );
|
|
|
+ assertThrows(ElasticsearchStatusException.class, () -> { service.updateModelWithChatCompletionDetails(model); });
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testUpdateModelWithChatCompletionDetails_NullSimilarityInOriginalModel() throws IOException {
|
|
|
+ testUpdateModelWithChatCompletionDetails_Successful(null);
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testUpdateModelWithChatCompletionDetails_NonNullSimilarityInOriginalModel() throws IOException {
|
|
|
+ testUpdateModelWithChatCompletionDetails_Successful(randomNonNegativeInt());
|
|
|
+ }
|
|
|
+
|
|
|
+ private void testUpdateModelWithChatCompletionDetails_Successful(Integer maxNewTokens) throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+ try (var service = new AzureAiStudioService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+ var model = AzureAiStudioChatCompletionModelTests.createModel(
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomFrom(AzureAiStudioProvider.values()),
|
|
|
+ randomFrom(AzureAiStudioEndpointType.values()),
|
|
|
+ randomAlphaOfLength(10),
|
|
|
+ randomDouble(),
|
|
|
+ randomDouble(),
|
|
|
+ randomBoolean(),
|
|
|
+ maxNewTokens,
|
|
|
+ RateLimitSettingsTests.createRandom()
|
|
|
+ );
|
|
|
+
|
|
|
+ Model updatedModel = service.updateModelWithChatCompletionDetails(model);
|
|
|
+ assertThat(updatedModel, instanceOf(AzureAiStudioChatCompletionModel.class));
|
|
|
+ AzureAiStudioChatCompletionTaskSettings updatedTaskSettings = (AzureAiStudioChatCompletionTaskSettings) updatedModel
|
|
|
+ .getTaskSettings();
|
|
|
+ Integer expectedMaxNewTokens = maxNewTokens == null
|
|
|
+ ? AzureAiStudioChatCompletionTaskSettings.DEFAULT_MAX_NEW_TOKENS
|
|
|
+ : maxNewTokens;
|
|
|
+ assertEquals(expectedMaxNewTokens, updatedTaskSettings.maxNewTokens());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOException {
|
|
|
var sender = mock(Sender.class);
|
|
|
|