|
@@ -30,6 +30,7 @@ import org.elasticsearch.threadpool.ThreadPool;
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
import org.elasticsearch.xcontent.XContentType;
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
|
import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults;
|
|
|
|
+import org.elasticsearch.xpack.inference.ModelConfigurationsTests;
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
|
@@ -38,6 +39,7 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
|
|
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
|
|
|
|
+import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
|
import org.hamcrest.CoreMatchers;
|
|
import org.hamcrest.CoreMatchers;
|
|
import org.hamcrest.MatcherAssert;
|
|
import org.hamcrest.MatcherAssert;
|
|
import org.hamcrest.Matchers;
|
|
import org.hamcrest.Matchers;
|
|
@@ -388,6 +390,48 @@ public class MistralServiceTests extends ESTestCase {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_InvalidModelProvided() throws IOException {
|
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
|
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
|
+ var model = new Model(ModelConfigurationsTests.createRandomInstance());
|
|
|
|
+
|
|
|
|
+ assertThrows(
|
|
|
|
+ ElasticsearchStatusException.class,
|
|
|
|
+ () -> { service.updateModelWithEmbeddingDetails(model, randomNonNegativeInt()); }
|
|
|
|
+ );
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_NullSimilarityInOriginalModel() throws IOException {
|
|
|
|
+ testUpdateModelWithEmbeddingDetails_Successful(null);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ public void testUpdateModelWithEmbeddingDetails_NonNullSimilarityInOriginalModel() throws IOException {
|
|
|
|
+ testUpdateModelWithEmbeddingDetails_Successful(randomFrom(SimilarityMeasure.values()));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure similarityMeasure) throws IOException {
|
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
|
+ try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
|
+ var embeddingSize = randomNonNegativeInt();
|
|
|
|
+ var model = MistralEmbeddingModelTests.createModel(
|
|
|
|
+ randomAlphaOfLength(10),
|
|
|
|
+ randomAlphaOfLength(10),
|
|
|
|
+ randomAlphaOfLength(10),
|
|
|
|
+ randomNonNegativeInt(),
|
|
|
|
+ randomNonNegativeInt(),
|
|
|
|
+ similarityMeasure,
|
|
|
|
+ RateLimitSettingsTests.createRandom()
|
|
|
|
+ );
|
|
|
|
+
|
|
|
|
+ Model updatedModel = service.updateModelWithEmbeddingDetails(model, embeddingSize);
|
|
|
|
+
|
|
|
|
+ SimilarityMeasure expectedSimilarityMeasure = similarityMeasure == null ? SimilarityMeasure.DOT_PRODUCT : similarityMeasure;
|
|
|
|
+ assertEquals(expectedSimilarityMeasure, updatedModel.getServiceSettings().similarity());
|
|
|
|
+ assertEquals(embeddingSize, updatedModel.getServiceSettings().dimensions().intValue());
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws IOException {
|
|
public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws IOException {
|
|
var sender = mock(Sender.class);
|
|
var sender = mock(Sender.class);
|
|
|
|
|