Tags: pytorch/pytorch
[TESTING] Append rules to sample skips and xfails (#185905)

Recently, running `pytest test/functorch/test_vmap.py -v` started failing during collection:

```
  File "/opt/pytorch/pytorch/test/functorch/test_vmap.py", line 6579, in <module>
    instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/_internal/common_device_type.py", line 1074, in instantiate_device_type_tests
    device_type_test_class.instantiate_test(
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/_internal/common_device_type.py", line 634, in instantiate_test
    instantiate_test_helper(
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/_internal/common_device_type.py", line 540, in instantiate_test_helper
    test = decorator(test)
           ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/_internal/opinfo/core.py", line 1713, in __call__
    raise RuntimeError("Multiple sets of sample_skips_and_xfails defined")
RuntimeError: Multiple sets of sample_skips_and_xfails defined
========================= short test summary info =========================
ERROR ../opt/pytorch/pytorch/test/functorch/test_vmap.py - RuntimeError: Multiple sets of sample_skips_and_xfails defined
!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!
============================ 1 error in 3.91s =============================
```

The error occurs because `test_vmap_exhaustive` carries multiple `skipIf` and `xfailIf` decorators from `test/functorch/common_utils.py`.
Both of these rely on `sample_skips_and_xfails`:

https://github.com/pytorch/pytorch/blob/d3ce23d75ee5e488787aafb12c281b5142d91e75/test/functorch/common_utils.py#L481-L497
https://github.com/pytorch/pytorch/blob/d3ce23d75ee5e488787aafb12c281b5142d91e75/test/functorch/common_utils.py#L510-L526

However, applying more than one of them to the same test raises the `RuntimeError`:

https://github.com/pytorch/pytorch/blob/d3ce23d75ee5e488787aafb12c281b5142d91e75/torch/testing/_internal/opinfo/core.py#L1713

This PR proposes appending the rules instead of raising, which simplifies the usage of `sample_skips_and_xfails`.

Fixes #184894
Pull Request resolved: #185905
Approved by: https://github.com/benjaminglass1, https://github.com/eqy
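The append behaviour described above can be illustrated with a minimal sketch. The `SampleRules` class and the string rules are hypothetical stand-ins, and storing the rules on a `sample_skips_and_xfails` attribute of the test function is an assumption made for illustration, not the actual code in `opinfo/core.py`.

```python
class SampleRules:
    """Hypothetical decorator that attaches skip/xfail rules to a test function."""

    def __init__(self, rules):
        self.rules = list(rules)

    def __call__(self, fn):
        existing = getattr(fn, "sample_skips_and_xfails", None)
        if existing is None:
            # First decorator applied: start a fresh rule list.
            fn.sample_skips_and_xfails = list(self.rules)
        else:
            # Later decorators append instead of raising
            # "Multiple sets of sample_skips_and_xfails defined".
            existing.extend(self.rules)
        return fn


# Decorators apply bottom-up, so the xfail rule is registered first.
@SampleRules(["skip: op_a on cpu"])
@SampleRules(["xfail: op_b on cuda"])
def test_vmap_exhaustive():
    pass


print(test_vmap_exhaustive.sample_skips_and_xfails)
```

With this shape, stacking any number of rule decorators on one test accumulates their rules in application order.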
[pipelining] Add guards for non-float tensors when building pipeline (#… …183582) Fixes #183024. Tested working with code below by manually applying the patch over 2.11. Analysis and initial patches found using AI. Manually applied and tested patches. <details> <summary> Test code, run with torchrun on 2 gpus </summary> ```python import torch import torch.distributed as dist from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, Schedule1F1B from transformers import AutoModelForCausalLM import os # Initialize torchrun's distributed environment pp_group = dist.init_process_group(backend="nccl") rank = dist.get_rank() world_size = dist.get_world_size() local_rank = int(os.environ["LOCAL_RANK"]) # Assign this specific process to its designated GPU device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) print(f"[Rank {rank}] Starting model init on {device}") # GPT2 model_id = "openai-community/gpt2" device_count = torch.cuda.device_count() model = AutoModelForCausalLM.from_pretrained(model_id) print(model) # kv cache is enabled by default and causes other issues with dynamo tracing/pipelining, disabling model.config.use_cache = False # run on 2 or more gpus split_spec = { f"transformer.h.5": SplitPoint.END, } model = model.to(device) pipe_module = pipeline( module=model, mb_args=(), mb_kwargs={ 'input_ids': torch.zeros((1, 512), dtype=torch.long).to(device), 'attention_mask': torch.ones((1, 512), dtype=torch.bool).to(device) }, split_spec=split_spec, ) stage = pipe_module.build_stage(rank, device, pp_group) # placeholder loss function def loss_fn(outputs, other): if isinstance(outputs, tuple): return outputs[0] elif isinstance(outputs, torch.Tensor): logits = outputs else: logits = outputs['logits'] loss = logits.sum() - other.sum() return loss schedule = Schedule1F1B(stage, n_microbatches=4, loss_fn=loss_fn) if rank == 0: inputs = { 'input_ids': torch.zeros((4, 512), dtype=torch.long).to(device), 'attention_mask': 
torch.ones((4, 512), dtype=torch.bool).to(device), } schedule.step( **inputs ) elif rank == world_size - 1: losses = [] output = schedule.step(target=torch.randn(4, 512).to(device), losses=losses) print(f"losses: {losses}") else: schedule.step() ``` </details> <details> <summary> Prompt and Initial AI output </summary> Prompt: Analyze, triage, and provide me with a fix for this github issue `https://github.com/pytorch/pytorch/issues/183024` Z.ai agent response: I'll analyze this GitHub issue for you. Let me start by fetching the issue details and understanding the problem. Let me extract the actual issue content more cleanly. Now let me look at the PyTorch source code for the specific failing function to understand the bug deeply. Now I have a complete picture. Let me also verify the related backward code path to ensure the fix is comprehensive. I now have a comprehensive understanding of the issue. Here's my full analysis, triage, and proposed fix: --- ## 🔍 Issue Analysis: PyTorch #183024 ### **Title:** [pipelining] Unable to use PP with HF transformers models due to non-float inputs ### **Bug Summary** When using PyTorch's Pipeline Parallelism (`torch.distributed.pipelining`) with HuggingFace transformer models (GPT2, LLaMA, etc.), the pipeline crashes with: ``` RuntimeError: only Tensors of floating point dtype can require gradients ``` --- ## 🧬 Root Cause The root cause is a **dtype-agnostic `requires_grad` assignment** in the pipeline stage infrastructure. When `has_backward=True` (training mode), the code **unconditionally** sets `requires_grad=True` on **all** activation receive/send buffers — including non-float tensors like `input_ids` (`torch.long`) and `attention_mask` (`torch.bool`). PyTorch only allows floating-point tensors to require gradients (a fundamental constraint since [Issue #37680, open since 2020](#37680)). 
The crash occurs at **three specific locations**: ### **Location 1: `torch/distributed/pipelining/stage.py` — `create_recv_tensor` inside `_create_act_recv_info`** (PRIMARY CRASH) ```python # Lines ~1361-1377 (broken) tensor_meta = _TensorMeta( shape=example_value.shape, stride=example_value.stride(), dtype=example_value.dtype, requires_grad=self.has_backward, # ← BUG: True for ALL dtypes ) buffer = _make_tensor_from_meta(tensor_meta, self.device) if self.has_backward: buffer.requires_grad_(True) # ← CRASH: non-float tensors can't require grad ``` This is the **exact crash point** from the traceback. When the pipeline creates receive buffers for intermediate activations between stages, it encounters `input_ids` (dtype=`torch.long`) or `attention_mask` (dtype=`torch.bool`) and tries to set `requires_grad_(True)`, which is illegal. ### **Location 2: `torch/distributed/pipelining/stage.py` — `_create_act_send_info`** (SECONDARY) ```python # Lines ~1462-1468 (broken) output_metas.append( _TensorMeta( shape=val.shape, stride=val.stride(), dtype=val.dtype, requires_grad=self.has_backward, # ← Same bug for outputs ) ) ``` ### **Location 3: `torch/distributed/pipelining/_utils.py` — `_TensorMeta.to_tensor`** (LATENT) ```python # Line ~96 (broken) t = _make_tensor_from_meta(self, device) t.requires_grad_(self.requires_grad) # ← Crashes if requires_grad=True and dtype is non-float return t ``` ### **Additional Related Bug: `torch/distributed/pipelining/_backward.py` — `_get_grad_fn_or_grad_acc`** Even after fixing the above, the backward pass has a **second bug** waiting to surface. The function `_get_grad_fn_or_grad_acc` assumes its argument is always a `torch.Tensor`: ```python def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None: if t.requires_grad and t.grad_fn is None: # ← AttributeError if t is not a Tensor ... ``` This is called via `map(_get_grad_fn_or_grad_acc, input_values)` in `stage_backward_input`, where `input_values` can contain non-tensor objects. 
While the immediate crash is masked by Location 1, once that's fixed, this will be the **next crash site**. --- ## 🔧 Proposed Fix The fix is to **guard `requires_grad=True` with a `tensor.is_floating_point()` check** at all affected locations. This is semantically correct because: - Non-float tensors (integer, boolean) **never** produce meaningful gradients - Their gradient in the backward pass should be `None` - This is consistent with PyTorch's autograd semantics ### **Patch 1: `torch/distributed/pipelining/stage.py` — `create_recv_tensor`** ```python def create_recv_tensor(placeholder, arg_node): example_value = placeholder.meta["val"] # ... (DTensor check unchanged) ... if arg_node.op == "placeholder": return _RecvInfo( input_name=f"root_input_{placeholder.name}", source=None, buffer=None, tensor_meta=_TensorMeta.from_tensor(example_value), is_root_arg=True, ) # Received activation from a previous stage. while arg_node.target is operator.getitem: arg_node = arg_node.args[0] if arg_node.op != "call_module": raise PipeliningMetadataError( f"Expecting call_module, got {arg_node.op}" ) src_stage = self.get_stage_index_of_submod(arg_node.name) # FIX: Only set requires_grad for floating-point tensors needs_grad = self.has_backward and example_value.is_floating_point() tensor_meta = _TensorMeta( shape=example_value.shape, stride=example_value.stride(), dtype=example_value.dtype, requires_grad=needs_grad, # ← was: self.has_backward ) buffer = _make_tensor_from_meta(tensor_meta, self.device) if needs_grad: # ← was: if self.has_backward buffer.requires_grad_(True) return _RecvInfo( arg_node.name, src_stage, buffer, tensor_meta, ) ``` ### **Patch 2: `torch/distributed/pipelining/stage.py` — `_create_act_send_info`** ```python # Inside _create_act_send_info, where output_metas are built: output_metas.append( _TensorMeta( shape=val.shape, stride=val.stride(), dtype=val.dtype, requires_grad=self.has_backward and val.is_floating_point(), # ← was: self.has_backward ) ) ``` ### 
**Patch 3: `torch/distributed/pipelining/_utils.py` — `_TensorMeta.to_tensor`** ```python def to_tensor(self, device: torch.device | str) -> torch.Tensor: t = _make_tensor_from_meta(self, device) if self.requires_grad and t.is_floating_point(): # ← was: unconditionally t.requires_grad_(True) return t ``` ### **Patch 4: `torch/distributed/pipelining/_backward.py` — `_get_grad_fn_or_grad_acc`** ```python def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None: # FIX: Guard against non-tensor inputs if not isinstance(t, torch.Tensor): return None if t.requires_grad and t.grad_fn is None: viewed_t = t.view_as(t) grad_fn = viewed_t.grad_fn if grad_fn is not None: return grad_fn.next_functions[0][0] else: raise RuntimeError( "Attempted to get grad_fn, but got None." "Is this being created in a no-grad context?" ) else: return t.grad_fn ``` --- ## 📊 Triage Assessment | Aspect | Rating | |--------|--------| | **Severity** | 🔴 **High** — Completely blocks pipeline parallelism for any model with non-float inputs (all HF transformer models) | | **Impact** | 🔴 **High** — Affects GPT2, LLaMA, BERT, T5 — essentially all popular transformer architectures | | **Workaround?** | 🟡 **Difficult** — Would require rewriting models to avoid passing non-float tensors between stages, or monkey-patching the pipelining internals | | **Fix Complexity** | 🟢 **Low** — Four surgical `is_floating_point()` / `isinstance` guards, no architectural changes | | **Regression Risk** | 🟢 **Low** — The fix only restricts `requires_grad` for cases that already crash; float tensors are unaffected | ### Why PRs #182182 and #182644 Don't Fix This - **PR #182182** fixes None gradient handling for float tensors that are unused — it doesn't address the dtype mismatch - **PR #182644** fixes FX graph node ordering — it doesn't touch the `requires_grad` logic at all - This is a **distinct, orthogonal bug** that needs its own fix --- ## 🚀 Suggested Next Steps 1. 
**File a PR** against `pytorch/pytorch` with the four patches above 2. **Add regression tests** that pipeline an HF model (even a tiny one like GPT2-small) with `input_ids` and `attention_mask` 3. **Coordinate with PR #182182** — both fixes are needed for full pipeline parallelism support, and they should be tested together since they both touch the backward path Prompt: Check your work Z.ai agent response: Good call — let me verify my analysis against the actual PyTorch source code. I'll cross-check each claim. Now I have verified all the source code against my original analysis. Let me also check one more thing — whether the `to_tensor` method is actually called in the crash path or not. Now let me verify the secondary bug claim about `_get_grad_fn_or_grad_acc` — whether it would actually be hit in the backward path after fixing the primary bug. Good — my verification is complete. Here's the **corrected and refined** analysis: --- ## ✅ Verification Results ### Claims Confirmed ✅ | Claim | Status | Evidence | |-------|--------|----------| | **Primary crash: `create_recv_tensor` sets `requires_grad=self.has_backward` unconditionally** | ✅ **Confirmed** | Exact match with source at `stage.py` lines ~1361-1377 | | **Primary crash: `buffer.requires_grad_(True)` when `self.has_backward`** | ✅ **Confirmed** | Line ~1377, crashes for `torch.long`/`torch.bool` tensors | | **Secondary bug: `_create_act_send_info` uses same `requires_grad=self.has_backward` pattern** | ✅ **Confirmed** | Line ~1467, same unconditional pattern for outputs | | **Latent bug: `_TensorMeta.to_tensor()` calls `t.requires_grad_(self.requires_grad)` unconditionally** | ✅ **Confirmed** | `_utils.py` line ~96, same crash pattern | | **PR #182182 does NOT fix this** | ✅ **Confirmed** | It fixes None gradient handling for float tensors, not the dtype mismatch | | **PR #182644 does NOT fix this** | ✅ **Confirmed** | It fixes FX graph node ordering, doesn't touch `requires_grad` logic | | 
**`is_floating_point()` guard is the correct fix** | ✅ **Confirmed** | Semantically correct: non-float tensors can never produce meaningful gradients in PyTorch | ### Claim Clarified⚠️ | Claim | Status | Correction | |-------|--------|------------| | **`_get_grad_fn_or_grad_acc` will crash next after the primary fix** |⚠️ **Partially correct** | This IS a real bug, but it's for a **different scenario** — when non-Tensor Python objects (int, bool, None) appear in `input_values`. In the HF GPT2 pipeline case specifically, `input_values` contains `torch.Tensor` objects (just non-float ones), which after the primary fix will have `requires_grad=False` and work fine. The `isinstance` guard is still a valid defense-in-depth fix but is **not directly triggered** by this issue. | ### Revised Fix Assessment **Patch 1 (`create_recv_tensor`)** — **Essential, directly fixes the crash** ✅ **Patch 2 (`_create_act_send_info`)** — **Important, prevents same crash for non-float outputs crossing stage boundaries** ✅ (Not triggered in GPT2 specifically since embedding outputs are float, but needed for models where non-float tensors cross stages) **Patch 3 (`_TensorMeta.to_tensor`)** — **Defensive, prevents latent crash if this method is called with `requires_grad=True` on non-float metadata** ✅ **Patch 4 (`_get_grad_fn_or_grad_acc`)** — **Separate bug, defense-in-depth**⚠️ Should be framed as a separate fix for the torchtitan #3112 scenario (non-Tensor Python objects in `input_values`), not as part of the immediate fix for issue #183024. --- ### Bottom Line My original analysis was **substantially correct**. The three core patches (1, 2, 3) are accurate and necessary. Patch 4 is a valid bug fix but should be presented as an **orthogonal improvement** rather than a direct consequence of fixing issue #183024. The root cause, severity assessment, and fix approach all hold up under verification. </details> Pull Request resolved: #183582 Approved by: https://github.com/sanketpurandare
(dynamo) Implement tp_str slot via generic_str mirroring PyObject_Str (#180494)

Introduces str_impl as a proper slot on VariableTracker (mirroring CPython's tp_str) and generic_str in object_protocol.py as the single dispatch point for all str() calls. Previously, str() was handled via an ad-hoc isinstance cascade in BuiltinVariable.call_str.

Pull Request resolved: #180494
Approved by: https://github.com/guilhermeleobas, https://github.com/trichmo
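As a rough illustration of slot-based dispatch in the style of `PyObject_Str`, the sketch below tries a `str_impl` slot first, falls back to a repr slot when the str slot is unfilled, and rejects non-string results. The `Tracker` classes, `repr_impl`, and the use of `NotImplemented` to mark an empty slot are assumptions for illustration; they are not Dynamo's actual classes or signatures.

```python
class Tracker:
    """Hypothetical tracker base; str_impl / repr_impl stand in for tp_str / tp_repr."""

    def str_impl(self):
        return NotImplemented  # slot not filled by default

    def repr_impl(self):
        return f"<{type(self).__name__}>"


class ConstTracker(Tracker):
    def __init__(self, value):
        self.value = value

    def str_impl(self):
        return str(self.value)


def generic_str(obj):
    """Single dispatch point for str(): str slot first, then the repr fallback."""
    result = obj.str_impl()
    if result is NotImplemented:
        result = obj.repr_impl()
    if not isinstance(result, str):
        raise TypeError(f"str slot returned non-string (type {type(result).__name__})")
    return result


print(generic_str(ConstTracker(3)))
print(generic_str(Tracker()))
```

The point of the single entry point is that new tracker types only fill a slot, rather than adding another branch to an isinstance cascade.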
[FSDP2] Add set_separate_reduce_scatter_group (opt-in AG/RS overlap) (#186335)

By default FSDP2 runs all-gather and reduce-scatter on separate CUDA streams but through the same process group -- one NCCL communicator, which processes one collective at a time and so serializes them on the wire. This adds an opt-in FSDPModule API to give reduce-scatter its own communicator:

FSDPModule.set_separate_reduce_scatter_group(enable=True, *, recurse=True)

When enabled, FSDP creates a dedicated process group over the shard ranks (dist.new_group), one per distinct set of shard ranks (typically one communicator), so reduce-scatter and all-gather can progress concurrently when the network can sustain it; enable=False resets to the shared group. The default is unchanged -- reduce-scatter shares the shard process group -- so this is purely opt-in and creates no extra communicators unless requested.

This redesigns the approach explored in (closed) PR #177015, which made the separate communicator the unconditional default (and created one even for post-forward meshes); here it is an opt-in toggle.

Test Plan:

```
python test/distributed/_composable/fsdp/test_fully_shard_overlap.py \
    TestFullyShardOverlap.test_set_separate_reduce_scatter_group \
    TestFullyShardOverlap.test_fully_shard_backward_comm_overlap
```

Both pass on 4xH100:
- test_set_separate_reduce_scatter_group: default shares the shard PG; enabling creates one dedicated PG shared across same-rank-set meshes; disabling resets to the shared PG.
- test_fully_shard_backward_comm_overlap: real backward AG/RS overlap (large matmuls + collectives) is no slower than a serialized single-communicator reference.

Authored with Claude.

Co-authored-by: Lei Tian <2119521+leitian@users.noreply.github.com>
Pull Request resolved: #186335
Approved by: https://github.com/anshul-si
ghstack dependencies: #186000
[DTensor] Fix group_norm scalar adjuster crash when weight=None (#184819)

Fixes #184816

The _adjust_group_norm_scalars function counts DTensorSpec args to find where the N/C/HxW scalar args start. When weight or bias is None (e.g. GroupNorm with affine=False), the None slots are not DTensorSpec, so the count is too low and the adjuster overwrites the None weight slot with the local N value. This causes a RuntimeError: expected Optional[Tensor] but got int.

Fix by counting None slots alongside DTensorSpec when computing the offset. Adds a regression test for group_norm with weight=None.

Authored by Claude.

Pull Request resolved: #184819
Approved by: https://github.com/aditvenk
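The off-by-count problem can be reproduced in a few lines. This is only a sketch: `Spec` is a hypothetical stand-in for `DTensorSpec`, the two helper functions are illustrative names, and the argument layout (input, weight, bias, then N, C, HxW) is assumed from the description above.

```python
class Spec:
    """Hypothetical stand-in for DTensorSpec."""


def scalar_offset_buggy(args):
    # Counts only spec args, so None weight/bias slots are missed.
    return sum(isinstance(a, Spec) for a in args)


def scalar_offset_fixed(args):
    # Counts None slots alongside specs, as the fix describes.
    return sum(a is None or isinstance(a, Spec) for a in args)


# Assumed layout: input, weight=None, bias=None, N, C, HxW
args = [Spec(), None, None, 8, 32, 64]

print(scalar_offset_buggy(args))  # index of the None weight slot
print(scalar_offset_fixed(args))  # index of N
```

With the buggy count, writing the local N at the computed index would land in the weight slot, which matches the reported "expected Optional[Tensor] but got int" failure.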
[FakeTensor] Add hinted symbolic storage size metadata (#183839) Trace tooling serializes FakeTensor storage metadata to JSON. When the storage size is a hinted SymInt, the trace can have both useful pieces of information: - the symbolic storage expression, which preserves provenance for downstream trace consumers; - the optimization hint, which gives diagnostic/policy tooling a concrete expected extent when one was explicitly provided. Keep the existing `size` field symbolic for symbolic storage sizes and add a separate `size_hint` field only when every free symbol in the storage-size expression has an explicit optimization hint override. This preserves the existing symbolic trace contract while exposing concrete policy metadata without specializing tensor shapes or changing runtime semantics. Fixes #183835 Test Plan: python test/test_fake_tensor.py FakeTensorTest.test_meta_storage_trace_uses_hint_for_symbolic_size -q lintrunner --config=.lintrunner.toml torch/_subclasses/meta_utils.py test/test_fake_tensor.py Pull Request resolved: #183839 Approved by: https://github.com/ezyang, https://github.com/laithsakka
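A minimal sketch of the "symbolic `size` plus optional `size_hint`" rule follows. The function name, the string form of the expression, and the `evaluate` callback are hypothetical; only the two field names and the all-symbols-hinted condition come from the description above.

```python
import json


def storage_trace_entry(expr, free_symbols, hints, evaluate):
    """Build a JSON-ready entry: always symbolic `size`, `size_hint` only if fully hinted.

    expr:         symbolic storage-size expression, kept as a string (assumption)
    free_symbols: names of the free symbols appearing in expr
    hints:        explicit optimization hint overrides, symbol name -> int
    evaluate:     hypothetical callback that computes the concrete size from hints
    """
    entry = {"size": expr}
    if all(sym in hints for sym in free_symbols):
        entry["size_hint"] = evaluate(hints)
    return entry


hinted = storage_trace_entry("4*s0", ["s0"], {"s0": 16}, lambda h: 4 * h["s0"])
unhinted = storage_trace_entry("4*s0", ["s0"], {}, lambda h: 4 * h["s0"])

print(json.dumps(hinted))
print(json.dumps(unhinted))
```

Keeping `size` symbolic in both cases is what preserves the existing trace contract; consumers that do not know about `size_hint` see no change.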
[inductor] Replace topological sort with direct node moves in overlap scheduling (#184711)

`ManualOverlapScheduler._manual_reorder_graph` used a full-graph topological sort to enforce a handful of all-gather prefetch and reduce-scatter defer dependencies. On a Llama 70B graph this moved ~97% of nodes (18,147 / 18,624) just to reposition a few AG/RS chains, making the pass unpredictable and hard to compose with downstream passes.

Replace with `_move_overlap_nodes`, which surgically repositions only the AG start chains (earlier) and RS wait chains (later), leaving the remaining ~95% of the graph untouched.

### Complexity comparison

| | Stable topological sort | `_move_overlap_nodes` |
|---|---|---|
| **Time** | O(N + E), all nodes and edges | O(N) position dict + O(K * V) chain moves |
| **Nodes visited** | All N nodes re-sorted | Only K chains of V nodes each |
| **Nodes displaced** | ~97% (18,147 / 18,624) | ~5% (AG/RS chains only) |
| **Graph disruption** | Full reshuffle | Surgical: untouched nodes stay in place |

N ~ 18K (graph nodes), K ~ 30 (transformer layers), V ~ 3 (chain depth).

### New helpers

- **`_collect_nodes_must_be_after(node)`**: BFS forward collecting transitive users whose inputs are all satisfied within the set. Used for RS wait+unpack chains.
- **`_collect_nodes_must_be_before(node, node_positions)`**: BFS backward collecting non-placeholder dependencies, topo-sorted. O(V) per chain. Used for AG start chains.
- **`_move_overlap_nodes(graph, overlap_deps, bucketed_node_types)`**: Classifies overlap dependencies into RS defer / AG prefetch, then repositions each chain via `Node.prepend` / `Node.append`.

### How it works

**RS defer**: For each RS wait, find the latest RS start it must follow, collect the wait + unpack chain, move it right after that RS start.

**AG prefetch**: For each AG wait, collect the AG start chains that should be prefetched before it, move each chain right before the AG wait.

Pull Request resolved: #184711
Approved by: https://github.com/SherlockNoMad
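The two kinds of move can be sketched on a plain list of node names. This is a simplification under stated assumptions: the helper names and node names are hypothetical, and rebuilding a Python list is O(N) per move, whereas the PR relies on in-place `Node.prepend` / `Node.append` on the FX graph's linked list.

```python
def move_chain_before(order, chain, anchor):
    """AG prefetch: place `chain` (kept in its own order) right before `anchor`."""
    chain_set = set(chain)
    rest = [n for n in order if n not in chain_set]
    i = rest.index(anchor)
    return rest[:i] + list(chain) + rest[i:]


def move_chain_after(order, chain, anchor):
    """RS defer: place `chain` (kept in its own order) right after `anchor`."""
    chain_set = set(chain)
    rest = [n for n in order if n not in chain_set]
    i = rest.index(anchor) + 1
    return rest[:i] + list(chain) + rest[i:]


# Prefetch the next layer's all-gather start ahead of the current AG wait.
ag_order = ["ag0_start", "ag0_wait", "mm0", "ag1_start", "ag1_wait", "mm1"]
ag_result = move_chain_before(ag_order, ["ag1_start"], "ag0_wait")

# Defer a reduce-scatter wait until after the next RS start.
rs_order = ["rs0_start", "rs0_wait", "mm", "rs1_start", "rs1_wait"]
rs_result = move_chain_after(rs_order, ["rs0_wait"], "rs1_start")

print(ag_result)
print(rs_result)
```

In both cases every node outside the moved chain keeps its relative position, which is the property the full topological sort did not guarantee.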
Route ProcessGroup Python constructors through the PyProcessGroup trampoline (#186853)

The 3-arg ProcessGroup constructor exposed to Python used a single-callback py::init factory that unconditionally returned c10::make_intrusive<ProcessGroup>(...). When a Python subclass of ProcessGroup was constructed via this path, pybind11 needs the most-derived (alias) instance, so it runs an is_alias dynamic_cast check on the returned holder. Because the factory returned a base ProcessGroup rather than the PyProcessGroup trampoline, that check failed and construction raised:

TypeError: pybind11::init(): construction failed: returned holder-wrapped instance is not an alias instance

So any Python ProcessGroup subclass that called super().__init__(store, rank, size) was unusable; even had it succeeded, the object would have been a raw C++ ProcessGroup with no dispatch back into the Python overrides.

The fix switches both Python-facing constructors to pybind11's two-callback factory form: the first callback builds a plain ProcessGroup (used when ProcessGroup itself is constructed), the second builds the PyProcessGroup trampoline (used when a Python subclass is constructed) so overridden virtual methods dispatch back into Python. This is factored into an init_nogil helper. Both callbacks release the GIL via a local gil_scoped_release rather than a call_guard, which is unsafe in init (pybind/pybind11#5473); the 2-arg constructor previously used py::init<int, int>() and now releases the GIL the same way for consistency.

A two-callback factory is required (rather than reverting to py::init<...>()) because the 3-arg constructor must release the GIL during construction, which is only expressible through a factory lambda.

Test Plan:

Added test_store_constructor to test_c10d_pypg. It constructs a ProcessGroup subclass via the 3-arg (store, rank, size) constructor, verifies the store is accessible through get_group_store(), and verifies that a getBackendName() override dispatches back into Python (proving the trampoline is in use). Before the fix this test failed at construction with the TypeError above.

pip install -e . -v --no-build-isolation
python test/distributed/test_c10d_pypg.py

All 51 tests pass.

Pull Request resolved: #186853
Approved by: https://github.com/dolpm, https://github.com/kapilsh
[pipelining] Add guards for non-float tensors when building pipeline (#… …183582) Fixes #183024. Tested working with code below by manually applying the patch over 2.11. Analysis and initial patches found using AI. Manually applied and tested patches. <details> <summary> Test code, run with torchrun on 2 gpus </summary> ```python import torch import torch.distributed as dist from torch.distributed.pipelining import pipeline, SplitPoint, PipelineStage, Schedule1F1B from transformers import AutoModelForCausalLM import os # Initialize torchrun's distributed environment pp_group = dist.init_process_group(backend="nccl") rank = dist.get_rank() world_size = dist.get_world_size() local_rank = int(os.environ["LOCAL_RANK"]) # Assign this specific process to its designated GPU device = torch.device(f"cuda:{local_rank}") torch.cuda.set_device(device) print(f"[Rank {rank}] Starting model init on {device}") # GPT2 model_id = "openai-community/gpt2" device_count = torch.cuda.device_count() model = AutoModelForCausalLM.from_pretrained(model_id) print(model) # kv cache is enabled by default and causes other issues with dynamo tracing/pipelining, disabling model.config.use_cache = False # run on 2 or more gpus split_spec = { f"transformer.h.5": SplitPoint.END, } model = model.to(device) pipe_module = pipeline( module=model, mb_args=(), mb_kwargs={ 'input_ids': torch.zeros((1, 512), dtype=torch.long).to(device), 'attention_mask': torch.ones((1, 512), dtype=torch.bool).to(device) }, split_spec=split_spec, ) stage = pipe_module.build_stage(rank, device, pp_group) # placeholder loss function def loss_fn(outputs, other): if isinstance(outputs, tuple): return outputs[0] elif isinstance(outputs, torch.Tensor): logits = outputs else: logits = outputs['logits'] loss = logits.sum() - other.sum() return loss schedule = Schedule1F1B(stage, n_microbatches=4, loss_fn=loss_fn) if rank == 0: inputs = { 'input_ids': torch.zeros((4, 512), dtype=torch.long).to(device), 'attention_mask': 
torch.ones((4, 512), dtype=torch.bool).to(device), } schedule.step( **inputs ) elif rank == world_size - 1: losses = [] output = schedule.step(target=torch.randn(4, 512).to(device), losses=losses) print(f"losses: {losses}") else: schedule.step() ``` </details> <details> <summary> Prompt and Initial AI output </summary> Prompt: Analyze, triage, and provide me with a fix for this github issue `https://github.com/pytorch/pytorch/issues/183024` Z.ai agent response: I'll analyze this GitHub issue for you. Let me start by fetching the issue details and understanding the problem. Let me extract the actual issue content more cleanly. Now let me look at the PyTorch source code for the specific failing function to understand the bug deeply. Now I have a complete picture. Let me also verify the related backward code path to ensure the fix is comprehensive. I now have a comprehensive understanding of the issue. Here's my full analysis, triage, and proposed fix: --- ## 🔍 Issue Analysis: PyTorch #183024 ### **Title:** [pipelining] Unable to use PP with HF transformers models due to non-float inputs ### **Bug Summary** When using PyTorch's Pipeline Parallelism (`torch.distributed.pipelining`) with HuggingFace transformer models (GPT2, LLaMA, etc.), the pipeline crashes with: ``` RuntimeError: only Tensors of floating point dtype can require gradients ``` --- ## 🧬 Root Cause The root cause is a **dtype-agnostic `requires_grad` assignment** in the pipeline stage infrastructure. When `has_backward=True` (training mode), the code **unconditionally** sets `requires_grad=True` on **all** activation receive/send buffers — including non-float tensors like `input_ids` (`torch.long`) and `attention_mask` (`torch.bool`). PyTorch only allows floating-point tensors to require gradients (a fundamental constraint since [Issue #37680, open since 2020](#37680)). 

The crash occurs at **three specific locations**:

### **Location 1: `torch/distributed/pipelining/stage.py`, `create_recv_tensor` inside `_create_act_recv_info`** (PRIMARY CRASH)

```python
# Lines ~1361-1377 (broken)
tensor_meta = _TensorMeta(
    shape=example_value.shape,
    stride=example_value.stride(),
    dtype=example_value.dtype,
    requires_grad=self.has_backward,  # ← BUG: True for ALL dtypes
)
buffer = _make_tensor_from_meta(tensor_meta, self.device)
if self.has_backward:
    buffer.requires_grad_(True)  # ← CRASH: non-float tensors can't require grad
```

This is the **exact crash point** from the traceback. When the pipeline creates receive buffers for intermediate activations between stages, it encounters `input_ids` (dtype=`torch.long`) or `attention_mask` (dtype=`torch.bool`) and tries to set `requires_grad_(True)`, which is illegal.

### **Location 2: `torch/distributed/pipelining/stage.py`, `_create_act_send_info`** (SECONDARY)

```python
# Lines ~1462-1468 (broken)
output_metas.append(
    _TensorMeta(
        shape=val.shape,
        stride=val.stride(),
        dtype=val.dtype,
        requires_grad=self.has_backward,  # ← Same bug for outputs
    )
)
```

### **Location 3: `torch/distributed/pipelining/_utils.py`, `_TensorMeta.to_tensor`** (LATENT)

```python
# Line ~96 (broken)
t = _make_tensor_from_meta(self, device)
t.requires_grad_(self.requires_grad)  # ← Crashes if requires_grad=True and dtype is non-float
return t
```

### **Additional Related Bug: `torch/distributed/pipelining/_backward.py`, `_get_grad_fn_or_grad_acc`**

Even after fixing the above, the backward pass has a **second bug** waiting to surface. The function `_get_grad_fn_or_grad_acc` assumes its argument is always a `torch.Tensor`:

```python
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
    if t.requires_grad and t.grad_fn is None:  # ← AttributeError if t is not a Tensor
        ...
```

This is called via `map(_get_grad_fn_or_grad_acc, input_values)` in `stage_backward_input`, where `input_values` can contain non-tensor objects. While the immediate crash is masked by Location 1, once that's fixed, this will be the **next crash site**.

---

## 🔧 Proposed Fix

The fix is to **guard `requires_grad=True` with a `tensor.is_floating_point()` check** at all affected locations. This is semantically correct because:

- Non-float tensors (integer, boolean) **never** produce meaningful gradients
- Their gradient in the backward pass should be `None`
- This is consistent with PyTorch's autograd semantics

### **Patch 1: `torch/distributed/pipelining/stage.py`, `create_recv_tensor`**

```python
def create_recv_tensor(placeholder, arg_node):
    example_value = placeholder.meta["val"]
    # ... (DTensor check unchanged) ...
    if arg_node.op == "placeholder":
        return _RecvInfo(
            input_name=f"root_input_{placeholder.name}",
            source=None,
            buffer=None,
            tensor_meta=_TensorMeta.from_tensor(example_value),
            is_root_arg=True,
        )

    # Received activation from a previous stage.
    while arg_node.target is operator.getitem:
        arg_node = arg_node.args[0]

    if arg_node.op != "call_module":
        raise PipeliningMetadataError(
            f"Expecting call_module, got {arg_node.op}"
        )
    src_stage = self.get_stage_index_of_submod(arg_node.name)

    # FIX: Only set requires_grad for floating-point tensors
    needs_grad = self.has_backward and example_value.is_floating_point()
    tensor_meta = _TensorMeta(
        shape=example_value.shape,
        stride=example_value.stride(),
        dtype=example_value.dtype,
        requires_grad=needs_grad,  # ← was: self.has_backward
    )
    buffer = _make_tensor_from_meta(tensor_meta, self.device)
    if needs_grad:  # ← was: if self.has_backward
        buffer.requires_grad_(True)
    return _RecvInfo(
        arg_node.name,
        src_stage,
        buffer,
        tensor_meta,
    )
```

### **Patch 2: `torch/distributed/pipelining/stage.py`, `_create_act_send_info`**

```python
# Inside _create_act_send_info, where output_metas are built:
output_metas.append(
    _TensorMeta(
        shape=val.shape,
        stride=val.stride(),
        dtype=val.dtype,
        requires_grad=self.has_backward and val.is_floating_point(),  # ← was: self.has_backward
    )
)
```

### **Patch 3: `torch/distributed/pipelining/_utils.py`, `_TensorMeta.to_tensor`**

```python
def to_tensor(self, device: torch.device | str) -> torch.Tensor:
    t = _make_tensor_from_meta(self, device)
    if self.requires_grad and t.is_floating_point():  # ← was: unconditionally
        t.requires_grad_(True)
    return t
```

### **Patch 4: `torch/distributed/pipelining/_backward.py`, `_get_grad_fn_or_grad_acc`**

```python
def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Node | None:
    # FIX: Guard against non-tensor inputs
    if not isinstance(t, torch.Tensor):
        return None
    if t.requires_grad and t.grad_fn is None:
        viewed_t = t.view_as(t)
        grad_fn = viewed_t.grad_fn
        if grad_fn is not None:
            return grad_fn.next_functions[0][0]
        else:
            raise RuntimeError(
                "Attempted to get grad_fn, but got None."
                "Is this being created in a no-grad context?"
            )
    else:
        return t.grad_fn
```

---

## 📊 Triage Assessment

| Aspect | Rating |
|--------|--------|
| **Severity** | 🔴 **High**: Completely blocks pipeline parallelism for any model with non-float inputs (all HF transformer models) |
| **Impact** | 🔴 **High**: Affects GPT2, LLaMA, BERT, T5, essentially all popular transformer architectures |
| **Workaround?** | 🟡 **Difficult**: Would require rewriting models to avoid passing non-float tensors between stages, or monkey-patching the pipelining internals |
| **Fix Complexity** | 🟢 **Low**: Four surgical `is_floating_point()` / `isinstance` guards, no architectural changes |
| **Regression Risk** | 🟢 **Low**: The fix only restricts `requires_grad` for cases that already crash; float tensors are unaffected |

### Why PRs #182182 and #182644 Don't Fix This

- **PR #182182** fixes None gradient handling for float tensors that are unused; it doesn't address the dtype mismatch
- **PR #182644** fixes FX graph node ordering; it doesn't touch the `requires_grad` logic at all
- This is a **distinct, orthogonal bug** that needs its own fix

---

## 🚀 Suggested Next Steps

1. **File a PR** against `pytorch/pytorch` with the four patches above
2. **Add regression tests** that pipeline an HF model (even a tiny one like GPT2-small) with `input_ids` and `attention_mask`
3. **Coordinate with PR #182182**: both fixes are needed for full pipeline parallelism support, and they should be tested together since they both touch the backward path

Prompt: Check your work

Z.ai agent response:

Good call, let me verify my analysis against the actual PyTorch source code. I'll cross-check each claim.

Now I have verified all the source code against my original analysis. Let me also check one more thing: whether the `to_tensor` method is actually called in the crash path or not.

Now let me verify the secondary bug claim about `_get_grad_fn_or_grad_acc`: whether it would actually be hit in the backward path after fixing the primary bug.

Good, my verification is complete. Here's the **corrected and refined** analysis:

---

## ✅ Verification Results

### Claims Confirmed ✅

| Claim | Status | Evidence |
|-------|--------|----------|
| **Primary crash: `create_recv_tensor` sets `requires_grad=self.has_backward` unconditionally** | ✅ **Confirmed** | Exact match with source at `stage.py` lines ~1361-1377 |
| **Primary crash: `buffer.requires_grad_(True)` when `self.has_backward`** | ✅ **Confirmed** | Line ~1377, crashes for `torch.long`/`torch.bool` tensors |
| **Secondary bug: `_create_act_send_info` uses same `requires_grad=self.has_backward` pattern** | ✅ **Confirmed** | Line ~1467, same unconditional pattern for outputs |
| **Latent bug: `_TensorMeta.to_tensor()` calls `t.requires_grad_(self.requires_grad)` unconditionally** | ✅ **Confirmed** | `_utils.py` line ~96, same crash pattern |
| **PR #182182 does NOT fix this** | ✅ **Confirmed** | It fixes None gradient handling for float tensors, not the dtype mismatch |
| **PR #182644 does NOT fix this** | ✅ **Confirmed** | It fixes FX graph node ordering, doesn't touch `requires_grad` logic |
| **`is_floating_point()` guard is the correct fix** | ✅ **Confirmed** | Semantically correct: non-float tensors can never produce meaningful gradients in PyTorch |

### Claim Clarified ⚠️

| Claim | Status | Correction |
|-------|--------|------------|
| **`_get_grad_fn_or_grad_acc` will crash next after the primary fix** | ⚠️ **Partially correct** | This IS a real bug, but it's for a **different scenario**: when non-Tensor Python objects (int, bool, None) appear in `input_values`. In the HF GPT2 pipeline case specifically, `input_values` contains `torch.Tensor` objects (just non-float ones), which after the primary fix will have `requires_grad=False` and work fine. The `isinstance` guard is still a valid defense-in-depth fix but is **not directly triggered** by this issue. |

### Revised Fix Assessment

**Patch 1 (`create_recv_tensor`)**: **Essential, directly fixes the crash** ✅

**Patch 2 (`_create_act_send_info`)**: **Important, prevents same crash for non-float outputs crossing stage boundaries** ✅ (Not triggered in GPT2 specifically since embedding outputs are float, but needed for models where non-float tensors cross stages)

**Patch 3 (`_TensorMeta.to_tensor`)**: **Defensive, prevents latent crash if this method is called with `requires_grad=True` on non-float metadata** ✅

**Patch 4 (`_get_grad_fn_or_grad_acc`)**: **Separate bug, defense-in-depth** ⚠️ Should be framed as a separate fix for the torchtitan #3112 scenario (non-Tensor Python objects in `input_values`), not as part of the immediate fix for issue #183024.

---

### Bottom Line

My original analysis was **substantially correct**. The three core patches (1, 2, 3) are accurate and necessary. Patch 4 is a valid bug fix but should be presented as an **orthogonal improvement** rather than a direct consequence of fixing issue #183024. The root cause, severity assessment, and fix approach all hold up under verification.

</details>

Pull Request resolved: #183582
Approved by: https://github.com/sanketpurandare