
[ML] Fix incorrect logging of unexpected model size error (#81089)

David Kyle · 3 years ago · commit 3101118440

+ 10 - 8
x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchStateStreamer.java

@@ -37,12 +37,16 @@ public class PyTorchStateStreamer {
 
     private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);
 
+    /** The size of the data written before the model definition */
+    private static final int NUM_BYTES_IN_PRELUDE = 4;
+
     private final OriginSettingClient client;
     private final ExecutorService executorService;
     private final NamedXContentRegistry xContentRegistry;
     private volatile boolean isCancelled;
     private volatile int modelSize = -1;
-    private final AtomicInteger bytesWritten = new AtomicInteger();
+    // model bytes only, does not include the prelude
+    private final AtomicInteger modelBytesWritten = new AtomicInteger();
 
     public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
         this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
@@ -59,7 +63,7 @@ public class PyTorchStateStreamer {
 
     /**
      * First writes the size of the model so the native process can
-     * allocated memory then writes the chunks of binary state.
+     * allocate memory then writes the chunks of binary state.
      *
      * @param modelId  The model to write
      * @param index    The index to search for the model
@@ -72,11 +76,11 @@ public class PyTorchStateStreamer {
         restorer.setSearchSize(1);
         restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
             logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
-            if (bytesWritten.get() != modelSize) {
+            if (modelBytesWritten.get() != modelSize) {
                 logger.error(
                     "model [{}] restored state size [{}] does not equal the expected model size [{}]",
                     modelId,
-                    bytesWritten,
+                    modelBytesWritten,
                     modelSize
                 );
             }
@@ -96,7 +100,7 @@ public class PyTorchStateStreamer {
         // The array backing the BytesReference may be bigger than what is
         // referred to so write only what is after the offset
         outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
-        bytesWritten.addAndGet(doc.getBinaryData().length());
+        modelBytesWritten.addAndGet(doc.getBinaryData().length());
         return true;
     }
 
@@ -139,12 +143,10 @@ public class PyTorchStateStreamer {
             throw new IllegalStateException(message);
         }
 
-        final int NUM_BYTES = 4;
-        ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES);
+        ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
         lengthBuffer.putInt(modelSizeBytes.intValue());
         outputStream.write(lengthBuffer.array());
 
-        bytesWritten.addAndGet(NUM_BYTES);
         return modelSizeBytes.intValue();
     }
 }
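
For context, the bug was that the 4-byte size prelude was added to the same counter as the model definition bytes, so the check against the expected model size could fail (and log an error) even when the model restored correctly. Below is a minimal, self-contained sketch of the pattern the patch enforces; the class and method names are hypothetical and this is not the actual PyTorchStateStreamer code. It writes the prelude, counts only model bytes, and then compares like with like.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical, simplified illustration: the 4-byte size prelude is written
// to the stream but never added to the model-byte counter, so the final
// comparison is against model definition bytes only.
public class ModelStreamSketch {

    private static final int NUM_BYTES_IN_PRELUDE = 4;

    // Counts model definition bytes only; the prelude is excluded.
    private final AtomicInteger modelBytesWritten = new AtomicInteger();

    public void stream(int expectedModelSize, List<byte[]> chunks, OutputStream out) throws IOException {
        // Prelude: the model size as a 4-byte big-endian int, so the reader
        // can allocate memory before the definition arrives.
        ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
        lengthBuffer.putInt(expectedModelSize);
        out.write(lengthBuffer.array());
        // Deliberately no modelBytesWritten.addAndGet(NUM_BYTES_IN_PRELUDE) here.

        // Stream the binary model definition, counting only model bytes.
        for (byte[] chunk : chunks) {
            out.write(chunk);
            modelBytesWritten.addAndGet(chunk.length);
        }

        // The check now compares model bytes against the expected model size.
        if (modelBytesWritten.get() != expectedModelSize) {
            System.err.printf(
                "restored state size [%d] does not equal the expected model size [%d]%n",
                modelBytesWritten.get(),
                expectedModelSize
            );
        }
    }

    public static void main(String[] args) throws IOException {
        byte[] chunk = new byte[] { 1, 2, 3, 4, 5 };
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        new ModelStreamSketch().stream(chunk.length, List.of(chunk), out);
        // Total stream length = prelude (4) + model bytes (5) = 9,
        // but only the 5 model bytes are counted for the size check.
        System.out.println("bytes on the wire: " + out.size());
    }
}
```

Under this scheme the error message can no longer fire spuriously just because the prelude inflated the counter, which matches the intent of the renamed modelBytesWritten field in the patch above.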