|
@@ -76,13 +76,18 @@ public class PyTorchStateStreamer {
|
|
|
restorer.setSearchSize(1);
|
|
|
restorer.restoreModelDefinition(doc -> writeChunk(doc, restoreStream), success -> {
|
|
|
logger.debug("model [{}] state restored in [{}] documents from index [{}]", modelId, restorer.getNumDocsWritten(), index);
|
|
|
- if (modelBytesWritten.get() != modelSize) {
|
|
|
- logger.error(
|
|
|
- "model [{}] restored state size [{}] does not equal the expected model size [{}]",
|
|
|
- modelId,
|
|
|
- modelBytesWritten,
|
|
|
- modelSize
|
|
|
- );
|
|
|
+
|
|
|
+ if (success) {
|
|
|
+ if (modelBytesWritten.get() != modelSize) {
|
|
|
+ logger.error(
|
|
|
+ "model [{}] restored state size [{}] does not equal the expected model size [{}]",
|
|
|
+ modelId,
|
|
|
+ modelBytesWritten,
|
|
|
+ modelSize
|
|
|
+ );
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ logger.info("[{}] loading model state cancelled", modelId);
|
|
|
}
|
|
|
listener.onResponse(success);
|
|
|
}, listener::onFailure);
|