AsyncGRPOTrainer: vision-language (*ForConditionalGeneration) checkpoints can't be trained (weight-sync key mismatch) #6028

@adithya-s-k

Environment

  • trl main (trl.experimental.async_grpo)
  • transformers 5.11, vllm 0.22.0 (native 4-phase RL weight transfer; needs
    #5892 for the WeightTransferClient 0.22 path)
  • Layout: vLLM server (--data-parallel-size 2) + trainer FSDP2, weights synced over NCCL

What happens

Training a vision-language checkpoint (e.g. Qwen/Qwen3.5-4B, google/gemma-4-E4B-it) — even for text-only
RL — crashes at the first weight sync:

ValueError: There is no module or parameter named 'model' in Qwen3_5ForConditionalGeneration.
The available parameters are: {'language_model.model.layers...', 'visual.blocks...'}

Root cause

async_grpo_trainer.py loads the policy with a hardcoded:

model = AutoModelForCausalLM.from_pretrained(model, ...)

For a VL checkpoint this yields only the text tower, whose weight keys are model.*. vLLM, however, serves
the model under the architecture named in its config, *ForConditionalGeneration, whose keys are
language_model.model.* + visual.*. The NCCL weight transfer broadcasts the trainer's model.* keys, vLLM's
loader cannot map them, and it raises the ValueError above. The two sides disagree on the model's
parameter namespace.
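The mismatch can be illustrated without loading any weights. The following is a minimal sketch, not TRL or vLLM code: the helper `unmapped_keys` is hypothetical, and the parameter names are shortened illustrative examples built from the prefixes quoted in the error message.

```python
def unmapped_keys(trainer_keys, server_prefixes):
    """Return the trainer keys that start with none of the server-side prefixes."""
    return [k for k in trainer_keys if not any(k.startswith(p) for p in server_prefixes)]


# Illustrative key names only (assumption); real checkpoints have many more parameters.
causal_lm_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
]
vl_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "visual.blocks.0.attn.qkv.weight",
]

# Prefixes taken from the ValueError message above.
server_prefixes = ("language_model.model.", "visual.")

missing_causal = unmapped_keys(causal_lm_keys, server_prefixes)  # every key is unmapped
missing_vl = unmapped_keys(vl_keys, server_prefixes)             # nothing is unmapped
```

A check of this kind, run on the trainer's state-dict keys before the first sync, would surface the namespace disagreement earlier than the NCCL transfer does.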

Proposed fix

Load the model class that matches what vLLM serves. When the checkpoint's config.architectures is a
conditional-generation / image-text-to-text type, load via AutoModelForImageTextToText and freeze the
vision tower (text-only RL), so the trainer's keys (language_model.model.* + visual.*) match vLLM:

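from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText

# `model` is the checkpoint name or path on entry, as in the current trainer code.
archs = getattr(AutoConfig.from_pretrained(model), "architectures", None) or []
if any("ConditionalGeneration" in a or "ImageTextToText" in a for a in archs):
    # Load the same class family that vLLM serves, so the weight keys line up.
    model = AutoModelForImageTextToText.from_pretrained(model, ...)
    # Text-only RL: freeze the vision tower.
    for name, p in model.named_parameters():
        if "visual" in name or "vision" in name:
            p.requires_grad = False
else:
    model = AutoModelForCausalLM.from_pretrained(model, ...)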

Text models are unaffected.
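Because the freeze rule is a plain substring match, it may be worth checking which parameter names it actually selects for a given checkpoint. A minimal sketch follows, assuming the same "visual" / "vision" rule as the snippet above; the helper names and the example parameter names are hypothetical.

```python
def is_vision_param(name: str) -> bool:
    """Same substring rule as the proposed fix."""
    return "visual" in name or "vision" in name


def split_by_freeze(names):
    """Partition parameter names into (frozen, trainable) under the rule above."""
    frozen = [n for n in names if is_vision_param(n)]
    trainable = [n for n in names if not is_vision_param(n)]
    return frozen, trainable


# Hypothetical parameter names, for illustration only.
example_names = [
    "visual.blocks.0.attn.qkv.weight",
    "model.vision_tower.encoder.layers.0.mlp.fc1.weight",
    "language_model.model.layers.0.mlp.up_proj.weight",
    "lm_head.weight",
]

frozen, trainable = split_by_freeze(example_names)
```

On a real model the names would come from `model.named_parameters()`. The rule matches any name containing "vision", so a projector with such a name would be frozen too, while a tower named differently (for example an audio encoder) would stay trainable.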

Two FSDP2 caveats for VL models (upstream of TRL, in accelerate — noted for repro)

  • TRANSFORMER_BASED_WRAP auto-detection fails (it looks for the vision block class, which isn't
    instantiated for causal-LM use), so fsdp_transformer_layer_cls_to_wrap must be pinned to the text decoder
    layer (e.g. Qwen3_5DecoderLayer, Gemma4TextDecoderLayer).
  • fsdp_cpu_ram_efficient_loading: true crashes in
    accelerate/utils/fsdp_utils.py::fsdp2_load_full_state_dict with
    AttributeError: 'Tensor' object has no attribute 'device_mesh' — that loader assumes every param is a
    sharded DTensor, but the VL model leaves a param (vision tower) as a plain Tensor. Workaround:
    fsdp_cpu_ram_efficient_loading: false. (Likely worth a separate accelerate issue.)
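To confirm the second caveat on a given model, one could list the parameters that lack a `device_mesh` attribute after FSDP2 wrapping. This is a diagnostic sketch, not accelerate code: `params_without_device_mesh` is a hypothetical helper, and the stand-in objects below only mimic the one attribute that the failing loader reads.

```python
from types import SimpleNamespace


def params_without_device_mesh(named_params):
    """Names of parameters that have no `device_mesh` attribute (not DTensor-like)."""
    return [name for name, p in named_params if not hasattr(p, "device_mesh")]


# Stand-ins for illustration: one sharded-looking param, one plain one.
sharded = SimpleNamespace(device_mesh="mesh")  # mimics a DTensor
plain = SimpleNamespace()                      # mimics a plain Tensor

fake_named_params = [
    ("language_model.model.layers.0.mlp.up_proj.weight", sharded),
    ("visual.blocks.0.attn.qkv.weight", plain),
]

unsharded = params_without_device_mesh(fake_named_params)
```

With a real wrapped model the call would be `params_without_device_mesh(model.named_parameters())`; a non-empty result would point at the parameters that `fsdp2_load_full_state_dict` trips over.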

Validation

With this fix (plus #5892 and the env-reward fix filed separately), async GRPO on a tool-using (Harbor) task
suite trained end-to-end on Qwen/Qwen3.5-4B: weight sync passes, 2 steps complete, and the run exits cleanly.
google/gemma-4-E4B-it gets past the weight-key issue and is being validated with
fsdp_cpu_ram_efficient_loading: false.
