Scripts and APIs for demonstrating TorchTitan parallelism in CPU-only environments (e.g., Google Colab) using PyTorch's FakeTensorMode.
# Colab's runtime ships with a stable torch that pip would skip as
# "already satisfied", leaving you on a release that doesn't match the
# nightly TorchTitan. Uninstall first, then force-install a pinned
# nightly wheel. Update the wheel URL to a more recent nightly if the
# pinned one stops resolving; pick from
# https://download.pytorch.org/whl/nightly/cpu/torch/
_WHEEL = "https://download-r2.pytorch.org/whl/nightly/cpu/torch-2.13.0.dev20260518%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl"
!pip uninstall torch -y
!pip install $_WHEEL --force-reinstall --quiet --progress-bar off && echo "Installed torch."
!pip install git+https://github.com/pytorch/torchtitan.git --force-reinstall --quiet --progress-bar off && echo "Installed torchtitan."
!pip install git+https://github.com/fegin/titan-demo.git --force-reinstall --quiet --progress-bar off && echo "Installed titan-demo."

import torch
from titan_demo import (
    make_model_spec,
    build_fake_model,
    make_parallel_dims,
    parallelize_fake_model,
    estimate_memory,
    print_memory_estimate,
)
# 1. Get the spec (with default sharding declarations installed).
spec = make_model_spec(
    "debugmodel",  # or "1B", "3B", "8B", "70B", "405B"
    seq_len=128,
)
# 2. (Optional) inspect the default sharding plan, then tweak spec.model
# if you want to override anything.
from titan_demo import print_sharding_config
print_sharding_config(spec) # root + embedding/loss + layer 0
# print_sharding_config(spec, layer_id=1) # any other layer if needed
# 3. Build under FakeTensorMode (no real memory).
model, fake_mode = build_fake_model(spec.model, dtype=torch.bfloat16)
print(f"{spec.flavor}: {sum(p.numel() for p in model.parameters()):,} params")
# 4. Pick parallelism degrees. world_size is required; dp_shard=-1 takes
# whatever ranks remain after the other degrees are applied.
# full_dtensor is always True.
parallel_dims = make_parallel_dims(world_size=8, tp=2) # -> dp_shard=4
# 5. Apply TP/CP sharding + fully_shard wrapping. Auto-inits a fake
# process group of size world_size on first call.
parallelize_fake_model(model, parallel_dims=parallel_dims)
# 6. Estimate per-rank peak memory (params, grads, opt-state, acts,
# all-gather / reduce-scatter buffers). Internally runs one full
# fake training step (forward + backward + AdamW.step) under
# torch.distributed._tools.fsdp2_mem_tracker.FSDPMemTracker.
snap = estimate_memory(model, fake_mode, parallel_dims, batch_size=2, seq_len=64)
print_memory_estimate(snap, units="MiB")

Parameters and activations are FakeTensors. No real memory is allocated, so 70B and 405B builds finish in seconds on a single CPU.
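Step 4 of the walkthrough above lets dp_shard=-1 absorb the leftover ranks. The arithmetic can be sketched as below. The helper name resolve_dp_shard is hypothetical and is not part of titan_demo; the sketch assumes the leftover is simply world_size divided by the product of the other degrees, which matches the `world_size=8, tp=2 -> dp_shard=4` example.

```python
from math import prod


def resolve_dp_shard(world_size: int, *, tp: int = 1, cp: int = 1,
                     dp_replicate: int = 1) -> int:
    """Hypothetical helper: ranks left for dp_shard once the other degrees are fixed.

    Assumes world_size == dp_replicate * dp_shard * cp * tp.
    """
    fixed = prod((tp, cp, dp_replicate))
    if world_size % fixed != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by tp*cp*dp_replicate={fixed}"
        )
    return world_size // fixed


print(resolve_dp_shard(8, tp=2))  # 4, as in the walkthrough
```

A configuration whose fixed degrees do not divide world_size raises instead of silently rounding, which is the behavior one would want from a planning tool.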
notebooks/parallelism_explorer.ipynb walks through three scenarios (8B on 8 GPUs, 70B on 128 GPUs, 405B on 1024 GPUs) with interactive widgets for tp, cp, dp_replicate, batch_size, seq_len. Pick a config, click "Run Interact", and see whether it FITS, is NEAR OOM (>=95% of per-GPU budget), or OOMs. Install the notebook extra (pip install -e .[notebook]) to get ipywidgets. CP under fake mode relies on the patches described below.
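The three verdicts can be read as a simple threshold check. The following is a minimal sketch, not the notebook's code: classify_fit is a hypothetical name, and it assumes that OOM means the estimated peak strictly exceeds the per-GPU budget, while NEAR OOM covers peaks at or above 95% of the budget.

```python
GiB = 1024 ** 3


def classify_fit(peak_bytes: int, budget_bytes: int, near_ratio: float = 0.95) -> str:
    """Hypothetical classifier for an estimated per-rank peak against a GPU budget."""
    if peak_bytes > budget_bytes:
        return "OOM"
    if peak_bytes >= near_ratio * budget_bytes:
        return "NEAR OOM"
    return "FITS"


budget = 80 * GiB  # illustrative per-GPU budget
print(classify_fit(60 * GiB, budget))  # FITS
print(classify_fit(78 * GiB, budget))  # NEAR OOM (78 / 80 = 0.975)
print(classify_fit(81 * GiB, budget))  # OOM
```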
scripts/parallelism_explorer.py is the same explorer as a standalone Python script. Defaults to 70B + FSDP + TP + CP on 128 GPUs x 80 GiB; every axis is overridable via flags (--flavor, --world-size, --tp, --cp, --dp-replicate, --batch-size, --seq-len, ...). Run python scripts/parallelism_explorer.py --help for the full list.
The estimate models eager training memory: TorchTitan's standard config compiles only the FlexAttention kernel; the rest of the model (MLP, RMSNorm, etc.) runs eager, and our estimator measures the activations those eager ops allocate.
Under FakeTensorMode the FlexAttention call goes through its registered fake impl, which allocates just the outputs (output + lse). That matches what the real Triton flex_attention kernel allocates -- no quadratic (B, H, S, S) intermediate in either case.
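To see why the missing (B, H, S, S) intermediate matters, compare the two footprints with plain arithmetic. This sketch makes assumptions that the text does not state: the attention output is bf16 (2 bytes per element), lse is fp32 (4 bytes per element), and a materialized score matrix would be bf16. The shapes are illustrative and not taken from any TorchTitan config.

```python
MiB = 1024 ** 2


def flex_attention_output_bytes(b: int, h: int, s: int, d: int,
                                out_bytes: int = 2, lse_bytes: int = 4) -> int:
    """Bytes for the attention output (B, H, S, D) plus lse (B, H, S)."""
    return b * h * s * d * out_bytes + b * h * s * lse_bytes


def dense_score_bytes(b: int, h: int, s: int, score_bytes: int = 2) -> int:
    """Bytes a materialized (B, H, S, S) score matrix would need."""
    return b * h * s * s * score_bytes


# Illustrative shapes: batch 2, 32 heads, sequence 8192, head dim 128.
B, H, S, D = 2, 32, 8192, 128
print(f"{flex_attention_output_bytes(B, H, S, D) / MiB:.1f} MiB")  # 130.0 MiB
print(f"{dense_score_bytes(B, H, S) / MiB:.1f} MiB")               # 8192.0 MiB
```

The outputs grow linearly in S, while the score matrix grows quadratically, so the gap widens as the sequence length increases.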
If you run real training with full-model torch.compile enabled, fused MLP / norm kernels can drop some of the intermediates we count, so the actual peak may be lower than this estimate. The demo does not model that.
The OptState category lands on cpu in the raw tracker snapshot (an artifact of fake AdamW). The notebook helper rewrites it to a deterministic local_params * 12 bytes (AdamW master + exp_avg + exp_avg_sq, fp32) and attributes it to the rank's compute device, which under FakeTensorMode is reported as cpu (it would be cuda in real training).
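The 12-bytes-per-parameter figure is three fp32 tensors per local parameter. A minimal sketch of that arithmetic follows; opt_state_bytes is a hypothetical name, and the example assumes a hypothetical 8e9-parameter model whose parameters are split evenly over 8 ranks (real sharding can leave some parameters replicated).

```python
BYTES_FP32 = 4
ADAMW_FP32_TENSORS = 3  # master copy + exp_avg + exp_avg_sq


def opt_state_bytes(local_params: int) -> int:
    """Deterministic AdamW state size for the parameters held on one rank."""
    return local_params * ADAMW_FP32_TENSORS * BYTES_FP32


# Assumption: 8e9 parameters sharded evenly across 8 ranks.
local_params = 8_000_000_000 // 8
print(f"{opt_state_bytes(local_params) / 1024**3:.2f} GiB")  # 11.18 GiB
```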
titan_demo/_patches.py applies five PyTorch patches and one TorchTitan patch on import. They are needed only when Context Parallel (cp > 1) is enabled under FakeTensorMode. The patches are idempotent and are no-ops outside FakeTensorMode. See the module docstring for the full rationale; in short:
- _StridedShard.local_shard_size_and_offset -- wrap .tolist() in unset_fake_temporarily so the internal torch.arange does not trip the data-dependent guard.
- FSDPMemTracker HOO support -- advertise supports_higher_order_operators and handle HOOs in __torch_dispatch__ so FlexAttention (a HOO) is tracked correctly.
- FSDPMemTracker infra mode -- set is_infra_mode() = True so FlexAttention's internal torch.compile does not bail when the tracker is on the stack.
- ModTracker -- skip torch.fx.GraphModule instances in the pre/post hooks so compile-generated modules do not corrupt the parent stack.
- FlexAttention._compiled_flex_attn -- dispatch to eager flex_attention under fake mode (avoids the inductor lowering that has no CPU backend).
- redistribute_cost -- short-circuit to 0.0 for _StridedShard placements under fake mode so strategy enumeration doesn't run the slow Dijkstra (avoids the hang on TP + CP combos). The execution-time transforms still go through the full path and produce correct shapes.
titan_demo/      Python package: public API
tests/           pytest smoke tests
pyproject.toml   pip-installable metadata
pip install -e .[test]
pytest

Linting mirrors a subset of TorchTitan's setup: ufmt (black + usort) and flake8 with flake8-bugbear + pep8-naming.
pip install -e .[lint]
pre-commit install # auto-run on commit
pre-commit run --all-files # one-shot
# or directly:
ufmt format titan_demo tests
flake8 titan_demo tests --config=.flake8