1
0
Эх сурвалжийг харах

[ML] Remove SageMaker Elastic updates (#131301)

Rather than silently drop the payload,
throw a validation error when Users try to send task settings in the
update payload for SageMaker inference with the Elastic API.
Pat Whelan 3 сар өмнө
parent
commit
037ddaa5c8

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerModel.java

@@ -116,7 +116,7 @@ public class SageMakerModel extends Model {
             getConfigurations(),
             getSecrets(),
             serviceSettings,
-            taskSettings.updatedTaskSettings(taskSettingsOverride),
+            taskSettings.override(taskSettingsOverride),
             awsSecretSettings
         );
     }

+ 11 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/model/SageMakerTaskSettings.java

@@ -71,11 +71,21 @@ record SageMakerTaskSettings(
     @Override
     public SageMakerTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
         var validationException = new ValidationException();
-
         var updateTaskSettings = fromMap(newSettings, apiTaskSettings.updatedTaskSettings(newSettings), validationException);
+        validationException.throwIfValidationErrorsExist();
+
+        return override(updateTaskSettings);
+    }
 
+    public SageMakerTaskSettings override(Map<String, Object> newSettings) {
+        var validationException = new ValidationException();
+        var updateTaskSettings = fromMap(newSettings, apiTaskSettings.override(newSettings), validationException);
         validationException.throwIfValidationErrorsExist();
 
+        return override(updateTaskSettings);
+    }
+
+    private SageMakerTaskSettings override(SageMakerTaskSettings updateTaskSettings) {
         var updatedExtraTaskSettings = updateTaskSettings.apiTaskSettings().equals(SageMakerStoredTaskSchema.NO_OP)
             ? apiTaskSettings
             : updateTaskSettings.apiTaskSettings();

+ 4 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerStoredTaskSchema.java

@@ -68,4 +68,8 @@ public interface SageMakerStoredTaskSchema extends TaskSettings {
 
     @Override
     SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings);
+
+    default SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
+        return updatedTaskSettings(newSettings);
+    }
 }

+ 0 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayload.java

@@ -88,12 +88,6 @@ interface ElasticPayload extends SageMakerSchemaPayload {
 
     @Override
     default SageMakerElasticTaskSettings apiTaskSettings(Map<String, Object> taskSettings, ValidationException validationException) {
-        if (taskSettings != null && (taskSettings.isEmpty() == false)) {
-            validationException.addValidationError(
-                InferenceAction.Request.TASK_SETTINGS.getPreferredName()
-                    + " is only supported during the inference request and cannot be stored in the inference endpoint."
-            );
-        }
         return SageMakerElasticTaskSettings.empty();
     }
 

+ 12 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/SageMakerElasticTaskSettings.java

@@ -9,10 +9,12 @@ package org.elasticsearch.xpack.inference.services.sagemaker.schema.elastic;
 
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.TransportVersions;
+import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xpack.core.inference.action.InferenceAction;
 import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStoredTaskSchema;
 
 import java.io.IOException;
@@ -40,6 +42,16 @@ record SageMakerElasticTaskSettings(@Nullable Map<String, Object> passthroughSet
 
     @Override
     public SageMakerStoredTaskSchema updatedTaskSettings(Map<String, Object> newSettings) {
+        var validationException = new ValidationException();
+        validationException.addValidationError(
+            InferenceAction.Request.TASK_SETTINGS.getPreferredName()
+                + " is only supported during the inference request and cannot be stored in the inference endpoint."
+        );
+        throw validationException;
+    }
+
+    @Override
+    public SageMakerStoredTaskSchema override(Map<String, Object> newSettings) {
         return new SageMakerElasticTaskSettings(newSettings);
     }
 

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/SageMakerSchemaPayloadTestCase.java

@@ -119,7 +119,7 @@ public abstract class SageMakerSchemaPayloadTestCase<T extends SageMakerSchemaPa
         }
     }
 
-    public final void testUpdate() throws IOException {
+    public void testUpdate() throws IOException {
         var taskSettings = randomApiTaskSettings();
         if (taskSettings != SageMakerStoredTaskSchema.NO_OP) {
             var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);

+ 15 - 7
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/sagemaker/schema/elastic/ElasticPayloadTestCase.java

@@ -18,8 +18,8 @@ import org.elasticsearch.xpack.inference.services.sagemaker.schema.SageMakerStor
 import java.util.List;
 import java.util.Map;
 
-import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.is;
+import static org.elasticsearch.xpack.inference.services.InferenceSettingsTestCase.toMap;
+import static org.hamcrest.Matchers.containsString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -50,6 +50,7 @@ public abstract class ElasticPayloadTestCase<T extends ElasticPayload> extends S
         return model;
     }
 
+    @Override
     public void testApiTaskSettings() {
         {
             var validationException = new ValidationException();
@@ -67,14 +68,21 @@ public abstract class ElasticPayloadTestCase<T extends ElasticPayload> extends S
             var validationException = new ValidationException();
             var actualApiTaskSettings = payload.apiTaskSettings(Map.of("hello", "world"), validationException);
             assertTrue(actualApiTaskSettings.isEmpty());
-            assertFalse(validationException.validationErrors().isEmpty());
-            assertThat(
-                validationException.validationErrors().get(0),
-                is(equalTo("task_settings is only supported during the inference request and cannot be stored in the inference endpoint."))
-            );
+            assertTrue(validationException.validationErrors().isEmpty());
         }
     }
 
+    @Override
+    public void testUpdate() {
+        var taskSettings = randomApiTaskSettings();
+        var otherTaskSettings = randomValueOtherThan(taskSettings, this::randomApiTaskSettings);
+        var e = assertThrows(ValidationException.class, () -> taskSettings.updatedTaskSettings(toMap(otherTaskSettings)));
+        assertThat(
+            e.getMessage(),
+            containsString("task_settings is only supported during the inference request and cannot be stored in the inference endpoint")
+        );
+    }
+
     public void testRequestWithRequiredFields() throws Exception {
         var request = new SageMakerInferenceRequest(null, null, null, List.of("hello"), false, InputType.UNSPECIFIED);
         var sdkByes = payload.requestBytes(mockModel(), request);