
fegin/titan-demo


titan-demo

Scripts and APIs for demonstrating TorchTitan parallelism on CPU-only environments (e.g., Google Colab) via PyTorch's FakeTensorMode.

Install (from a notebook)

# Colab's runtime ships with a stable torch that pip would skip as
# "already satisfied", leaving you on a release that doesn't match the
# nightly TorchTitan. Uninstall first, then force-install a pinned
# nightly wheel. Update the wheel URL to a more recent nightly if the
# pinned one stops resolving; pick from
# https://download.pytorch.org/whl/nightly/cpu/torch/
_WHEEL = "https://download-r2.pytorch.org/whl/nightly/cpu/torch-2.13.0.dev20260518%2Bcpu-cp312-cp312-manylinux_2_28_x86_64.whl"
!pip uninstall torch -y
!pip install $_WHEEL --force-reinstall --quiet --progress-bar off && echo "Installed torch."
!pip install git+https://github.com/pytorch/torchtitan.git --force-reinstall --quiet --progress-bar off && echo "Installed torchtitan."
!pip install git+https://github.com/fegin/titan-demo.git --force-reinstall --quiet --progress-bar off && echo "Installed titan-demo."
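Colab's stable torch can silently survive an install, so a quick sanity check is useful afterwards. The sketch below is not part of titan-demo, and `is_nightly_cpu` is a hypothetical helper. It reads the installed version through `importlib.metadata`, so the currently imported module is not consulted, and it treats a `.dev` segment plus a `+cpu` local tag as "nightly CPU build". That rule is inferred from the pinned wheel name (`torch-2.13.0.dev20260518+cpu`).

```python
from importlib.metadata import PackageNotFoundError, version


def is_nightly_cpu(torch_version: str) -> bool:
    """True for a CPU nightly version string such as '2.13.0.dev20260518+cpu'."""
    # Split the public version from the local tag after '+'.
    public, _, local = torch_version.partition("+")
    # Nightlies carry a '.devYYYYMMDD' segment; CPU wheels carry the 'cpu' local tag.
    return ".dev" in public and local == "cpu"


try:
    installed = version("torch")  # reads installed package metadata
    verdict = "nightly CPU build" if is_nightly_cpu(installed) else "not a CPU nightly"
    print(f"torch {installed}: {verdict}")
except PackageNotFoundError:
    print("torch is not installed")
```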

Quick start

import torch
from titan_demo import (
    make_model_spec,
    build_fake_model,
    make_parallel_dims,
    parallelize_fake_model,
    estimate_memory,
    print_memory_estimate,
)

# 1. Get the spec (with default sharding declarations installed).
spec = make_model_spec(
    "debugmodel",         # or "1B", "3B", "8B", "70B", "405B"
    seq_len=128,
)

# 2. (Optional) inspect the default sharding plan, then tweak spec.model
#    if you want to override anything.
from titan_demo import print_sharding_config
print_sharding_config(spec)              # root + embedding/loss + layer 0
# print_sharding_config(spec, layer_id=1)  # any other layer if needed

# 3. Build under FakeTensorMode (no real memory).
model, fake_mode = build_fake_model(spec.model, dtype=torch.bfloat16)
print(f"{spec.flavor}: {sum(p.numel() for p in model.parameters()):,} params")

# 4. Pick parallelism degrees. world_size is required; dp_shard=-1 takes
#    whatever degree is left after the other axes. full_dtensor is always True.
parallel_dims = make_parallel_dims(world_size=8, tp=2)  # -> dp_shard=4

# 5. Apply TP/CP sharding + fully_shard wrapping. Auto-inits a fake
#    process group of size world_size on first call.
parallelize_fake_model(model, parallel_dims=parallel_dims)

# 6. Estimate per-rank peak memory (params, grads, opt-state, acts,
#    all-gather / reduce-scatter buffers). Internally runs one full
#    fake training step (forward + backward + AdamW.step) under
#    torch.distributed._tools.fsdp2_mem_tracker.FSDPMemTracker.
snap = estimate_memory(model, fake_mode, parallel_dims, batch_size=2, seq_len=64)
print_memory_estimate(snap, units="MiB")
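Step 4 relies on `dp_shard=-1` being filled with the leftover degree. The sketch below shows that arithmetic under one assumption: `make_parallel_dims` follows the same rule as TorchTitan's `ParallelDims`, where world_size is divided by the product of the other degrees. The real validation may check more, for example expert parallelism. `resolve_dp_shard` is a hypothetical helper, not part of the titan_demo API.

```python
def resolve_dp_shard(world_size: int, dp_shard: int = -1, dp_replicate: int = 1,
                     tp: int = 1, cp: int = 1, pp: int = 1) -> int:
    """Hypothetical helper: resolve dp_shard=-1 to the leftover degree."""
    others = dp_replicate * tp * cp * pp  # product of every non-FSDP-shard degree
    if world_size % others != 0:
        raise ValueError(f"world_size={world_size} is not divisible by {others}")
    if dp_shard == -1:
        dp_shard = world_size // others  # FSDP sharding absorbs the remaining ranks
    if dp_shard * others != world_size:
        raise ValueError(f"dp_shard={dp_shard} x {others} != world_size={world_size}")
    return dp_shard


print(resolve_dp_shard(8, tp=2))  # the quick-start case: 8 // 2 = 4
```

With illustrative degrees (not the notebook's defaults), tp=8 and cp=2 on 128 GPUs would leave dp_shard=8.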

Parameters and activations are FakeTensors. No real memory is allocated, so 70B and 405B builds finish in seconds on a single CPU.
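For scale, the arithmetic below estimates the bf16 weights alone that a real build would have to allocate. It uses nominal parameter counts (8e9, 70e9, 405e9), and the actual flavors differ slightly. It ignores gradients, optimizer state, and activations, so it is a lower bound on what FakeTensorMode avoids.

```python
GIB = 2**30
BF16_BYTES = 2  # bytes per bfloat16 element


def bf16_weights_gib(num_params: float) -> float:
    """GiB needed to materialize num_params bf16 weights (weights only)."""
    return num_params * BF16_BYTES / GIB


# Nominal counts; the real flavors have slightly different parameter totals.
for flavor, n in [("8B", 8e9), ("70B", 70e9), ("405B", 405e9)]:
    print(f"{flavor}: {bf16_weights_gib(n):,.1f} GiB")
```

That comes to roughly 15, 130, and 754 GiB respectively, far beyond a Colab CPU runtime.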

Notebook

notebooks/parallelism_explorer.ipynb walks through three scenarios (8B on 8 GPUs, 70B on 128 GPUs, 405B on 1024 GPUs) with interactive widgets for tp, cp, dp_replicate, batch_size, and seq_len. Pick a config, click "Run Interact", and the notebook reports whether it FITS, is NEAR OOM (>= 95% of the per-GPU memory budget), or OOMs. Install the notebook extra (pip install -e .[notebook]) to get ipywidgets. Running with cp > 1 under fake mode relies on the patches described under Monkey-patches below.
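The FITS / NEAR OOM / OOM verdict is a threshold check of the per-rank peak against the per-GPU budget. The sketch below is a hypothetical `classify` helper, not the notebook's code. The 95% cutoff comes from the text above. Treating "peak > budget" as OOM, so that exactly 100% still counts as NEAR OOM, is an assumption.

```python
GIB = 2**30


def classify(peak_bytes: int, budget_bytes: int) -> str:
    """Bucket a per-rank peak against a per-GPU budget (hypothetical helper)."""
    if peak_bytes > budget_bytes:
        return "OOM"
    # Integer form of peak >= 0.95 * budget, avoiding float rounding at the edge.
    if peak_bytes * 100 >= 95 * budget_bytes:
        return "NEAR OOM"
    return "FITS"


budget = 80 * GIB  # e.g. an 80 GiB GPU, as in the CLI default
for peak_gib in (60, 77, 82):
    print(peak_gib, classify(peak_gib * GIB, budget))
```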

CLI

scripts/parallelism_explorer.py is the same explorer packaged as a standalone Python script. It defaults to 70B + FSDP + TP + CP on 128 GPUs x 80 GiB, and every axis can be overridden via flags (--flavor, --world-size, --tp, --cp, --dp-replicate, --batch-size, --seq-len, ...). Run python scripts/parallelism_explorer.py --help for the full list.

What the memory estimate represents

The estimate models eager training memory. TorchTitan's standard config compiles only the FlexAttention kernel; the rest of the model (MLP, RMSNorm, etc.) runs eagerly, and the estimator measures the activations those eager ops allocate.

Under FakeTensorMode, the FlexAttention call goes through its registered fake implementation, which allocates only the outputs (output and lse). That matches what the real Triton flex_attention kernel allocates: neither path materializes a quadratic (B, H, S, S) intermediate.
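To see why skipping the quadratic intermediate matters, compare the bytes for FlexAttention's outputs with a dense score matrix. The shapes below (B=2, H=32, S=8192, D=128) are illustrative and are not a titan-demo default. The dtypes are also assumptions: the output is taken to be bf16 like the inputs, and lse fp32.

```python
MIB = 2**20


def flex_attention_output_bytes(B, H, S, D, out_bytes=2, lse_bytes=4):
    """Bytes for FlexAttention's outputs: (B, H, S, D) output plus (B, H, S) lse."""
    output = B * H * S * D * out_bytes  # same dtype as the inputs (bf16 assumed)
    lse = B * H * S * lse_bytes  # log-sum-exp, fp32 assumed
    return output + lse


def dense_scores_bytes(B, H, S, score_bytes=2):
    """Bytes for a materialized (B, H, S, S) score matrix, which neither path allocates."""
    return B * H * S * S * score_bytes


B, H, S, D = 2, 32, 8192, 128  # illustrative shapes
print(f"outputs: {flex_attention_output_bytes(B, H, S, D) / MIB:.0f} MiB")
print(f"dense (B, H, S, S) scores: {dense_scores_bytes(B, H, S) / MIB:.0f} MiB")
```

At these shapes the outputs take 130 MiB, while a dense score matrix would take 8192 MiB per attention call.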

If you run real training with full-model torch.compile enabled, fused MLP / norm kernels can drop some of the intermediates we count, so the actual peak may be lower than this estimate. The demo does not model that.

In the raw tracker snapshot, the OptState category is attributed to cpu (an artifact of fake AdamW). The notebook helper replaces it with a deterministic local_params * 12 bytes (fp32 master weights, exp_avg, and exp_avg_sq, at 4 bytes each) and attributes it to the rank's compute device, which FakeTensorMode reports as cpu (it would be cuda in real training).
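Here is the local_params * 12 rule as worked arithmetic. The usage line assumes, for illustration only, that 8e9 parameters are split evenly across 8 ranks. Real sharding leaves some ranks with slightly more, and the notebook derives local_params from the sharded model.

```python
GIB = 2**30
# fp32 master weight + exp_avg + exp_avg_sq = 3 tensors x 4 bytes each
OPT_STATE_BYTES_PER_PARAM = 3 * 4


def opt_state_bytes(local_params: int) -> int:
    """Per-rank AdamW state using the local_params * 12 bytes rule."""
    return local_params * OPT_STATE_BYTES_PER_PARAM


# Assumption for illustration: 8e9 params split evenly across 8 ranks.
local_params = 8_000_000_000 // 8
print(f"{opt_state_bytes(local_params) / GIB:.2f} GiB per rank")
```

Under that assumption each rank holds 12e9 bytes, about 11.18 GiB, of optimizer state.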

Monkey-patches

titan_demo/_patches.py applies five PyTorch patches and one TorchTitan patch on import. They are needed only when Context Parallel (cp > 1) is enabled under FakeTensorMode. The patches are idempotent and are no-ops outside FakeTensorMode. See the module docstring for the full rationale; in short:

  • _StridedShard.local_shard_size_and_offset -- wrap .tolist() in unset_fake_temporarily so the internal torch.arange does not trip the data-dependent guard.
  • FSDPMemTracker HOO support -- advertise supports_higher_order_operators and handle HOOs in __torch_dispatch__ so FlexAttention (a HOO) is tracked correctly.
  • FSDPMemTracker infra mode -- set is_infra_mode() = True so FlexAttention's internal torch.compile does not bail when the tracker is on the stack.
  • ModTracker -- skip torch.fx.GraphModule instances in the pre/post hooks so compile-generated modules do not corrupt the parent stack.
  • FlexAttention._compiled_flex_attn -- dispatch to eager flex_attention under fake mode (avoids the inductor lowering that has no CPU backend).
  • redistribute_cost -- short-circuit to 0.0 for _StridedShard placements under fake mode so strategy enumeration doesn't run the slow Dijkstra (avoids the hang on TP + CP combos). The execution-time transforms still go through the full path and produce correct shapes.

Layout

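titan_demo/        Python package: public API
notebooks/         parallelism_explorer.ipynb (interactive explorer)
scripts/           parallelism_explorer.py (CLI explorer)
tests/             pytest smoke tests
pyproject.toml     pip-installable metadata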

Tests

pip install -e .[test]
pytest

Lint and format

Mirrors a subset of TorchTitan's setup: ufmt (black + usort) and flake8 with flake8-bugbear and pep8-naming.

pip install -e .[lint]
pre-commit install            # auto-run on commit
pre-commit run --all-files    # one-shot
# or directly:
ufmt format titan_demo tests
flake8 titan_demo tests --config=.flake8

About

Demo Scripts for TorchTitan
