Optix is a functional lens library for JAX/Equinox that lets you focus on and modify nested values within PyTree structures. It generates the same HLO as direct access, so using a lens adds no runtime overhead.
- Type-safe lenses for any JAX PyTree structure
- Zero runtime overhead (generates the same HLO as direct access)
- Intuitive API for accessing and modifying nested values
- Complete static typing support
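```python
import equinox as eqx
import jax
import jax.numpy as jnp
from optix import focus

# Example PyTree types (illustrative; any JAX PyTree should work)
class NestedStruct(eqx.Module):
    y: jax.Array

class MyStruct(eqx.Module):
    x: jax.Array
    nested: NestedStruct

# Create a nested PyTree structure
data = MyStruct(
    x=jnp.array([1.0, 2.0]),
    nested=NestedStruct(y=jnp.array(3.0)),
)

# Focus on and modify a nested value (returns a new PyTree)
result = focus(data).at(lambda s: s.nested.y).apply(jnp.square)
# result:
# MyStruct(
#     x=Array([1., 2.], dtype=float32),
#     nested=NestedStruct(
#         y=Array(9., dtype=float32)
#     )
# )
```

Install with `pip install jax-optix`.

MIT License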
Special thanks to Patrick Kidger for providing helpful hints and the Equinox library, which this project builds upon.