|
|
@@ -37,12 +37,16 @@ public class PyTorchStateStreamer {
|
|
|
|
|
|
private static final Logger logger = LogManager.getLogger(PyTorchStateStreamer.class);
|
|
|
|
|
|
+ /** The size in bytes of the prelude (the model size encoded as an int) written before the model definition */
|
|
|
+ private static final int NUM_BYTES_IN_PRELUDE = 4;
|
|
|
+
|
|
|
private final OriginSettingClient client;
|
|
|
private final ExecutorService executorService;
|
|
|
private final NamedXContentRegistry xContentRegistry;
|
|
|
private volatile boolean isCancelled;
|
|
|
private volatile int modelSize = -1;
|
|
|
- private final AtomicInteger bytesWritten = new AtomicInteger();
|
|
|
+ // model bytes only, does not include the prelude
|
|
|
+ private final AtomicInteger modelBytesWritten = new AtomicInteger();
|
|
|
|
|
|
public PyTorchStateStreamer(Client client, ExecutorService executorService, NamedXContentRegistry xContentRegistry) {
|
|
|
this.client = new OriginSettingClient(Objects.requireNonNull(client), ML_ORIGIN);
|
|
|
@@ -59,7 +63,7 @@ public class PyTorchStateStreamer {
|
|
|
|
|
|
/**
|
|
|
* First writes the size of the model so the native process can
|
|
|
- * allocated memory then writes the chunks of binary state.
|
|
|
+ * allocate memory, then writes the chunks of binary state.
|
|
|
*
|
|
|
* @param modelId The model to write
|
|
|
* @param index The index to search for the model
|
|
|
@@ -72,11 +76,11 @@ public class PyTorchStateStreamer {
|
|
|
restorer.setSearchSize(1);
|
|
|
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
|
|
|
logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
|
|
|
- if (bytesWritten.get() != modelSize) {
|
|
|
+ if (modelBytesWritten.get() != modelSize) {
|
|
|
logger.error(
|
|
|
"model [{}] restored state size [{}] does not equal the expected model size [{}]",
|
|
|
modelId,
|
|
|
- bytesWritten,
|
|
|
+ modelBytesWritten,
|
|
|
modelSize
|
|
|
);
|
|
|
}
|
|
|
@@ -96,7 +100,7 @@ public class PyTorchStateStreamer {
|
|
|
// The array backing the BytesReference may be bigger than what is
|
|
|
// referred to so write only what is after the offset
|
|
|
outputStream.write(doc.getBinaryData().array(), doc.getBinaryData().arrayOffset(), doc.getBinaryData().length());
|
|
|
- bytesWritten.addAndGet(doc.getBinaryData().length());
|
|
|
+ modelBytesWritten.addAndGet(doc.getBinaryData().length());
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
@@ -139,12 +143,10 @@ public class PyTorchStateStreamer {
|
|
|
throw new IllegalStateException(message);
|
|
|
}
|
|
|
|
|
|
- final int NUM_BYTES = 4;
|
|
|
- ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES);
|
|
|
+ ByteBuffer lengthBuffer = ByteBuffer.allocate(NUM_BYTES_IN_PRELUDE);
|
|
|
lengthBuffer.putInt(modelSizeBytes.intValue());
|
|
|
outputStream.write(lengthBuffer.array());
|
|
|
|
|
|
- bytesWritten.addAndGet(NUM_BYTES);
|
|
|
return modelSizeBytes.intValue();
|
|
|
}
|
|
|
}
|