Browse Source

[ML] validate model definition on start deployment (#80439)

When a deployment is started, we do not validate that the definition
documents are all present and not truncated. This commit adds a
validation on _start that prevents a bad state from occurring where the
deployment starts, but the model is incorrectly defined, or some unknown
error occurs to late in the deployment process.
Benjamin Trent 3 years ago
parent
commit
c3c3f88000

+ 1 - 1
docs/reference/ml/df-analytics/apis/put-trained-model-definition-part.asciidoc

@@ -45,7 +45,7 @@ The definition part for the model. Must be a base64 encoded string.
 
 `total_definition_length`::
 (Required, number)
-The total uncompressed definition length.
+The total uncompressed definition length in bytes. Not base64 encoded.
 
 `total_parts`::
 (Required, number)

+ 7 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelDefinitionPartAction.java

@@ -27,6 +27,7 @@ import java.util.Objects;
 import static org.elasticsearch.action.ValidateActions.addValidationError;
 
 public class PutTrainedModelDefinitionPartAction extends ActionType<AcknowledgedResponse> {
+    public static final int MAX_NUM_NATIVE_DEFINITION_PARTS = 10_000;
 
     public static final PutTrainedModelDefinitionPartAction INSTANCE = new PutTrainedModelDefinitionPartAction();
     public static final String NAME = "cluster:admin/xpack/ml/trained_models/part/put";
@@ -88,6 +89,12 @@ public class PutTrainedModelDefinitionPartAction extends ActionType<Acknowledged
             if (totalParts <= 0) {
                 validationException = addValidationError("[total_parts] must be greater than 0", validationException);
             }
+            if (totalParts > MAX_NUM_NATIVE_DEFINITION_PARTS) {
+                validationException = addValidationError(
+                    "[total_parts] must be less than or equal to " + MAX_NUM_NATIVE_DEFINITION_PARTS,
+                    validationException
+                );
+            }
             if (totalDefinitionLength <= 0) {
                 validationException = addValidationError("[total_definition_length] must be greater than 0", validationException);
             }

+ 15 - 0
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelDefinitionPartActionRequestTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
 import org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction.Request;
 
+import static org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction.MAX_NUM_NATIVE_DEFINITION_PARTS;
 import static org.hamcrest.Matchers.containsString;
 
 public class PutTrainedModelDefinitionPartActionRequestTests extends AbstractBWCWireSerializationTestCase<Request> {
@@ -40,6 +41,20 @@ public class PutTrainedModelDefinitionPartActionRequestTests extends AbstractBWC
 
         exception = badRequest.validate();
         assertThat(exception.getMessage(), containsString("[part] must be less than total_parts"));
+
+        badRequest = new Request(
+            randomAlphaOfLength(10),
+            new BytesArray(randomAlphaOfLength(10)),
+            5,
+            10L,
+            randomIntBetween(MAX_NUM_NATIVE_DEFINITION_PARTS + 1, Integer.MAX_VALUE)
+        );
+
+        exception = badRequest.validate();
+        assertThat(
+            exception.getMessage(),
+            containsString("[total_parts] must be less than or equal to " + MAX_NUM_NATIVE_DEFINITION_PARTS)
+        );
     }
 
     @Override

+ 9 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/AutoscalingIT.java

@@ -39,6 +39,7 @@ import org.junit.After;
 import org.junit.Before;
 
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.List;
 import java.util.SortedMap;
@@ -172,6 +173,7 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
         );
     }
 
