TensorDict is a batched, nested dict[str, Tensor] that behaves like a tensor.
Move it, slice it, reshape it, stack it, save it, compile it, or do arithmetic on
it: every tensor leaf follows the same operation, and one shared batch_size
keeps the structure honest.

```text
TensorDict(batch_size=[32])
|-- obs: Tensor[32, 128]
|-- action: Tensor[32]
|-- reward: Tensor[32]
`-- next:
    `-- obs: Tensor[32, 128]
```

```python
import torch
from tensordict import TensorDict

batch = TensorDict(
    {
        "obs": torch.randn(32, 128),
        "action": torch.randint(0, 4, (32,)),
        "reward": torch.randn(32),
        "next": {"obs": torch.randn(32, 128)},
    },
    batch_size=[32],
)

mini = batch[:8]  # slices every leaf
device = "cuda" if torch.cuda.is_available() else "cpu"
on_device = batch.to(device)  # moves every leaf; non-blocking internally
scaled = batch * 0.5  # arithmetic on the whole structure
merged = batch + batch  # leaf-wise TensorDict arithmetic
stacked = torch.stack([batch, batch], 0)

print(mini.shape)  # torch.Size([8])
print(stacked.shape)  # torch.Size([2, 32])
```

The object remains a mapping, but the batch acts like a tensor. That is the point: write the operation once, apply it to every tensor that belongs to the same example, rollout, batch, parameter set, or dataset shard.

Plain dictionaries are flexible. TensorDict keeps that flexibility and adds the parts tensor programs need once the code gets serious.

| With a plain dict | With TensorDict |
|---|---|
| Manually keep leading dimensions aligned | One `batch_size` validates the structure |
| Repeat `.to(device)` for every tensor | `td.to(device)` moves the full batch |
| Hand-roll slicing, stacking, reshaping | `td[:32]`, `torch.stack`, `td.reshape` |
| Manually recurse through nested state | Nested keys are first-class |
| Duplicate arithmetic over leaves | `td + td`, `td * scalar`, `td.abs()` |
| Invent checkpoint formats | `td.save`, `td.memmap`, `load_memmap` |
| Hope generic code keeps working | PyTorch-native APIs, `torch.compile` coverage |

Use TensorDict when the unit of data is not one tensor anymore, but it should still move through your program like one tensor.
TensorDict is not just syntax for recursive Python loops. Core paths are built for high-throughput PyTorch workloads:

- Arithmetic dispatch: operations such as `td + td`, `td * 0.5`, `td.abs()` and in-place variants apply directly to leaves and use PyTorch foreach kernels where available.
- Device and host transfers: D2H and H2D copies are dispatched across the full structure. TensorDict uses non-blocking leaf transfers internally when possible, so the common path is just `td.to(device)`; pass `non_blocking=False` only when you need an explicitly synchronous transfer.
- Shape operations without boilerplate: indexing, `view`, `reshape`, `permute`, `unsqueeze`, `squeeze`, `flatten`, `unflatten`, `stack` and `cat` operate on the batch structure rather than on hand-maintained lists of leaves.
- Low-allocation workflows: lazy stacks, preallocation, memory mapping and `inplace=True` shape-changing operations help reduce peak memory in data-heavy pipelines.
- Compile-aware internals: TensorDict is used in compiled training and RL loops, and the codebase carries dedicated `torch.compile` coverage for hot paths.

For deeper numbers, see the benchmark notes.
TensorDict 0.13 focuses on making structured tensor programs more practical in large training systems:
- Tabular import/export for pandas, CSV, Parquet and JSON workflows.
- More `inplace=True` shape operations, including `gather`, `repeat`, `repeat_interleave`, `roll`, `reshape`, `flatten`, `unflatten` and `contiguous`.
- Improved `torch.compile` behavior for TensorClass initialization, dynamic-shape export, locking paths and shallow clones.
- Safer memmap filenames by default through robust key encoding.
- A migration path for module state preservation with `to_module(..., preserve_module_state=...)`.
- CPU-only release wheels for TensorDict, avoiding duplicate GPU wheel artifacts for a package whose compiled extension is device-independent.

