Issue
I am using PyTorch Lightning with, among other callbacks, a ModelCheckpoint
that saves models under a formatted filename such as filename="model_{epoch}-{val_acc:.2f}".
Later in the process I want to load these checkpoints again. For simplicity, let's say I only want the best ones, kept via save_top_k=N.
Because the filename is dynamic, I wonder how I can retrieve the checkpoints easily.
Is there a built-in attribute, on the callback or via the trainer, that gives the saved checkpoints?
For example, something like
checkpoint_callback.get_top_k_paths()
I know I can do it with glob and model_dir, but I am wondering whether there is a built-in one-line solution somewhere.
Solution
You can retrieve the path of the best model after training from the checkpoint callback:
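from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # your LightningModule
trainer.fit(model)
best_path = checkpoint_callback.best_model_path  # path of the best checkpoint, as a string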
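For the save_top_k=N case from the question, the callback also keeps a best_k_models dictionary that, as far as I know, maps each retained checkpoint path to its monitored score once a monitor such as "val_acc" is set. The helper below is a minimal sketch of ranking those paths. top_k_paths is a hypothetical name (not part of Lightning), and the SimpleNamespace object with made-up paths only stands in for a real ModelCheckpoint so the snippet runs without training anything.

```python
from types import SimpleNamespace


def top_k_paths(checkpoint_callback, mode="max"):
    """Return the retained checkpoint paths ordered from best to worst score.

    Assumes `checkpoint_callback.best_k_models` is a dict of path -> score.
    `mode` should match the callback's own mode ("max" for accuracy,
    "min" for loss).
    """
    scores = checkpoint_callback.best_k_models
    # float() handles both plain numbers and 0-dim tensors.
    return sorted(scores, key=lambda p: float(scores[p]), reverse=(mode == "max"))


# Stand-in for a trained ModelCheckpoint (illustrative paths and scores only).
fake_callback = SimpleNamespace(
    best_k_models={
        "my/path/model_epoch=3-val_acc=0.81.ckpt": 0.81,
        "my/path/model_epoch=7-val_acc=0.93.ckpt": 0.93,
        "my/path/model_epoch=5-val_acc=0.88.ckpt": 0.88,
    }
)

ranked = top_k_paths(fake_callback, mode="max")
```

With a real callback you would pass checkpoint_callback itself after trainer.fit(model), and ranked[0] should then agree with checkpoint_callback.best_model_path.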
To find all the checkpoints, list the files in the dirpath
where the checkpoints are saved.
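Listing the directory could look like the sketch below. list_checkpoints is a hypothetical helper name, and the "*.ckpt" pattern assumes the default checkpoint file extension has not been changed.

```python
from pathlib import Path


def list_checkpoints(dirpath, pattern="*.ckpt"):
    """Return all checkpoint files directly inside dirpath, sorted by name."""
    return sorted(Path(dirpath).glob(pattern))
```

In practice you would call list_checkpoints(checkpoint_callback.dirpath) after training and load any returned path with your LightningModule's load_from_checkpoint class method.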
Answered By - Aniket Maurya