
RFC: Agent RL training with environment: environment-owned dataset, if any #5903

@qgallouedec


Scope: GRPOTrainer + AsyncGRPOTrainer. The same pattern applies to any online trainer with environment support.
Out of scope: reward de-correlation (separate RFC).

Summary

Let the environment own the dataset, if any.
Today, the trainer owns the data: it samples a prompt from the dataset and resets the environment with that prompt as an argument.
This RFC shifts that ownership to the env, which can hold its own corpus, self-sample, or generate procedurally.

For some environments, such as procedurally generated ones, a dataset does not make sense. The new design should support them directly, instead of relying on the current dummy-dataset workaround.
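For instance, a procedurally generated environment can build each task from its own random number generator at reset time, so there is nothing for a trainer-side dataset to hold. A minimal sketch, assuming the contract proposed here (`reset()` takes no dataset row and returns the prompt); `ArithmeticEnvironment`, its `answer` tool, and the task format are hypothetical and only for illustration.

```python
import random
from typing import Optional


class ArithmeticEnvironment:
    """Hypothetical procedurally generated environment: no dataset at all.

    Each reset() draws a fresh addition problem from the env's own RNG,
    following the proposed reset() -> prompt contract.
    """

    def __init__(self, seed: Optional[int] = None):
        self._rng = random.Random(seed)  # env-owned randomness, not trainer-owned data
        self._target: Optional[int] = None

    def reset(self) -> str:
        a, b = self._rng.randint(10, 99), self._rng.randint(10, 99)
        self._target = a + b
        return f"What is {a} + {b}?"

    def answer(self, value: int) -> str:
        """Submit an answer, get feedback (hypothetical tool)."""
        return "correct" if value == self._target else "incorrect"


env = ArithmeticEnvironment(seed=0)
prompt = env.reset()
```

Because the task is generated inside `reset()`, the only thing the trainer still needs to know is how many steps to run.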

Other data points

  • verifiers: the environment owns the dataset. However, an env must have a dataset (verifiers raises an error if none is provided), which I don't think is a good design choice (see below).
  • OpenEnv: most environments don't need any external dataset to sample from.

Design: before / after

Before


Before: the Trainer samples from the dataset it owns and pushes the data into the env via reset(prompt).

After


After: the Trainer only calls env.reset(), and the env returns the prompt (sampled from its own dataset, if any).
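The difference between the two designs fits in a few lines of trainer-side code. A minimal sketch with hypothetical names (`make_prompts_before`, `make_prompts_after`); it is not TRL's actual rollout code, and it assumes that in the old design every column of a dataset row (prompt included) is forwarded to `reset()` as a keyword argument.

```python
from typing import Any, Callable, Iterable


def make_prompts_before(
    dataset: Iterable[dict[str, Any]], env_factory: Callable[[], Any]
) -> list[str]:
    """Before: the trainer owns the data and pushes each row into a fresh env."""
    prompts = []
    for row in dataset:
        env = env_factory()
        env.reset(**row)  # the env receives the row's columns as kwargs
        prompts.append(row["prompt"])  # the prompt still comes from the trainer's dataset
    return prompts


def make_prompts_after(num_prompts: int, env_factory: Callable[[], Any]) -> list[str]:
    """After: the trainer only calls reset(); the env decides what the task is."""
    return [env_factory().reset() for _ in range(num_prompts)]
```

In the second function the trainer no longer needs a dataset, only a number of prompts, which is why `max_steps` becomes the natural way to bound training.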

Usage

Example 1: Wordle-like, i.e., when a dataset fits

For some environments, like Wordle, there is a natural notion of dataset (the list of target words).

Before

Two options, both awkward. Either the trainer owns the words and iterates over them, passing each row's columns to reset as keyword arguments:

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("my_words", split="train")  # rows like {"prompt": "Guess the 5-letter word.", "target": "crane"}

class WordleEnvironment:
    def reset(self, target: str) -> None:    # receives the row's columns
        self._target = target

    def guess(self, word: str) -> str:
        """Submit a guess, get per-letter feedback."""
        ...

trainer = GRPOTrainer(
    model=my_model,
    reward_funcs=my_reward_func,
    train_dataset=dataset,                   # trainer owns and iterates the words
    environment_factory=WordleEnvironment,
)

…or the environment owns the words, but you still have to fabricate a dummy dataset whose only job is to set the length of the training loop:

import random

from datasets import Dataset, load_dataset
from trl import GRPOTrainer

dummy = Dataset.from_dict({"prompt": ["Guess the 5-letter word."] * 10_000})  # placeholder rows

class WordleEnvironment:
    def __init__(self):
        self.dataset = load_dataset("my_words", split="train")

    def reset(self, **kwargs) -> None:       # ignores the dummy row
        self._target = random.choice(self.dataset)["target"]  # self-sample one word

    def guess(self, word: str) -> str: ...

trainer = GRPOTrainer(
    model=my_model,
    reward_funcs=my_reward_func,
    train_dataset=dummy,                     # exists only to set the number of steps
    environment_factory=WordleEnvironment,
)

After

Everything task-related is owned by the environment; no dummy dataset:

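import random

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

class WordleEnvironment:
    def __init__(self):
        self.dataset = load_dataset("my_words", split="train")

    def reset(self) -> str:
        self._target = random.choice(self.dataset)["target"]  # env samples its own task
        return "Guess the 5-letter word."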

    def guess(self, word: str) -> str:
        """Submit a guess, get per-letter feedback."""
        ...

trainer = GRPOTrainer(
    model=my_model,
    reward_funcs=my_reward_func,
    environment_factory=WordleEnvironment,   # owns the words
    args=GRPOConfig(max_steps=1000),         # required when no dataset
)

Example 2: chess-like, i.e., when a dataset does not fit

In chess, the initial state is always the same starting position (or is produced by self-play). There is no corpus to iterate over, so forcing a dataset is purely artificial.

Before

You must fabricate a dummy dataset just to drive the training loop:

import chess
from datasets import Dataset
from trl import GRPOTrainer

dummy = Dataset.from_dict({"prompt": [""] * 10_000})  # placeholder, content ignored

class ChessEnvironment:
    def reset(self, **kwargs) -> str:
        self._board = chess.Board()          # always the standard starting position
        return str(self._board)

    def move(self, uci: str) -> str:
        """Play a move in UCI notation, get the opponent's reply."""
        ...

trainer = GRPOTrainer(
    model=my_model,
    reward_funcs=my_reward_func,
    train_dataset=dummy,                     # only there to set the number of steps
    environment_factory=ChessEnvironment,
)

After

No dataset at all; max_steps sets the training length:

import chess
from trl import GRPOConfig, GRPOTrainer

class ChessEnvironment:
    def reset(self) -> str:
        self._board = chess.Board()
        return str(self._board)

    def move(self, uci: str) -> str:
        """Play a move in UCI notation, get the opponent's reply."""
        ...

trainer = GRPOTrainer(
    model=my_model,
    reward_funcs=my_reward_func,
    environment_factory=ChessEnvironment,
    args=GRPOConfig(max_steps=1000),         # no dataset → max_steps required
)

What this would imply

transformers.Trainer derives the number of training steps from the length of the training dataloader (together with num_train_epochs), unless max_steps is set; when the dataset has no length (e.g., an IterableDataset), max_steps is mandatory and Trainer raises an error otherwise. If we make the dataset optional, we should add a check that max_steps is provided whenever no dataset is given at all.
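A minimal sketch of that check, assuming it runs in the trainer's `__init__` before the training schedule is computed; `check_training_length` is a hypothetical helper, not an existing TRL or transformers function. It relies on `max_steps` defaulting to -1 (disabled) in `TrainingArguments`.

```python
from typing import Any, Optional


def check_training_length(train_dataset: Optional[Any], max_steps: int) -> None:
    """Hypothetical guard: without a dataset, only max_steps can define the run length.

    In transformers.TrainingArguments, max_steps defaults to -1 (disabled).
    """
    if train_dataset is None and max_steps <= 0:
        raise ValueError(
            "`max_steps` must be set to a positive value when no `train_dataset` is "
            "provided, since the number of training steps cannot be derived from a dataset."
        )
```

Failing early with an explicit message is friendlier than letting the missing dataset surface later as an unrelated error deep inside the training loop.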

Key decisions and questions (so far, will be edited)

Group-state sharing is conventional, not required

Important: this is not inherent to the new design; the current design has the same limitation.

GRPO assumes that the environment is reset to the exact same state for all G members of a group. However, neither the current nor the proposed environment implementation enforces this.

Today, the only way to achieve this is to add, e.g., a "seed" column to the dataset and have the env reset from the seed it receives.

Note that, technically, we could run GRPO with heterogeneous states across the G members: the baseline would only be noisier. It is not the spirit of GRPO, but it would probably work.
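To see what "noisier baseline" means, recall that GRPO scores each completion relative to its own group: the group mean acts as the baseline, and the advantage of member i is its reward minus that mean, divided by the group standard deviation. A minimal sketch of this normalization; the `eps` value and the use of the population standard deviation are choices made here, and TRL's implementation may differ in such details.

```python
from statistics import fmean, pstdev


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Group-relative advantages: (r_i - mean(group)) / (std(group) + eps)."""
    mean = fmean(rewards)
    std = pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Shared start state: reward differences come only from the completions.
same_state = group_advantages([1.0, 0.0, 1.0, 0.0])

# Hypothetical heterogeneous group: member 0 drew an easier state, so part of its
# positive advantage reflects the state rather than the quality of its completion.
mixed_states = group_advantages([1.0, 0.0, 0.0, 0.0])
```

The formula itself does not change with heterogeneous states; what changes is that the baseline mixes state difficulty into every member's advantage.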

Recommendation: always pass a group_id to reset() and let the env decide what to do with it. This allows group-state sharing as a convention without enforcing it. Using it should be documented as a good practice.
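A minimal sketch of this convention, assuming the trainer calls `reset(group_id=...)` with the same value for all G members of a group; the keyword name, its type (an int here), and `SeededWordleEnvironment` are hypothetical, since this API is only proposed. Seeding a private `random.Random` from `group_id` makes the G instances start in the same state without the trainer enforcing anything.

```python
import random
from typing import Optional


class SeededWordleEnvironment:
    """Hypothetical env following the proposed reset(group_id=...) convention."""

    def __init__(self, words: list[str]):
        self.words = words  # env-owned corpus (a plain list here, to stay self-contained)
        self._target: Optional[str] = None

    def reset(self, group_id: Optional[int] = None) -> str:
        # Same group_id -> same RNG stream -> same target for every group member.
        # group_id=None seeds from system entropy, i.e. no sharing.
        rng = random.Random(group_id)
        self._target = rng.choice(self.words)
        return "Guess the 5-letter word."


words = ["crane", "slate", "pious", "mound", "tiger"]
group = [SeededWordleEnvironment(words) for _ in range(4)]  # G = 4 instances
for env in group:
    env.reset(group_id=7)
```

An env that has no use for group sharing (e.g., chess from the standard starting position) can simply accept the argument and ignore it.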

Alternative (not recommended): instantiate one env and deep-copy it G times at each reset. This is brittle and may not even be possible for some envs (e.g., envs holding open connections, subprocesses, or locks).
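The brittleness is easy to reproduce: `copy.deepcopy` falls back to the pickling protocol, so any env holding an unpicklable resource cannot be cloned. A minimal sketch, where `ServerBackedEnvironment` and `clone_group` are hypothetical and a `threading.Lock` stands in for a real connection or handle.

```python
import copy
import threading


class ServerBackedEnvironment:
    """Hypothetical env holding a resource that cannot be deep-copied."""

    def __init__(self):
        self._lock = threading.Lock()  # stand-in for a socket, subprocess, or client
        self.state = "start"

    def reset(self) -> str:
        return self.state


def clone_group(env, group_size: int) -> list:
    """The not-recommended alternative: one env, deep-copied G times at each reset."""
    return [copy.deepcopy(env) for _ in range(group_size)]


try:
    clone_group(ServerBackedEnvironment(), 4)
except TypeError as err:  # e.g. "cannot pickle '_thread.lock' object"
    print(f"deepcopy failed: {err}")
```

A plain in-memory env clones without complaint, which is what makes this approach tempting before it breaks on real environments.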
