Description
The Lightning `ModelCheckpoint` has an argument `save_top_k`:

> save_top_k (int) – if `save_top_k == k`, the best k models according to the quantity monitored will be saved. If `save_top_k == 0`, no models are saved. If `save_top_k == -1`, all models are saved. Please note that the monitors are checked every `every_n_epochs` epochs. If `save_top_k >= 2` and the callback is called multiple times inside an epoch, and the filename remains unchanged, the name of the saved file will be appended with a version count starting with v1 to avoid collisions unless `enable_version_counter` is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions.
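To illustrate the version-counter behavior described above, here is a minimal sketch (an illustration of the documented behavior, not Lightning's actual implementation) of how appending `-v1`, `-v2`, … avoids filename collisions:

```python
import os

def versioned_filename(dirpath, name, ext=".ckpt"):
    """Sketch of Lightning's version counter: if name.ckpt already
    exists, try name-v1.ckpt, name-v2.ckpt, ... until a free name
    is found, and return that path."""
    candidate = os.path.join(dirpath, name + ext)
    version = 0
    while os.path.exists(candidate):
        version += 1
        candidate = os.path.join(dirpath, f"{name}-v{version}{ext}")
    return candidate
```

This is why repeated saves under the same filename produce `last-v1.ckpt`, `last-v2.ckpt`, and so on.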
METL training sets `save_top_k=-1` here:

`metl/code/train_source_model.py`, lines 267 to 272 in 0d8129e:

```python
# checkpoints at regular intervals (every 10 epochs)
checkpoint_callback_2 = ModelCheckpoint(
    dirpath=join(log_dir, "checkpoints", "interval_checkpoints"),
    every_n_epochs=10,
    save_top_k=-1
)
```
We do not expect versioned "last" checkpoints, but they were observed under certain conditions while training on AWS. That caused a problem when resuming: checkpoints with filenames like `last-v3.ckpt` are not found by the code here

`metl/code/train_source_model.py`, lines 94 to 101 in 0d8129e:
```python
def get_checkpoint_path(log_dir):
    if isfile(join(log_dir, "checkpoints", "last.ckpt")):
        ckpt_path = join(log_dir, "checkpoints", "last.ckpt")
        print("Found checkpoint, resuming training from: {}".format(ckpt_path))
    else:
        ckpt_path = None
        print("No checkpoint found, training from scratch")
    return ckpt_path
```
which only looks for `last.ckpt`.
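One possible fix is to also match versioned names and resume from the highest version. The following is a sketch of that idea, not the project's code; it treats `last.ckpt` as version 0 and picks the largest `last-vN.ckpt`:

```python
import os
import re

def get_checkpoint_path(log_dir):
    """Resolve last.ckpt, falling back to the highest-numbered
    last-vN.ckpt produced by Lightning's version counter."""
    ckpt_dir = os.path.join(log_dir, "checkpoints")
    pattern = re.compile(r"^last(?:-v(\d+))?\.ckpt$")
    best_path, best_version = None, -1
    if os.path.isdir(ckpt_dir):
        for fname in os.listdir(ckpt_dir):
            m = pattern.match(fname)
            if m:
                # unversioned last.ckpt counts as version 0
                version = int(m.group(1)) if m.group(1) else 0
                if version > best_version:
                    best_version = version
                    best_path = os.path.join(ckpt_dir, fname)
    if best_path is not None:
        print("Found checkpoint, resuming training from: {}".format(best_path))
    else:
        print("No checkpoint found, training from scratch")
    return best_path
```

With this, a run that left behind `last.ckpt`, `last-v1.ckpt`, and `last-v3.ckpt` would resume from `last-v3.ckpt` instead of silently restarting from scratch.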