+    @AwaitsFix(bugUrl = "Cannot be fixed until we move estimation to config and not rely on definition length only")
     public void testMLAutoscalingForLargeModelAllocation() {
         String modelId = "really_big_model";
         SortedMap<String, Settings> deciders = new TreeMap<>();
@@ -266,7 +268,13 @@ public class AutoscalingIT extends MlNativeAutodetectIntegTestCase {
         ).actionGet();
         client().execute(
             PutTrainedModelDefinitionPartAction.INSTANCE,
-            new PutTrainedModelDefinitionPartAction.Request(modelId, new BytesArray(BASE_64_ENCODED_MODEL), 0, memoryUse, 1)
+            new PutTrainedModelDefinitionPartAction.Request(
+                modelId,
+                new BytesArray(Base64.getDecoder().decode(BASE_64_ENCODED_MODEL)),
+                0,
+                memoryUse,
+                1
+            )
         ).actionGet();
         client().execute(
             PutTrainedModelVocabularyAction.INSTANCE,

+ 25 - 0
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

@@ -429,6 +429,31 @@ public class PyTorchModelIT extends ESRestTestCase {
         assertThat(ex.getMessage(), containsString("[should-fail-get] is type [pytorch] and does not support retrieving the definition"));
     }
 
+    public void testStartDeploymentWithTruncatedDefinition() throws IOException {
+        String model = "should-fail-get";
+        createTrainedModel(model);
+        putVocabulary(List.of("once", "twice"), model);
+        Request request = new Request("PUT", "_ml/trained_models/" + model + "/definition/0");
+        request.setJsonEntity(
+            "{  "
+                + "\"total_definition_length\":"
+                + RAW_MODEL_SIZE
+                + 2L
+                + ","
+                + "\"definition\": \""
+                + BASE_64_ENCODED_MODEL
+                + "\","
+                + "\"total_parts\": 1"
+                + "}"
+        );
+        client().performRequest(request);
+        Exception ex = expectThrows(Exception.class, () -> startDeployment(model));
+        assertThat(
+            ex.getMessage(),
+            containsString("Model definition truncated. Unable to deserialize trained model definition [" + model + "]")
+        );
+    }
+
     public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
         String model = "not-deployed";
         createTrainedModel(model);

+ 8 - 1
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureResetIT.java

@@ -39,6 +39,7 @@ import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts;
 import org.junit.After;
 
 import java.util.Arrays;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
@@ -199,7 +200,13 @@ public class TestFeatureResetIT extends MlNativeAutodetectIntegTestCase {
         ).actionGet();
         client().execute(
             PutTrainedModelDefinitionPartAction.INSTANCE,
-            new PutTrainedModelDefinitionPartAction.Request(TRAINED_MODEL_ID, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1)
+            new PutTrainedModelDefinitionPartAction.Request(
+                TRAINED_MODEL_ID,
+                new BytesArray(Base64.getDecoder().decode(BASE_64_ENCODED_MODEL)),
+                0,
+                RAW_MODEL_SIZE,
+                1
+            )
         ).actionGet();
         client().execute(
             PutTrainedModelVocabularyAction.INSTANCE,

+ 7 - 1
x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelCRUDIT.java

@@ -85,7 +85,13 @@ public class TrainedModelCRUDIT extends MlSingleNodeTestCase {
         assertThat(((IndexLocation) config.getLocation()).getIndexName(), equalTo(InferenceIndexConstants.nativeDefinitionStore()));
         client().execute(
             PutTrainedModelDefinitionPartAction.INSTANCE,
-            new PutTrainedModelDefinitionPartAction.Request(modelId, new BytesArray(BASE_64_ENCODED_MODEL), 0, RAW_MODEL_SIZE, 1)
+            new PutTrainedModelDefinitionPartAction.Request(
+                modelId,
+                new BytesArray(Base64.getDecoder().decode(BASE_64_ENCODED_MODEL)),
+                0,
+                RAW_MODEL_SIZE,
+                1
+            )
         ).actionGet();
 
         assertThat(

+ 1 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelDefinitionPartAction.java

@@ -95,6 +95,7 @@ public class TransportPutTrainedModelDefinitionPartAction extends TransportMaste
                 new TrainedModelDefinitionDoc.Builder().setModelId(request.getModelId())
                     .setDocNum(request.getPart())
                     .setEos(isEos)
+                    // XContentParser::binaryValue pulls out the raw, base64 decoded bytes automatically. So, we only need the length here
                     .setDefinitionLength(request.getDefinition().length())
                     .setTotalDefinitionLength(request.getTotalDefinitionLength())
                     .setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)

+ 113 - 20
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java

@@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceAlreadyExistsException;
+import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.master.TransportMasterNodeAction;
@@ -26,11 +27,16 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.inject.Inject;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.query.QueryBuilders;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
 import org.elasticsearch.persistent.PersistentTasksCustomMetadata;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.search.SearchHit;
+import org.elasticsearch.search.sort.SortBuilders;
+import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportService;
@@ -47,11 +53,15 @@ import org.elasticsearch.xpack.core.ml.inference.allocation.AllocationStatus;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingState;
 import org.elasticsearch.xpack.core.ml.inference.allocation.RoutingStateAndReason;
 import org.elasticsearch.xpack.core.ml.inference.allocation.TrainedModelAllocation;
+import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
+import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
+import org.elasticsearch.xpack.core.ml.job.messages.Messages;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.MachineLearning;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationService;
 import org.elasticsearch.xpack.ml.inference.persistence.ChunkedTrainedModelRestorer;
+import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelDefinitionDoc;
 import org.elasticsearch.xpack.ml.process.MlMemoryTracker;
 
 import java.util.Collections;
@@ -65,6 +75,7 @@ import java.util.function.Predicate;
 import java.util.stream.Collectors;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
+import static org.elasticsearch.xpack.core.ml.action.PutTrainedModelDefinitionPartAction.MAX_NUM_NATIVE_DEFINITION_PARTS;
 
 public class TransportStartTrainedModelDeploymentAction extends TransportMasterNodeAction<
     StartTrainedModelDeploymentAction.Request,
@@ -174,26 +185,28 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
                 listener.onFailure(ExceptionsHelper.serverError("model [{}] does not have location", trainedModelConfig.getModelId()));
                 return;
             }
-
-            getModelBytes(trainedModelConfig, ActionListener.wrap(modelBytes -> {
-                TaskParams taskParams = new TaskParams(
-                    trainedModelConfig.getModelId(),
-                    modelBytes,
-                    request.getInferenceThreads(),
-                    request.getModelThreads(),
-                    request.getQueueCapacity()
-                );
-                PersistentTasksCustomMetadata persistentTasks = clusterService.state()
-                    .getMetadata()
-                    .custom(PersistentTasksCustomMetadata.TYPE);
-                memoryTracker.refresh(
-                    persistentTasks,
-                    ActionListener.wrap(
-                        aVoid -> trainedModelAllocationService.createNewModelAllocation(taskParams, waitForDeploymentToStart),
-                        listener::onFailure
-                    )
-                );
-            }, listener::onFailure));
+            validateModelDefinition(
+                trainedModelConfig,
+                ActionListener.wrap(validate -> getModelBytes(trainedModelConfig, ActionListener.wrap(modelBytes -> {
+                    TaskParams taskParams = new TaskParams(
+                        trainedModelConfig.getModelId(),
+                        modelBytes,
+                        request.getInferenceThreads(),
+                        request.getModelThreads(),
+                        request.getQueueCapacity()
+                    );
+                    PersistentTasksCustomMetadata persistentTasks = clusterService.state()
+                        .getMetadata()
+                        .custom(PersistentTasksCustomMetadata.TYPE);
+                    memoryTracker.refresh(
+                        persistentTasks,
+                        ActionListener.wrap(
+                            aVoid -> trainedModelAllocationService.createNewModelAllocation(taskParams, waitForDeploymentToStart),
+                            listener::onFailure
+                        )
+                    );
+                }, listener::onFailure)), listener::onFailure)
+            );
 
         }, listener::onFailure);
 
@@ -270,6 +283,86 @@ public class TransportStartTrainedModelDeploymentAction extends TransportMasterN
 
     }
 
