Tags: pytorch/xla
Update PyTorch and XLA pin. (#9668)

This PR updates the following pins:

- PyTorch: pytorch/pytorch@928ac57 to pytorch/pytorch@21fec65 (v2.9.0-rc5)
- OpenXLA: openxla/xla@92f7b59 to openxla/xla@9a9aa0e
- `libtpu`: 0.0.21 to 0.0.24
- JAX (and `jaxlib`): 0.7.1 to 0.8.0

**Key Changes:**

- `@python` was replaced by `@rules_python` in the `BUILD` file (ref: [jax-ml/jax#31709](jax-ml/jax#31709))
- `TF_ATTRIBUTE_NORETURN` was removed in favor of abseil (ref: [openxla/xla#31699](openxla/xla#31699))
- Replaced the include of `xla/pjrt/tfrt_cpu_pjrt_client.h` with `xla/pjrt/cpu/cpu_client.h` in `pjrt_registry.cpp` ([openxla/xla#30936](openxla/xla#30936))
- Moved the old `xla/tsl/platform/default/logging.*` to `torch_xla/csrc/runtime/tsl_platform_logging.*`
    - They were removed in [openxla/xla#29477](openxla/xla#29477)
    - Copied them here, temporarily. They should be removed once we update our error-throwing macros.
    - Commented out a few macro definitions to avoid macro redefinitions

**Update (Oct 3):**

- Added an OpenXLA patch fixing `static_assert(false)` for GCC < 13 ([ref](https://gcc.gnu.org/git/?p=gcc.git;a=commit;h=9944ca17c0766623bce260684edc614def7ea761))
- Removed the `flax` pin, since it no longer overwrites `jax`
- Removed the `TPU*` prefix from `jax.experimental.pallas.tpu` components (ref: [jax-ml/jax#29115](jax-ml/jax#29115))

---------

Co-authored-by: Bhavya Bahl <bbahl@google.com>
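The commit message does not include the `static_assert(false)` patch itself. For background: GCC before version 13 rejects `static_assert(false, ...)` inside a template even when it sits in a discarded `if constexpr` branch, because the condition does not depend on a template parameter. A common workaround is to make the condition dependent. The sketch below shows that workaround using hypothetical names (`always_false_v`, `TypeName`); it is not the OpenXLA patch, which may take a different approach.

```cpp
#include <string>
#include <type_traits>

// Hypothetical helper: a `false` that depends on T, so the compiler only
// evaluates it when the branch containing it is actually instantiated.
template <typename T>
inline constexpr bool always_false_v = false;

// Hypothetical example: map a C++ type to a short type name.
template <typename T>
std::string TypeName() {
  if constexpr (std::is_same_v<T, int>) {
    return "s32";
  } else if constexpr (std::is_same_v<T, float>) {
    return "f32";
  } else {
    // Writing `static_assert(false, ...)` here is rejected by GCC < 13 even
    // when this branch is discarded; the dependent condition is accepted.
    static_assert(always_false_v<T>, "unsupported type");
    return {};
  }
}
```

With this form, `TypeName<int>()` returns `"s32"`, while `TypeName<double>()` fails to compile with the "unsupported type" message, on both older and newer GCC.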
Update libtpu and jax versions to use with release (#9526)