TensorDict lets datasets, models and losses agree on one container instead of a long argument list.

```python
for batch in dataloader:
    batch = batch.to(device)
    batch = model(batch)
    loss = loss_module(batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

That loop can stay stable while the schema changes from classification to segmentation, RL rollouts, model-based prediction or LLM post-training batches.


```python
td = TensorDict(
    {
        "agents": {
            "policy": torch.randn(64, 8),
            "value": torch.randn(64, 1),
        },
        "env": {
            "reward": torch.randn(64),
            "done": torch.zeros(64, dtype=torch.bool),
        },
    },
    batch_size=[64],
)

policy = td["agents", "policy"]
td["env", "reward"] = td["env", "reward"].clip(-1, 1)
```

Nested keys are part of the API, not an afterthought.

TensorDict can hold module parameters, swap them into modules, vectorize over ensembles and make model state explicit.

```python
from tensordict import TensorDict

params = TensorDict.from_module(module)
with params.to_module(module, preserve_module_state=True):
    out = module(inputs)
```

This is the same foundation used by TorchRL modules and functional training utilities.


```python
td = TensorDict({"tokens": tokens, "scores": scores}, batch_size=[n])
td.memmap("/tmp/batch")  # memory-map every leaf
reloaded = TensorDict.load_memmap("/tmp/batch")
```

Memory-mapped TensorDicts are useful for large offline datasets, replay buffers, inter-process handoff and checkpointed intermediate state.

- Tensor-like collection ops: indexing, slicing, device casting, dtype casting, reshaping, stacking and concatenation. [tutorial]
- Nested structures with tuple keys and predictable batch semantics. [tutorial]
- Fast memory workflows: asynchronous transfers, memmap, consolidated tensors, lazy stacks and preallocation. [tutorial]
- Functional programming with parameter TensorDicts, `to_module` and compatibility with `torch.vmap`. [tutorial]
- `@tensorclass`: a tensor-aware dataclass for structured tensor objects. [tutorial]
- Distributed and multiprocessed pipelines across workers, devices and machines. [doc]
- Serialization and memory mapping for efficient checkpointing and dataset storage. [doc]
For a longer tour, start with GETTING_STARTED.md or the online documentation.

With pip:

```bash
pip install tensordict
```

With conda:

```bash
conda install -c conda-forge tensordict
```

Nightly builds:

```bash
pip install tensordict-nightly
```

From source with an existing PyTorch install:

```bash
pip install -e . --no-deps
```

If you use uv with PyTorch nightlies, keep torch pinned to the PyTorch wheel index or install TensorDict with `--no-deps` so the resolver does not replace your existing PyTorch build:

```bash
uv pip install -e . --no-deps
uv pip install -e . --prerelease=allow -f "https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html"
```

TensorDict started in reinforcement learning, where batches quickly become nested trajectories. It is now used anywhere tensor batches are structured data: RL rollouts, LLM post-training samples, robotics trajectories, simulation state, model parameters, checkpointed datasets and scientific pipelines.

| Domain | Projects |
|---|---|
| Reinforcement Learning | TorchRL (PyTorch), DreamerV3-torch, Dreamer4, SkyRL |
| LLM Post-Training | verl, ROLL (Alibaba), LMFlow, LoongFlow (Baidu) |
| Robotics and Simulation | MuJoCo Playground (Google DeepMind), ProtoMotions (NVIDIA), holosoma (Amazon) |
| Physics and Scientific ML | PhysicsNeMo (NVIDIA) |
| Genomics | Medaka (Oxford Nanopore) |
If you use TensorDict, please cite the TorchRL paper:

```bibtex
@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

TensorDict is licensed under the MIT License. See LICENSE for details.