|
@@ -22,6 +22,7 @@ import org.elasticsearch.inference.InputType;
|
|
|
import org.elasticsearch.inference.Model;
|
|
|
import org.elasticsearch.inference.ModelConfigurations;
|
|
|
import org.elasticsearch.inference.ModelSecrets;
|
|
|
+import org.elasticsearch.inference.SimilarityMeasure;
|
|
|
import org.elasticsearch.inference.TaskType;
|
|
|
import org.elasticsearch.test.ESTestCase;
|
|
|
import org.elasticsearch.test.http.MockResponse;
|
|
@@ -99,8 +100,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
|
|
|
MatcherAssert.assertThat(
|
|
|
embeddingsModel.getTaskSettings(),
|
|
@@ -131,8 +132,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), equalTo(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
@@ -257,7 +258,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
|
|
|
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}, (e) -> fail("Model parsing should have succeeded " + e.getMessage()));
|
|
@@ -295,8 +296,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -345,7 +346,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
|
|
|
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -370,8 +371,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.BYTE));
|
|
|
MatcherAssert.assertThat(
|
|
|
embeddingsModel.getTaskSettings(),
|
|
@@ -402,7 +403,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -427,8 +428,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -451,7 +452,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -478,8 +479,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret"));
|
|
|
}
|
|
@@ -497,8 +498,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE)));
|
|
|
assertNull(embeddingsModel.getSecretSettings());
|
|
|
}
|
|
@@ -535,8 +536,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- assertNull(embeddingsModel.getServiceSettings().getCommonSettings().getUri());
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ assertNull(embeddingsModel.getServiceSettings().getCommonSettings().uri());
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getEmbeddingType(), is(CohereEmbeddingType.FLOAT));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(null, null)));
|
|
|
assertNull(embeddingsModel.getSecretSettings());
|
|
@@ -556,7 +557,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(CohereEmbeddingsTaskSettings.EMPTY_SETTINGS));
|
|
|
assertNull(embeddingsModel.getSecretSettings());
|
|
|
}
|
|
@@ -574,7 +575,7 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.SEARCH, null)));
|
|
|
assertNull(embeddingsModel.getSecretSettings());
|
|
|
}
|
|
@@ -595,8 +596,8 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
MatcherAssert.assertThat(model, instanceOf(CohereEmbeddingsModel.class));
|
|
|
|
|
|
var embeddingsModel = (CohereEmbeddingsModel) model;
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getUri().toString(), is("url"));
|
|
|
- MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().getModelId(), is("model"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().uri().toString(), is("url"));
|
|
|
+ MatcherAssert.assertThat(embeddingsModel.getServiceSettings().getCommonSettings().modelId(), is("model"));
|
|
|
MatcherAssert.assertThat(embeddingsModel.getTaskSettings(), is(new CohereEmbeddingsTaskSettings(InputType.INGEST, null)));
|
|
|
assertNull(embeddingsModel.getSecretSettings());
|
|
|
}
|
|
@@ -755,6 +756,135 @@ public class CohereServiceTests extends ESTestCase {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "float": [
|
|
|
+ [
|
|
|
+ 0.123,
|
|
|
+ -0.123
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ var model = CohereEmbeddingsModelTests.createModel(
|
|
|
+ getUrl(webServer),
|
|
|
+ "secret",
|
|
|
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ null,
|
|
|
+ null
|
|
|
+ );
|
|
|
+ PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
|
|
+ service.checkModelConfig(model, listener);
|
|
|
+ var result = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ result,
|
|
|
+ // the dimension is set to 2 because there are 2 embeddings returned from the mock server
|
|
|
+ is(
|
|
|
+ CohereEmbeddingsModelTests.createModel(
|
|
|
+ getUrl(webServer),
|
|
|
+ "secret",
|
|
|
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
|
|
+ 10,
|
|
|
+ 2,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ SimilarityMeasure.DOT_PRODUCT
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException {
|
|
|
+ var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
+
|
|
|
+ try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) {
|
|
|
+
|
|
|
+ String responseJson = """
|
|
|
+ {
|
|
|
+ "id": "de37399c-5df6-47cb-bc57-e3c5680c977b",
|
|
|
+ "texts": [
|
|
|
+ "hello"
|
|
|
+ ],
|
|
|
+ "embeddings": {
|
|
|
+ "float": [
|
|
|
+ [
|
|
|
+ 0.123,
|
|
|
+ -0.123
|
|
|
+ ]
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ "meta": {
|
|
|
+ "api_version": {
|
|
|
+ "version": "1"
|
|
|
+ },
|
|
|
+ "billed_units": {
|
|
|
+ "input_tokens": 1
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "response_type": "embeddings_by_type"
|
|
|
+ }
|
|
|
+ """;
|
|
|
+ webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
|
|
+
|
|
|
+ var model = CohereEmbeddingsModelTests.createModel(
|
|
|
+ getUrl(webServer),
|
|
|
+ "secret",
|
|
|
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
|
|
+ 10,
|
|
|
+ 1,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ SimilarityMeasure.COSINE
|
|
|
+ );
|
|
|
+ PlainActionFuture<Model> listener = new PlainActionFuture<>();
|
|
|
+ service.checkModelConfig(model, listener);
|
|
|
+ var result = listener.actionGet(TIMEOUT);
|
|
|
+
|
|
|
+ MatcherAssert.assertThat(
|
|
|
+ result,
|
|
|
+ // the dimension is set to 2 because there are 2 embeddings returned from the mock server
|
|
|
+ is(
|
|
|
+ CohereEmbeddingsModelTests.createModel(
|
|
|
+ getUrl(webServer),
|
|
|
+ "secret",
|
|
|
+ CohereEmbeddingsTaskSettings.EMPTY_SETTINGS,
|
|
|
+ 10,
|
|
|
+ 2,
|
|
|
+ null,
|
|
|
+ null,
|
|
|
+ SimilarityMeasure.COSINE
|
|
|
+ )
|
|
|
+ )
|
|
|
+ );
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
public void testInfer_UnauthorisedResponse() throws IOException {
|
|
|
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
|
|
|