PyTorch Lightning example #3189
Conversation
Addresses #3171.
```python
def main(args):
    # Create a model, synthetic data, a guide, and a lightning module.
    pyro.set_rng_seed(args.seed)
```
This option, added in #3149, ensures that parameters of PyroModules will not be implicitly shared across model instances via the Pyro parameter store:
```diff
 pyro.set_rng_seed(args.seed)
+pyro.settings.set(module_local_params=True)
```
It's not really exercised in this simple example, since there's only one model and one guide, but I think it's good practice to enable it whenever models and guides can be written as PyroModules and trained using generic PyTorch infrastructure like torch.optim and PyTorch Lightning. A sketch of the pitfall it avoids follows below.
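As a rough illustration of that pitfall (the `Scale` module below is hypothetical, not from the example): without `module_local_params=True`, two instances of the same `PyroModule` class resolve identically named parameters to a single entry in the global parameter store; with it, each instance owns its parameters like an ordinary `nn.Module`.

```python
import torch
import pyro
from pyro.nn import PyroModule, PyroParam

pyro.settings.set(module_local_params=True)

class Scale(PyroModule):
    def __init__(self):
        super().__init__()
        self.weight = PyroParam(torch.ones(2))

    def forward(self, x):
        return self.weight * x

a, b = Scale(), Scale()
with torch.no_grad():
    a.weight.add_(1.0)  # mutate one instance's parameter in place
# With module_local_params=True, b's parameter is unaffected;
# without it, both instances would read the same store entry named "weight".
assert not torch.allclose(a.weight, b.weight)
```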
```python
guide = AutoNormal(model)
training_plan = PyroLightningModule(model, guide, args.learning_rate)
```
This change uses the new __call__ method added to the base pyro.infer.elbo.ELBO in #3149, which takes a model and a guide and returns a torch.nn.Module wrapper around the loss:
```diff
 guide = AutoNormal(model)
-training_plan = PyroLightningModule(model, guide, args.learning_rate)
+loss_fn = Trace_ELBO()(model, guide)
+training_plan = PyroLightningModule(loss_fn, args.learning_rate)
```
It saves you from having to pass a model and guide around everywhere or deal with the Pyro parameter store, which makes SVI a little easier to use with other PyTorch tools like Lightning and the PyTorch JIT.
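For instance, a minimal end-to-end sketch of plain torch.optim training with the wrapper (the toy model here is hypothetical; assumes a Pyro version that includes #3149):

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoNormal

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = torch.randn(100) + 3.0
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)  # a torch.nn.Module wrapping model, guide, and loss
loss_fn(data)  # one forward pass initializes the guide's lazy parameters

optim = torch.optim.Adam(loss_fn.parameters(), lr=0.01)
for step in range(1000):
    optim.zero_grad()
    loss = loss_fn(data)
    loss.backward()
    optim.step()
```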
I didn't know about ELBOModule. This is much neater!
```python
# All relevant parameters need to be initialized before ``configure_optimizers`` is called.
# Since we used an AutoNormal guide, our parameters have not been initialized yet.
# Therefore we warm up the guide by running one mini-batch through it.
mini_batch = dataset[: args.batch_size]
guide(*mini_batch)
```
```diff
 # All relevant parameters need to be initialized before ``configure_optimizers`` is called.
 # Since we used an AutoNormal guide, our parameters have not been initialized yet.
-# Therefore we warm up the guide by running one mini-batch through it.
+# Therefore we initialize the model and guide by running one mini-batch through the loss.
 mini_batch = dataset[: args.batch_size]
-guide(*mini_batch)
+loss_fn(*mini_batch)
```
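The reason the warm-up call is needed at all: AutoNormal builds its parameters lazily on its first trace of the model, so before that call there is nothing for the optimizer to see. A rough illustration, reusing the example's names (`dataset` and `args` come from the script):

```python
guide = AutoNormal(model)
loss_fn = Trace_ELBO()(model, guide)
assert len(list(guide.parameters())) == 0  # guide params don't exist yet

mini_batch = dataset[: args.batch_size]
loss_fn(*mini_batch)  # first call traces model and guide, creating their parameters
assert len(list(guide.parameters())) > 0  # now visible to configure_optimizers
```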
```python
# Distributed training via PyTorch Lightning.
#
# This tutorial demonstrates how to distribute SVI training across multiple
# machines (or multiple GPUs on one or more machines) using the PyTorch Lightning
```
Where is the distributed training in this example? Is it hidden in the default configuration of the DataLoader and TrainingPlan in main below?
Argparse arguments are passed to the pl.Trainer:

```python
trainer = pl.Trainer.from_argparse_args(args)
```

So you can run the script as follows:

```sh
$ python examples/svi_lightning.py --accelerator gpu --devices 2 --max_epochs 100 --strategy ddp
```

When there are multiple devices, the DataLoader will use a DistributedSampler automatically.
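Put together, the relevant wiring looks roughly like this (a sketch against Lightning 1.x, where add_argparse_args/from_argparse_args exist; `training_plan` and `dataloader` are built earlier in the script):

```python
import argparse
import pytorch_lightning as pl

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", default=0.01, type=float)
# Adds Trainer flags such as --accelerator, --devices, --strategy, --max_epochs.
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

trainer = pl.Trainer.from_argparse_args(args)
# With --strategy ddp and multiple devices, Lightning replaces the DataLoader's
# sampler with a DistributedSampler automatically.
trainer.fit(training_plan, train_dataloaders=dataloader)
```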
```python
def __init__(self, model, guide, lr):
    super().__init__()
    self.pyro_model = model
    self.pyro_guide = guide
    self.loss_fn = Trace_ELBO().differentiable_loss
```
```diff
-def __init__(self, model, guide, lr):
+def __init__(self, loss_fn: pyro.infer.elbo.ELBOModule, lr: float):
     super().__init__()
-    self.pyro_model = model
-    self.pyro_guide = guide
-    self.loss_fn = Trace_ELBO().differentiable_loss
+    self.loss_fn = loss_fn
+    self.model = loss_fn.model
+    self.guide = loss_fn.guide
```
```python
def training_step(self, batch, batch_idx):
    """Training step for Pyro training."""
    loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
```
```diff
-loss = self.loss_fn(self.pyro_model, self.pyro_guide, *batch)
+loss = self.loss_fn(*batch)
```
```python
def configure_optimizers(self):
    """Configure an optimizer."""
    return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
```
```diff
-return torch.optim.Adam(self.pyro_guide.parameters(), lr=self.lr)
+return torch.optim.Adam(self.loss_fn.parameters(), lr=self.lr)
```
```python
self.lr = lr
```
Adding a forward method that calls Predictive is sometimes helpful:
```diff
         self.lr = lr
+        self.predictive = pyro.infer.Predictive(
+            self.model, guide=self.guide, num_samples=1
+        )
+
+    def forward(self, *args):
+        return self.predictive(*args)
```
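With that, the trained Lightning module can be called like any nn.Module to draw posterior-predictive samples (note that Predictive requires num_samples or posterior_samples to be given, hence num_samples=1 above). A hypothetical usage sketch:

```python
# After trainer.fit(training_plan, ...), draw posterior-predictive samples
# for a batch; Predictive returns a dict mapping sample-site names to tensors.
samples = training_plan(*mini_batch)
print({name: tensor.shape for name, tensor in samples.items()})
```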
ordabayevy left a comment:
Thanks for reviewing @eb8680. I think it is much neater now using ELBOModule!
@fritzo There is something wrong with building tutorials when I run ... Trying to figure out what is wrong (if you know a quick fix, I would appreciate it).
@ordabayevy not sure what's causing the build issue... Unrelated, I see ...
Could you add svi_lightning to tutorial/source/index.rst so it shows up on the website?
Still no luck with the tutorial build ... (Details)
I was able to build the tutorial by ignoring warnings, and can confirm that the generated doc is readable and the title in the left-hand-side TOC is not too long.
PyTorch Lightning example (pyro-ppl#3189)
This example shows how to train Pyro models using PyTorch Lightning and is adapted from the Horovod example.