feat: introduce rank-local and synced source of randomness #1160
ssmmnn11 wants to merge 8 commits into
# Conflicts:
#   models/src/anemoi/models/models/transport_encoder_processor_decoder.py
#   models/src/anemoi/models/samplers/transport_samplers.py
#   models/tests/models/test_diffusion_sampling_pipeline.py
#   training/src/anemoi/training/train/methods/diffusion.py
#   training/src/anemoi/training/train/train.py
#   training/tests/unit/train/test_methods.py
Thanks for adding this! I think the separate saved RNG state acting as the synced stream generally works well. A couple of comments below, mostly regarding maintainability, such as making sure future contributions don't desynchronise the synced stream.
One small comment along those lines regarding RandomFourierEmbeddings in diffusion.py (which wasn't modified, so I can't comment there directly): I would prefer the torch.randn call in the init to be explicitly wrapped in use_synced_torch_rng rather than relying on it being called implicitly during model construction in train.py (which is wrapped in the context). This is only to be more explicit; it works as it is.
More generally, I am wondering if we can make the two states of the stream more explicit for future contributors, e.g. having with_synced_torch_rng and with_local_torch_rng. (This doesn't enforce the correct context but it might help against forgetting about the context.) Maybe even wrap torch.randn calls in synced_randn and local_randn which assert they are in the correct context?
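A minimal sketch of what such asserting wrappers could look like. It is framework-free: the standard library's global `random` state stands in for the default torch RNG, and `random.getstate()`/`random.setstate()` play the role of saving and restoring the torch RNG state. The names `with_synced_rng`, `synced_randn`, `local_randn`, and the module-level variables are hypothetical and not taken from the PR.

```python
import random
from contextlib import contextmanager

# Hypothetical module state: which stream the global RNG currently represents,
# and the saved state of the synced stream (seeded identically on every rank).
_active_stream = "local"
_synced_state = random.Random(1234).getstate()


@contextmanager
def with_synced_rng():
    """Swap the saved synced state into the global RNG, then swap back."""
    global _active_stream, _synced_state
    local_state = random.getstate()
    random.setstate(_synced_state)
    _active_stream = "synced"
    try:
        yield
    finally:
        _synced_state = random.getstate()  # persist progress of the synced stream
        random.setstate(local_state)  # the local stream resumes where it left off
        _active_stream = "local"


def synced_randn(n):
    """Draw n normal samples; only valid inside with_synced_rng()."""
    assert _active_stream == "synced", "synced_randn called outside with_synced_rng()"
    return [random.gauss(0.0, 1.0) for _ in range(n)]


def local_randn(n):
    """Draw n normal samples; only valid outside with_synced_rng()."""
    assert _active_stream == "local", "local_randn called inside with_synced_rng()"
    return [random.gauss(0.0, 1.0) for _ in range(n)]
```

The assertions do not enforce that a call site picked the right stream for its purpose, but they do turn a forgotten or misplaced context into an immediate failure. This sketch does not handle nested contexts.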
Similarly, perhaps it's worth thinking about adding guards against rank-specific calls to torch.randn within synced contexts? E.g. an optional check for shape consistency (broadcast the shape from rank 0 and assert that all ranks agree) inside synced_randn, controlled by a debug flag? Then we could add an integration test that runs in debug mode. (Maybe this is too cautious, but I imagine these types of issues would be hard to debug.)
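The comparison step of such a debug check could be sketched as below. The function name `assert_shapes_agree` is hypothetical. It takes the requested shapes already collected from every rank; in a real run that list would presumably come from a collective such as `torch.distributed.all_gather_object`, which is left out here so the sketch runs on a single process.

```python
def assert_shapes_agree(shapes_by_rank):
    """Raise if any rank requested a different shape than rank 0.

    shapes_by_rank[i] is the shape rank i passed to the synced sampler.
    """
    reference = tuple(shapes_by_rank[0])
    mismatched = {
        rank: tuple(shape)
        for rank, shape in enumerate(shapes_by_rank)
        if tuple(shape) != reference
    }
    if mismatched:
        msg = (
            f"Synced RNG draw with inconsistent shapes: rank 0 requested {reference}, "
            f"but other ranks requested {mismatched}. The synced stream would desynchronise."
        )
        raise RuntimeError(msg)
```

Because the check needs a collective call on every draw, gating it behind a debug flag keeps the cost out of normal training runs.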
    noise_shape = (
        batch_size,
        ensemble_size,
        grid_size if self.noise_graph_provider is None else self.noise_graph_provider.projection_matrix.shape[1],
        self.noise_channels,
    )
Can we make sure that this is the same across all ranks, specifically for the grid_size branch taken when self.noise_graph_provider is None? I.e. prevent callers from passing shard-local sizes, which would desynchronize the synced stream. E.g. take the full grid size and shard inside, or add a check?
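The "take the full grid size and shard inside" option could look roughly like this. The sketch is framework-free: a `random.Random` instance stands in for the synced stream, and `shard_bounds` is a hypothetical even-split rule, not the repository's `get_shard_sizes`. Every rank draws the same full-size sample and keeps only its own slice, so all ranks consume the synced stream identically.

```python
import random


def shard_bounds(full_size, world_size, rank):
    """Return [start, stop) of this rank's slice; earlier ranks take the remainder."""
    base, extra = divmod(full_size, world_size)
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return start, stop


def sample_sharded_noise(synced_rng, full_grid_size, world_size, rank):
    """Draw noise for the full grid from the synced stream, then keep the local shard."""
    full = [synced_rng.gauss(0.0, 1.0) for _ in range(full_grid_size)]
    start, stop = shard_bounds(full_grid_size, world_size, rank)
    return full[start:stop]
```

The trade-off is that each rank temporarily materialises the full noise tensor, which costs memory but makes it impossible for a caller to desynchronise the stream by passing a shard-local size.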
    noise_shard_sizes = get_shard_sizes(noise, 0, model_comm_group)
    noise = shard_tensor(noise, 0, noise_shard_sizes, model_comm_group)  # sharded grid dim, full channels

    noise = checkpoint(self.noise_mlp, noise, use_reentrant=False)
Do we need maybe_checkpoint here?
    training_method_cfg = self.config.training.method
    training_method_cls = get_class(training_method_cfg._target_)
    model = instantiate_with_runtime_kwargs(training_method_cfg, **kwargs)  # Task -> pl.LightningModule
    seed_torch_rng_sources(self.initial_seed, self.strategy.global_rank, reset_synced=True)
This resets the groups to the stream derived from sync_group_id=0. Elsewhere, we set sync_group_id=model_comm_group_id. Do we need to restore this after model init?
        "CUDA was initialized inside use_synced_torch_rng(). "
        "Initialize CUDA before entering the synced random context."
    )
    raise RuntimeError(msg)
This raises after modifying _synced_state. That seems fine for training runs because they crash here, but I am not sure how it interacts with tests, where the exception may be caught and the process keeps running.
    assert output.shape == x.shape


    def test_pointwisemlp_processor_checkpointed_dropout_preserves_rng(pointwisemlp_processor_init):
Do we need a seed_torch_rng_sources call here?
Introduce separate model RNG streams for distributed training. The default PyTorch RNG is now seeded independently per rank, so dropout can use the rank-local random source. A synced PyTorch RNG stream is available through a context manager for operations that must stay identical across ranks, such as model initialization and selected noise sampling operations.
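The two-stream idea described above can be illustrated with a small, framework-free sketch. `random.Random` instances stand in for the PyTorch RNG states, and `derive_seed` and `make_rng_sources` are hypothetical helpers; the hashing scheme is an assumption for illustration and not the PR's `seed_torch_rng_sources`. The local stream is keyed by the global rank, while the synced stream is keyed by a group id shared by the ranks that must agree.

```python
import hashlib
import random


def derive_seed(base_seed, stream, index):
    """Derive a 64-bit seed from the base seed, a stream label and an index."""
    digest = hashlib.sha256(f"{base_seed}:{stream}:{index}".encode()).digest()
    return int.from_bytes(digest[:8], "big")


def make_rng_sources(base_seed, global_rank, sync_group_id=0):
    """Return (local, synced) generators for one rank."""
    local = random.Random(derive_seed(base_seed, "local", global_rank))
    synced = random.Random(derive_seed(base_seed, "synced", sync_group_id))
    return local, synced


# Two ranks in the same sync group: dropout-style draws differ, synced draws match.
local0, synced0 = make_rng_sources(42, global_rank=0)
local1, synced1 = make_rng_sources(42, global_rank=1)
```

Ranks that share a `sync_group_id` reproduce the same synced draws, while ranks in different groups get independent synced streams.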