Insights: jax-ml/jax
Overview
113 Pull requests merged by 12 people
-
Enforce the shape of MMA operands and result in their types.
#30390 merged
Jul 22, 2025 -
[pallas:mosaic_gpu] Do not use `lax.zeros_like_array`
#30385 merged
Jul 22, 2025 -
[pallas] `pl.loop` now accepts `step=`
#30274 merged
Jul 22, 2025 -
Update deprecation schedule for `mlir.custom_call`, attributes of `xla_client`, and `xla_bridge.get_backend`.
#30386 merged
Jul 22, 2025 -
[pallas:mosaic_gpu] `async_load_tmem` no longer accepts an `idx`
#30382 merged
Jul 22, 2025 -
Rollback XLA archival change, current automation does not calculate the shasum correctly
#30379 merged
Jul 22, 2025 -
[Mosaic GPU][NFC] Implement `inline_mgpu` and `custom_primitive` for WG semantics.
#30147 merged
Jul 22, 2025 -
Add visibility to `//jax:pretty_printer`
#30373 merged
Jul 22, 2025 -
write docs about controlling array layouts
#30368 merged
Jul 22, 2025 -
Remove `SafeNumDevices` now that all shardings are implemented in C++
#30375 merged
Jul 22, 2025 -
Loosen test tolerance for `tests/linalg_sharding_test.py::LinalgShardingTest::test_batch_axis_sharding_jvp`.
#30370 merged
Jul 22, 2025 -
Fix handling of empty arrays in ufunc.reduce/accumulate
#29248 merged
Jul 21, 2025 -
Fix mesh_cast + new rng keys usage
#30367 merged
Jul 21, 2025 -
[Pallas] Only run TPU interpret mode tests on CPU.
#30366 merged
Jul 21, 2025 -
Fix `py_import` dependencies.
#30365 merged
Jul 21, 2025 -
[Pallas][Mosaic GPU] Fix racy test by having each warpgroup work on non-overlapping data.
#30354 merged
Jul 21, 2025 -
[JAX] Purge more caches when clearing backends
#30324 merged
Jul 21, 2025 -
Migrate third party uses of `jax.lib.xla_bridge.get_backend` to `jax.extend.backend.get_backend`.
#30359 merged
Jul 21, 2025 -
[Pallas] Make semantics of BlockSpec.memory_space in emit_pipeline consistent with pallas_call.
#30284 merged
Jul 21, 2025 -
jnp.fft.fftfreq: fix regression in complex dtype support
#30356 merged
Jul 21, 2025 -
[jex] Add `proto` <-> `xla_client.HloSharding` and `proto` -> `xla_client.OpSharding` APIs to jex.
#30357 merged
Jul 21, 2025 -
Fetch XLA archive using the GitHub API endpoint instead of the web link
#30358 merged
Jul 21, 2025 -
Parametrize build system on CUDA major version
#28968 merged
Jul 21, 2025 -
Allow setting transfer_size for cross host transfers.
#30355 merged
Jul 21, 2025 -
gumbel distribution implementation
#29343 merged
Jul 21, 2025 -
Automated g4 rollback of changelist 784947368.
#30353 merged
Jul 21, 2025 -
Consider this example:
#30350 merged
Jul 21, 2025 -
[Mosaic] Add store canonicalization for an expand reshape->store fusion
#30333 merged
Jul 21, 2025 -
Add note about direct linearize to changelog
#30351 merged
Jul 21, 2025 -
Update editorconfig to the Python line length of 88
#30348 merged
Jul 21, 2025 -
Repair initializers not matching Initializer protocol
#29561 merged
Jul 21, 2025 -
test_multiprocess: only initialize GPUs that exist
#30248 merged
Jul 21, 2025 -
#sdy Add section on Shardy migration issues about all meshes being the same and custom_partitioning.
#30347 merged
Jul 21, 2025 -
update `use_shardy_partitioner` description.
#30345 merged
Jul 21, 2025 -
improve AOT input sharding/layout mismatch error message
#30344 merged
Jul 21, 2025 -
Fix creation of sharding rules for fused_attention_stablehlo.
#30343 merged
Jul 20, 2025 -
[JAX] Enable Shardy by default in JAX.
#30302 merged
Jul 20, 2025 -
[mutable-arrays] remat discharge rule
#29370 merged
Jul 19, 2025 -
[mutable-arrays] basic shard_map + mutable arrays support
#30340 merged
Jul 19, 2025 -
Add JAX tests for deadlock verifier
#30264 merged
Jul 19, 2025 -
Fix CI config for non-rbe CUDA tests with py_import.
#30329 merged
Jul 19, 2025 -
[mosaic] `tpu_custom_call.CostEstimate` is now a typed dict
#30303 merged
Jul 19, 2025 -
Reverts 7c394f41574b1f669cbb6a47fd03bf1925a19444
#30331 merged
Jul 19, 2025 -
[Pallas/TPU] Interpret mode: Don't use source_info context manager if it's None
#30326 merged
Jul 19, 2025 -
Use `operand.aval.sharding` in convert_element_type's transpose rule
#30330 merged
Jul 19, 2025 -
Fixed a bug previously hidden by jax_use_direct_linearize=True
#30325 merged
Jul 19, 2025 -
Move ShapeDtypeStruct to core.py to break circular deps
#30323 merged
Jul 18, 2025 -
Create workflow for testing bazel cuda non rbe + py import
#28860 merged
Jul 18, 2025 -
Replace `jax.stages.OutInfo` with `jax.ShapeDtypeStruct`
#30318 merged
Jul 18, 2025 -
#sdy condition jax export tests on compatibility version when using shardy
#30312 merged
Jul 18, 2025 -
[Mosaic GPU] Fix barrier indexing in tests.
#30307 merged
Jul 18, 2025 -
Exit the CUDA test job early if not all wheels can be downloaded.
#30316 merged
Jul 18, 2025 -
Add `eval_shape` as a method of `Traced`.
#30314 merged
Jul 18, 2025 -
#sdy fix-forward jaxlib compatibility issue
#30311 merged
Jul 18, 2025 -
Use direct linearization by default
#30262 merged
Jul 18, 2025 -
[Scaled Matmul] add a sharding rule and fix custom partition method
#30250 merged
Jul 18, 2025 -
[JAX] add a missing sharding rule in ffi.md and ffi.ipynb
#30305 merged
Jul 18, 2025 -
[Mosaic GPU] Forbid the assignment of splat layouts to non-splat constants.
#30280 merged
Jul 18, 2025 -
[Mosaic GPU] Prevent `vector.load`s from being assigned `splat` layouts.
#30278 merged
Jul 18, 2025 -
[Mosaic GPU] Introduce constraints in the equational layout inference.
#30275 merged
Jul 18, 2025 -
#sdy add missing include in py_client.cc for Shardy.
#30300 merged
Jul 18, 2025 -
[Mosaic GPU] Add lowering for `tmem_alloc` and `tmem_dealloc`.
#30223 merged
Jul 18, 2025 -
[Pallas/Mosaic GPU] Add a deviceless export test for Pallas/Mosaic GPU.
#30298 merged
Jul 18, 2025 -
[Mosaic] Add store canonicalization for an expand reshape->store fusion
#30254 merged
Jul 18, 2025 -
[Mosaic TPU] Allow producing predicate phi node in control follow.
#30292 merged
Jul 18, 2025 -
Fix `JAX_ENABLE_X64` env var name for CI scripts.
#30295 merged
Jul 17, 2025 -
[XLA] Add a primitive for tagging individual XLA ops with frontend_attributes.
#30282 merged
Jul 17, 2025 -
Add type interface for `jax.nn`
#30196 merged
Jul 17, 2025 -
layout: Implement out-of-line template method as a non-template.
#30285 merged
Jul 17, 2025 -
[rollforward]: Add make_transfer_server_interface_factory to the py_socket_transfer jaxlib
#30283 merged
Jul 17, 2025 -
fused_attention_stablehlo_test_gpu: test skip for 12.0
#30272 merged
Jul 17, 2025 -
[Pallas] Add option to allow out-of-bound reads in TPU interpret mode.
#30259 merged
Jul 17, 2025 -
[Mosaic GPU] Add the `Distinct` constraint to the equation system.
#30273 merged
Jul 17, 2025 -
[rollforward]: Add transfer_server_interface to xla.cc as an extra optional factory argument
#30269 merged
Jul 17, 2025 -
[Pallas/TPU] Add empty_ref_like function + SMEM support for Refs
#30239 merged
Jul 17, 2025 -
Rollback #30206 due to downstream test failures
#30266 merged
Jul 17, 2025 -
Use registers for state by default in emit_pipeline
#30265 merged
Jul 16, 2025 -
[XLA] Refactor `xla_metadata.py` to resolve dependency cycle
#30240 merged
Jul 16, 2025 -
Add nn.logmeanexp.
#30206 merged
Jul 16, 2025 -
[Mosaic] Allow sublane rotation for non-sublane dim aligned shape.
#29419 merged
Jul 16, 2025 -
Reverts aee89744efbd3aa93ffea7a3bab803483add9289
#30260 merged
Jul 16, 2025 -
Disable `too_slow` in data.draw() for test_cast_from_32bit
#30261 merged
Jul 16, 2025 -
[jax:benchmark] Add tracing benchmarks for some common operations.
#29413 merged
Jul 16, 2025 -
Set `check_leaks=False` in `def direct_linearize` to fix some tests
#30256 merged
Jul 16, 2025 -
[Mosaic] Support multiple non-contracting dims if they are collapsable.
#30076 merged
Jul 16, 2025 -
fix jax2tf scatter impl
#30255 merged
Jul 16, 2025 -
[pallas] `pl.core_map` now support `functools.partial`ed functions
#30220 merged
Jul 16, 2025 -
[Mosaic TPU] Fix the assumption that beforeBody's inputs size is same as afterBody's in scf.while.
#30244 merged
Jul 16, 2025 -
support cudnn sdpa on gb300 (compute_cap=10.3) with cudnn > 9.11
#30242 merged
Jul 16, 2025 -
[Mosaic GPU] Add `tmem_alloc` and `tmem_dealloc` to the Mosaic GPU dialect.
#30183 merged
Jul 16, 2025 -
Includes `sharding_constraint_p` primitive in `roofline`.
#30232 merged
Jul 16, 2025 -
[Mosaic GPU] Add `tcgen05.mma` to the Mosaic GPU dialect.
#30189 merged
Jul 16, 2025 -
[Mosaic GPU] Add equational layout inference rule for `scf.WhileOp`.
#30230 merged
Jul 16, 2025 -
[Mosaic GPU] Add equational layout inference rule for `scf.ForOp`.
#30229 merged
Jul 16, 2025 -
[Pallas] Physicalize fusion-dtype input avals before evaluating expressions.
#30241 merged
Jul 16, 2025 -
- Add more items in the ragged paged attention auto tuning table
#30151 merged
Jul 15, 2025 -
Remove version checks for ml_dtypes >= 0.5.0
#30235 merged
Jul 15, 2025 -
[Pallas] Support all integer casting cases.
#30127 merged
Jul 15, 2025 -
Add a test mentioned in https://github.com/jax-ml/jax/issues/29883 to JAX test suite
#30234 merged
Jul 15, 2025 -
[Pallas][Mosaic GPU] Thread the python kernel function name to the MLIR kernel name.
#30200 merged
Jul 15, 2025 -
[doc] add docs for jax.lax.reduce_window
#30225 merged
Jul 15, 2025 -
Mark the jax.util submodule as deprecated
#30227 merged
Jul 15, 2025 -
[direct-linearize] add `instantiate` to linearize_jaxpr, fixpoint test
#30231 merged
Jul 15, 2025 -
[doc] replace tensorflow.org/xla with openxla.org/xla
#30224 merged
Jul 15, 2025
48 Pull requests opened by 4 people
-
Testing orbax changes with shardy enabled by default
#30226 opened
Jul 15, 2025 -
Refactor: Add carry support to nd_loop
#30237 opened
Jul 15, 2025 -
FIX: Handle WGSplatFragLayout in cond lowering.
#30238 opened
Jul 15, 2025 -
Support core axis index in the `device_id` dict for async copy and semaphore.
#30243 opened
Jul 16, 2025 -
[XLA:GPU] Add JAX-based precision tests for Triton and cuBLAS
#30246 opened
Jul 16, 2025 -
[Mosaic GPU] Query amount of shared memory programmatically
#30257 opened
Jul 16, 2025 -
Skip Pallas and Mosaic GPU tests that don't fit on RTX 6000 PRO
#30258 opened
Jul 16, 2025 -
[Mosaic:TPU] Explicitly instantiate VectorLayout::print
#30263 opened
Jul 16, 2025 -
Remove `local_config_nvshmem` repository and corresponding macros.
#30267 opened
Jul 17, 2025 -
Implement performance optimized w8a8 pallas kernel
#30268 opened
Jul 17, 2025 -
Add Windows Bazel CPU tests with py_import dependency to continuous tests.
#30286 opened
Jul 17, 2025 -
[Pallas/TPU] Add option to allow skipping the device barrier
#30289 opened
Jul 17, 2025 -
[jax:custom_partitioning] Allow factors for non-batching dimensions to
#30291 opened
Jul 17, 2025 -
[Mosaic][SC] Add custom assembly format to tpu.enqueue_indirect_dma
#30293 opened
Jul 17, 2025 -
[JAX] Disable Shardy in `JaxExportTest` and `CompatTest` if jaxlib version is before 0.7.0.
#30304 opened
Jul 18, 2025 -
[JAX] Disable Shardy in JAX export if jaxlib version is before 0.7.0.
#30306 opened
Jul 18, 2025 -
Test PR for runner
#30315 opened
Jul 18, 2025 -
Update ragged_dot kernels to use new GroupInfo for persistence
#30317 opened
Jul 18, 2025 -
Reverts dd59b47c07caa777f57637107379374e7906de12
#30319 opened
Jul 18, 2025 -
Reverts c9700e637550b6404e85aeae1ff4eb207e1f2d76
#30320 opened
Jul 18, 2025 -
flip flag to check OSS tests
#30321 opened
Jul 18, 2025 -
[Pallas:MGPU] Expose TCGEN05_TMEM_NATIVE_COL
#30322 opened
Jul 18, 2025 -
Add aval out to pull_block_spec signature in fusible dtype
#30332 opened
Jul 19, 2025 -
Automated Code Change
#30334 opened
Jul 19, 2025 -
fix core.Tracer inheritance (just to see what breaks)
#30341 opened
Jul 19, 2025 -
[mutable-arrays] flip JAX_MUTABLE_ARRAY_CHECKS=True by default
#30342 opened
Jul 20, 2025 -
Fix numerical bugs in the gradients of sigmoid/logistic, tanh, expm1.
#30346 opened
Jul 21, 2025 -
Accelerate deprecation for jax.lib.xla_bridge.get_backend.
#30349 opened
Jul 21, 2025 -
Add more Bazel tests to Nightly/Release job.
#30361 opened
Jul 21, 2025 -
Bump fonttools from 4.51.0 to 4.59.0
#30362 opened
Jul 21, 2025 -
Bump fsspec from 2024.5.0 to 2025.7.0
#30363 opened
Jul 21, 2025 -
Bump hypothesis from 6.102.4 to 6.136.1
#30364 opened
Jul 21, 2025 -
[Mosaic] Allow matrix-vector dot.
#30369 opened
Jul 21, 2025 -
[Mosaic] Allow vector::Extract for non-32 bits vector result.
#30371 opened
Jul 21, 2025 -
Implement explicit mode sharding rule for scatter
#30377 opened
Jul 22, 2025 -
Bump libtpu version before release.
#30380 opened
Jul 22, 2025 -
#sdy Delete JAX test configs enabling Shardy.
#30381 opened
Jul 22, 2025 -
Test B200 against CUDA 12.8 only
#30383 opened
Jul 22, 2025 -
[pallas] Forked `load` and `store` into `triton` and `tpu`
#30384 opened
Jul 22, 2025 -
[Mosaic GPU] Add support for warp shuffles with elements wider than 32-bit
#30387 opened
Jul 22, 2025 -
Finalize deprecation of xla_bridge.get_compile_options.
#30388 opened
Jul 22, 2025 -
Finalize deprecation of xla_extension.
#30389 opened
Jul 22, 2025 -
[JAX] Optimize `util.lru_cache(..., trace_context_in_key=False)`
#30392 opened
Jul 22, 2025 -
Refactor Bazel CPU RBE and Bazel GPU Non-RBE and add more Bazel tests to Nightly/Release job.
#30393 opened
Jul 22, 2025 -
[jax/BUILD] avoid `_src` dependencies on top-level `jax` packages
#30394 opened
Jul 22, 2025 -
#sdy Delete JAX test configs enabling Shardy.
#30395 opened
Jul 22, 2025
12 Issues closed by 7 people
-
Failed build: CI - with Numpy/Scipy nightly wheels (nightly)
#30055 closed
Jul 22, 2025 -
Add Gumbel distribution to scipy.stats
#29319 closed
Jul 21, 2025 -
jax.numpy.fft.fftfreq no longer supports complex dtype
#30287 closed
Jul 21, 2025 -
`jax.tree_util.tree_map` fails when a registered pydantic object which has been copied using `deep=True`
#30299 closed
Jul 19, 2025 -
Could you add the support of the new optimizer: Muon
#30309 closed
Jul 19, 2025 -
[sharding-in-types] setting global mesh+complex tensors+linear solve = problems
#30327 closed
Jul 19, 2025 -
Add nn.logmeanexp
#30178 closed
Jul 16, 2025 -
`fori_loop` gets slower after `jit`
#30245 closed
Jul 16, 2025 -
[Pallas, jax 0.6.0] Interpret mode seems to invoke GPU compilation process instead of being CPU only
#30214 closed
Jul 16, 2025 -
No error message but failing zero-copy?
#30228 closed
Jul 16, 2025 -
[sharding-in-types] `jnp.linalg.{slogdet/solve}` do not work with explicit sharding
#29883 closed
Jul 15, 2025
12 Issues opened by 11 people
-
Pallas fails to write to output ref when using bfloat16 on TPU v3-8
#30391 opened
Jul 22, 2025 -
Profiling tool shows no GPU processes despite GPU Util being as expected.
#30378 opened
Jul 22, 2025 -
jet and equinox.nn.MLP
#30352 opened
Jul 21, 2025 -
Fusing the optimizer step into the backward pass
#30338 opened
Jul 19, 2025 -
jax.grad precision: float32 gradients of bfloat16 weights
#30337 opened
Jul 19, 2025 -
Segfault while Performing All-to-All Collective Operation on 8xH100 SXM5 (Shard + Swapaxes + Shard)
#30335 opened
Jul 19, 2025 -
Inconsistent Reduction precision in backwards computation
#30310 opened
Jul 18, 2025 -
jax.numpy.sort performance regression
#30296 opened
Jul 18, 2025 -
Json parse exceptions (and others) in perfetto traces
#30290 opened
Jul 17, 2025 -
[Feature Request] Add Sparse Attention kernel for GPUs in Pallas
#30281 opened
Jul 17, 2025 -
"ComputeCallSignature failed" in verbose logs investigating performance loss.
#30270 opened
Jul 17, 2025
25 Unresolved conversations
Sometimes conversations happen on old items that aren’t yet closed. Here is a list of all the Issues and Pull Requests with unresolved conversations.
-
Extend numpy.size to accept multiple axes.
#30132 commented on
Jul 22, 2025 • 14 new comments -
added solve_sylvester and accompanying tests
#28810 commented on
Jul 21, 2025 • 5 new comments -
Implementing LİSHT
#30218 commented on
Jul 18, 2025 • 0 new comments -
gather/scatter: push negative index handling into primitives
#30205 commented on
Jul 15, 2025 • 0 new comments -
#sdy Fix forward of making JAX changes so we can fall back to GSPMD in JAX export if the loaded module was lowered for GSPMD.
#30190 commented on
Jul 22, 2025 • 0 new comments -
[jaxprs] Hoist large constants as arguments during lowering
#30180 commented on
Jul 18, 2025 • 0 new comments -
[Mosaic:TPU] tileArrayShape is 1 for replicated dims
#30138 commented on
Jul 18, 2025 • 0 new comments -
#sdy Remove MHLO shardings from round-trip export
#30091 commented on
Jul 17, 2025 • 0 new comments -
#sdy Fix forward of making XLA C++ changes so we can fall back to GSPMD in JAX export if the loaded module was lowered for GSPMD.
#29951 commented on
Jul 17, 2025 • 0 new comments -
Add metadata for CUDA and libtpu versions
#29715 commented on
Jul 15, 2025 • 0 new comments -
Add Hermetic C++ Toolchains for Linux x86_64 builds.
#29672 commented on
Jul 17, 2025 • 0 new comments -
[CI] Add bazel TPU presubmit testing
#29660 commented on
Jul 18, 2025 • 0 new comments -
WIP: Refactor "How to think in JAX" and surrounding pages
#29541 commented on
Jul 21, 2025 • 0 new comments -
Bump fsspec from 2024.5.0 to 2025.5.1
#29011 commented on
Jul 21, 2025 • 0 new comments -
Initial commit for attaching XLA metadata to individual HLO operations via 'jax.attach_metadata(...)'
#28953 commented on
Jul 15, 2025 • 0 new comments -
[Doc] Rename gpu_performance_tips.md to performance_tips.md with new CPU performance tips session
#24961 commented on
Jul 16, 2025 • 0 new comments -
Add custom derivative for scipy.special.hyp2f1
#30195 commented on
Jul 22, 2025 • 0 new comments -
Error with `jax.custom_batching.sequential_vmap` in the `jax.ensure_compile_time_eval` context
#29996 commented on
Jul 22, 2025 • 0 new comments -
Support autodiff of Eigendecomposition with repeated eigenvalues
#669 commented on
Jul 21, 2025 • 0 new comments -
custom jvp + transpose raises UnexpectedTracerError
#29948 commented on
Jul 21, 2025 • 0 new comments -
feature request: sparse jacobian and sparse hessians
#1032 commented on
Jul 19, 2025 • 0 new comments -
Segmentation fault when calling exported choleksy on CPU
#29610 commented on
Jul 18, 2025 • 0 new comments -
Saturating arithmetic
#26566 commented on
Jul 17, 2025 • 0 new comments -
Gemma 3 + `jax-metal`: 'mhlo.convolution' op Not supported
#27288 commented on
Jul 17, 2025 • 0 new comments -
jax_explain_cache_misses is not thread safe
#30163 commented on
Jul 16, 2025 • 0 new comments