visu-hlo is a Python package that displays the HLO (High Level Operations) representation of JAX functions as SVG graphs. It helps developers understand the computation graphs that XLA produces when compiling their code.
- 🎯 Easy Visualization: Display HLO graphs with a single function call
- ⚡ JIT Support: Works with both regular and jitted JAX functions
- 🖼️ SVG Output: High-quality vector graphics that scale perfectly
- 🖥️ Cross-Platform: Supports Linux, macOS, and Windows
- 📦 Lightweight: Minimal dependencies, just JAX and Graphviz
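visu-hlo distinguishes between optimized HLO (the default) and non-optimized HLO. To see what those two stages contain as plain text, you can use JAX's ahead-of-time lowering API directly, without visu-hlo. The sketch below assumes that visu-hlo's default view corresponds roughly to XLA's post-compilation HLO and that its non-optimized view corresponds to the HLO JAX emits before XLA's optimization passes. The helper `hlo_before_and_after` is hypothetical and not part of visu-hlo.

```python
import jax
import jax.numpy as jnp


def hlo_before_and_after(fn, *args):
    """Hypothetical helper: return (non-optimized, optimized) HLO text for fn(*args)."""
    # Trace and lower the function for the given example arguments
    lowered = jax.jit(fn).lower(*args)
    # HLO as emitted by JAX, before XLA runs its optimization passes
    before = lowered.as_text(dialect="hlo")
    # HLO after XLA has compiled and optimized it for the default backend
    after = lowered.compile().as_text()
    return before, after


before, after = hlo_before_and_after(lambda x: 3 * x * 2, jnp.ones(10))
print(before)
print(after)
```

Comparing the two printouts is a quick way to see what XLA's passes (for example, fusion of elementwise operations) changed, which is the difference the graphs make visible.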
import jax.numpy as jnp
from visu_hlo import show
# Display optimized HLO (default)
show(lambda x: 3 * x * 2, jnp.ones(10))
To display the non-optimized HLO:
show(lambda x: 3 * x * 2, jnp.ones(10), jit=False)
To save the graph as an SVG file:
from visu_hlo import write_svg
write_svg('graph.svg', lambda x: 3 * x * 2, jnp.ones(10))
To install:
pip install visu-hlo
System dependency: Graphviz must be installed separately, since pip does not install it.
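Because Graphviz is a system dependency rather than a Python package, it can be useful to check that it is available before rendering. The sketch below assumes that rendering relies on Graphviz's `dot` executable being on `PATH`; the helper `graphviz_version` is hypothetical and not part of visu-hlo. It uses only the Python standard library.

```python
import shutil
import subprocess
from typing import Optional


def graphviz_version() -> Optional[str]:
    """Hypothetical helper: return Graphviz's version banner, or None if `dot` is missing."""
    dot = shutil.which("dot")
    if dot is None:
        return None
    # `dot -V` writes its version banner to stderr rather than stdout
    result = subprocess.run([dot, "-V"], capture_output=True, text=True, check=False)
    return (result.stderr or result.stdout).strip()


version = graphviz_version()
if version is None:
    print("Graphviz not found: install it so that the `dot` executable is on PATH")
else:
    print(version)
```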
Full documentation: https://visu-hlo.readthedocs.io/
License: MIT