Description
Summary
RLlib's new API stack (enable_rl_module_and_learner=True) currently lacks native support for gymnasium.spaces.Dict observation spaces with automatic encoder routing. When using a Dict obs space combining image and vector data, RLlib either flattens everything into a single MLP input or raises errors about missing encoder configs.
Current Behavior (Ray 2.52.1)
With the new API stack enabled and a Dict observation space:
```python
import numpy as np
from gymnasium import spaces

# Inside the custom environment's __init__():
self.observation_space = spaces.Dict({
    "pixels": spaces.Box(low=0.0, high=1.0, shape=(84, 84, 4), dtype=np.float32),
    "features": spaces.Box(low=-1.0, high=1.0, shape=(9,), dtype=np.float32),
})
```

RLlib does not automatically:
- Route `pixels` (3D Box) through a CNN encoder
- Route `features` (1D Box) through an MLP encoder
- Concatenate the encoded representations before the policy/value heads
Instead, users must implement a custom RLModule class to handle this manually, which requires significant boilerplate and deep knowledge of RLlib internals.
Desired Behavior
Automatic encoder routing based on observation space structure, similar to Stable-Baselines3's CombinedExtractor:
"You can use environments with dictionary observation spaces. This is useful in the case where one can't directly concatenate observations such as an image from a camera combined with a vector of servo sensor data."
— SB3 Documentation
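For reference, the SB3 side of this needs no custom extractor code at all; picking `MultiInputPolicy` is enough for a Dict observation space. A minimal usage sketch (the env id is a hypothetical placeholder):

```python
# SB3 reference behavior: with a Dict observation space, "MultiInputPolicy" applies
# CombinedExtractor automatically (an appropriate extractor per key, e.g. NatureCNN
# for image-like sub-spaces, flattening otherwise), then concatenates the results.
from stable_baselines3 import PPO

model = PPO("MultiInputPolicy", "MyDictObsEnv-v0", verbose=1)  # placeholder env id
model.learn(total_timesteps=10_000)
```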
Expected behavior with a Dict obs space (a rough sketch follows this list):
- Detect the observation space type (Dict with mixed sub-spaces)
- Automatically select an appropriate encoder per key:
  - 3D Box (image-like) → CNN encoder
  - 1D Box (vector) → MLP encoder
  - Nested Dict → recursive handling
- Concatenate the encoded features
- Feed the combined representation to the policy/value heads
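A minimal sketch of what that could look like under the hood, written in plain PyTorch outside of RLlib. All names, layer sizes, and the NHWC handling here are illustrative assumptions, not an existing RLlib API; nested Dicts are omitted for brevity:

```python
# Illustrative sketch only: per-key encoder selection + concatenation for a Dict
# observation space. Not RLlib's actual encoder/Catalog framework.
import gymnasium as gym
import torch
import torch.nn as nn


def _encoder_for(space: gym.Space) -> tuple[nn.Module, int]:
    """Pick an encoder for a single sub-space and return it with its output size."""
    if isinstance(space, gym.spaces.Box) and len(space.shape) == 3:
        h, w, c = space.shape  # image-like (H, W, C) Box -> small CNN
        cnn = nn.Sequential(
            nn.Conv2d(c, 32, 8, 4), nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1), nn.ReLU(),
            nn.Flatten(),
        )
        with torch.no_grad():
            out_dim = cnn(torch.zeros(1, c, h, w)).shape[1]
        return cnn, out_dim
    if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
        return nn.Sequential(nn.Linear(space.shape[0], 64), nn.ReLU()), 64  # vector -> MLP
    raise NotImplementedError(f"No default encoder for {space}")


class DictObsEncoder(nn.Module):
    """Encodes each Dict key separately, then concatenates the results."""

    def __init__(self, space: gym.spaces.Dict):
        super().__init__()
        self.encoders = nn.ModuleDict()
        self.out_dim = 0
        for key, sub_space in space.spaces.items():
            enc, dim = _encoder_for(sub_space)
            self.encoders[key] = enc
            self.out_dim += dim

    def forward(self, obs: dict) -> torch.Tensor:
        outs = []
        for key, enc in self.encoders.items():
            x = obs[key]
            if x.dim() == 4:  # (B, H, W, C) image -> (B, C, H, W) for Conv2d
                x = x.permute(0, 3, 1, 2)
            outs.append(enc(x))
        return torch.cat(outs, dim=-1)  # combined representation for policy/value heads


# Usage: enc = DictObsEncoder(env.observation_space); z = enc(batch_of_dict_obs)
```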
Affected Algorithms
This feature would benefit all algorithms using the new API stack:
- PPO (my primary use case)
- DreamerV3 (especially relevant for vision+proprioception)
- SAC, IMPALA, APPO, etc.
Current Workaround
Users must implement a custom RLModule (see `action_masking_rlm.py` and `tiny_atari_cnn_rlm.py` for patterns), manually extracting the Dict keys and routing them through separate encoders.
Environment
- Ray version: 2.52.1
- PyTorch: 2.9.1+cu130
- Python: 3.11
- OS: Windows 11
Use case
Primary Use Case: Game AI with Vision + State
Training an RL agent for a 2D platformer game (SuperTux) where optimal control requires:
- Visual input (`pixels`): 84x84x4 stacked grayscale frames for spatial/temporal awareness
- Numeric state (`features`): 9D vector with player position, velocity, coin count, etc.
Neither input alone is sufficient:
- Pixels alone: Agent lacks precise velocity/position data
- Features alone: Agent can't see enemies, platforms, or level geometry
This is a common pattern in robotics (camera + joint encoders), autonomous driving (LIDAR + vehicle state), and games (screen + game memory).
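For concreteness, a stripped-down stand-in for such an environment (all values are dummies; the real env renders frames and reads game state from SuperTux, and the 6-action space is a placeholder):

```python
# Toy stand-in showing only the Dict observation plumbing.
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class PlatformerDictObsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Dict({
            "pixels": spaces.Box(0.0, 1.0, shape=(84, 84, 4), dtype=np.float32),
            "features": spaces.Box(-1.0, 1.0, shape=(9,), dtype=np.float32),
        })
        self.action_space = spaces.Discrete(6)  # placeholder action count

    def _obs(self):
        return {
            "pixels": np.zeros((84, 84, 4), dtype=np.float32),  # stacked frames
            "features": np.zeros(9, dtype=np.float32),          # position, velocity, coins, ...
        }

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self._obs(), {}

    def step(self, action):
        return self._obs(), 0.0, False, False, {}
```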
Why Not Flatten?
Flattening the 84x84x4 pixels to a 28,224-dim vector and concatenating it with the 9 features (rough numbers below):
- Loses spatial structure that CNNs exploit
- Massively inefficient (huge MLP input layer)
- Poor sample efficiency compared to proper CNN encoding
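Back-of-the-envelope numbers behind those points, counting first-layer weights only for the flattened MLP versus all conv weights of a standard 3-layer stack (the 256-unit hidden size and filter sizes are assumptions):

```python
# Weight counts only, biases ignored; sizes are assumptions.
flat_input = 84 * 84 * 4 + 9                    # 28_233-dim flattened observation
mlp_first_layer = flat_input * 256              # 7_227_648 weights in the first Linear alone
conv_stack = 8*8*4*32 + 4*4*32*64 + 3*3*64*64   # 77_824 weights across all three conv layers
print(mlp_first_layer, conv_stack)
```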
Why Not Use Old API?
The old ModelV2 API with `model_config={"use_lstm": False, ...}` is deprecated. The new RLModule API is the future, but currently lacks this ergonomic feature.
Proposed API (suggestion)
```python
config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=True, ...)
    .environment(env="MyDictObsEnv")
    # Automatic encoder routing based on obs space structure
    .training(
        model_config={
            "encoder_config": {
                "pixels": {"type": "cnn", "conv_filters": [[32, 8, 4], [64, 4, 2], [64, 3, 1]]},
                "features": {"type": "mlp", "hiddens": [64, 64]},
            },
            "post_concat_hiddens": [256],  # after concatenation
        }
    )
)
```

Or even simpler with auto-detection:

```python
.training(model_config={"auto_encoder_for_dict_obs": True})
```