Highlights
- Broader Platform Support: Build support has been added for Python 3.12 and 3.13.
- New Quantization Features: Introduced support for weight-8-bit, activation-8-bit (w8a8) quantized matrix multiplication, including both the kernel and the Torch/XLA wrapper.
- Torchax Enhancements: The torchax library has been significantly improved with new features like torchax.amp.autocast, support for bicubic and bilinear resampling, and better interoperability with Flax.
- Distributed Computing: Added support for the torch.distributed.scatter collective operation (see the sketch after this list) and fixed the all_gather_into_tensor logic.
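To make the new collective concrete, here is a minimal, hedged sketch of calling torch.distributed.scatter under the XLA backend. The launch and process-group setup (torch_xla.launch, the xla:// init method) follow the usual PyTorch/XLA multiprocess pattern and are assumptions for illustration, not something specified by these notes:

```python
# Illustrative sketch only; launch/init details are assumptions.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend


def _mp_fn(index):
    dist.init_process_group("xla", init_method="xla://")
    device = xm.xla_device()
    world_size = dist.get_world_size()

    # Rank 0 supplies one chunk per rank; every rank receives its own chunk.
    output = torch.zeros(4, device=device)
    scatter_list = None
    if dist.get_rank() == 0:
        scatter_list = [
            torch.full((4,), float(i), device=device) for i in range(world_size)
        ]
    dist.scatter(output, scatter_list, src=0)
    torch_xla.sync()  # materialize the pending XLA graph


if __name__ == "__main__":
    torch_xla.launch(_mp_fn)
```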
Bug Fixes
- Fixed an issue in EmbeddingDenseBackward by removing an unnecessary cast of padding_idx to double.
- Corrected all_gather_into_tensor logic.
- Resolved an issue where Allgather would incorrectly check tuple shapes.
- Fixed unspecified behavior in custom calls.
- Addressed a bug in the ragged paged attention kernel by applying a KV mask to filter out NaNs.
- Corrected the reloading of sharded optimizer parameter groups and master weights.
- Fixed an issue where NoneRemover was not returning the modified list/tuple.
Deprecations
- A warning will now be surfaced upon initialization if the deprecated XLA:CUDA device is used. Nightly CUDA builds have been removed.
- The devkind argument of xla_model.xla_device is deprecated.
- ShapeOfXlaOp is deprecated in favor of GetShape.
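For reference, a minimal sketch of what moving off the deprecated devkind argument can look like, assuming the current public device helpers (torch_xla.device() and xm.xla_device()); it is illustrative only, not an excerpt from the project documentation:

```python
# Minimal sketch of migrating off the deprecated devkind argument (assumed
# public helpers of this release; verify against the docs for your setup).
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Before (deprecated): selecting a device kind explicitly, e.g.
#   device = xm.xla_device(devkind='TPU')

# After: let the configured runtime pick the device.
device = torch_xla.device()  # equivalent to xm.xla_device()

# The plain string device form is also accepted
# (see the .to('xla') migration in What's Changed below).
t = torch.ones(2, 2).to('xla')
```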
What's Changed
- [torchax] Fix functional.max_pool by @dvhg in #8814
- Support unused arguments in xb.call_jax by @tengyifei in #8830
- Remove xla_mark_sharding_dynamo_custom_op by @tengyifei in #8823
- Adapt Neuron data type tests by @rpsilva-aws in #8835
- Revamp SPMD guide by @tengyifei in #8807
- Add build triggers for 2.7-rc1 by @zpcore in #8828
- Fix for triton pin update. by @ysiraichi in #8825
- Expose mark_sharding_with_gradients as a public API. by @iwknow in #8826
- Add info related to codegen by @pgmoka in #8834
- Add information about codegen in codegen/xla_native_functions.yaml by @pgmoka in #8817
- install jax dependency by default to unblock pytorch ci by @zpcore in #8845
- Make cxx_abi enable by default by @zpcore in #8844
- fix existing bazel link removal by @zpcore in #8848
- fix dependency by @zpcore in #8852
- [ragged-paged-attn] Use hidden states in kv cache and support any num_kv_head by @bythew3i in #8851
- Increase TPU CI node count from 2 to 32 by @tengyifei in #8857
- Fix missing min_tpu_nodes argument by @tengyifei in #8863
- Avoid destroying the cluster by @tengyifei in #8864
- fix CPP test build failure in upstream by @zpcore in #8867
- Support collective matmul optimization in mp by @yaochengji in #8855
- Update README.md with torch_xla-2.8.0 wheels by @bhavya01 in #8870
- Add lowering for bitwise_left_shift. by @ysiraichi in #8865
- 8727 create a site map or centralize links in readme by @pgmoka in #8843
- Correct type conversions in OpBuilder and add corresponding test. by @iwknow in #8873
- Add lowering for bitwise_right_shift. by @ysiraichi in #8866
- Cache HLO in xb.call_jax and support non-tensor args by @tengyifei in #8878
- Increase CPU node count in tpu-ci by @tengyifei in #8885
- update Dockerfile config for CI by @zpcore in #8886
- Avoid unnecessary copy in TensorSource by @lsy323 in #8849
- Encapsulate Mesh invariants by @rpsilva-aws in #8882
- [ragged-paged-attn] Combine k_pages and v_pages on num_kv_head by @bythew3i in #8892
- 8740 add single processing to getting started instructions by @pgmoka in #8875
- Update cudnn dependency for newer cuda versions by @zpcore in #8894
- Remove dynamic grid by @bythew3i in #8896
- [Build] Pin cmake to build PyTorch by @lsy323 in #8902
- Use shard_as in scan to ensure that inputs and their gradients have the same sharding by @tengyifei in #8879
- point build to rc3 by @zpcore in #8907
- Add Pallas tutorials to the Pallas guide at torch_xla. by @vanbasten23 in #8904
- Fix fp8 test by @lsy323 in #8909
- Lower isneginf(). by @ysiraichi in #8912
- Fix cuda dependency not found by @zpcore in #8903
- [torchax] Support linalg.det and linalg.lu_solve by @dvhg in #8872
- Use f32 scratch for output so we only need to transfer output with desired dtype back to HBM. by @vanbasten23 in #8924
- Clean up deprecated APIs by @zpcore in #8927
- Add heuristic default block sizes for different cases in ragged attention kernel by @yaochengji in #8922
- [ragged-paged-attn] Unify kv strided load to one. by @bythew3i in #8929
- pin update by @lsy323 in #8908
- Include torchax in torch_xla by @tengyifei in #8895
- Fix up pin update by @tengyifei in #8935
- Rewrite scan-based GRU based on nn.GRU by @iwknow in #8914
- Update and rename torch_xla2.yml to torchax.yml by @tengyifei in #8936
- Trigger build for 2.6.1 patch by @zpcore in #8938
- Add alternative dynamo backend by @qihqi in #8893
- Add block size table for ragged_paged_attention by @yaochengji in #8942
- add rc4 trigger by @zpcore in #8941
- Fix ragged_paged_attention op signature by @yaochengji in #8943
- Use pages_per_seq * page_size instead of directly passing max_model_len by @yaochengji in #8950
- Pin update to 20250406 by @tengyifei in #8945
- Update to pytorch 2.6 by @qihqi in #8944
- Add test to ensure scan-based and the standard GRU are interchangeable. by @iwknow in #8949
- Remove deprecated typing module in ansible flow by @tengyifei in #8955
- [call_jax] support returning PyTree from the JAX function by @tengyifei in #8957
- Make scan-based GRU support batch_first parameter. by @iwknow in #8964
- Adapt Splash Attention from TorchPrime by @zpcore in #8911
- Add tuned parameters for Qwen/Qwen2.5-32B by @yarongmu-google in #8966
- Disable one splash attention test by @zpcore in #8970
- Avoid re-computing computation hashes by @rpsilva-aws in #8976
- @assume_pure by @tengyifei in #8962
- Update to ubuntu latest by @qihqi in #8979
- Fix call_jax hashing by @zpcore in #8981
- Fix splash attention test by @zpcore in #8978
- Add a helper class to handle mesh and sharding by @qihqi in #8967
- Add an option for JittableModule to dedup parameters. by @qihqi in #8965
- Fix merge conflict by @zpcore in #8991
- Set scoped vmem for paged attention by @zpcore in #8988
- Fix the source param sharding for GradAcc API by @rpsilva-aws in #8999
- scan-based GRU falls back to nn.GRU when bidirectional is true. by @iwknow in #8984
- typo fix by @haifeng-jin in #8990
- Default explicit donation for step barriers by @rpsilva-aws in #8982
- update to r2.7 rc5 by @zpcore in #9001
- Replace upstream GRU implementation with scan-based GRU by @iwknow in #9010
- Guide to debugging in PyTorch by @yaoshiang in #8998
- Remove config warning log from GradAcc API by @rpsilva-aws in #9006
- fix libtpu path by @zpcore in #9008
- composability of assume_pure and call_jax by @qihqi in #8989
- handle requires_grad in torchax by @qihqi in #8992
- Typo fix for OP_LOWERING_GUIDE by @haifeng-jin in #9020
- Remove trailing spaces from CONTRIBUTING.md by @ghpvnist in #9024
- disable tests using mark_sharding + assume_pure on GPU by @qihqi in #9013
- [cleanup] Remove install_post_deps_pytorch_xla by @tengyifei in #9027
- Fix link to show rendered markdown instead of edit by @haifeng-jin in #9028
- Trigger final wheel by @zpcore in #9029
- Fix broken link to GPU docs by @ghpvnist in #9023
- Update CONTRIBUTING.md by @haifeng-jin in #9030
- Update README for 2.7 release by @zpcore in #9031
- Propagate mutations of Tensor slices to the source Tensor. by @wenxindongwork in #9025
- cpp_debugger.md updates by @mikegre-google in #9034
- Add runtime check when using non-kernel for ragged paged attn by @bythew3i in #8958
- Migration from deprecated (@openxla/head) headers and build-targets by @sdasgup3 in #9033
- refactor. by @yaoshiang in #9040
- Improve assume_pure docs and tests by @tengyifei in #9038
- fix runner cancel issue by @zpcore in #9011
- Use new tuned table by @bythew3i in #9041
- Update bazel.md to replace Tensorflow with Openxla by @bhavya01 in #9042
- Add support for building PyTorch/XLA using clang. by @zhanyong-wan in #9059
- Update pin 04/24 by @bhavya01 in #9045
- Extend device data node binding API to not clone specified input tensors by @rpsilva-aws in #9054
- update DTensor usage with upstream by @zpcore in #9079
- Try to re-enable previously disabled CPU test by @tengyifei in #9085
- Add tengyifei, bhavya01 and qihqi to infra owners by @bhavya01 in #9093
- Add clang link to PATH. by @ysiraichi in #9053
- Create copy of nightly wheel without version number by @bfolie in #9091
- Add missing tensor data types (unsigned int 16, 32, 64) to PopulateTensorBuffer by @iwknow in #9090
- change all_to_all check to allow for split sizes > 1 by @bfolie in #9100
- Revert "Add
clang
link toPATH
." by @bhavya01 in #9099 - Enrich instructions for setting up dev environment. by @zhanyong-wan in #9104
- Add instructions on creating a PR and bringing forks up-to-date. by @zhanyong-wan in #9105
- Fix nightly installation instruction by @tengyifei in #9106
- Fix typo executation -> execution by @ghpvnist in #9109
- Upgrade sccache to v0.10.0 in upstream docker image by @clee2000 in #9102
- Remove trailing whitespace from the repo by @ghpvnist in #9065
- [torchax] Support mark_sharding_with_gradients by @tengyifei in #9122
- Unify style across pytorch/xla by @tengyifei in #9124
- Add instructions on making VSCode discover the pytorch/xla repo. by @zhanyong-wan in #9134
- 9080 expose mat mul precision by @yaoshiang in #9081
- Add tooling and documentation for setting up clangd by @zhanyong-wan in #9137
- Disable running slow CI for doc-only changes by @ghpvnist in #9072
- Refine the gradient accumulation API by @rpsilva-aws in #9078
- [benchmarks] Fix run single config command error by @haifeng-jin in #9115
- Silence distributed warning by @tengyifei in #9140
- Test torchax on Python 3.10 - 3.12 by @tengyifei in #9139
- Ensure that all of PyTorch/XLA C++ is compiled with exceptions enabled. by @zhanyong-wan in #9146
- Re-land: Add clang link to PATH. by @ysiraichi in #9144
- Update instructions for updating forked pytorch by @zhanyong-wan in #9154
- Rename *.cc files to *.cpp for consistency. by @zhanyong-wan in #9149
- Add a linter check to prevent *.cc file names. by @zhanyong-wan in #9147
- Add instructions on how to rebase local commits by @zhanyong-wan in #9162
- Update XLA to a 2025/5/13 revision. by @zhanyong-wan in #9155
- Don't install wheels when building docs by @ghpvnist in #9153
- add aarch64 platform build support by @snadampal in #8663
- Fix a clang compilation error. by @zhanyong-wan in #9156
- Switch bazel default spawn strategy to sandboxed by @zhanyong-wan in #9168
- Replace implementation of xla_torch.sync() to xm.mark_step() by @ghpvnist in #9086
- fix: reload the param groups in optimizer by @avizon-aws in #9164
- Add assume_pure_torch implementation for forward pass only by @bhavya01 in #9135
- Replace xm.mark_step with torch_xla.sync() wherever possible by @ghpvnist in #9070
- test: add test_xla_graph_execution to test flags (_set_allow_execution with PT_XLA_DEBUG_LEVEL) by @aws-yyjau in #9171
- get master ip address for neuron device by @aws-zhenguo in #9120
- [Kernel] Use call_jax to simplify the gmm pallas kernel wrapper by @yaochengji in #9180
- Allow run_tests to run a subset of the tests. by @zhanyong-wan in #9190
- Make 2 more run_tests.sh support test selection by @zhanyong-wan in #9192
- Split expensive TPU tests to run in parallel by @ghpvnist in #9198
- Clean up env var usages in run_tests.sh scripts. by @zhanyong-wan in #9207
- Update ci.md and fix typo by @tengyifei in #9150
- Reflow ci.md to 80 chars by @tengyifei in #9151
- Fix the examples in API_GUIDE by @zhanyong-wan in #9213
- Improve CI build speed by letting bazel decide how many worker threads to have by @zhanyong-wan in #9209
- Document test/neuron/run_tests.sh. by @zhanyong-wan in #9205
- [ragged-paged-attn] Apply kv mask to filter out NaNs by @bythew3i in #9219
- Update the OpenXLA pin by @haifeng-jin in #9216
- Fixed reloading of sharded master weights by @avizon-aws in #9224
- Subset of existing, approved PR by @yaoshiang in #9204
- 9082 educate users on mat mul precision by @yaoshiang in #9103
- Run all torchax tests with find by @tengyifei in #9227
- Yho doc cicd attempt 1 by @yaoshiang in #9233
- Make run_tests.sh less spammy. by @zhanyong-wan in #9237
- Add description of prerequisites for running benchmarks by @haifeng-jin in #9230
- Fix + Run DynamicShapeDetector tests on CI. by @ysiraichi in #9075
- Configure jobs to use better machine by @ghpvnist in #9229
- Separate pallas test into a separate test shard by @ghpvnist in #9231
- Modify NoneRemover to return modified list/tuple by @ghpvnist in #9232
- Fix docstring indentation error by @ghpvnist in #9244
- Delete unused tracing functions from ops.cpp. by @ysiraichi in #9240
- Fix unspecified behavior on custom calls. by @ysiraichi in #9247
- changed build to mostly match upstream PT, fixed some doc references. by @yaoshiang in #9248
- Revert "Switch bazel default spawn strategy to
sandboxed
" by @zhanyong-wan in #9212 - Add relevant imports to code snippets in amp.md by @ghpvnist in #9255
- [Kernel] add group_offset and transpose_rhs support in gmm kernel by @yaochengji in #9251
- Update README.md with supported Python version by @zzzwen in #9239
- Optimize build_developer.sh for the common workflow. by @zhanyong-wan in #9262
- Convert some XLA_CHECKs to fatal errors. by @zhanyong-wan in #9263
- Document the difference between tracing time and execution time by @sdasgup3 in #9133
- Improve torch_xla.compile documentation by @sdasgup3 in #9194
- [Kernel] support kv cache quantization in ragged attention kernel by @yaochengji in #9249
- Run test_collective_permute.py on TPU CI. by @ysiraichi in #9257
- [torchax] Added support for bicubic and bilinear resampling by @qihqi in #9222
- In pjrt runtime client, raise a Python exception if XLA compilation fails. by @zhanyong-wan in #9138
- Update libtpu and jax to their latest nightly builds. by @zhanyong-wan in #9264
- Tune CI job time-outs to be closer to the actual durations. by @zhanyong-wan in #9267
- [torchax] Support View in 'jax_view' by @lsy323 in #9273
- Add scripts/update_deps.py to automate updating the dependencies. by @zhanyong-wan in #9270
- Refactor GetNumDevices and create GetNumGlobalDevices by @pgmoka in #9184
- Remove CI jobs for GPU by @zhanyong-wan in #9277
- Add instruction for exporting inlined constant by @qihqi in #8707
- Update OpenXLA to 2025/5/30. by @zhanyong-wan in #9271
- Migrate runtime.xla_device in favor of core.xla_model.xla_device by @ghpvnist in #9200
- Add interop with flax (Part 1) by @qihqi in #9176
- dynamo_bridge.py: convert tuple to list before none_remover is called by @wzhang313 in #9279
- configurable number of thread for dcp by @aws-zhenguo in #9188
- [torchax] Fixes test for cat #7398 by @qianminj123 in #9047
- Remove GPU tests from test/run_tests.sh by @zhanyong-wan in #9280
- [torchax] Make copy_ works on different device. by @qihqi in #9211
- Support fp8 lowering in scalar by @yaochengji in #9283
- Remove a few CUDA tests in preparation for #9202 by @ghpvnist in #9286
- Split training tests to separate test shard by @ghpvnist in #9281
- Disable AMP by default on CPU by @haifeng-jin in #9218
- [Kernel] add heuristic gmm block sizes choosing logic by @yaochengji in #9289
- Add w8a8 quantized matmul kernel by @vanbasten23 in #9278
- Add a script to update local/forked repos to match upstream by @zhanyong-wan in #9288
- Add w8a8 quantized matmul torchxla wrapper by @vanbasten23 in #9290
- Deprecate devkind field from xla_model.xla_device by @ghpvnist in #9284
- Introduce annotate_custom_sharding binding by @rpsilva-aws in #9203
- Add cache to value_and_grad_partitioned by @iwknow in #9163
- Split CPU tests so the shards have a more even test time by @ghpvnist in #9292
- Deprecate XLA:CUDA: add warning on initialization. by @ysiraichi in #9295
- Pin update 06052025 by @bhavya01 in #9299
- Do not clear UncachedCompile between graph executions by @haifeng-jin in #9282
- Misc torchax improvements by @qihqi in #9294
- [dynamic shapes] support sym_min, sym_max for XLA SymNode by @pianpwk in #9291
- Sync up jax date with libtpu date by @tengyifei in #9296
- [torchax] Fixes test for kthvalue #7458 by @qianminj123 in #9223
- Add a script to automatically format C++/Python files upon git push by @zhanyong-wan in #9293
- Merge cpp_tests1 and cpp_tests2 into single shard by @ghpvnist in #9298
- Update dependency Updater by @tengyifei in #9302
- Support trace_me and xp.Trace in assume_pure by @tengyifei in #9311
- Stop tracking data and log files by @zzzwen in #9308
- Fix torchgen imports in codegen by @XuehaiPan in #9310
- Extend build_developer.sh and git_sync_main.py with a -a flag. by @zhanyong-wan in #9319
- Test torchprime from PyTorch/XLA by @tengyifei in #9152
- enable_python_dispatcher in some XLA custom passes. by @laithsakka in #9312
- Clean up torchax readme. by @zhanyong-wan in #9321
- Revert Create copy of nightly wheel without version number by @bfolie in #9318
- Migrate uses of import torch_xla as xla to import torch_xla by @ghpvnist in #9325
- Migrate .to(torch_xla.device()) to .to('xla') by @ghpvnist in #9324
- Remove nightly CUDA builds by @tengyifei in #9329
- Update submodules whenever we switch branches. by @zhanyong-wan in #9320
- doc: update pytorch-on-xla-devices and troubleshoot doc for tensor synchronization issue by @aws-yyjau in #9258
- Update pins by @pgmoka in #9330
- Surface CUDA deprecation warning. by @ysiraichi in #9333
- Fix usage of dimensions_size() to check for tuple by @bfolie in #9347
- cleanup: remove defunct --distinct_host_configuration flag by @rickeylev in #9300
- Switch XLA_CHECKs to ABSL_CHECKs in lowering_context. by @zhanyong-wan in #9338
- Ignore the MODULE.bazel* files generated by bazel. by @zhanyong-wan in #9352
- Split _ops and _decomps. by @qihqi in #9323
- Update ragged paged attention kernel to prevent vmem oom by @Chenyaaang in #9346
- docs: fix contribute guide command by @giacomoni in #9355
- Add an option to not use dlpack. by @qihqi in #9304
- Improve assume_pure SPMD functionality by @tengyifei in #9360
- fix all_gather_into_tensor test and logic by @bfolie in #9332
- Avoid recompilation caused by scan_layers by @tengyifei in #9367
- update l_ref in kernel matrix calculation by @Chenyaaang in #9372
- Improve the API for binding python functions in C++ by @zhanyong-wan in #9351
- Fix race condition in runAtenTest by @benawilson in #9306
- Deprecate ShapeOfXlaOp in favor of GetShape by @zhanyong-wan in #9381
- Update the tuned block size. by @QiliangCui in #9376
- Add autocast feature as torchax.amp.autocast. by @qihqi in #9364
- Include both pytorch and torch_xla revisions in the compilation cache key. by @zhanyong-wan in #9383
- Implement prng_key as a mutable array by @tengyifei in #9305
- Revert "Avoid unnecessary copy in TensorSource (#8849)" by @jeffhataws in #9379
- Wrap def_static to enable warning reporting in static python functions. by @zhanyong-wan in #9380
- Add cache support for scan_layers by @iwknow in #9297
- Torchax: JittableModule isinstance to work with encapsulated model by @zmelumian972 in #9375
- Make the shape property adhere to torch.Tensor interface. by @hfan in #9342
- Update documentations for scan cache by @iwknow in #9394
- Update gru.py to use is_fn_pure by @tengyifei in #9393
- [build_developer] Fix vision installation command by @tengyifei in #9395
- Refactor call_jax to allow implementing ops in python by @qihqi in #9354
- Add jax_device context manager to control the device target by @zzzwen in #9382
- Update README.md by @shauheen in #9397
- [torchax]: JittableModule statedict handling by @zmelumian972 in #9195
- Revert "[torchax]: JittableModule statedict handling" by @qihqi in #9401
- Support torch.distributed.scatter collective by @bfolie in #9365
- Migrate to correct logger interface by @emmanuel-ferdman in #9191
- ErrorHandling: make GetComputationClient() return StatusOr<T> type. by @ysiraichi in #9384
- Prepare for pytorch tensor impl change in is_contiguous_custom by @laithsakka in #9402
- Allgather coalesce: Check tuple shape only if return shape is tuple. by @jeffhataws in #9403
- adding tol for numeric test of checkpointing by @yaoshiang in #9404
- Update CODEOWNERS by @bhavya01 in #9409
- Add nightly and dev images for python 3.12 by @bhavya01 in #9408
- Style improvements. by @zhanyong-wan in #9410
- Add support for 3.13 builds by @bhavya01 in #9417
- Pin update to 20250617. After that there is a pallas regression by @qihqi in #9415
- implement diagonal_copy by @Chenyaaang in #9416
- s/torch_xla2/torchax by @tengyifei in #9353
- EmbeddingDenseBackward: Remove padding_idx cast to double by @unterumarmung in #9406
- Update tests CI for r2.8 by @pgmoka in #9426
- Update release CI to use python 3.12 torch_xla development docker images by @bhavya01 in #9485
- [torchax] Remove safe_zip by @bhavya01 in #9525
- Update libtpu and jax versions to use with release by @bhavya01 in #9526
- cherry-pick: make jax as optional dependency by @qihqi in #9530
- Add flags for size optimization by @bhavya01 in #9547
New Contributors
- @iwknow made their first contribution in #8826
- @yarongmu-google made their first contribution in #8966
- @haifeng-jin made their first contribution in #8990
- @yaoshiang made their first contribution in #8998
- @zhanyong-wan made their first contribution in #9059
- @clee2000 made their first contribution in #9102
- @aws-yyjau made their first contribution in #9171
- @aws-zhenguo made their first contribution in #9120
- @wzhang313 made their first contribution in #9279
- @qianminj123 made their first contribution in #9047
- @pianpwk made their first contribution in #9291
- @laithsakka made their first contribution in #9312
- @rickeylev made their first contribution in #9300
- @Chenyaaang made their first contribution in #9346
- @giacomoni made their first contribution in #9355
- @benawilson made their first contribution in #9306
- @QiliangCui made their first contribution in #9376
- @hfan made their first contribution in #9342
- @unterumarmung made their first contribution in #9406
Full Changelog: v2.7.0...v2.8.0