PyTorch Lightning example #3189

Merged
eb8680 merged 16 commits into dev from svi-lightning on Mar 16, 2023

Conversation

@ordabayevy
Member

This example shows how to train Pyro models using PyTorch Lightning and is adapted from the Horovod example.

Yerdos Ordabayev added 2 commits March 13, 2023 02:49
@ordabayevy
Member Author

Addresses #3171.

Member

@eb8680 eb8680 left a comment

Nice! I've been using Lightning recently as well, so I left some (optional) suggestions aimed at making the example slightly more PyTorch-idiomatic, using new features from #3149.

Comment thread examples/svi_lightning.py

def main(args):
    # Create a model, synthetic data, a guide, and a lightning module.
    pyro.set_rng_seed(args.seed)
Member

This option added in #3149 ensures that parameters of PyroModules will not be implicitly shared across model instances via the Pyro parameter store:

Suggested change
-     pyro.set_rng_seed(args.seed)
+     pyro.set_rng_seed(args.seed)
+     pyro.settings.set(module_local_params=True)

It's not really exercised in this simple example, since there's only one model and guide, but I think it's good practice to enable it whenever models and guides can be written as PyroModules and trained using generic PyTorch infrastructure like torch.optim and PyTorch Lightning.

Comment thread examples/svi_lightning.py Outdated
Comment on lines +79 to +80
guide = AutoNormal(model)
training_plan = PyroLightningModule(model, guide, args.learning_rate)
Member

This change uses the new __call__ method added to the base pyro.infer.elbo.ELBO in #3149, which takes a model and guide and returns a torch.nn.Module wrapper around the loss:

Suggested change
-     guide = AutoNormal(model)
-     training_plan = PyroLightningModule(model, guide, args.learning_rate)
+     guide = AutoNormal(model)
+     loss_fn = Trace_ELBO()(model, guide)
+     training_plan = PyroLightningModule(loss_fn, args.learning_rate)

It saves you from having to pass around a model and guide everywhere or deal with the Pyro parameter store, which makes SVI a little easier to use with other PyTorch tools like Lightning and the PyTorch JIT.

Member Author

I didn't know about ELBOModule. This is much neater!

Comment thread examples/svi_lightning.py Outdated
Comment on lines +86 to +90
    # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
    # Since we used AutoNormal guide our parameters have not be initialized yet.
    # Therefore we warm up the guide by running one mini-batch through it.
    mini_batch = dataset[: args.batch_size]
    guide(*mini_batch)
Member

Suggested change
-     # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
-     # Since we used AutoNormal guide our parameters have not be initialized yet.
-     # Therefore we warm up the guide by running one mini-batch through it.
-     mini_batch = dataset[: args.batch_size]
-     guide(*mini_batch)
+     # All relevant parameters need to be initialized before ``configure_optimizer`` is called.
+     # Since we used an AutoNormal guide, our parameters have not been initialized yet.
+     # Therefore we initialize the model and guide by running one mini-batch through the loss.
+     mini_batch = dataset[: args.batch_size]
+     loss_fn(*mini_batch)

Comment thread examples/svi_lightning.py
Comment on lines +4 to +7
# Distributed training via Pytorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
Member

Where is the distributed training in this example? Is it hidden in the default configuration of the DataLoader and TrainingPlan in main below?

Member Author

Argparse arguments are passed to the pl.Trainer:

trainer = pl.Trainer.from_argparse_args(args)

So you can run the script as follows:

$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp

When there are multiple devices, the DataLoader will use a DistributedSampler automatically.

Comment thread examples/svi_lightning.py Outdated
Comment on lines +54 to +58
    def __init__(self, model, guide, lr):
        super().__init__()
        self.pyro_model = model
        self.pyro_guide = guide
        self.loss_fn = Trace_ELBO().differentiable_loss
Member

Suggested change
-     def __init__(self, model, guide, lr):
-         super().__init__()
-         self.pyro_model = model
-         self.pyro_guide = guide
-         self.loss_fn = Trace_ELBO().differentiable_loss
+     def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
+         super().__init__()
+         self.loss_fn = loss_fn
+         self.model = loss_fn.model
+         self.guide = loss_fn.guide

Comment thread examples/svi_lightning.py Outdated

    def training_step(self, batch, batch_idx):
        """Training step for Pyro training."""
        loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
