Fix GRPO use_liger_kernel under DeepSpeed ZeRO-3 #5891
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: ff3f888d55
Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.
Reviewed by Cursor Bugbot for commit ff3f888. Configure here.
@codex review
Pull request overview
Fixes a GRPO training crash when use_liger_kernel=True is combined with DeepSpeed ZeRO-3 by explicitly all-gathering lm_head parameters for the Liger fused loss call (which accesses weights directly and bypasses ZeRO-3’s usual gather hooks). Adds a distributed regression test that runs GRPO with --use_liger_kernel across multiple distributed configs.
Changes:
- Wraps the Liger GRPO loss call in a ZeRO-3-aware `GatheredParameters` context for `lm_head.weight`/`bias`.
- Avoids gathering when parameters are already `AVAILABLE` (e.g., tied embeddings).
- Adds `test_grpo_liger` to the distributed suite and gates it behind `require_liger_kernel`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `trl/trainer/grpo_trainer.py` | Adds ZeRO-3 parameter gathering around the Liger fused GRPO loss invocation. |
| `tests/distributed/test_distributed.py` | Adds a distributed integration test running GRPO with `--use_liger_kernel` across ddp/zero2/zero3/fsdp2. |
What this fixes
`use_liger_kernel=True` + DeepSpeed ZeRO-3 in GRPO crashes with a size mismatch. The matrix below is what people have been hitting (see #3368):
Why
In `compute_liger_loss` we hand `unwrapped_model.lm_head.weight` straight to the Liger fused loss. Under ZeRO-3 that weight is sharded across ranks, and since Liger reads it by attribute access instead of going through `model.forward()`, DeepSpeed's gather hook never fires. So on the partitioned ranks the fused matmul gets an empty shard and blows up.
Fix
Wrap the Liger call in a `GatheredParameters` context so the full `lm_head` weight/bias are all-gathered for the matmul, then re-partitioned on exit. This is the same `_maybe_gather_lm_head_ctx` pattern we use for SFT's chunked-NLL path. It's a no-op when not on ZeRO-3, or when the weight is already gathered (tied embeddings keep it `AVAILABLE`).
We only need the gathered weight during the forward: `LigerFusedLinearGRPOLoss` computes the gradient w.r.t. the weight eagerly inside the forward and saves it for backward, so re-partitioning before `backward()` is fine.
Test
Added `test_grpo_liger` to the distributed suite, parametrized over ddp / zero2 / zero3 / fsdp2 with `--use_liger_kernel` (mirrors the existing `test_grpo`).
Closes #3368
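To illustrate why re-partitioning before `backward()` is safe, here is a dependency-free toy of the "compute the weight gradient in forward, cache it for backward" idea. It is only a sketch: `EagerGradLinearLoss` is a hypothetical name, the loss is a scalar squared error, and none of this is Liger's or PyTorch's actual code.

```python
class EagerGradLinearLoss:
    """Toy scalar loss 0.5 * (w * x - y) ** 2 that caches dL/dw during forward."""

    def __init__(self):
        self._saved_grad_w = None

    def forward(self, w, x, y):
        residual = w * x - y
        # The gradient w.r.t. the weight is computed now, while w is available.
        self._saved_grad_w = residual * x
        return 0.5 * residual ** 2

    def backward(self, grad_output=1.0):
        # Only the cached gradient is used; the weight itself is never read.
        return grad_output * self._saved_grad_w


loss_fn = EagerGradLinearLoss()
weight = 3.0   # stands in for the gathered lm_head weight
loss = loss_fn.forward(weight, x=2.0, y=4.0)
weight = None  # stands in for re-partitioning when the context exits
grad_w = loss_fn.backward()
```

Because `backward` touches only the cached value, dropping the weight between the two calls does not change the result, which is the property the fix relies on.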
Note
Medium Risk
Touches GRPO training loss under ZeRO-3 and tied-embedding edge cases; scope is narrow and covered by new distributed tests.
Overview
Fixes a crash when GRPO runs with `use_liger_kernel` under DeepSpeed ZeRO-3: Liger reads `lm_head` weights directly, so ZeRO-3 never gathers shards and the fused matmul fails with a size mismatch.
In `compute_liger_loss`, the Liger call is wrapped in `deepspeed.zero.GatheredParameters` for `lm_head` weight/bias when ZeRO-3 is active and params are not already `AVAILABLE` (e.g. tied embeddings). Otherwise behavior is unchanged.
Adds distributed `test_grpo_liger` (ddp / zero2 / zero3 / fsdp2) mirroring `test_grpo` with `--use_liger_kernel`, gated by `require_liger_kernel`.
Reviewed by Cursor Bugbot for commit b4d345e.
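The gathering logic summarized in the overview could look roughly like the sketch below. This is an assumption-laden illustration, not the PR's code: `params_needing_gather` and `maybe_gather_lm_head` are hypothetical names, the `zero3_enabled` flag is assumed to be supplied by the caller, and detecting ZeRO-3-managed parameters through their `ds_id` / `ds_status` attributes is an assumption about how the real helper decides.

```python
import contextlib
from types import SimpleNamespace


def params_needing_gather(lm_head):
    """Return lm_head weight/bias that ZeRO-3 currently holds only as a shard."""
    # Imported lazily so this module still loads when DeepSpeed is not installed.
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    candidates = [lm_head.weight, getattr(lm_head, "bias", None)]
    return [
        p
        for p in candidates
        if p is not None
        and hasattr(p, "ds_id")  # assumption: only ZeRO-3 params carry ds_id
        and p.ds_status != ZeroParamStatus.AVAILABLE
    ]


def maybe_gather_lm_head(lm_head, zero3_enabled):
    """Context manager that all-gathers lm_head params only when needed."""
    if not zero3_enabled:
        return contextlib.nullcontext()  # ddp / zero2 / fsdp2: nothing to do

    import deepspeed

    params = params_needing_gather(lm_head)
    if not params:
        # Already gathered, e.g. tied embeddings that are AVAILABLE.
        return contextlib.nullcontext()
    # modifier_rank=None: the gathered params are read, not modified.
    return deepspeed.zero.GatheredParameters(params, modifier_rank=None)


# Outside ZeRO-3 the helper is a no-op and never imports DeepSpeed.
head = SimpleNamespace(weight=object(), bias=None)
ctx = maybe_gather_lm_head(head, zero3_enabled=False)
with ctx:
    pass  # the fused Liger loss call would go here
```

The fused loss call would sit inside the `with` block so the full weight exists only for the duration of the forward.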