feat: introduce rank-local and synced source of randomness#1160

Draft
ssmmnn11 wants to merge 8 commits into main from feat/random_seed_sharded

ssmmnn11 commented Jun 2, 2026

Member

Introduce separate model RNG streams for distributed training. The default PyTorch RNG is now seeded independently per rank, so dropout can use the rank-local random source. A synced PyTorch RNG stream is available through a context manager for operations that must stay identical across ranks, such as model initialization and selected noise sampling operations.
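
The split between the two streams can be pictured with a small single-process sketch. It is an illustration under stated assumptions, not the code from this PR: `SyncedStream`, its `use()` method, `simulate_rank`, and the `seed + rank + 1` offset are hypothetical, and only the CPU generator is handled.

```python
import contextlib

import torch


class SyncedStream:
    """Hypothetical holder for a saved RNG state that is identical on every rank."""

    def __init__(self, seed: int, rank: int) -> None:
        # Capture the synced state from the shared seed ...
        torch.manual_seed(seed)
        self._synced_state = torch.get_rng_state()
        # ... then reseed the default RNG per rank (assumed offset scheme).
        torch.manual_seed(seed + rank + 1)

    @contextlib.contextmanager
    def use(self):
        """Temporarily swap the default CPU RNG for the synced stream."""
        local_state = torch.get_rng_state()
        torch.set_rng_state(self._synced_state)
        try:
            yield
        finally:
            # Remember how far the synced stream advanced, then restore local.
            self._synced_state = torch.get_rng_state()
            torch.set_rng_state(local_state)


def simulate_rank(rank: int, seed: int = 42):
    """Draw one synced and one rank-local sample, as a single rank would."""
    stream = SyncedStream(seed, rank)
    with stream.use():
        synced = torch.randn(3)  # e.g. weight init or shared noise
    local = torch.randn(3)  # e.g. what dropout would consume
    return synced, local


synced0, local0 = simulate_rank(0)
synced1, local1 = simulate_rank(1)
```

In this sketch the synced draws agree between the two simulated ranks while the local draws differ, which is the behaviour the description asks for.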

ssmmnn11 requested a review from VeraChristina June 2, 2026 10:29
ssmmnn11 self-assigned this Jun 2, 2026
github-project-automation (bot) moved this to "To be triaged" in Anemoi-dev Jun 2, 2026
ssmmnn11 changed the title from "Feat/random seed sharded" to "feat: introduce rank-local and synced source of randomness" Jun 2, 2026
ssmmnn11 marked this pull request as draft June 2, 2026 10:30

ssmmnn11 commented Jun 2, 2026

Member Author

#595

ssmmnn11 added 5 commits June 2, 2026 16:51
# Conflicts:
#	models/src/anemoi/models/models/transport_encoder_processor_decoder.py
#	models/src/anemoi/models/samplers/transport_samplers.py
#	models/tests/models/test_diffusion_sampling_pipeline.py
#	training/src/anemoi/training/train/methods/diffusion.py
#	training/src/anemoi/training/train/train.py
#	training/tests/unit/train/test_methods.py

VeraChristina left a comment

Collaborator

Thanks for adding this! I think the separate saved RNG state acting as the synced stream generally works well. A couple of comments below, mostly regarding maintainability, like making sure future contributions don't desynchronise the synced stream.

One small comment along those lines regarding RandomFourierEmbeddings in diffusion.py (which wasn't modified, so I can't comment there directly): I would prefer the torch.randn call in its __init__ to be explicitly wrapped in use_synced_torch_rng, rather than relying on the model construction in train.py being wrapped in the context. It works as it is; this would just be more explicit.

More generally, I am wondering if we can make the two states of the stream more explicit for future contributors, e.g. having with_synced_torch_rng and with_local_torch_rng. (This doesn't enforce the correct context but it might help against forgetting about the context.) Maybe even wrap torch.randn calls in synced_randn and local_randn which assert they are in the correct context?
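
A minimal sketch of what such guarded wrappers could look like. The names `with_synced_torch_rng`, `with_local_torch_rng`, `synced_randn` and `local_randn` are the ones proposed above; everything else is an assumption. In particular, the contexts here only track which stream the caller claims to be in (the actual RNG state swapping is left out), and the wrappers raise `RuntimeError` rather than using a bare `assert` so the guard survives `python -O`.

```python
import contextlib
import contextvars

import torch

# Hypothetical bookkeeping: True while inside the synced context.
_SYNCED = contextvars.ContextVar("synced_rng_active", default=False)


@contextlib.contextmanager
def with_synced_torch_rng():
    """Mark the enclosed block as synced (RNG state swapping omitted here)."""
    token = _SYNCED.set(True)
    try:
        yield
    finally:
        _SYNCED.reset(token)


@contextlib.contextmanager
def with_local_torch_rng():
    """Mark the enclosed block as rank-local, even when nested in a synced one."""
    token = _SYNCED.set(False)
    try:
        yield
    finally:
        _SYNCED.reset(token)


def synced_randn(*shape: int) -> torch.Tensor:
    """torch.randn that refuses to run outside the synced context."""
    if not _SYNCED.get():
        raise RuntimeError("synced_randn called outside with_synced_torch_rng()")
    return torch.randn(*shape)


def local_randn(*shape: int) -> torch.Tensor:
    """torch.randn that refuses to run inside the synced context."""
    if _SYNCED.get():
        raise RuntimeError("local_randn called inside with_synced_torch_rng()")
    return torch.randn(*shape)
```

As noted, this does not enforce that the right context is chosen, but a call made in the wrong one fails loudly instead of silently advancing the wrong stream.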

Similarly, perhaps it's worth thinking about adding guards against rank-specific calls to torch.randn within synced contexts? E.g. an optional check for shape consistency (broadcast the shape from rank 0 and assert all ranks agree) inside synced_randn, controlled by a debug flag? Then we could add an integration test that runs in debug mode. (Maybe this is too cautious, but I imagine these types of issues would be hard to debug.)
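
One way such a debug check could be sketched. The helper names `find_shape_mismatches` and `assert_shape_synced` and the `debug` parameter are hypothetical. The sketch gathers the shapes from all ranks with `torch.distributed.all_gather_object` instead of broadcasting from rank 0, which sidesteps having to work out which global rank is the source within a sub-group; with the NCCL backend the object collectives additionally need the current CUDA device to be set.

```python
from typing import List, Optional, Sequence

import torch.distributed as dist


def find_shape_mismatches(shapes: Sequence[Sequence[int]]) -> List[int]:
    """Return the ranks whose requested shape differs from rank 0's."""
    reference = tuple(shapes[0])
    return [rank for rank, shape in enumerate(shapes) if tuple(shape) != reference]


def assert_shape_synced(shape: Sequence[int], group: Optional[object] = None, debug: bool = True) -> None:
    """Raise if ranks in `group` request different shapes from the synced stream."""
    if not debug or not (dist.is_available() and dist.is_initialized()):
        return  # check disabled, or single-process run with nothing to compare
    gathered = [None] * dist.get_world_size(group)
    dist.all_gather_object(gathered, tuple(shape), group=group)
    mismatched = find_shape_mismatches(gathered)
    if mismatched:
        msg = f"Synced RNG shapes disagree on group ranks {mismatched}: {gathered}"
        raise RuntimeError(msg)
```

Because the collective is a synchronisation point, keeping it behind a debug flag and exercising it in an integration test, as suggested, avoids paying for it in normal training runs.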

Comment on lines 178 to 183
noise_shape = (
    batch_size,
    ensemble_size,
    grid_size if self.noise_graph_provider is None else self.noise_graph_provider.projection_matrix.shape[1],
    self.noise_channels,
)

Can we make sure that this is the same across all ranks, specifically grid_size when self.noise_graph_provider is None? I.e. prevent callers from passing shard-local sizes, which would desynchronize the synced stream. E.g. take the full grid size and shard inside, or add a check?
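
The "take the full grid size and shard inside" option could look roughly like the sketch below. It is a single-process illustration: `sample_sharded_noise` is a hypothetical name, an identically seeded `torch.Generator` stands in for the synced stream, and plain `torch.tensor_split` stands in for the project's sharding helpers.

```python
import torch


def sample_sharded_noise(full_grid_size: int, channels: int, rank: int, world_size: int, seed: int) -> torch.Tensor:
    """Draw noise for the full grid on every rank, then keep only this rank's shard."""
    # Stand-in for the synced stream: the same seed on every rank.
    generator = torch.Generator().manual_seed(seed)
    # Every rank consumes the same amount of randomness, independent of its shard size.
    full_noise = torch.randn(full_grid_size, channels, generator=generator)
    shards = torch.tensor_split(full_noise, world_size, dim=0)
    return shards[rank]


# Simulate three ranks sharing a grid of 10 points (shard sizes 4, 3, 3).
parts = [sample_sharded_noise(10, 4, rank, world_size=3, seed=7) for rank in range(3)]
reference = torch.randn(10, 4, generator=torch.Generator().manual_seed(7))
```

The cost of this variant is that each rank briefly materialises the full noise tensor; a shape check is cheaper in memory but only detects the problem instead of preventing it.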

noise_shard_sizes = get_shard_sizes(noise, 0, model_comm_group)
noise = shard_tensor(noise, 0, noise_shard_sizes, model_comm_group) # sharded grid dim, full channels

noise = checkpoint(self.noise_mlp, noise, use_reentrant=False)

Do we need maybe_checkpoint here?

training_method_cfg = self.config.training.method
training_method_cls = get_class(training_method_cfg._target_)
model = instantiate_with_runtime_kwargs(training_method_cfg, **kwargs) # Task -> pl.LightningModule
seed_torch_rng_sources(self.initial_seed, self.strategy.global_rank, reset_synced=True)

This resets the groups to the stream derived from sync_group_id=0. Elsewhere, we set sync_group_id=model_comm_group_id. Do we need to restore this after model init?

"CUDA was initialized inside use_synced_torch_rng(). "
"Initialize CUDA before entering the synced random context."
)
raise RuntimeError(msg)

This raises after modifying the _synced_state. This seems fine for training runs because they crash here, but I'm not sure how it interacts with tests.
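
One possible way to keep a failed check from leaking into later tests is to commit the advanced synced state only after the check has passed. The sketch below is an assumption about how that could be structured, not the PR's implementation: `SyncedState`, `use_synced_rng_sketch` and the `check` callable are hypothetical, and only the CPU generator is handled.

```python
import contextlib
from typing import Callable

import torch


class SyncedState:
    """Hypothetical holder for the saved synced CPU RNG state."""

    def __init__(self, seed: int) -> None:
        local_state = torch.get_rng_state()
        torch.manual_seed(seed)
        self.state = torch.get_rng_state()
        torch.set_rng_state(local_state)  # leave the default RNG untouched


@contextlib.contextmanager
def use_synced_rng_sketch(holder: SyncedState, check: Callable[[], None] = lambda: None):
    """Swap in the synced stream; persist its progress only if `check` passes."""
    local_state = torch.get_rng_state()
    torch.set_rng_state(holder.state)
    try:
        yield
        check()  # may raise, e.g. if CUDA was initialised inside the block
        holder.state = torch.get_rng_state()  # commit only after the check passed
    finally:
        torch.set_rng_state(local_state)  # the local stream is always restored
```

With this ordering, a test that expects the `RuntimeError` leaves both the synced and the local state exactly as they were before the block.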

assert output.shape == x.shape


def test_pointwisemlp_processor_checkpointed_dropout_preserves_rng(pointwisemlp_processor_init):

Do we need a seed_torch_rng_sources call here?
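
For illustration, explicit seeding inside a test could be isolated as in the sketch below. It uses plain `torch.manual_seed` as a stand-in for `seed_torch_rng_sources` (whose behaviour in tests is the open question here) and `torch.random.fork_rng` so the seeding does not leak into other tests; `run_dropout` is a hypothetical helper and only the CPU generator is forked.

```python
import torch


def run_dropout(seed: int) -> torch.Tensor:
    """Apply dropout under an explicit seed without disturbing the global CPU RNG."""
    with torch.random.fork_rng(devices=[]):  # devices=[] forks the CPU generator only
        torch.manual_seed(seed)
        layer = torch.nn.Dropout(p=0.5)
        layer.train()  # dropout is only active in training mode
        return layer(torch.ones(16))


state_before = torch.get_rng_state()
first = run_dropout(0)
second = run_dropout(0)
state_after = torch.get_rng_state()
```

Seeding this way makes the dropout mask reproducible between the two calls and leaves the surrounding RNG state unchanged, so the test neither depends on nor affects the order in which other tests run.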
