|
@@ -18,8 +18,8 @@ import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStor
|
|
|
import java.util.List;
|
|
|
import java.util.Map;
|
|
|
|
|
|
-import static org.hamcrest.Matchers.equalTo;
|
|
|
-import static org.hamcrest.Matchers.is;
|
|
|
+import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
|
|
|
+import static org.hamcrest.Matchers.containsString;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
|
|
@@ -50,6 +50,7 @@ public abstract class ElasticPayloadTestCase<T extends ElasticPayload> extends S
|
|
|
return model;
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
public void testApiTaskSettings() {
|
|
|
{
|
|
|
var validationException = new ValidationException();
|
|
@@ -67,14 +68,21 @@ public abstract class ElasticPayloadTestCase<T extends ElasticPayload> extends S
|
|
|
var validationException = new ValidationException();
|
|
|
var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException);
|
|
|
assertTrue(actualApiTaskSettings.isEmpty());
|
|
|
- assertFalse(validationException.validationErrors().isEmpty());
|
|
|
- assertThat(
|
|
|
- validationException.validationErrors().get(0),
|
|
|
- is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint."))
|
|
|
- );
|
|
|
+ assertTrue(validationException.validationErrors().isEmpty());
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @Override
|
|
|
+ public void testUpdate() {
|
|
|
+ var taskSettings = randomApiTaskSettings();
|
|
|
+ var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
|
|
|
+ var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
|
|
|
+ assertThat(
|
|
|
+ e.getMessage(),
|
|
|
+ containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint")
|
|
|
+ );
|
|
|
+ }
|
|
|
+
|
|
|
public void testRequestWithRequiredFields() throws Exception {
|
|
|
var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED);
|
|
|
var sdkByes = payload.requestBytes(mockModel(), request);
|