Loading last checkpoint and versioned last checkpoint #12

@agitter

The Lightning ModelCheckpoint callback has an argument save_top_k:

save_top_k (int) – if save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, and the filename remains unchanged, the name of the saved file will be appended with a version count starting with v1 to avoid collisions unless enable_version_counter is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions.
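The following is a simplified sketch of the collision-avoiding naming described above. It assumes the `-vN` suffix format seen in filenames like last-v3.ckpt. `next_versioned_name` is a hypothetical helper written for illustration and is not Lightning's implementation.

```python
def next_versioned_name(stem, existing, ext=".ckpt"):
    """Return a filename that does not collide with any name in `existing`.

    Hypothetical helper: the plain name is used if it is free, otherwise a
    version suffix starting at v1 is appended
    (e.g. last.ckpt -> last-v1.ckpt -> last-v2.ckpt).
    """
    candidate = f"{stem}{ext}"
    version = 1
    while candidate in existing:
        candidate = f"{stem}-v{version}{ext}"
        version += 1
    return candidate


existing = {"last.ckpt", "last-v1.ckpt", "last-v2.ckpt"}
print(next_versioned_name("last", existing))  # last-v3.ckpt
```

Under this scheme, a directory that already contains several "last" files gets yet another versioned name. Any loader that only checks for the exact name last.ckpt will skip these files.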

METL training sets save_top_k=-1 here:

```python
# checkpoints at regular intervals (every 10 epochs)
checkpoint_callback_2 = ModelCheckpoint(
    dirpath=join(log_dir, "checkpoints", "interval_checkpoints"),
    every_n_epochs=10,
    save_top_k=-1
)
```

We do not expect to have versioned "last" checkpoints, but they were observed under certain conditions while training in AWS. Checkpoints with filenames like last-v3.ckpt caused a problem when loading with the code here, which only looks for the exact name last.ckpt and therefore never considers a versioned file:

```python
from os.path import isfile, join


def get_checkpoint_path(log_dir):
    if isfile(join(log_dir, "checkpoints", "last.ckpt")):
        ckpt_path = join(log_dir, "checkpoints", "last.ckpt")
        print("Found checkpoint, resuming training from: {}".format(ckpt_path))
    else:
        ckpt_path = None
        print("No checkpoint found, training from scratch")
    return ckpt_path
```
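One possible way to handle versioned files is to also match last-vN.ckpt and pick the newest one. The sketch below assumes that a higher version suffix means a more recent checkpoint and treats the unversioned last.ckpt as version 0. The regex and selection rule are illustrative assumptions, not existing METL code.

```python
import re
from os import listdir
from os.path import isdir, isfile, join

# Matches "last.ckpt" and versioned variants such as "last-v3.ckpt".
LAST_CKPT_RE = re.compile(r"^last(?:-v(\d+))?\.ckpt$")


def get_checkpoint_path(log_dir):
    """Return the newest "last" checkpoint in log_dir/checkpoints, or None.

    Assumption: a higher version suffix is more recent; last.ckpt is version 0.
    """
    ckpt_dir = join(log_dir, "checkpoints")
    candidates = []
    if isdir(ckpt_dir):
        for fn in listdir(ckpt_dir):
            match = LAST_CKPT_RE.match(fn)
            if match and isfile(join(ckpt_dir, fn)):
                version = int(match.group(1)) if match.group(1) else 0
                candidates.append((version, fn))
    if not candidates:
        print("No checkpoint found, training from scratch")
        return None
    # Highest version number wins; ties cannot occur because filenames are unique.
    _, newest = max(candidates)
    ckpt_path = join(ckpt_dir, newest)
    print("Found checkpoint, resuming training from: {}".format(ckpt_path))
    return ckpt_path
```

Selecting by version number avoids relying on file modification times, which may not be preserved when checkpoints are copied between machines. If timestamps are trustworthy in a given setup, `os.path.getmtime` could be used as the sort key instead.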