+    private void validateModelDefinition(TrainedModelConfig config, ActionListener<Void> listener) {
+        if (config.getLocation() instanceof IndexLocation == false) {
+            listener.onResponse(null);
+            return;
+        }
+        final String modelId = config.getModelId();
+        final String[] requiredSourceFields = new String[] {
+            TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName(),
+            TrainedModelDefinitionDoc.DOC_NUM.getPreferredName(),
+            TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName(),
+            TrainedModelDefinitionDoc.EOS.getPreferredName() };
+        final Set<String> requiredSet = Set.of(requiredSourceFields);
+        String index = ((IndexLocation) config.getLocation()).getIndexName();
+        client.prepareSearch(index)
+            .setQuery(
+                QueryBuilders.constantScoreQuery(
+                    QueryBuilders.boolQuery()
+                        .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId))
+                        .filter(
+                            QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), TrainedModelDefinitionDoc.NAME)
+                        )
+                )
+            )
+            .setFetchSource(requiredSourceFields, new String[0])
+            .setSize(MAX_NUM_NATIVE_DEFINITION_PARTS)
+            .setTrackTotalHits(true)
+            .addSort(SortBuilders.fieldSort(TrainedModelDefinitionDoc.DOC_NUM.getPreferredName()).order(SortOrder.ASC).unmappedType("long"))
+            .execute(ActionListener.wrap(response -> {
+                SearchHit[] hits = response.getHits().getHits();
+                if (hits.length == 0) {
+                    listener.onFailure(new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
+                    return;
+                }
+                long summedLengths = 0;
+                for (SearchHit hit : hits) {
+                    Map<String, Object> fields = hit.getSourceAsMap();
+                    if (fields == null) {
+                        listener.onFailure(
+                            ExceptionsHelper.badRequestException(
+                                "[{}] model definition [{}] is missing required fields {}, unable to be deployed",
+                                modelId,
+                                TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
+                                List.of(requiredSourceFields)
+                            )
+                        );
+                        return;
+                    }
+                    Set<String> diff = Sets.difference(fields.keySet(), requiredSet);
+                    if (diff.isEmpty() == false) {
+                        listener.onFailure(
+                            ExceptionsHelper.badRequestException(
+                                "[{}] model definition [{}] is missing required fields {}, unable to be deployed",
+                                modelId,
+                                TrainedModelDefinitionDoc.docNum(modelId, Objects.requireNonNull(hit.getId())),
+                                diff
+                            )
+                        );
+                        return;
+                    }
+                    summedLengths += ((Number) fields.get(TrainedModelDefinitionDoc.DEFINITION_LENGTH.getPreferredName())).longValue();
+                }
+                long totalLength = ((Number) hits[hits.length - 1].getSourceAsMap()
+                    .get(TrainedModelDefinitionDoc.TOTAL_DEFINITION_LENGTH.getPreferredName())).longValue();
+                Boolean eos = (Boolean) hits[hits.length - 1].getSourceAsMap().get(TrainedModelDefinitionDoc.EOS.getPreferredName());
+                if (summedLengths != totalLength || eos == null || eos == false) {
+                    listener.onFailure(ExceptionsHelper.serverError(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
+                    return;
+                }
+                listener.onResponse(null);
+            }, e -> {
+                if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
+                    Exception ex = new ResourceNotFoundException(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId));
+                    ex.addSuppressed(e);
+                    listener.onFailure(ex);
+                    return;
+                }
+                listener.onFailure(e);
+            }));
+    }
+
     @Override
     protected ClusterBlockException checkBlock(StartTrainedModelDeploymentAction.Request request, ClusterState state) {
         // We only delegate here to PersistentTasksService, but if there is a metadata writeblock,

+ 19 - 0
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java

@@ -77,6 +77,25 @@ public class TrainedModelDefinitionDoc implements ToXContentObject {
         return NAME + "-" + modelId + "-" + docNum;
     }
 
+    /**
+     * Return the document number as represented in the docId
+     * @param modelId The model Id
+     * @param docId the document ID
+     * @return the document number or -1 if not found (invalid)
+     */
+    public static int docNum(String modelId, String docId) {
+        String prefix = NAME + "-" + modelId + "-";
+        if (prefix.length() >= docId.length()) {
+            return -1;
+        }
+        String numString = docId.substring(prefix.length());
+        try {
+            return Integer.parseInt(numString);
+        } catch (NumberFormatException _ex) {
+            return -1;
+        }
+    }
+
     private final BytesReference binaryData;
     private final String modelId;
     private final int docNum;