浏览代码

[Inference API] Use extractOptionalPositiveInteger in MistralEmbeddingsServiceSettings for dims and maxInputTokens (#110485)

Tim Grein 1 年之前
父节点
当前提交
930ff47c2f

+ 1 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java

@@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARIT
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
-import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType;
 import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;
 import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD;
 
 
 public class MistralEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
 public class MistralEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings {
@@ -67,7 +66,7 @@ public class MistralEmbeddingsServiceSettings extends FilteredXContentObject imp
             MistralService.NAME,
             MistralService.NAME,
             context
             context
         );
         );
-        Integer dims = removeAsType(map, DIMENSIONS, Integer.class);
+        Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException);
 
 
         if (validationException.validationErrors().isEmpty() == false) {
         if (validationException.validationErrors().isEmpty() == false) {
             throw validationException;
             throw validationException;

+ 80 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java

@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.services.mistral.embeddings;
 package org.elasticsearch.xpack.inference.services.mistral.embeddings;
 
 
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
 import org.elasticsearch.common.io.stream.ByteArrayStreamInput;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Nullable;
@@ -27,6 +28,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Map;
 
 
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.is;
 
 
 public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
 public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
@@ -77,6 +79,84 @@ public class MistralEmbeddingsServiceSettingsTests extends ESTestCase {
         assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, null, null)));
         assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, null, null)));
     }
     }
 
 
+    public void testFromMap_ThrowsException_WhenDimensionsAreZero() {
+        var model = "mistral-embed";
+        var dimensions = 0;
+
+        var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE);
+
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_ThrowsException_WhenDimensionsAreNegative() {
+        var model = "mistral-embed";
+        var dimensions = randomNegativeInt();
+
+        var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE);
+
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;",
+                    dimensions
+                )
+            )
+        );
+    }
+
+    public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() {
+        var model = "mistral-embed";
+        var maxInputTokens = 0;
+
+        var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE);
+
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;")
+        );
+    }
+
+    public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() {
+        var model = "mistral-embed";
+        var maxInputTokens = randomNegativeInt();
+
+        var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE);
+
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST)
+        );
+
+        assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format(
+                    "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;",
+                    maxInputTokens
+                )
+            )
+        );
+    }
+
     public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() {
     public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() {
         var model = "mistral-embed";
         var model = "mistral-embed";