
RFC: Rollout sources — a unified seam for environment & agent RL #5974

@adithya-s-k


RFC: A rollout-source seam for environment & agent RL in TRL

Scope: GRPOTrainer, AsyncGRPOTrainer (the pattern generalizes to any online trainer).
Relates to: #5912 (environment-owned reward), #5903 (environment-owned dataset).

This RFC proposes a single, narrow contract that decouples the trainers from the systems that
produce rollouts (environments, tool harnesses, prebuilt agents). It lets TRL integrate the growing
environment ecosystem through thin adapters instead of trainer-resident special cases, and it makes
multi-turn / agentic RL correct by construction.


1. Motivation

Online RL trainers in TRL keep accreting one bespoke hook per integration: reward_funcs for
reward models, environment_factory for OpenEnv-style envs, rollout_func for custom backends.
Each new system (OpenReward, Harbor, verifiers, in-house agents) adds conditional surface inside the
trainer, and each multi-turn variation re-implements the dangerous parts — tokenization across
turns, the loss mask, logprob capture.

                    today                                  proposed
   ┌───────────────────────────────┐         ┌───────────────────────────────┐
   │           GRPOTrainer          │         │           GRPOTrainer          │
   │  reward_funcs                  │         │  consumes ONE contract         │
   │  environment_factory           │  ──▶    │  (RolloutSample); never calls  │
   │  rollout_func                  │         │  an environment itself         │
   │  …trainer knows each backend   │         └───────────────────────────────┘
   └───────────────────────────────┘                       ▲ groups()
        N backends, N code paths                 ┌──────────┴──────────┐
                                                 │   RolloutSource     │  ← integrations live here
                                                 └─────────────────────┘

The proposal: invert the dependency. The trainer consumes a small per-rollout bundle and never
calls into environments. Environments and agents produce that bundle; the trainer pulls it.
"How does the trainer learn about system X" becomes the much smaller, stable question "how does X
produce the contract."

And once inverted, every integration is one of two kinds, both producing the same contract —
the only difference is who runs the loop and where tokens are captured:

                                   ┌──────────────────────────────┐
                                   │          GRPOTrainer          │
                                   │   consumes RolloutSample      │
                                   └──────────────────────────────┘
                                                  ▲ groups() → RolloutGroup
                            ┌─────────────────────┴─────────────────────┐
                            │                                           │
              WHITE-BOX (TRL drives the loop)          BLACK-BOX (agent owns the loop)
         ┌───────────────────────────────────┐   ┌───────────────────────────────────┐
         │        EnvRolloutSource           │   │       AgentRolloutSource          │
         │  reset → policy.generate → step/  │   │  launch agent → it calls the model│
         │  call_tool → append obs → repeat  │   │  capture ids+logprobs at the      │
         │  (captures ids+logprobs locally)  │   │  inference boundary (client/seam) │
         └───────────────────────────────────┘   └───────────────────────────────────┘
                     ▲ Env / ToolEnv / StepEnv          ▲ AgentRunner + Verifier
              OpenEnv-style, function-calling,     coding agents, computer-use, SWE
              gym, verifier envs                   agents, any prebuilt opaque agent

2. The contract: RolloutSample

This is the spine. The trainer's loss code touches only this type. It is the public, versioned,
stable surface; everything else is replaceable.

```python
from dataclasses import dataclass, field
from typing import Literal

Status = Literal["ok", "timeout", "context_overflow", "error"]

@dataclass
class RolloutSample:
    input_ids: list[int]            # the exact token sequence: prompt + every generated turn +
                                    # every observation / tool output, in order
    completion_mask: list[int]      # same length as input_ids; 1 where the MODEL generated a token,
                                    # 0 for prompt, tool output, observations, scaffolding
    old_log_probs: list[float]      # the sampling-policy logprobs, one per generated token, aligned
                                    # 1:1 with the positions where completion_mask == 1
    reward: float | list[float] | dict[str, float]
                                    # scalar (outcome) | per-step list (process) | named components
    group_id: str                   # all G attempts at one task share this id
    model_version: int              # which policy-weights version produced this rollout
    status: Status                  # terminal status; failed rollouts must never look like
                                    # zero-reward successes
    advantage: float | None = None  # normally computed trainer-side; a producer MAY pre-fill it
    metadata: dict = field(default_factory=dict)   # turns, finish reason, tool calls (logging only)

@dataclass
class RolloutGroup:
    samples: list[RolloutSample]    # the G attempts at one task (GRPO compares within the group)
```

Field-by-field rationale

  • input_ids / completion_mask are kept as one interleaved sequence rather than separate
    prompt/completion tensors, because multi-turn rollouts interleave model turns and tool/observation
    spans. The mask is what tells the loss which positions are trainable.
  • old_log_probs is the importance-sampling denominator for PPO/GRPO. It must be the sampling
    policy's logprob for the token actually emitted, captured at the inference engine — not recomputed
    from text.
  • reward is deliberately a union. Scalar covers outcome reward; a per-step list covers process
    reward / per-turn credit; a dict covers multi-component reward the trainer weights (preserving the
    existing reward_weights capability). The trainer normalizes whatever shape it receives.
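  • group_id is what makes GRPO work: the advantage is the reward's deviation from the group mean
    (usually also divided by the group's standard deviation), so the source must emit ≥2 attempts
    that share an id. With a single attempt, or identical rewards across the group, there is no
    variance to learn from.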
  • model_version enables async / off-policy correctness: a rollout produced under stale weights can
    be importance-corrected rather than discarded.
  • status exists because real rollouts fail (timeouts, context overflow, sandbox errors). A failed
    trajectory must be masked or dropped, never fed as a zero-reward success.

Invariants enforced at construction (so a broken producer fails loudly, not silently):

```python
assert len(completion_mask) == len(input_ids)
assert sum(completion_mask) == len(old_log_probs)
# RolloutGroup: all samples share one group_id; group is non-empty
```

The three correctness rules the contract protects. Every silent multi-turn-RL bug is one of:

  1. Logprobs must line up with the exact emitted tokens. Capture token ids + logprobs at the engine.
  2. The mask must be exact across turns — only sampled tokens are 1.
  3. There must be several attempts per task under a shared group_id.

3. The source interface: RolloutSource

One interface, drained two ways. Sync and async are not forked into separate APIs (that is how these
abstractions rot); they differ only in how the trainer consumes the iterator.

```python
from __future__ import annotations
from typing import AsyncIterator, Protocol

class RolloutSource(Protocol):
    def groups(
        self, *, policy: "Policy", group_size: int, model_version: int
    ) -> AsyncIterator[RolloutGroup]: ...
```

  • Sync (GRPOTrainer) drains the iterator into a list for the step.
  • Async (AsyncGRPOTrainer) reads it as a stream/queue while training proceeds, using
    model_version for staleness handling.
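A minimal sketch of "one interface, drained two ways", assuming groups() is an async generator and that a group's staleness is the current weights version minus the oldest model_version among its samples. `drain_sync`, `consume_stream`, `max_staleness`, the toy source, and the `Sample`/`Group` stand-ins are all hypothetical. The default staleness policy is an open question (§9); to keep the sketch short it drops over-stale groups, whereas the RFC leans toward importance-correcting them.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, Callable

@dataclass
class Sample:            # stand-in for RolloutSample (only the field used here)
    model_version: int

@dataclass
class Group:             # stand-in for RolloutGroup
    samples: list[Sample]

async def toy_groups(n: int, versions: list[int]) -> AsyncIterator[Group]:
    """Toy source: yields n groups, stamping group i with versions[i]."""
    for i in range(n):
        await asyncio.sleep(0)  # pretend to wait on rollouts
        yield Group([Sample(versions[i]), Sample(versions[i])])

async def drain_sync(source: AsyncIterator[Group], num_groups: int) -> list[Group]:
    """GRPOTrainer-style: collect exactly the groups one optimizer step needs."""
    batch: list[Group] = []
    async for group in source:
        batch.append(group)
        if len(batch) == num_groups:
            break
    return batch

async def consume_stream(
    source: AsyncIterator[Group],
    current_version: Callable[[], int],
    max_staleness: int,
) -> list[Group]:
    """Stream-style: take each group as it arrives, filtering on model_version."""
    fresh: list[Group] = []
    async for group in source:
        oldest = min(s.model_version for s in group.samples)
        if current_version() - oldest <= max_staleness:
            fresh.append(group)  # the RFC would rather importance-correct stale groups
    return fresh
```

Both consumers read the same iterator type, so a source author never writes separate sync and async paths.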

The Policy handle

policy is not a single object — white-box and black-box need different things from it, so it is
made explicit:

```python
from typing import Any, Protocol

class Policy(Protocol):
    url: str        # for black-box: the served endpoint the external agent/seam calls
    version: int    # current weights version, stamped onto emitted samples

    async def generate(
        self, token_ids: list[int], sampling_params: Any
    ) -> tuple[list[int], list[float]]:   # for white-box: (token_ids, logprobs)
        ...
```

A white-box source calls policy.generate(...) and captures locally. A black-box source hands
policy.url to the agent and captures at the inference boundary. Both read policy.version. The
sampling parameters a producer uses (temperature in particular) must match what the trainer
assumes; otherwise old_log_probs describe a different distribution from the one the loss models.
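A small worked example of why the sampling parameters matter, assuming temperature is applied by dividing the logits before the softmax (the usual convention): the same logits give a different logprob for the same emitted token at T = 1.0 and T = 0.7. A producer that samples at one temperature while the trainer assumes another therefore feeds a spurious importance ratio into the loss. `token_logprob` is a hypothetical helper.

```python
import math

def token_logprob(logits: list[float], token: int, temperature: float = 1.0) -> float:
    """log softmax(logits / T)[token], computed with the max-subtraction trick."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return scaled[token] - log_z

logits = [2.0, 1.0, 0.5]
lp_t1 = token_logprob(logits, 0, temperature=1.0)
lp_t07 = token_logprob(logits, 0, temperature=0.7)
# Same sampled token, different assumed temperature: the ratio is not 1.
ratio = math.exp(lp_t07 - lp_t1)
```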


4. White-box integrations (TRL drives generation)

A white-box harness is one whose interaction loop TRL can drive: the model produces each turn,
and the environment supplies the tools, the reward, and the tasks. Because TRL is the one calling
the inference engine, it captures token ids + logprobs locally — no proxy, no cooperation needed
from the environment. This is the default and the simpler case.

The environment implements one small protocol, in two flavors that share a base:

```python
from __future__ import annotations   # Task, Prompt, ToolSpec, ToolOutput, StepResult are kit types
from typing import Protocol

class Env(Protocol):                               # the base every white-box env implements
    def tasks(self, split: str = "train") -> list[Task]: ...
    async def reset(self, task: Task) -> Prompt: ...   # initial prompt (messages or token_ids)

class ToolEnv(Env, Protocol):                      # tool-calling flavor (MCP / function-calling envs)
    async def tools(self) -> list[ToolSpec]: ...
    async def call_tool(self, name: str, arguments: dict) -> ToolOutput: ...   # .text, .reward, .done

class StepEnv(Env, Protocol):                      # gym flavor
    async def step(self, action: str) -> StepResult: ...                    # .observation, .reward, .done
```

The built-in driver, EnvRolloutSource, owns the loop and emits the contract:

```
for each task, for each of G attempts (concurrently, bounded):
    prompt        = await env.reset(task)
    ids           = render(prompt)                    # chat template → token ids, once
    turns = []
    repeat until done / max_turns / context budget:
        completion_ids, logprobs = await policy.generate(ids, sampling_params)
        turns.append(Turn(prompt_ids=ids, completion_ids=completion_ids, logprobs=logprobs))
        obs, reward, done = await env.call_tool(...) | env.step(...)   # reward travels inline
        ids = encode(obs)                             # next turn's conditioning span (mask 0)
    input_ids, completion_mask, old_log_probs = build_sample_arrays(turns)
    yield RolloutSample(...)                          # step-wise; reward on the terminating turn
```

build_sample_arrays is the single shared mask builder — the one place that turns per-turn
(prompt_ids, completion_ids, logprobs) into the flat (input_ids, completion_mask, old_log_probs),
so masking is correct by construction everywhere.
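A minimal sketch of build_sample_arrays, assuming each Turn's prompt_ids holds only the span appended before that turn's completion (the full rendered prompt on the first turn, the encoded observation afterwards), as in the driver loop. Field names follow the pseudocode; the length check is an illustrative addition.

```python
from dataclasses import dataclass

@dataclass
class Turn:
    prompt_ids: list[int]      # conditioning span added before this turn (mask 0)
    completion_ids: list[int]  # tokens the model sampled this turn (mask 1)
    logprobs: list[float]      # sampling logprobs, 1:1 with completion_ids

def build_sample_arrays(turns: list[Turn]) -> tuple[list[int], list[int], list[float]]:
    """Flatten turns into (input_ids, completion_mask, old_log_probs)."""
    input_ids: list[int] = []
    completion_mask: list[int] = []
    old_log_probs: list[float] = []
    for i, turn in enumerate(turns):
        if len(turn.completion_ids) != len(turn.logprobs):
            raise ValueError(f"turn {i}: {len(turn.logprobs)} logprobs for "
                             f"{len(turn.completion_ids)} completion tokens")
        input_ids += turn.prompt_ids
        completion_mask += [0] * len(turn.prompt_ids)
        # Sampled ids are copied verbatim: never decoded and re-encoded.
        input_ids += turn.completion_ids
        completion_mask += [1] * len(turn.completion_ids)
        old_log_probs += turn.logprobs
    return input_ids, completion_mask, old_log_probs
```

For two turns, `Turn([1, 2, 3], [10, 11], [-0.5, -0.1])` and `Turn([20], [12], [-0.7])`, this yields input_ids `[1, 2, 3, 10, 11, 20, 12]` with mask `[0, 0, 0, 1, 1, 0, 1]`, so both invariants from §2 hold by construction.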

Tasks belong to the source (aligns with #5903): the env owns its task list addressed by split;
the trainer does not thread dataset rows into reset. Reward travels with the rollout (aligns
with #5912): call_tool/step returns it inline.
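To show the white-box loop end to end, here is a self-contained toy under stated assumptions: a counting StepEnv whose reward arrives inline with each step, a fake policy that always returns the same two token ids and logprobs, a toy tokenizer that maps each character to its code point, and an asyncio.Semaphore for the "concurrently, bounded" part. `CountEnv`, `FakePolicy`, `run_attempt`, and `run_group` are hypothetical names; this is not the kit's EnvRolloutSource, and it emits plain dicts instead of RolloutSample.

```python
import asyncio
from dataclasses import dataclass

@dataclass
class StepResult:
    observation: str
    reward: float
    done: bool

class CountEnv:
    """Toy StepEnv: the episode ends after `task` steps; reward 1.0 on the last one."""
    def __init__(self) -> None:
        self.steps, self.target = 0, 0

    def tasks(self, split: str = "train") -> list[int]:
        return [2, 3]  # each task is just a target step count

    async def reset(self, task: int) -> str:
        self.steps, self.target = 0, task
        return f"count to {task}"

    async def step(self, action: str) -> StepResult:
        self.steps += 1
        done = self.steps >= self.target
        return StepResult(f"at {self.steps}", 1.0 if done else 0.0, done)

class FakePolicy:
    """Stands in for Policy.generate: always emits two tokens with fixed logprobs."""
    version = 0

    async def generate(self, token_ids: list[int], sampling_params: dict) -> tuple[list[int], list[float]]:
        await asyncio.sleep(0)
        return [900, 901], [-0.2, -0.3]

def encode(text: str) -> list[int]:
    return [ord(c) for c in text]  # toy tokenizer: one id per character

async def run_attempt(env_factory, task, policy, group_id: str, max_turns: int = 8) -> dict:
    env = env_factory()                     # fresh env per attempt (own session/sandbox)
    ids = encode(await env.reset(task))
    input_ids, mask, logprobs, reward = [], [], [], 0.0
    for _ in range(max_turns):
        completion, lps = await policy.generate(ids, {"temperature": 1.0})
        input_ids += ids + completion       # sampled ids kept verbatim
        mask += [0] * len(ids) + [1] * len(completion)
        logprobs += lps
        # Decoding for the env's benefit is fine; only re-encoding would be wrong.
        result = await env.step("".join(chr(t) for t in completion))
        reward = result.reward              # reward travels inline with the step
        if result.done:
            break
        ids = encode(result.observation)    # next conditioning span (mask 0)
    return {"input_ids": input_ids, "completion_mask": mask, "old_log_probs": logprobs,
            "reward": reward, "group_id": group_id, "model_version": policy.version}

async def run_group(env_factory, task, policy, group_size: int, max_concurrency: int = 4) -> list[dict]:
    sem = asyncio.Semaphore(max_concurrency)   # bound the number of in-flight attempts

    async def one() -> dict:
        async with sem:
            return await run_attempt(env_factory, task, policy, group_id=f"task-{task}")

    return await asyncio.gather(*(one() for _ in range(group_size)))
```

Passing the class `CountEnv` as `env_factory` gives each of the G attempts its own instance, which is the factory-vs-instance distinction described in §6a.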


5. Black-box integrations (the agent owns its loop)

A black-box harness is a prebuilt agent that owns its own interaction loop and is opaque to TRL:
it plans, calls its own tools inside its own sandbox, decides when to stop. TRL cannot drive it; the
only thing crossing the boundary is the model call going out and the completion coming back. This is
the harder, increasingly important case (coding agents, computer-use agents, SWE agents).

Because TRL is not the one generating, the white-box trick (capture locally) does not apply. The
token ids + logprobs must be captured at the inference boundary, by one of two mechanisms — both
filling the same RolloutCapture shape, so the downstream is identical:

```python
from dataclasses import dataclass

@dataclass
class TurnCapture:
    prompt_token_ids: list[int]       # the engine's own prompt ids for this call (ground truth)
    completion_token_ids: list[int]   # the engine's own completion ids
    completion_logprobs: list[float]  # sampling logprobs, 1:1 with completion ids
    model_version: int                # stamped per call → mid-trajectory weight updates handled

@dataclass
class RolloutCapture:
    rollout_id: str
    turns: list[TurnCapture]
```

Mechanism A — self-report (cooperative agents). If you control the agent and it routes LLM
calls through one client, give it an OpenAI-shaped client wrapper that records token_ids +
logprobs from each response keyed by a rollout id. No proxy; the agent runs unchanged except for
which client object it was handed. Many agent frameworks already expose this (a "collect rollout
details" flag that accumulates per-turn token ids/logprobs) — the adapter just reads it back.
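A minimal sketch of a self-reporting client, under loud assumptions: the inner engine call is a hypothetical async function that already returns the engine's prompt ids, completion ids, logprobs, and weights version alongside the text (real servers expose these through backend-specific options, which this sketch does not model), and a single `complete(messages)` method stands in for the OpenAI-shaped surface. The wrapper returns only the text to the agent and records one TurnCapture per call under its rollout id. `CapturingClient` follows the module list in §6; its methods here are illustrative.

```python
from dataclasses import dataclass, field
from typing import Awaitable, Callable

@dataclass
class TurnCapture:
    prompt_token_ids: list[int]
    completion_token_ids: list[int]
    completion_logprobs: list[float]
    model_version: int

@dataclass
class RolloutCapture:
    rollout_id: str
    turns: list[TurnCapture] = field(default_factory=list)

# Hypothetical engine call: messages -> (text, prompt_ids, completion_ids, logprobs, version)
EngineCall = Callable[[list[dict]], Awaitable[tuple[str, list[int], list[int], list[float], int]]]

class CapturingClient:
    """Hands the agent plain text; keeps the engine's own ids/logprobs per rollout."""
    def __init__(self, engine: EngineCall, rollout_id: str, store: dict[str, RolloutCapture]):
        self._engine = engine
        self._capture = store.setdefault(rollout_id, RolloutCapture(rollout_id))

    async def complete(self, messages: list[dict]) -> str:
        text, prompt_ids, completion_ids, logprobs, version = await self._engine(messages)
        self._capture.turns.append(
            TurnCapture(prompt_ids, completion_ids, logprobs, model_version=version))
        return text  # the agent sees an ordinary completion
```

After the agent exits, `store[rollout_id]` is the RolloutCapture the source reads back; the agent's own code only ever saw `complete`.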

Mechanism B — the seam (any agent). Put an OpenAI-compatible proxy in front of the inference
server and point the agent's base URL at it. The agent runs completely unchanged, thinking it talks
to a normal OpenAI endpoint. The seam re-issues each call to the real engine with token-id + logprob
capture on, records them by rollout id, and returns the plain response. This is the general answer:
even a fully opaque binary can be trained as long as you can set its base URL.
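The seam's core is a request handler; the HTTP server around it is omitted here. A minimal sketch under two labeled assumptions: the rollout id is encoded in the base URL path handed to the agent (for example `http://seam/r/<rollout_id>/v1`; the RFC does not specify how ids reach the seam), and a hypothetical `engine` callable re-issues the request with token-id and logprob capture enabled, returning the response JSON plus the captured ids. The handler records the capture and blanks the per-choice `logprobs` field so the agent receives a plain-looking response. All names are illustrative.

```python
import copy
from typing import Any, Awaitable, Callable

# Hypothetical engine call: body -> (response_json, prompt_ids, completion_ids, logprobs, version)
Engine = Callable[[dict], Awaitable[tuple[dict, list[int], list[int], list[float], int]]]

CAPTURES: dict[str, list[dict[str, Any]]] = {}

def rollout_id_from_path(path: str) -> str:
    """Assumes paths like /r/<rollout_id>/v1/chat/completions."""
    parts = path.strip("/").split("/")
    if len(parts) < 2 or parts[0] != "r":
        raise ValueError(f"no rollout id in path {path!r}")
    return parts[1]

async def handle_chat_completion(path: str, body: dict, engine: Engine) -> dict:
    rollout_id = rollout_id_from_path(path)
    response, prompt_ids, completion_ids, logprobs, version = await engine(body)
    CAPTURES.setdefault(rollout_id, []).append({
        "prompt_token_ids": prompt_ids,        # the engine's ids, never a re-render
        "completion_token_ids": completion_ids,
        "completion_logprobs": logprobs,
        "model_version": version,
    })
    plain = copy.deepcopy(response)            # leave the engine's object untouched
    for choice in plain.get("choices", []):
        choice["logprobs"] = None              # agent sees a normal, capture-free response
    return plain
```

Because the only agent-side change is the base URL, the same handler serves any opaque agent; each launch just gets a different path prefix.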

                        ┌────────────────────┐
   cooperative agent ──▶│  CapturingClient   │──┐  (A: in-process client wrapper)
                        └────────────────────┘  │
                                                 ├──▶ RolloutCapture ──▶ AgentRolloutSource
   opaque agent      ──▶┌────────────────────┐  │                          │ Verifier.reward()
   (set base URL only)  │  Seam (HTTP proxy) │──┘                          ▼
                        └────────────────────┘                    RolloutSample (step-wise)

The agent and its scoring are described by two tiny protocols:

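```python
from __future__ import annotations   # AgentResult is a kit type
from typing import Protocol

class AgentRunner(Protocol):
    async def run(self, task, policy_url_or_client, rollout_id: str) -> AgentResult: ...  # launches the agent

class Verifier(Protocol):
    async def reward(self, task, result: AgentResult) -> float | dict[str, float]: ...
```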

AgentRolloutSource mints a rollout id per attempt, launches the agent (Mechanism A or B), reads
back the RolloutCapture after it exits, calls the Verifier for the reward, and emits step-wise
samples — one RolloutSample per captured turn, reward on the last.
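A minimal sketch of the post-exit conversion from a capture to step-wise samples. It assumes non-final steps carry no reward of their own (represented as None; how the trainer propagates the final reward or advantage to earlier steps is left to it), a scalar verifier reward, and that the run's status is copied onto every step so failed runs can be masked. `StepSample` is a simplified stand-in for RolloutSample and `capture_to_samples` is a hypothetical name.

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class TurnCapture:
    prompt_token_ids: list[int]
    completion_token_ids: list[int]
    completion_logprobs: list[float]
    model_version: int

@dataclass
class StepSample:                      # simplified RolloutSample for this sketch
    input_ids: list[int]
    completion_mask: list[int]
    old_log_probs: list[float]
    reward: Optional[float]            # None: no reward at this step (assumption)
    group_id: str
    model_version: int
    status: str
    metadata: dict = field(default_factory=dict)

def capture_to_samples(turns: list[TurnCapture], reward: float, group_id: str,
                       status: str = "ok") -> list[StepSample]:
    """One sample per captured turn; the verifier's reward lands on the last."""
    samples = []
    for i, t in enumerate(turns):
        is_last = i == len(turns) - 1
        samples.append(StepSample(
            input_ids=t.prompt_token_ids + t.completion_token_ids,
            completion_mask=[0] * len(t.prompt_token_ids) + [1] * len(t.completion_token_ids),
            old_log_probs=list(t.completion_logprobs),
            reward=reward if is_last else None,
            group_id=group_id,
            model_version=t.model_version,   # per-turn stamp survives mid-run weight updates
            status=status,
            metadata={"turn": i},
        ))
    return samples
```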

Things specific to black-box that the design must handle:

  • On-policy is a hard prerequisite. The agent must call our served model. If it calls an
    outside provider, there is no gradient. The model is therefore served over HTTP (serve/async mode,
    not colocate), reached via a tunnel when the sandbox is remote.
  • Capture the engine's own ids. The seam/self-report records the engine's actual prompt and
    completion token ids, never a re-render of the agent's messages through our template. Re-rendering
    drifts; the engine's ids are ground truth (see §7).
  • Long rollouts go stale. A trial can run for minutes while weights advance. Stamp
    model_version per turn and importance-correct rather than discard; do not throw expensive agent
    rollouts away.
  • Reward = run the agent's own verifier. Read the harness's reward; do not re-implement the
    grader.
  • Concurrency / backpressure. A black-box source can spawn whole sandboxes; the source must honor
    an in-flight cap and apply backpressure, or the first real run melts the budget.
  • History rewriting breaks the objective, not just the implementation. Agents that compact /
    summarize / strip their own history mean "the previous turn" was never a sampled trajectory. The
    workaround: freeze everything up to the last rewrite point as prompt (loss-mask 0) and keep only
    the genuine sampled tail under gradient.
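The freeze-the-prefix workaround can be made mechanical when captures carry the engine's own prompt ids: if turn k's prompt does not start with turn k-1's prompt plus completion, the agent rewrote its history at k. A minimal sketch under that assumption, with captures given as (prompt_ids, completion_ids) pairs and logprobs omitted for brevity (they would follow the kept completions). `extends`, `last_rewrite_index`, and `build_tail` are hypothetical helpers.

```python
def extends(prev_prompt: list[int], prev_completion: list[int], prompt: list[int]) -> bool:
    """True if `prompt` starts with the previous turn's full token sequence."""
    prefix = prev_prompt + prev_completion
    return prompt[:len(prefix)] == prefix

def last_rewrite_index(turns: list[tuple[list[int], list[int]]]) -> int:
    """Index of the first turn after the last history rewrite (0 if none)."""
    start = 0
    for k in range(1, len(turns)):
        (prev_prompt, prev_completion), (prompt, _) = turns[k - 1], turns[k]
        if not extends(prev_prompt, prev_completion, prompt):
            start = k
    return start

def build_tail(turns: list[tuple[list[int], list[int]]]) -> tuple[list[int], list[int]]:
    """input_ids and mask for the genuine sampled tail; the rewritten prefix is frozen."""
    start = last_rewrite_index(turns)
    input_ids = list(turns[start][0])          # rewritten history becomes plain prompt
    mask = [0] * len(input_ids)
    for k in range(start, len(turns)):
        prompt, completion = turns[k]
        new_span = prompt[len(input_ids):] if k > start else []  # observation/tool span
        input_ids += new_span + completion
        mask += [0] * len(new_span) + [1] * len(completion)
    return input_ids, mask
```

For example, with captures `([1, 2], [3])`, `([1, 2, 3, 9], [4])`, `([7, 8], [5])`, `([7, 8, 5, 9], [6])`, the third prompt breaks the prefix, so only the last two turns are trained: input_ids `[7, 8, 5, 9, 6]` with mask `[0, 0, 1, 0, 1]`.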

The guiding intuition: a black-box rollout is a white-box rollout where the loop and tools moved
across the network into the agent.
The contract is identical; only the capture point moved.


6. Building a new integration

The extension story is the point of the design. Two recipes, depending on whether TRL can drive the
loop.

6a. A new white-box environment

  1. Implement Env + one flavor (ToolEnv or StepEnv) over your backend. This is the only code
    you write — typically a thin wrapper:
```python
# Task, Prompt, ToolSpec, ToolOutput are kit types; `client` is your backend's SDK.
class MyToolEnv:                                   # satisfies ToolEnv structurally
    def __init__(self, client):
        self._client = client

    def tasks(self, split="train"):
        return [Task(task_id=t.id, payload={"raw": t}) for t in self._client.list_tasks(split)]

    async def reset(self, task):
        return Prompt(messages=self._client.prompt_for(task.payload["raw"]))

    async def tools(self):
        return [ToolSpec(name=t.name, description=t.desc, parameters=t.schema)
                for t in self._client.tools()]

    async def call_tool(self, name, arguments):
        out = await self._client.call(name, arguments)
        return ToolOutput(text=out.text, reward=out.reward, done=out.finished)
```
  2. Hand it to the driver; no trainer code, no new hook:
```python
trainer = GRPOTrainer(model=..., args=GRPOConfig(...),
                      environment=EnvRolloutSource(env=lambda: MyToolEnv(client), tokenizer=tok))
```

Use an env factory (lambda: MyToolEnv(...)) when each of the G attempts needs its own session /
sandbox; pass a single instance when the env is stateless. That is the entire integration. The
trainer never learns the environment's name.

A maintainer-facing convenience layer can wrap common backends so end users write one line
(environment=OpenEnv(...)), but that sugar is optional — it builds an EnvRolloutSource over an
Env underneath.

6b. A new black-box agent

  1. Implement AgentRunner (how to launch your agent against a task, given a capturing client or a
    base URL) and Verifier (how to read its reward). If your agent can be handed a client, use
    Mechanism A; if you can only set its base URL, use the seam (Mechanism B) and no agent change is
    needed.
  2. Hand them to AgentRolloutSource:
```python
trainer = GRPOTrainer(model=..., args=GRPOConfig(...),
                      environment=AgentRolloutSource(runner=MyRunner(), verifier=MyVerifier(),
                                                     tasks=tasks))
```

Everything after the agent exits — capture → step-wise samples → advantage → loss — is shared.

Module structure

trl/experimental/rollout/          ← the reusable KIT (what integration authors build against)
    sample.py        RolloutSample, RolloutGroup, Status            # the contract
    source.py        RolloutSource, Policy, Env / ToolEnv / StepEnv # the protocols
    mask.py          Turn, build_sample_arrays                      # the one shared mask builder
    env_source.py    EnvRolloutSource                               # white-box driver
    agent_source.py  AgentRolloutSource, AgentRunner, Verifier      # black-box driver
    capture.py       CapturingClient, RolloutCapture, TurnCapture   # self-report capture
    seam.py          Seam (OpenAI-compatible proxy)                 # proxy capture
    vllm_policy.py   a Policy backed by a served engine
    trainer_glue.py  attach_environment(trainer, environment)       # the trainer-side bridge

The kit is integration-agnostic. Concrete integrations are thin and live outside it (or in a
convenience catalog), each one a small adapter to Env/ToolEnv/StepEnv (white-box) or
AgentRunner/Verifier (black-box).


7. Correctness: token-in, token-out

The contract exists to enforce one rule that quietly breaks most multi-turn RL loops: never
re-encode tokens you decoded.
The naive loop keeps a message list, re-renders it each turn, and
re-tokenizes the whole conversation at the end — but encode∘decode is not the identity (BPE
segmentation, JSON whitespace, special-token re-rendering all drift), so the gradient lands on a
token sequence the model never sampled.

The design avoids this structurally: the model's sampled token ids go straight into the buffer that
becomes input_ids and are never re-encoded. White-box keeps each turn's completion_ids verbatim
and only tokenizes the new observation/tool span to append it. Black-box captures the engine's own
ids at the boundary. The one property required of a chat template is that appending a tool/obs
message extends the render token-for-token (prefix-preserving); when it isn't, swap in a
training-safe template once at init. This keeps old_log_probs aligned with the tokens actually
under gradient.
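A toy illustration of why encode∘decode is not the identity, using a hypothetical three-entry vocabulary with greedy longest-match encoding (real BPE differs in detail, but the failure mode is the same): the model may sample "a" and then "b" as two tokens, yet re-encoding the decoded text "ab" yields the single merged token. The second helper sketches the prefix-preserving check on a chat template, where `render` is any function from a message list to token ids; no real tokenizer or template API is assumed.

```python
VOCAB = {"a": 0, "b": 1, "ab": 2}
ID_TO_STR = {i: s for s, i in VOCAB.items()}

def decode(ids: list[int]) -> str:
    return "".join(ID_TO_STR[i] for i in ids)

def encode(text: str) -> list[int]:
    """Greedy longest-match encoding over VOCAB (toy tokenizer)."""
    ids, i = [], 0
    while i < len(text):
        for length in range(len(text) - i, 0, -1):
            piece = text[i:i + length]
            if piece in VOCAB:
                ids.append(VOCAB[piece])
                i += length
                break
        else:
            raise ValueError(f"cannot encode {text[i]!r}")
    return ids

sampled = [0, 1]                        # the model emitted "a" then "b"
reencoded = encode(decode(sampled))     # [2]: a token sequence the model never sampled

def is_prefix_preserving(render, messages: list[dict], new_message: dict) -> bool:
    """Appending a message must extend the rendered ids token-for-token."""
    before = render(messages)
    after = render(messages + [new_message])
    return after[:len(before)] == before
```

A template that, say, appends a trailing marker after the last message fails this check, which is the case where a training-safe template would be swapped in at init.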


8. Trainer integration

A single optional kwarg is added to the trainers; everything else is reused.

GRPOTrainer(model=..., args=GRPOConfig(...), environment=<RolloutSource>)

When environment is set, an init-time bridge converts the source into the trainer's existing
internal rollout path (the same path rollout_func already feeds) plus a synthetic reward source
that surfaces sample.reward. The existing generation→scoring→advantage→loss→logging machinery is
reused unchanged; the new trainer-side surface is the kwarg plus a small bridge call.
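One plausible shape for the synthetic reward source, assuming the bridge forwards each sample's already-scalarized reward to reward functions as an extra keyword column. TRL custom reward functions receive `prompts` and `completions` plus extra columns as keyword arguments and return one float per completion; the `env_reward` column name and the function itself are hypothetical. Because it is just another reward function, the existing weighting and per-function logging would apply unchanged.

```python
from typing import Any

def environment_reward(prompts: list[Any], completions: list[Any],
                       env_reward: list[float], **kwargs) -> list[float]:
    """Return the rewards the environment already computed (hypothetical bridge column)."""
    if len(env_reward) != len(completions):
        raise ValueError("expected one environment reward per completion")
    return [float(r) for r in env_reward]
```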

Backwards compatibility (additive, nothing breaks):

  • Unset environment ⇒ behavior is byte-for-byte unchanged.
  • reward_funcs becomes optional when the environment supplies reward (it is registered as a reward
    source like any other; if both are present, the env reward is authoritative and reward_funcs may
    only add shaping terms).
  • train_dataset becomes optional when the environment owns its tasks (synthesize a placeholder
    sized by max_steps), following the #5903 direction (environment-owned dataset).
  • environment_factory and rollout_func can be re-expressed as thin forms of a RolloutSource,
    so existing configs keep running. Deprecating them is a separate, later decision.

9. Open questions for discussion

  1. Advantage location. Trainer-side group normalization by default, with an optional pre-filled
    advantage for producers that already aggregate across machines? Where should the default live
    once async grouping is in play?
  2. Async drain + staleness. What is the contract for AsyncGRPOTrainer draining groups() as a
    stream, and the default staleness policy (importance-correct vs drop) keyed on model_version?
  3. Reward typing in the loss. How far do we push the dict[str, float] named-component reward
    into reward_weights / per-component logging vs. collapsing to a scalar at the source?
  4. Where integrations live. In-tree trl.experimental catalog, vs. each environment shipping its
    own TRL adapter, vs. a registry. The kit makes all three possible; which do we bless?
  5. Naming. environment= reads naturally for the 90% case (training "on an environment") even
    though black-box agents aren't environments in the strict sense; is that acceptable, or should
    the kwarg name the contract (rollout_source=)?
  6. Step-wise cost. Step-wise multi-turn re-encodes the growing prefix in the trainer's forward
    pass (≈ quadratic in turns). Do we want prefix-sharing / packing in the contract, or leave it to
    the trainer?

Feedback welcome on the contract fields, the two-protocol split, and the trainer kwarg before
implementation hardens.