Member

Suggested change
-         loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
+         loss = self.loss_fn(*batch)

Comment thread examples/svi_lightning.py Outdated

    def configure_optimizers(self):
        """Configure an optimizer."""
        return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
Member

Suggested change
-         return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
+         return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)

Comment thread examples/svi_lightning.py
Comment on lines +59 to +60
self.lr = lr

Member

Adding a forward method that calls Predictive is sometimes helpful:

Suggested change
-         self.lr = lr
+         self.lr = lr
+         self.predictive = pyro.infer.Predictive(self.model, guide=self.guide)
+
+     def forward(self, *args):
+         return self.predictive(*args)

Member Author

@ordabayevy ordabayevy left a comment

Thanks for reviewing @eb8680. I think it is much neater now using ELBOModule!


fritzo previously approved these changes Mar 14, 2023
Member

@fritzo fritzo left a comment

Looks great! Can you just confirm the generated docs are readable, i.e. after running make tutorial? Also ensure the title isn't too long when it appears on the left hand side TOC.

@ordabayevy
Member Author

@fritzo There is something wrong with building tutorials when I run make tutorial:

make tutorial
make -C tutorial html
make[1]: Entering directory '/mnt/disks/dev/repos/pyro/tutorial'
Running Sphinx v6.1.3
building [mo]: targets for 0 po files that are out of date
writing output... 
building [html]: targets for 80 source files that are out of date
updating environment: [new config] 80 added, 0 changed, 0 removed
reading sources... [100%] svi_part_ii .. working_memory                                                                                                             

Warning, treated as error:
/mnt/disks/dev/repos/pyro/tutorial/source/gp.ipynb:973:Duplicate substitution definition name: "image0".
make[1]: *** [Makefile:20: html] Error 2
make[1]: Leaving directory '/mnt/disks/dev/repos/pyro/tutorial'
make: *** [Makefile:18: tutorial] Error 2

Trying to figure out what is wrong ... (if you know a quick fix I would appreciate it)

@fritzo
Member

fritzo commented Mar 15, 2023

@ordabayevy not sure what's causing the build issue...

Unrelated, I see

.../pyro-ppl/pyro/tutorial/source/svi_lightning.rst: WARNING: document isn't included in any toctree

Could you add svi_lightning to tutorial/source/index.rst so it shows up on the website?

@ordabayevy
Member Author

Still no luck with make tutorial. When I try to build tutorials on the dev branch I get this:

Details
make -C tutorial html
make[1]: Entering directory '/mnt/disks/dev/repos/pyro/tutorial'
Running Sphinx v6.1.3
making output directory... done
building [mo]: targets for 0 po files that are out of date
writing output... 
building [html]: targets for 79 source files that are out of date
updating environment: [new config] 79 added, 0 changed, 0 removed
reading sources... [100%] tensor_shapes .. working_memory                                                                                                           

Warning, treated as error:
/mnt/disks/dev/repos/pyro/tutorial/source/logistic-growth.ipynb:1220:File not found: 'workflow.html'
make[1]: *** [Makefile:20: html] Error 2
make[1]: Leaving directory '/mnt/disks/dev/repos/pyro/tutorial'
make: *** [Makefile:18: tutorial] Error 2

@ordabayevy
Member Author

> Can you just confirm the generated docs are readable, i.e. after running make tutorial? Also ensure the title isn't too long when it appears on the left hand side TOC.

I was able to build the tutorial by ignoring warnings and can confirm that the generated doc is readable and the title in the left hand side TOC is not too long.

Member

@fritzo fritzo left a comment

Looks great, thanks for building tutorials. I'll look into fixing those warnings.

@eb8680 any further comments? I'll hold off merging, feel free to merge

Member

@eb8680 eb8680 left a comment

LGTM

@eb8680 eb8680 merged commit c6851b8 into dev Mar 16, 2023
@eb8680 eb8680 deleted the svi-lightning branch March 16, 2023 03:43
@ordabayevy
Member Author

Thanks @eb8680 and @fritzo for reviewing!

luisdiaz1997 added a commit to luisdiaz1997/pyro that referenced this pull request Mar 16, 2023
luisdiaz1997 added a commit to luisdiaz1997/pyro that referenced this pull request Mar 16, 2023