Fix GRPO use_liger_kernel under DeepSpeed ZeRO-3 #5891

Merged
kashif merged 3 commits into main from fix-grpo-liger-zero3 on Jun 2, 2026

@kashif commented May 30, 2026 (Collaborator)

What this fixes

use_liger_kernel=True + DeepSpeed ZeRO-3 in GRPO crashes with a size mismatch. The matrix below is what people have been hitting (see #3368):

  • ✅ liger + zero2
  • ✅ no-liger + zero3
  • ❌ liger + zero3 → size mismatch

Why

In compute_liger_loss we hand unwrapped_model.lm_head.weight straight to the Liger fused loss. Under ZeRO-3 that weight is sharded across ranks, and since Liger reads it by attribute access instead of going through model.forward(), DeepSpeed's gather hook never fires. So on the partitioned ranks the fused matmul gets an empty shard and blows up.

Fix

Wrap the Liger call in a GatheredParameters context so the full lm_head weight/bias are all-gathered for the matmul, then re-partitioned on exit. This is the same _maybe_gather_lm_head_ctx pattern we use for SFT's chunked-NLL path. It's a no-op when not on ZeRO-3, or when the weight is already gathered (tied embeddings keep it AVAILABLE).

We only need the gathered weight during the forward: LigerFusedLinearGRPOLoss computes the gradient w.r.t. the weight eagerly inside the forward and saves it for backward, so re-partitioning before backward() is fine.
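
To illustrate why the weight is only needed during the forward, here is a toy pure-Python stand-in for the "compute the weight gradient eagerly" pattern. It is not Liger's implementation: the class `EagerGradLinearLoss` and the squared-error loss are hypothetical, chosen only to keep the arithmetic easy to follow.

```python
class EagerGradLinearLoss:
    """Toy loss 0.5 * (w . x - target)**2 whose weight gradient is computed in forward."""

    def forward(self, weight, x, target):
        pred = sum(w_i * x_i for w_i, x_i in zip(weight, x))
        err = pred - target
        # d(loss)/d(w_i) = err * x_i, computed now while `weight` is available.
        self.saved_grad_weight = [err * x_i for x_i in x]
        return 0.5 * err * err

    def backward(self, grad_output=1.0):
        # Uses only the saved gradient; the weight itself is not read again.
        return [grad_output * g for g in self.saved_grad_weight]


weight = [0.5, -1.0]
loss_fn = EagerGradLinearLoss()
loss = loss_fn.forward(weight, x=[2.0, 1.0], target=1.0)
weight.clear()  # stands in for ZeRO-3 re-partitioning the weight after forward
grad = loss_fn.backward()  # still works: [-2.0, -1.0]
```

Because `backward` reads only what `forward` saved, releasing the weight between the two calls does not affect the result.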

Test

Added test_grpo_liger to the distributed suite, parametrized over ddp / zero2 / zero3 / fsdp2 with --use_liger_kernel (mirrors the existing test_grpo).
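
A rough sketch of the test matrix, under stated assumptions: the helper `build_grpo_command`, the `<config>.yaml` file naming, and the script path are placeholders, not TRL's actual test code. Only the four config names and the `--use_liger_kernel` flag come from this PR. In a real suite the loop would typically be a `pytest.mark.parametrize` over the same list.

```python
CONFIGS = ["ddp", "zero2", "zero3", "fsdp2"]


def build_grpo_command(config, use_liger_kernel=False, script="path/to/grpo_script.py"):
    # `script` and the config file name are placeholders for illustration only.
    cmd = ["accelerate", "launch", "--config_file", f"{config}.yaml", script]
    if use_liger_kernel:
        cmd.append("--use_liger_kernel")
    return cmd


# One launch command per distributed config, each with the Liger flag enabled.
liger_commands = {cfg: build_grpo_command(cfg, use_liger_kernel=True) for cfg in CONFIGS}
```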

Closes #3368


Note

Medium Risk
Touches GRPO training loss under ZeRO-3 and tied-embedding edge cases; scope is narrow and covered by new distributed tests.

Overview
Fixes a crash when GRPO runs with use_liger_kernel under DeepSpeed ZeRO-3: Liger reads lm_head weights directly, so ZeRO-3 never gathers shards and the fused matmul fails with a size mismatch.

In compute_liger_loss, the Liger call is wrapped in deepspeed.zero.GatheredParameters for lm_head weight/bias when ZeRO-3 is active and params are not already AVAILABLE (e.g. tied embeddings). Otherwise behavior is unchanged.

Adds distributed test_grpo_liger (ddp / zero2 / zero3 / fsdp2) mirroring test_grpo with --use_liger_kernel, gated by require_liger_kernel.

Reviewed by Cursor Bugbot for commit b4d345e.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: ff3f888d55

Comment thread tests/distributed/test_distributed.py

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes using default effort and found 1 potential issue.

Reviewed by Cursor Bugbot for commit ff3f888.

Comment thread trl/trainer/grpo_trainer.py Outdated

@kashif commented May 30, 2026 (Collaborator, Author)

@codex review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Comment thread tests/distributed/test_distributed.py

Copilot AI (Contributor) left a comment

Pull request overview

Fixes a GRPO training crash when use_liger_kernel=True is combined with DeepSpeed ZeRO-3 by explicitly all-gathering lm_head parameters for the Liger fused loss call (which accesses weights directly and bypasses ZeRO-3’s usual gather hooks). Adds a distributed regression test that runs GRPO with --use_liger_kernel across multiple distributed configs.

Changes:

  • Wraps the Liger GRPO loss call in a ZeRO-3-aware GatheredParameters context for lm_head.weight/bias.
  • Avoids gathering when parameters are already AVAILABLE (e.g., tied embeddings).
  • Adds test_grpo_liger to the distributed suite and gates it behind require_liger_kernel.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| trl/trainer/grpo_trainer.py | Adds ZeRO-3 parameter gathering around the Liger fused GRPO loss invocation. |
| tests/distributed/test_distributed.py | Adds a distributed integration test running GRPO with --use_liger_kernel across ddp/zero2/zero3/fsdp2. |

Comment thread trl/trainer/grpo_trainer.py

@qgallouedec (Member) left a comment

thanks!

@kashif kashif merged commit f298ae3 into main Jun 2, 2026
13 checks passed
@kashif kashif deleted the fix-grpo-liger-zero3 branch June 2, 2026 09:32
Development

Successfully merging this pull request may close these issues.

GPRO: use_liger_loss + zero3 error

4 participants