Bläddra i källkod

Merge pull request ESQL-1051 from elastic/main

🤖 ESQL: Merge upstream
elasticsearchmachine 2 år sedan
förälder
incheckning
adb8ab27a2

+ 1 - 1
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportGetTrainedModelPackageConfigAction.java

@@ -98,7 +98,7 @@ public class TransportGetTrainedModelPackageConfigAction extends TransportMaster
                         return;
                     }
 
-                    if (Strings.isNullOrEmpty(packageConfig.getSha256())) {
+                    if (Strings.isNullOrEmpty(packageConfig.getSha256()) || packageConfig.getSha256().length() != 64) {
                         listener.onFailure(new ElasticsearchStatusException("Invalid package sha", RestStatus.INTERNAL_SERVER_ERROR));
                         return;
                     }

+ 27 - 2
x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java

@@ -114,8 +114,9 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
                 // simple round up
                 int totalParts = (int) ((size + DEFAULT_CHUNK_SIZE - 1) / DEFAULT_CHUNK_SIZE);
 
-                for (int part = 0; part < totalParts; ++part) {
+                for (int part = 0; part < totalParts - 1; ++part) {
                     BytesArray definition = chunkIterator.next();
+
                     PutTrainedModelDefinitionPartAction.Request r = new PutTrainedModelDefinitionPartAction.Request(
                         modelId,
                         definition,
@@ -126,7 +127,31 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<
 
                     client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, r).actionGet();
                 }
-                logger.debug(() -> format("finished uploading model using [%d] parts", totalParts));
+
+                // get the last part, this time verify the checksum
+                BytesArray definition = chunkIterator.next();
+
+                if (modelPackageConfig.getSha256().equals(chunkIterator.getSha256())) {
+                    PutTrainedModelDefinitionPartAction.Request r = new PutTrainedModelDefinitionPartAction.Request(
+                        modelId,
+                        definition,
+                        totalParts - 1,
+                        size,
+                        totalParts
+                    );
+
+                    client.execute(PutTrainedModelDefinitionPartAction.INSTANCE, r).actionGet();
+
+                    logger.debug(() -> format("finished uploading model using [%d] parts", totalParts));
+                } else {
+                    logger.error(
+                        format(
+                            "Model sha256 checksums do not match, expected [%s] but got [%s]",
+                            modelPackageConfig.getSha256(),
+                            chunkIterator.getSha256()
+                        )
+                    );
+                }
             } catch (MalformedURLException e) {
                 logger.error(format("Invalid URL [%s]", e));
             } catch (URISyntaxException e) {

+ 26 - 0
x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/ModelLoaderUtilsTests.java

@@ -7,8 +7,11 @@
 
 package org.elasticsearch.xpack.ml.packageloader.action;
 
+import org.elasticsearch.common.hash.MessageDigests;
 import org.elasticsearch.test.ESTestCase;
 
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
 import java.net.URI;
 import java.net.URISyntaxException;
 
@@ -57,4 +60,27 @@ public class ModelLoaderUtilsTests extends ESTestCase {
         );
         assertEquals("Repository must contain a scheme", e.getMessage());
     }
+
+    public void testSha256() throws IOException {
+        byte[] bytes = randomByteArrayOfLength(randomIntBetween(10, 1_000_000));
+        String expectedDigest = MessageDigests.toHexString(MessageDigests.sha256().digest(bytes));
+        assertEquals(64, expectedDigest.length());
+
+        int chunkSize = randomIntBetween(100, 10_000);
+
+        ModelLoaderUtils.InputStreamChunker inputStreamChunker = new ModelLoaderUtils.InputStreamChunker(
+            new ByteArrayInputStream(bytes),
+            chunkSize
+        );
+
+        int totalParts = (bytes.length + chunkSize - 1) / chunkSize;
+
+        for (int part = 0; part < totalParts - 1; ++part) {
+            assertEquals(chunkSize, inputStreamChunker.next().length());
+        }
+
+        assertEquals(bytes.length - (chunkSize * (totalParts - 1)), inputStreamChunker.next().length());
+
+        assertEquals(expectedDigest, inputStreamChunker.getSha256());
+    }
 }