Commit

Fix module of DistributedDataParallel when not using single gpu model when saving
antoine311200 committed Aug 20, 2024
1 parent d8a235a commit 7e217b4
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions torchrun_main.py
@@ -48,7 +48,7 @@ def parse_args(args):
     parser.add_argument("--weight_decay", type=float, default=0.0)
     parser.add_argument("--warmup_steps", type=int, default=1_000)
     parser.add_argument("--eval_every", type=int, default=5_000)
-    parser.add_argument("--target_eval_tokens", type=int, default=1_000_000)
+    parser.add_argument("--target_eval_tokens", type=int, default=10_000_000)
     parser.add_argument("--num_training_steps", type=int, default=10_000,
                         help="Number of **update steps** to train for. "
                              "Notice that gradient accumulation is taken into account.")
@@ -442,7 +442,12 @@ def optimizer_hook(p):
         current_model_directory = f"{args.save_dir}/model_{update_step}"
         logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
         os.makedirs(args.save_dir, exist_ok=True)
-        model.save_pretrained(current_model_directory, max_shard_size='100GB')
+
+        # Check if DistributedDataParallel is used
+        if not args.single_gpu:
+            model.save_pretrained(current_model_directory, max_shard_size='100GB')
+        else:
+            model.module.save_pretrained(current_model_directory, max_shard_size='100GB')
 
         optimizer_checkpoint = {
             "optimizer": optimizer.state_dict(),
@@ -519,7 +524,11 @@ def optimizer_hook(p):
     if global_rank == 0 and not os.path.exists(current_model_directory):
         logger.info(f"Saving model and optimizer to {current_model_directory}, update step {update_step}")
         os.makedirs(args.save_dir, exist_ok=True)
-        model.save_pretrained(current_model_directory)
+
+        if not args.single_gpu:
+            model.save_pretrained(current_model_directory)
+        else:
+            model.module.save_pretrained(current_model_directory)
 
         optimizer_checkpoint = {
             "optimizer": optimizer.state_dict(),
Expand Down
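For readers of the diff: `torch.nn.parallel.DistributedDataParallel` keeps the original network under its `.module` attribute, and Hugging Face's `save_pretrained` is defined on that underlying model rather than on the DDP wrapper, which is why the commit branches on `args.single_gpu` before saving. Below is a minimal sketch of the same idea keyed off the wrapper type instead of a flag; the `unwrap_model` helper is illustrative and not part of this repository.

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Illustrative helper: return the underlying model if it is DDP-wrapped."""
    # DDP stores the wrapped network on `.module`; `save_pretrained` lives there,
    # not on the DistributedDataParallel wrapper itself.
    return model.module if isinstance(model, DistributedDataParallel) else model


# Usage mirroring the save path in the diff above (names taken from the diff):
# unwrap_model(model).save_pretrained(current_model_directory, max_shard_size="100GB")
```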