Releases: pytorch/xla
PyTorch/XLA 2.8 release
Highlights
- Broader Platform Support: Build support has been added for Python 3.12 and 3.13.
- New Quantization Features: Introduced support for weight-8-bit, activation-8-bit (w8a8) quantized matrix multiplication, including both the kernel and the Torch/XLA wrapper.
- Torchax Enhancements: The torchax library has been significantly improved with new features like torchax.amp.autocast, support for bicubic and bilinear resampling, and better interoperability with Flax.
- Distributed Computing: Added support for the torch.distributed.scatter collective operation and fixed logic for all_gather_into_tensor.
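To make the w8a8 idea concrete, here is a minimal plain-Python sketch of the arithmetic behind an 8-bit weight, 8-bit activation matmul. It is not the torch_xla kernel or its wrapper. The symmetric quantization scheme, the per-row activation scales, the per-column weight scales, and all function names are assumptions chosen for illustration.

```python
# Illustrative w8a8 matmul in plain Python (not the torch_xla kernel).
# Assumptions: symmetric int8 quantization, one scale per activation row
# and one scale per weight output column.

def quantize_rows(matrix):
    """Quantize each row to int8 values with its own symmetric scale."""
    q_rows, scales = [], []
    for row in matrix:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def transpose(matrix):
    return [list(col) for col in zip(*matrix)]


def w8a8_matmul(x, w):
    """Approximate x @ w using int8 operands and integer accumulation."""
    xq, x_scales = quantize_rows(x)                # activations: scale per row
    wq_t, w_scales = quantize_rows(transpose(w))   # weights: scale per column
    out = []
    for i, x_row in enumerate(xq):
        out_row = []
        for j, w_col in enumerate(wq_t):
            acc = sum(a * b for a, b in zip(x_row, w_col))   # integer accumulate
            out_row.append(acc * x_scales[i] * w_scales[j])  # dequantize
        out.append(out_row)
    return out


def float_matmul(x, w):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]


x = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]
w = [[1.0, -2.0], [0.5, 0.5], [-1.5, 1.0]]
approx = w8a8_matmul(x, w)
exact = float_matmul(x, w)
max_err = max(abs(a - e) for ar, er in zip(approx, exact) for a, e in zip(ar, er))
```

The quantized result stays close to the float result because each operand is rounded to within half a quantization step of its original value.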
Bug Fixes
- Fixed an issue in EmbeddingDenseBackward by removing an unnecessary cast of padding_idx to double.
- Corrected all_gather_into_tensor logic.
- Resolved an issue where Allgather would incorrectly check tuple shapes.
- Fixed unspecified behavior in custom calls.
- Addressed a bug in the ragged paged attention kernel by applying a KV mask to filter out NaNs.
- Corrected the reloading of sharded optimizer parameter groups and master weights.
- Fixed an issue where NoneRemover was not returning the modified list/tuple.
Deprecations
- A warning will now be surfaced upon initialization if the deprecated XLA:CUDA device is used. Nightly CUDA builds have been removed.
- The devkind field from xla_model.xla_device is deprecated.
- ShapeOfXlaOp is deprecated in favor of GetShape.
What's Changed
- [torchax] Fix functional.max_pool by @dvhg in #8814
- Support unused arguments in xb.call_jax by @tengyifei in #8830
- Remove xla_mark_sharding_dynamo_custom_op by @tengyifei in #8823
- Adapt Neuron data type tests by @rpsilva-aws in #8835
- Revamp SPMD guide by @tengyifei in #8807
- Add build triggers for 2.7-rc1 by @zpcore in #8828
- Fix for triton pin update. by @ysiraichi in #8825
- Expose mark_sharding_with_gradients as a public API. by @iwknow in #8826
- Add info related to codegen by @pgmoka in #8834
- Add information about codegen in codegen/xla_native_functions.yaml by @pgmoka in #8817
- install jax dependency by default to unblock pytorch ci by @zpcore in #8845
- Make cxx_abi enable by default by @zpcore in #8844
- fix existing bazel link removal by @zpcore in #8848
- fix dependency by @zpcore in #8852
- [ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head by @bythew3i in #8851
- Increase TPU CI node count from 2 to 32 by @tengyifei in #8857
- Fix missing min_tpu_nodes argument by @tengyifei in #8863
- Avoid destroying the cluster by @tengyifei in #8864
- fix CPP test build failure in upstream by @zpcore in #8867
- Support collective matmul optimization in mp by @yaochengji in #8855
- Update README.md with torch_xla-2.8.0 wheels by @bhavya01 in #8870
- Add lowering for bitwise_left_shift. by @ysiraichi in #8865
- 8727 create a site map or centralize links in readme by @pgmoka in #8843
- Correct type conversions in OpBuilder and add corresponding test. by @iwknow in #8873
- Add lowering for bitwise_right_shift. by @ysiraichi in #8866
- Cache HLO in xb.call_jax and support non-tensor args by @tengyifei in #8878
- Increase CPU node count in tpu-ci by @tengyifei in #8885
- update Dockerfile config for CI by @zpcore in #8886
- Avoid unnecessary copy in TensorSource by @lsy323 in #8849
- Encapsulate Mesh invariants by @rpsilva-aws in #8882
- [ragged-paged-attn] Combine k_pages and v_pages on num_kv_head by @bythew3i in #8892
- 8740 add single processing to getting started instructions by @pgmoka in #8875
- Update cudnn dependency for newer cuda versions by @zpcore in #8894
- Remove dynamic grid by @bythew3i in #8896
- [Build] Pin cmake to build PyTorch by @lsy323 in #8902
- Use shard_as in scan to ensure that inputs and their gradients have the same sharding by @tengyifei in #8879
- point build to rc3 by @zpcore in #8907
- Add Pallas tutorials to the Pallas guide at torch_xla. by @vanbasten23 in #8904
- Fix fp8 test by @lsy323 in #8909
- Lower isneginf(). by @ysiraichi in #8912
- Fix cuda dependency not found by @zpcore in #8903
- [torchax] Support linalg.det and linalg.lu_solve by @dvhg in #8872
- Use f32 scratch for output so we only need to transfer output with desired dtype back to HBM. by @vanbasten23 in #8924
- Clean up deprecated APIs by @zpcore in #8927
- Add heuristic default block sizes for different cases in ragged attention kernel by @yaochengji in #8922
- [ragged-paged-attn] Unify kv strided load to one. by @bythew3i in #8929
- pin update by @lsy323 in #8908
- Include torchax in torch_xla by @tengyifei in #8895
- Fix up pin update by @tengyifei in #8935
- Rewrite scan-based GRU based on nn.GRU by @iwknow in #8914
- Update and rename torch_xla2.yml to torchax.yml by @tengyifei in #8936
- Trigger build for 2.6.1 patch by @zpcore in #8938
- Add alternative dynamo backend by @qihqi in #8893
- Add block size table for ragged_paged_attention by @yaochengji in #8942
- add rc4 trigger by @zpcore in #8941
- Fix ragged_paged_attention op signature by @yaochengji in #8943
- Use pages_per_seq * page_size instead of directly passing max_model_len by @yaochengji in #8950
- Pin update to 20250406 by @tengyifei in #8945
- Update to pytorch 2.6 by @qihqi in #8944
- Add test to ensure scan-based and the standard GRU are interchangeable. by @iwknow in #8949
- Remove deprecated typing module in ansible flow by @tengyifei in #8955
- [call_jax] support returning PyTree from the JAX function by @tengyifei in #8957
- Make scan-based GRU support batch_first parameter. by @iwknow in #8964
- Adapt Splash Attention from TorchPrime by @zpcore in #8911
- Add tuned parameters for Qwen/Qwen2.5-32B by @yarongmu-google in #8966
- Disable one splash attention test by @zpcore in #8970
- Avoid re-computing computation hashes by @rpsilva-aws in #8976
- @assume_pure by @tengyifei in #8962
- Update to ubuntu latest by @qihqi in #8979
- Fix call_jax hashing by @zpcore in #8981
- Fix splash attention test by @zpcore in #8978
- Add a helper class to handle mesh and sharding by @qihqi in #8967
- Add an option for JittableModule to dedup parameters. by @qihqi in #8965
- Fix merge conflict by @zpcore in #8991
- Set scoped vmem for paged attention by @zpcore in #8988
- Fix the source param sharding for GradAcc API by @rpsilva-aws in #8999
- scan-based GRU falls back to nn.GRU when bidirectional is true. by @iwknow in #8984
- typo fix by @haifeng-jin in #8990
- Default explicit donation for step barriers by @rpsilva-aws in #8982
- update to r2.7 rc5 by @zpcore in #9001
- Replace upstream GRU implementation with scan-based GRU by @iwknow in #9010
- Guide to debugging in PyTorch by @yaoshiang in #8998
- Remove config warning log from GradAcc API by @rpsilva-aws in #9006
- fix libtpu path by @zpcore in #9008
- composability of assume_pure and call_jax by @qihqi in #8989
- handle requires_grad in torchax by @qihqi in #8992
- Typo ...
PyTorch/XLA 2.7 release
Highlights
- Easier training on Cloud TPUs with TorchPrime
- A new Pallas-based kernel for ragged paged attention, enabling further optimizations on vLLM TPU (#8791)
- Usability improvements
- Experimental JAX interoperability with JAX operations (#8781, #8789, #8830, #8878)
- Re-enabled GPU CI build (#8593)
Stable Features
- Operator Lowering
- Support splitting physical axis in SPMD mesh (#8698)
- Support of placeholder tensor (#8785).
- Dynamo/AOTAutograd traceable flash attention (#8654)
- C++11 ABI build is the default
Experimental Features
- Gated Recurrent Unit (GRU) implemented with scan (#8777)
- Introduce apply_xla_patch_to_nn_linear to improve einsum performance (#8793)
- Enable default buffer donation for step barriers (#8721, #8982)
Usability
- Better profiling control: the start and the end of the profiling session can be controlled by the new profiler API (#8743)
- API to query number of cached compilation graphs (#8822)
- Enhancement on host-to-device transfer (#8849)
Bug fixes
- fix a bug in tensor.flatten (#8680)
- cummax: fix 0-sized dimension reduction. (#8653)
- Fix dk/dv autograd error on TPU flash attention (#8685)
- Fix a bug in flash attention where kv_seq_len should divide block_k_major. (#8671)
- [scan] Make sure inputs into fn are not device_data IR nodes (#8769)
Libtpu stable version
- Pin 2.7 release to stable libtpu version '0.0.11.1'
Deprecations
- Deprecate torch.export and instead, use torchax to export graph to StableHLO for full dynamism support
- Remove torch_xla.core.xla_model.xrt_world_size, replace with torch_xla.runtime.world_size
- Remove torch_xla.core.xla_model.get_ordinal, replace with torch_xla.runtime.global_ordinal
- Remove torch_xla.core.xla_model.parse_xla_device, replace with _utils.parse_xla_device
- Remove torch_xla.experimental.compile, replace with torch_xla.compile
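As a migration aid, the removals above can be checked mechanically. The following is a hedged sketch of a small scanner. The helper name find_removed_apis and the RENAMES table are hypothetical, not part of torch_xla. It only detects fully qualified names, so uses through an alias such as xm.xrt_world_size would need extra handling. parse_xla_device is left out because the notes give only a partial module path for its replacement.

```python
import re

# Old -> new names, copied from the 2.7 deprecation list.
RENAMES = {
    "torch_xla.core.xla_model.xrt_world_size": "torch_xla.runtime.world_size",
    "torch_xla.core.xla_model.get_ordinal": "torch_xla.runtime.global_ordinal",
    "torch_xla.experimental.compile": "torch_xla.compile",
}


def find_removed_apis(source: str) -> list[tuple[int, str, str]]:
    """Return (line_number, old_name, new_name) for each fully qualified use."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for old, new in RENAMES.items():
            # \b keeps get_ordinal from matching a longer identifier.
            if re.search(re.escape(old) + r"\b", line):
                hits.append((lineno, old, new))
    return hits


sample = (
    "import torch_xla\n"
    "n = torch_xla.core.xla_model.xrt_world_size()\n"
    "step = torch_xla.experimental.compile(step)\n"
)
hits = find_removed_apis(sample)
for lineno, old, new in hits:
    print(f"line {lineno}: replace {old} with {new}")
```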
PyTorch/XLA 2.6 release
Highlights
Kernel improvements for vLLM: Multi-Queries Paged Attention Pallas Kernel
- Added the multi-queries paged attention pallas kernel (#8328). Unlocks opportunities in vLLM such as prefix caching.
- Perf improvement: only write to HBM at the last iteration (#8393)
Experimental scan operator (#7901)
Previously, when you looped over many nn.Modules of the same structure in PyTorch/XLA, the loop was unrolled during graph tracing, leading to giant computation graphs. This unrolling results in long compilation times, up to an hour for large language models with many decoder layers. In this release we offer an experimental API to reduce compilation times called "scan", which mirrors the jax.lax.scan transform in JAX. When you replace a Python for loop with scan, instead of compiling every iteration individually, only the first iteration is compiled, and the compiled HLO is reused for all subsequent iterations. Building upon torch_xla.experimental.scan, torch_xla.experimental.scan_layers offers a convenient interface for looping over a sequence of nn.Modules without unrolling.
Documentation: https://pytorch.org/xla/release/r2.6/features/scan.html
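For readers unfamiliar with jax.lax.scan, the calling convention it uses can be written as a plain-Python reference loop. This is only a sketch of the semantics: fn takes (carry, x) and returns (new_carry, y). It assumes torch_xla.experimental.scan follows the same convention, which the linked documentation should confirm. It does not reproduce the compile-once behaviour described above.

```python
def scan(fn, init, xs):
    """Reference semantics: thread a carry through fn once per element of xs."""
    carry = init
    ys = []
    for x in xs:
        carry, y = fn(carry, x)  # fn returns the updated carry and one output
        ys.append(y)
    return carry, ys


def cumulative_sum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry


final, partial_sums = scan(cumulative_sum_step, 0, [1, 2, 3, 4])
print(final, partial_sums)  # 10 [1, 3, 6, 10]
```

Because the loop body is a single function applied uniformly to each element, a compiler only needs to trace that function once, which is the property scan exploits.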
C++11 ABI builds
Starting from PyTorch/XLA 2.6, we'll provide wheels and docker images built with two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default to align with PyTorch upstream, but C++11 ABI wheels and docker images have better lazy tensor tracing performance.
To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):
pip install torch==2.6.0+cpu.cxx11.abi \
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
The above command works for Python 3.10. We additionally have Python 3.9 and 3.11 wheels:
- 3.9: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp39-cp39-manylinux_2_28_x86_64.whl
- 3.10: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl
- 3.11: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp311-cp311-manylinux_2_28_x86_64.whl
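The three wheel URLs differ only in the CPython tag, so the pattern can be captured in a small helper for use in provisioning scripts. This is a hypothetical convenience function, not something shipped with torch_xla. It only covers the three Python versions listed above.

```python
BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"


def cxx11_wheel_url(python_version: str) -> str:
    """Build the 2.6.0 C++11 ABI wheel URL for a listed Python version."""
    supported = {"3.9", "3.10", "3.11"}
    if python_version not in supported:
        raise ValueError(f"no C++11 ABI wheel listed for Python {python_version}")
    tag = "cp" + python_version.replace(".", "")  # e.g. "3.10" -> "cp310"
    return f"{BASE}/torch_xla-2.6.0%2Bcxx11-{tag}-{tag}-manylinux_2_28_x86_64.whl"


print(cxx11_wheel_url("3.11"))
```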
To access C++11 ABI flavored docker image:
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
If your model is tracing bound (e.g. you see that the host CPU is busy tracing the model while TPUs are idle), switching to the C++11 ABI wheels/docker images can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global batch size 1024:
- Pre-C++11 ABI MFU: 33%
- C++11 ABI MFU: 39%
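The two MFU figures can be read either as an absolute gain in percentage points or as a relative improvement. The short calculation below works out both from the reported numbers. The variable names are illustrative.

```python
pre_cxx11_mfu = 0.33  # reported MFU with the pre-C++11 ABI build
cxx11_mfu = 0.39      # reported MFU with the C++11 ABI build

absolute_gain_points = (cxx11_mfu - pre_cxx11_mfu) * 100
relative_gain = cxx11_mfu / pre_cxx11_mfu - 1

print(f"{absolute_gain_points:.0f} points, {relative_gain:.1%} relative")
# prints: 6 points, 18.2% relative
```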
GPU builds are temporarily skipped in 2.6
We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We understand this is important and plan to reinstate GPU support by the 2.7 release. PyTorch/XLA remains an open-source project and we welcome contributions from the community to help maintain and improve the project. To contribute, please start with the contributors guide.
The newest stable version where PyTorch/XLA:GPU wheel is available is torch_xla 2.5.
Stable Features
Stable libtpu releases
Starting from PyTorch/XLA 2.6, TPU backend support will be provided by a stable libtpu Python package. That means we expect fewer TPU-specific bugs and improved test coverage overall. The libtpu-nightly Python package will be pinned to a special empty version to avoid conflicts. As long as you use our PyTorch/XLA docker images or follow the latest installation instructions in the README.md, no action is needed on your part and the right dependencies will be installed.
GSPMD
- [LoweringContext] Support an optimized parameter mapping for SPMD (#8460)
- [LoweringContext] SPMD propagation #8471: this ensures that the computation has the respective sharding specs deduced from the inputs (scoped to the creation of the parameters), and to propagate the input shardings to the output.
AMP
- Add autocast support for einsum #8420
- Add autocast support for XlaPatchedLinear #8421
- Support S32/U32 indices for BWD embedding & Neuron implicit downcast #8462
Bug fixes
- Getting "undefined symbol: _ZN5torch4lazy13MetricFnValueB5cxx11E" with torch-xla nightly wheel for 2.6 #8406
Experimental Features
Support for host offloading (#8350, #8477)
When doing reverse-mode automatic differentiation, many tensors are saved during the forward pass to be used to compute the gradient during the backward pass. Previously you could use torch_xla.utils.checkpoint to discard tensors that are easy to recompute later, a technique called "checkpointing" or "rematerialization". Now PyTorch/XLA also supports a technique called "host offloading", i.e. moving a tensor to the host and later moving it back, adding another tool in the arsenal to save memory. Use torch_xla.experimental.stablehlo_custom_call.place_to_host to move a tensor to host and torch_xla.experimental.stablehlo_custom_call.place_to_device to move a tensor back to the device. For example, you can use this to move intermediate activations to host during a forward pass, and move those activations back to device during the corresponding backward pass.
Because the XLA graph compiler aggressively reorders operations, host offloading is best used in combination with scan.
Updates to Flash Attention kernels
Support SegmentID in FlashAttention when doing data parallel SPMD #8425
Deprecations
See Backward Compatibility proposal.
APIs that will be removed in 2.7 release:
- Deprecate APIs (deprecated → new):
  - xla_model.xrt_world_size() → runtime.world_size() [#7679][#7743]
  - xla_model.get_ordinal() → runtime.global_ordinal() [#7679]
  - xla_model.get_local_ordinal() → runtime.local_ordinal() [#7679]
- Internalize APIs
  - xla_model.parse_xla_device() [#7675]
- Improvement
  - Automatic PJRT device detection when importing torch_xla [#7787]
  - Add deprecated decorator [#7703]
APIs that will be removed in 2.8 release:
- The XLA_USE_BF16 environment variable is deprecated. Please convert your model to bf16 directly. [#8474]
PyTorch/XLA 2.5.1: Readme update Release
PyTorch/XLA 2.5.1 fixes the torch_xla[tpu] PyPI README instructions and aligns with the PyTorch 2.5.1 hotfix release. No new features were added between PyTorch/XLA 2.5.0 and 2.5.1.
PyTorch/XLA 2.5 Release
Cloud TPUs now support the PyTorch 2.5 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.5 release, this release introduces several features and PyTorch/XLA-specific bug fixes.
Highlights
We are excited to announce the release of PyTorch XLA 2.5! This release supports the torch_xla.compile function, which improves the debugging experience during development, and aligns distributed APIs with upstream PyTorch through traceable collective support for both Dynamo and non-Dynamo cases. Starting from PyTorch/XLA 2.5, we propose a clarified vision for deprecating the older torch_xla API in favor of the existing PyTorch API, providing a simplified developer experience.
If you’ve used vLLM for serving models on GPUs, you’ll now be able to seamlessly switch to its TPU backend. vLLM is a widely adopted inference framework that also serves as an excellent way to drive accelerator interoperability. With vLLM on TPU, users will retain the same vLLM interface we’ve grown to love, with direct integration with Hugging Face Models to make model experimentation easy.
STABLE FEATURES
Eager
- Increase max in flight operation to accommodate eager mode [#7263]
- Unify the logics to check eager mode [#7709]
- Update eager.md [#7710]
- Optimize execution for ops that have multiple output in eager mode [#7680]
Quantization / Low Precision
- Asymmetric quantized matmul support [#7626]
- Add blockwise quantized dot support [#7605]
- Support int4 weight in quantized matmul / linear [#7235]
- Support fp8e5m2 dtype [#7740]
- Add fp8e4m3fn support [#7842]
- Support dynamic activation quant for per-channel quantized matmul [#7867]
- Enable cross entropy loss for xla autocast with FP32 precision [#8094]
Pallas Kernels
- Support ab for flash_attention [#7840], actual kernel is implemented in JAX
- Support logits_soft_cap parameter in paged_attention [#7704], actual kernel is implemented in JAX
- Support gmm and tgmm trace_pallas caching [#7921]
- Cache flash attention tracing [#8026]
- Improve the user guide [#7625]
- Update pallas doc with paged_attention [#7591]
StableHLO
- Add user guide for stablehlo composite op [#7826]
gSPMD
- Handle the parameter wrapping for SPMD [#7604]
- Add helper function to get 1d mesh [#7577]
- Support manual all-reduce [#7576]
- Expose apply_backward_optimization_barrier [#7477]
- Support reduce-scatter in manual sharding [#7231]
- Allow MpDeviceLoader to shard dictionaries of tensor [#8202]
Dynamo
- Optimize dynamo dynamic shape caching [#7726]
- Add support for dynamic shape in dynamo [#7676]
- In dynamo optim_mode avoid unnecessary set_attr [#7915]
- Fix the crash with copy op in dynamo [#7902]
- Optimize _split_xla_args_tensor_sym_constant [#7900]
- DYNAMO RNG seed update optimization [#7884]
- Support mark_dynamic [#7812]
- Support gmm as a custom op for dynamo [#7672]
- Fix dynamo inplace copy [#7933]
- CPU time optimization for GraphInputMatcher [#7895]
PJRT
- Improve device auto-detection [#7787]
- Move _xla_register_custom_call_target implementation into PjRtComputationClient [#7801]
- Handle SPMD case inside of ComputationClient::WaitDeviceOps [#7796]
GKE
Functionalization
- Add 1-layer gradient accumulation test to check aliasing [#7692]
AMP
- Fix norm data-type when using AMP [#7878]
BETA FEATURES
Op Lowering
- Lower aten::_linalg_eigh [#7674]
- Fallback _embedding_bag_backward and force sparse=false [#7584]
- Support trilinear by using upstream decomp [#7586]
Higher order ops
- [Fori_loop] Update randint max range to Support bool dtype [#7632]
TorchBench Integration
- [benchmarks] API alignment with PyTorch profiler events [#7930]
- [benchmarks] Add IR dump option when run torchbench [#7927]
- [benchmarks] Use same matmul precision between PyTorch and PyTorch/XLA [#7748]
- [benchmarks] Introduce verifier to verify the model output correctness against native pytorch [#7724, #7777]
- [benchmarks] Fix moco model issue on XLA [#7257, #7598]
- Type annotation for benchmarks/ [#7289]
- Default with CUDAGraphs on for inductor [#7749]
GPU
- Deprecate XRT for XLA:CUDA [#8006]
EXPERIMENTAL FEATURES
Backward Compatibility & APIs that will be removed in 2.7 release:
- Deprecate APIs (deprecated → new):
  - xla_model.xrt_world_size() → runtime.world_size() [#7679][#7743]
  - xla_model.get_ordinal() → runtime.global_ordinal() [#7679]
  - xla_model.get_local_ordinal() → runtime.local_ordinal() [#7679]
- Internalize APIs
  - xla_model.parse_xla_device() [#7675]
- Improvement
  - Automatic PJRT device detection when importing torch_xla [#7787]
  - Add deprecated decorator [#7703]
Distributed
Distributed API
We have aligned our distributed APIs with upstream PyTorch. Previously, we implemented custom distributed APIs, such as torch_xla.core.xla_model.all_reduce. With the traceable collective support, we now enable torch.distributed.all_reduce and similar functions for both Dynamo and non-Dynamo cases in torch_xla.
- Support of upstream distributed APIs (torch.distributed.*) like all_reduce, all_gather, reduce_scatter_tensor, all_to_all. Previously we used xla specific distributed APIs in xla_model [#7860, #7950, #8064].
- Introduce torch_xla.launch() to launch the multiprocess in order to unify torchrun and torch_xla.distributed.xla_multiprocessing.spawn() [#7764, #7648, #7695].
- torch.distributed.reduce_scatter_tensor(): [#7950]
- Register sdp lower precision autocast [#7299]
- Add Python binding for xla::DotGeneral [#7863]
- Fix input output alias for custom inplace ops [#7822]
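For reference, the collectives named above can be described by what each rank holds before and after the call. The sketch below simulates ranks as Python lists to show those semantics only. It does not call torch.distributed or torch_xla, the function names are illustrative, and it assumes a sum reduction and inputs whose length divides evenly by the number of ranks.

```python
def all_reduce_sum(per_rank):
    """Every rank ends up with the element-wise sum of all ranks' inputs."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def all_gather(per_rank):
    """Every rank ends up with the concatenation of all ranks' inputs."""
    gathered = [v for rank_vals in per_rank for v in rank_vals]
    return [list(gathered) for _ in per_rank]


def reduce_scatter_sum(per_rank):
    """Sum element-wise, then give rank i the i-th equal chunk of the result."""
    world = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


inputs = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two simulated ranks
reduced = all_reduce_sum(inputs)
gathered = all_gather(inputs)
scattered = reduce_scatter_sum(inputs)
```

Note that reduce-scatter followed by all-gather yields the same per-rank data as all-reduce, which is why the three are often discussed together.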
torch_xla.compile
- Support full_graph which will error out if there will be more than one ...
PyTorch/XLA 2.4 Release
Cloud TPUs now support the PyTorch 2.4 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.4 release, this release introduces several features and PyTorch/XLA-specific bug fixes.
🚀 The PyTorch/XLA 2.4 release delivers a 4% speedup (geometric mean) on torchbench evaluation benchmarks using the openxla_eval dynamo backend on TPUs, compared to the 2.3 release.
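A geometric-mean speedup averages per-model speedup ratios multiplicatively, so a 2x gain on one model and a 2x loss on another cancel out. The sketch below shows one way to compute it. The function name is hypothetical and the timings are made-up values used only to exercise the function, not benchmark results.

```python
import math


def geomean_speedup(old_times, new_times):
    """Geometric mean of per-model speedup ratios (old time / new time)."""
    ratios = [old / new for old, new in zip(old_times, new_times)]
    return math.exp(sum(math.log(r) for r in ratios) / len(ratios))


# Made-up timings: the first model gets 2x faster, the second 2x slower.
old = [2.0, 8.0]
new = [1.0, 16.0]
speedup = geomean_speedup(old, new)
```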
Highlights
We are excited to announce the release of PyTorch XLA 2.4! PyTorch 2.4 offers improved support for custom kernels using Pallas, including kernels like FlashAttention and Group Matrix Multiplication that can be used like any other torch operators and inference support for the PagedAttention kernel. We also add experimental support for eager mode that compiles and executes each operator for a better debugging and development experience.
Stable Features
PJRT
- Enable dynamic plugins by default #7270
GSPMD
- Support manual sharding and introduce high level manual sharding APIs #6915, #6931
- Support SPMDFullToShardShape, SPMDShardToFullShape #6922, #6925
Torch Compile
- Add a DynamoSyncInputExecuteTime counter #6813
- Fix runtime error when run dynamo with a profiler scope #6913
Export
- Add fx passes to support unbounded dynamism #6653
- Add dynamism support to conv1d, view, softmax #6653
- Add dynamism support to aten.embedding and aten.split_with_sizes #6781
- Inline all scalars by default in export path #6803
- Run shape propagation for inserted fx nodes #6805
- Add an option to not generate weights #6909
- Support export custom op to stablehlo custom call #7017
- Support array attribute in stablehlo composite #6840
- Add option to export FX Node metadata to StableHLO #7046
Beta Features
Pallas
- Support FlashAttention backward kernels #6870
- Make FlashAttention as torch.autograd.Function #6886
- Remove torch.empty in tracing to avoid allocating extra memory #6897
- Integrate FlashAttention with SPMD #6935
- Support scaling factor for attention weights in FlashAttention #7035
- Support segment ids in FlashAttention #6943
- Enable PagedAttention through Pallas #6912
- Properly support PagedAttention dynamo code path #7022
- Support megacore_mode in PagedAttention #7060
- Add Megablocks’ Group Matrix Multiplication kernel #6940, #7117, #7120, #7119, #7133, #7151
- Support histogram #7115, #7202
- Support tgmm #7137
- Make repeat_with_fixed_output_size not OOM on VMEM #7145
- Introduce GMM torch.autograd.function #7152
CoreAtenOpSet
- Lower embedding_bag_forward_only #6951
- Implement Repeat with fixed output shape #7114
- Add int8 per channel weight-only quantized matmul #7201
FSDP via SPMD
- Support multislice #7044
- Allow sharding on the maximal dimension of the weights #7134
- Apply optimization-barrier to all params and buffers during grad checkpointing #7206
Distributed Checkpoint
- Add optimizer priming for distributed checkpointing #6572
Usability
- Add xla.sync as a better name for mark_step. See #6399. #6914
- Add xla.step context manager to handle exceptions better. See #6751. #7068
- Implement ComputationClient::GetMemoryInfo for getting TPU memory allocation #7086
- Dump HLO HBM usage info #7085
- Add function for retrieving fallback operations #7116
- Deprecate XLA_USE_BF16 and XLA_USE_FP16 #7150
- Add PT_XLA_DEBUG_LEVEL to make it easier to distinguish between execution cause and compilation cause #7149
- Warn when using persistent cache with debug env vars #7175
- Add experimental MLIR debuginfo writer API #6799
GPU CUDA Fallback
- Add dlpack support #7025
- Make from_dlpack handle cuda synchronization implicitly for input tensors that have __dlpack__ and __dlpack_device__ attributes. #7125
Distributed
- Switch all_reduce to use the new functional collective op #6887
- Allow user to configure distributed runtime service. #7204
- Use dest_offsets directly in LoadPlanner #7243
Experimental Features
Eager Mode
- Enable Eager mode for PyTorch/XLA #7611
- Support eager mode with torch.compile #7649
- Eagerly execute inplace ops in eager mode #7666
- Support eager mode for multi-process training #7668
- Handle random seed for eager mode #7669
- Enable SPMD with eager mode #7673
Triton
While Loop
- Prepare for torch while_loop signature change. #6872
- Implement fori_loop as a wrapper around while_loop #6850
- Complete fori_loop/while_loop and additional test case #7306
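The relationship between the two loop operators can be sketched in plain Python: a counted loop is a while loop whose carried state includes the counter. The signatures below follow the general (cond_fn, body_fn, carried) and (lower, upper, body_fn, init) conventions used by torch's while_loop and jax.lax.fori_loop. The actual torch_xla signatures may differ, so treat this as an illustration of the idea rather than the library API.

```python
def while_loop(cond_fn, body_fn, carried):
    """Repeatedly apply body_fn to the carried tuple while cond_fn holds."""
    while cond_fn(*carried):
        carried = body_fn(*carried)
    return carried


def fori_loop(lower, upper, body_fn, init):
    """Run body_fn(i, value) for i in [lower, upper) by delegating to while_loop."""
    def cond(i, value):
        return i < upper

    def body(i, value):
        return i + 1, body_fn(i, value)  # advance the counter, update the value

    _, result = while_loop(cond, body, (lower, init))
    return result


total = fori_loop(0, 5, lambda i, acc: acc + i, 0)
print(total)  # 10
```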
Bug Fixes and Improvements
- Fix type promotion for pow. (#6745)
- Fix vector norm lowering #6883
- Manually init absl log to avoid log spam #6890
- Fix pixel_shuffle return empty #6907
- Make nms fallback to CPU implementation by default #6933
- Fix torch.full scalar type #7010
- Handle multiple inplace update input output aliasing #7023
- Fix overflow for div arguments. #7081
- Add data_type promotion to gelu_backward, stack #7090, #7091
- Fix index of 0-element tensor by 0-element tensor #7113
- Fix output data-type for upsample_bilinear #7168
- Fix a data-type related problem for mul operation by converting inputs to result type #7130
- Make clip_grad_norm_ follow input’s dtype #7205
PyTorch/XLA 2.3 Release Notes
Highlights
We are excited to announce the release of PyTorch XLA 2.3! PyTorch 2.3 offers experimental support for SPMD Auto Sharding on a single TPU host, which allows users to shard their models on TPU with a single config change. We also add experimental support for Pallas custom kernels for inference, which enables users to make use of popular custom kernels like flash attention and paged attention on TPU.
Stable Features
PJRT
- Experimental GPU PJRT Plugin (#6240)
- Define PJRT plugin interface in C++ (#6360)
- Add limit to max inflight TPU computations (#6533)
- Remove TPU_C_API device type (#6435)
GSPMD
Torch Compile
- Support activation sharding within torch.compile (#6524)
- Do not cache FX input args in dynamo bridge to avoid memory leak (#6553)
- Ignore non-XLA nodes and their direct dependents. (#6170)
Export
- Support of implicit broadcasting with unbounded dynamism (#6219)
- Support multiple StableHLO Composite outputs (#6295)
- Add support of dynamism for add (#6443)
- Enable unbounded dynamism on conv, softmax, addmm, slice (#6494)
- Handle constant variable (#6510)
Beta Features
CoreAtenOpSet
Support all Core Aten Ops used by torch.export
- Lower reflection_pad1d, reflection_pad1d_backward, reflection_pad3d and reflection_pad3d_backward (#6588)
- lower replication_pad3d and replication_pad3d_backward (#6566)
- Lower the embedding op (#6495)
- Lowering for _pdist_forward (#6507)
- Support mixed precision for torch.where (#6303)
Benchmark
- Unify PyTorch/XLA and PyTorch torchbench model configuration using the same torchbench.yaml (#6881)
- Align model data precision settings with pytorch HUD (#6447, #6518, #6555)
- Fix some torchbench models configuration to make it runnable using XLA (#6509, #6542, #6558, #6612).
FSDP via SPMD
Distributed Checkpoint
Usability
GPU
- Fix global_device_count(), local_device_count() for single process on CUDA(#6022)
- Automatically use XLA:GPU if on a GPU machine (#6605)
- Add SPMD on GPU instructions (#6684)
- Build XLA:GPU as a separate Plugin (#6825)
Distributed
Experimental Features
Pallas
- Introduce Flash Attention kernel using Pallas (#6827)
- Support Flash Attention kernel with causal mask (#6837)
- Support Flash Attention kernel with torch.compile (#6875)
- Support Pallas kernel (#6340)
- Support programmatically extracting the payload from Pallas kernel (#6696)
- Support Pallas kernel with torch.compile (#6477)
- Introduce helper to convert Pallas kernel to PyTorch/XLA callable (#6713)
GSPMD Auto-Sharding
Input Output Aliasing
- Support torch.compile for dynamo_set_buffer_donor
- Use XLA’s new API to alias graph input and output (#6855)
While Loop
Bug Fixes and Improvements
- Propagates requires_grad over to AllReduce output (#6326)
- Avoid fallback for avg_pool (#6409)
- Fix output tensor shape for argmin and argmax where keepdim=True and dim=None (#6536)
- Fix preserve_rng_state for activation checkpointing (#4690)
- Allow int data-type for Embedding indices (#6718)
- Don't terminate the whole process when Compile fails (#6707)
- Fix an incorrect assert on frame count for PT_XLA_DEBUG=1 (#6466)
- Refactor nms into TorchVision variant. (#6814)
PyTorch/XLA 2.2 Release Notes
Cloud TPUs now support the PyTorch 2.2 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.2 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
Installing PyTorch and PyTorch/XLA 2.2.0 wheel:
pip install torch~=2.2.0 torch_xla[tpu]~=2.2.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Please note that you might have to re-install the libtpu on your TPUVM depending on your previous installation:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
- Note: If you meet the error RuntimeError: operator torchvision::nms does not exist when using torchvision in the 2.2.0 docker image, please try the following command to fix the issue:
pip uninstall torch -y; pip install torch==2.2.0
Stable Features
PJRT
- PJRT_DEVICE=GPU has been renamed to PJRT_DEVICE=CUDA (#5754). PJRT_DEVICE=GPU will be removed in the 2.3 release.
- Optimize Host to Device transfer (#5772) and device to host transfer (#5825).
- Miscellaneous low-level refactoring and performance improvements (#5799, #5737, #5794, #5793, #5546).
Beta Features
GSPMD
- Support DTensor API integration and move GSPMD out of experimental (#5776).
- Enable debug visualization func visualize_tensor_sharding (#5742), added doc.
- Support mark_shard scalar tensors (#6158).
- Add apply_backward_optimization_barrier (#6157).
Export
- Handled lifted constants in torch export (#6111).
- Run decomp before processing (#5713).
- Support export to tf.saved_model for models with unused params (#5694).
- Add an option to not save the weights (#5964).
- Experimental support for dynamic dimension sizes in torch export to StableHLO (#5790, openxla/xla#6897).
CoreAtenOpSet
- PyTorch/XLA aims to support all PyTorch core ATen ops in the 2.3 release. We’re actively working on this, remaining issues to be closed can be found at issue list.
Benchmark
- Support of benchmark running automation and metric report analysis on both TPU and GPU (doc).
Experimental Features
FSDP via SPMD
- Introduce FSDP via SPMD, or FSDPv2 (#6187). The RFC can be found in #6379.
- Add FSDPv2 user guide (#6386).
Distributed Op
Persistent Compilation
- Enable persistent compilation caching (#6065).
- Document and introduce xr.initialize_cache python API (#6046).
Checkpointing
- Support auto checkpointing for TPU preemption (#5753).
- Support Async checkpointing through CheckpointManager (#5697).
Usability
Quantization
- Lower quant/dequant torch op to StableHLO (#5763).
GPU
Bug Fixes and Improvements
- Pow precision issue (#6103).
- Handle negative dim for Diagonal Scatter (#6123).
- Fix as_strided for inputs smaller than the arguments specification (#5914).
- Fix squeeze op lowering issue when dim is not in sorted order (#5751).
- Optimize RNG seed dtype for better memory utilization (#5710).
Lowering
- _prelu_kernel_backward (#5724).
PyTorch/XLA 2.1 Release
Cloud TPUs now support the PyTorch 2.1 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in the PyTorch 2.1 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
PJRT is now PyTorch/XLA's officially supported runtime! PJRT brings improved performance, superior usability, and broader device support. PyTorch/XLA r2.1 will be the last release with XRT available as a legacy runtime. Our main release build will not include XRT, but it will be available in a separate package. In most cases, we expect the migration to PJRT to require minimal changes. For more information, see our PJRT documentation.
GSPMD support has been added as an experimental feature to the PyTorch/XLA 2.1 release. GSPMD will transform the single-device program into a partitioned one with proper collectives, based on the user-provided sharding hints. This feature allows developers to write PyTorch programs as if they are on a single large device without any custom sharded computation ops and/or collective communications to scale. We published a blog post explaining the technical details and expected usage; you can also find more detail in this user guide.
PyTorch/XLA has transitioned from depending on TensorFlow to depending on the new OpenXLA repo. This allows us to reduce our binary size and simplify our build system. Starting from 2.1, PyTorch/XLA will release its TPU wheel on PyPI.
To install PyTorch/XLA 2.1.0 wheels, please find the installation instructions below.
Installing PyTorch and PyTorch/XLA 2.1.0 wheel:
pip install torch~=2.1.0 torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
Please note that you might have to re-install the libtpu on your TPUVM depending on your previous installation:
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
Stable Features
OpenXLA
- Migrate from pulling XLA from TensorFlow to pulling it from OpenXLA; the TF pin dependency is sunset (#5202)
- Instructions to build PyTorch/XLA with OpenXLA can be found in this doc.
PjRt Runtime
- Move PJRT APIs from experimental to torch_xla.runtime (#5011)
- Enable PJRT C API Client and other changes for Neuron (#5428)
- Enable PJRT C API Client for Intel XPU (#4891)
- Change pjrt:// init method to xla:// (#5560)
- Make TPU detection more robust (#5271)
- Add runtime.host_index (#5283)
Functionalization
Improvements and additions
- Op Lowering
- Build System
- Migrate the build system to Bazel (#4528)
Beta Features
AMP (Automatic Mixed Precision)
TorchDynamo
- Support CPU eager fallback in Dynamo bridge (#5000)
- Support torch.compile with SPMD for inference (#5002)
- Update the dynamo backend name to openxla and openxla_eval (#5402)
- Inference optimization for SPMD inference + torch.compile (#5447, #5446)
Traceable Collectives
Experimental Features
GSPMD
- Add SPMD user guide
- Enable Input-output aliasing (#5320)
- Introduce global_runtime_device_count to query the runtime device count (#5129)
- Support partial replication (#5411)
- Support tuple partition spec (#5488)
- Support mark_sharding on IRs (#5301)
- Make IR sharding custom sharding op (#5433)
- Introduce Hybrid Device mesh creation (#5147)
- Introduce SPMD-friendly patched nn.Linear (#5491)
- Allow dumping post optimizations HLO (#5302)
- Allow sharding n-d tensor on (n+1)-d Mesh (#5268)
- Support synchronous distributed checkpointing (#5130, #5170)
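Several items above (partial replication, tuple partition specs, sharding an n-d tensor on an (n+1)-d mesh) come down to how a partition spec maps tensor dimensions onto mesh axes. The following is a plain-Python sketch of that arithmetic only; it does not call torch_xla, the function name shard_shape is hypothetical, and rounding up uneven shards (with padding on the last one) is an assumption of the sketch.

```python
import math


def shard_shape(tensor_shape, mesh_shape, partition_spec):
    """Compute the per-device shard shape for a partition spec.

    Each entry of `partition_spec` corresponds to one tensor dimension:
      - None          -> dimension is replicated (not split)
      - int           -> dimension is split across that mesh axis
      - tuple of ints -> dimension is split across several mesh axes
    Illustrative arithmetic only; not a torch_xla API.
    """
    if len(partition_spec) != len(tensor_shape):
        raise ValueError("partition_spec needs one entry per tensor dimension")
    used_axes = set()
    result = []
    for dim_size, spec in zip(tensor_shape, partition_spec):
        if spec is None:
            axes = ()
        elif isinstance(spec, tuple):
            axes = spec
        else:
            axes = (spec,)
        for axis in axes:
            if axis in used_axes:
                raise ValueError(f"mesh axis {axis} is used more than once")
            used_axes.add(axis)
        # Number of pieces this dimension is cut into (1 if replicated).
        parts = math.prod(mesh_shape[axis] for axis in axes)
        result.append(math.ceil(dim_size / parts))
    return tuple(result)


# 2-d tensor on a 2-d mesh, both dimensions sharded.
print(shard_shape((8, 128), (2, 4), (0, 1)))
# Partial replication: only dim 0 is sharded, mesh axis 1 holds replicas.
print(shard_shape((8, 128), (2, 4), (0, None)))
# 2-d tensor on a 3-d mesh using a tuple entry for dim 0.
print(shard_shape((8, 128), (2, 2, 2), ((0, 1), 2)))
```

Mesh axes that no entry refers to hold replicas of the shard, which is how the partial replication case arises in this sketch.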
Serving Support
- SavedModel
- Added a script stablehlo-to-saved-model (#5493)
- Docs: https://github.com/pytorch/xla/blob/r2.1/docs/stablehlo.md#convert-saved-stablehlo-for-serving
StableHLO
- Add StableHLO user guide (#5523)
- Add save_as_stablehlo and save_torch_model_as_stablehlo APIs (#5493)
- Make StableHLO executable (#5476)
Ongoing Development
TorchDynamo
- Enable single step graph for training
- Avoid inter-graph reshapes from aot_autograd
- Support GSPMD for activation checkpointing
GSPMD
- Support auto-sharding
- Benchmark and improve GSPMD for XLA:GPU
- Integrate with PyTorch's Distributed Tensor API
GPU
- Support Multi-host GPU for PJRT runtime
- Improve performance on torchbench models
Quantization
- Support PyTorch PT2E quantization workflow
Bug Fixes and Improvements
PyTorch/XLA 2.0 release
Cloud TPUs now support the PyTorch 2.0 release, via PyTorch/XLA integration. On top of the underlying improvements and bug fixes in PyTorch's 2.0 release, this release introduces several features, and PyTorch/XLA specific bug fixes.
Beta Features
PJRT runtime
- Check out our newest document; PJRT is the default runtime in 2.0.
- New Implementation of xm.rendezvous with XLA collective communication which scales better (#4181)
- New PJRT TPU backend through the C-API (#4077)
- Default to PJRT if no runtime is configured (#4599)
- Experimental support for torch.distributed and DDP on TPU v2 and v3 (#4520)
FSDP
- Add auto_wrap_policy into XLA FSDP for automatic wrapping (#4318)
Stable Features
Lazy Tensor Core Migration
- The migration is complete; check out this dev discussion for more detail.
- Naively inherits LazyTensor (#4271)
- Adopt even more LazyTensor interfaces (#4317)
- Introduce XLAGraphExecutor (#4270)
- Inherits LazyGraphExecutor (#4296)
- Adopt more LazyGraphExecutor virtual interfaces (#4314)
- Rollback to use xla::Shape instead of torch::lazy::Shape (#4111)
- Use TORCH_LAZY_COUNTER/METRIC (#4208)
Improvements & Additions
- Add an option to increase the worker thread efficiency for data loading (#4727)
- Improve numerical stability of torch.sigmoid (#4311)
- Add an api to clear counter and metrics (#4109)
- Add met.short_metrics_report to display more concise metrics report (#4148)
- Document environment variables (#4273)
- Op Lowering
Experimental Features
TorchDynamo (torch.compile) support
- Check out our newest doc.
- Dynamo bridge python binding (#4119)
- Dynamo bridge backend implementation (#4523)
- Training optimization: make execution async (#4425)
- Training optimization: reduce graph execution per step (#4523)
PyTorch/XLA GSPMD on single host
- Preserve parameter sharding with sharded data placeholder (#4721)
- Transfer shards from server to host (#4508)
- Store the sharding annotation within XLATensor (#4390)
- Use d2d replication for more efficient input sharding (#4336)
- Mesh to support custom device order. (#4162)
- Introduce virtual SPMD device to avoid unpartitioned data transfer (#4091)
Ongoing development
Ongoing Dynamic Shape implementation
- Implement missing XLASymNodeImpl::Sub (#4551)
- Make empty_symint support dynamism. (#4550)
- Add dynamic shape support to SigmoidBackward (#4322)
- Add a forward pass NN model with dynamism test (#4256)