axonml-jit
Overview
axonml-jit is the AxonML tracing JIT. A `Tracer` records tensor operations
into a typed `Graph` IR; an `Optimizer` runs a stack of passes (constant
folding, dead-code elimination, common-subexpression elimination, algebraic
simplification, elementwise fusion, and strength reduction); a `JitCompiler`
emits either an interpreter-backed `CompiledFunction` or a native
Cranelift-compiled one; and a graph-hash `FunctionCache` (LRU) lets repeated
compilations hit the cache. Cranelift 0.111 is the code-gen backend.
Features
- Operation Tracing — `trace(|tracer| { ... })` and `Tracer` APIs build a `Graph` from recorded operations using thread-local state
- Typed IR — `Graph`, `Node`, `NodeId`, `Op` (40+ variants), `Shape` (with broadcast checks and broadcast-shape computation), `DataType`
- Optimizer — `Optimizer::default_passes()` plus six `OptimizationPass` variants (`ConstantFolding`, `DeadCodeElimination`, `CommonSubexpressionElimination`, `AlgebraicSimplification`, `ElementwiseFusion`, `StrengthReduction`)
- JIT Compiler — `JitCompiler` with interpreter execution and optional Cranelift native codegen (`enable_native(true)`)
- Higher-Level Facade — `compile_fn`, `compile_fn_with_config`, `compile_graph`, `compile_graph_with_config`, `CompiledModel`, `LazyCompiled` (deferred compilation) with `CompileConfig` (`Mode::{Default, ReduceOverhead, MaxAutotune}`, `Backend::{Default, Eager, AOT, ONNX}`, `fullgraph`, `dynamic`, `disable`, custom passes)
- Function Caching — `FunctionCache` with LRU eviction and `Self::hash_graph`-based keying; `CacheStats` with `utilization`
- Shape Inference — automatic shape propagation including broadcast semantics (`Shape::broadcast_shape`)
- Thread-Local Tracing — safe concurrent tracing via per-thread tracer state
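Broadcast semantics follow the usual align-trailing-dimensions rule. As a sketch of what a broadcast-shape computation like `Shape::broadcast_shape` involves (a plain-Rust illustration, not the crate's implementation):

```rust
// NumPy-style broadcast-shape computation: align shapes at the trailing
// dimensions; a missing or size-1 dimension stretches to match the other.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = vec![0; n];
    for i in 0..n {
        // Missing leading dimensions behave like size 1.
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        out[i] = match (da, db) {
            (x, y) if x == y => x,
            (1, y) => y,
            (x, 1) => x,
            _ => return None, // incompatible dimensions
        };
    }
    Some(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[2, 3], &[3]), Some(vec![2, 3]));
    assert_eq!(broadcast_shape(&[2, 1], &[1, 4]), Some(vec![2, 4]));
    assert_eq!(broadcast_shape(&[2, 3], &[4]), None);
}
```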
Modules
| Module | Description |
|---|---|
| `ir` | `Graph`, `Node`, `NodeId`, `Op`, `Shape`, `DataType`, topological order, validation |
| `trace` | `Tracer`, `TracedValue`, `trace` entry point, thread-local state |
| `optimize` | `Optimizer`, `OptimizationPass` (6 variants), `default_passes` |
| `codegen` | `JitCompiler`, `CompiledFunction` (Interpreted and Cranelift Native kinds) |
| `compile` | `compile_fn`, `compile_graph`, `CompiledModel`, `LazyCompiled`, `CompileConfig`, `CompileStats`, `Mode`, `Backend` |
| `cache` | `FunctionCache` (LRU), `CacheStats` |
| `error` | `JitError`, `JitResult` |
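The `ir` module's topological ordering is what lets the interpreter and codegen visit each node after all of its inputs. A minimal Kahn's-algorithm sketch over an adjacency layout (illustrative only; the crate's `Graph` representation differs):

```rust
use std::collections::VecDeque;

// inputs[i] lists the node ids that node i depends on.
fn topo_order(inputs: &[Vec<usize>]) -> Option<Vec<usize>> {
    let n = inputs.len();
    let mut indegree = vec![0usize; n];
    let mut users: Vec<Vec<usize>> = vec![Vec::new(); n];
    for (node, deps) in inputs.iter().enumerate() {
        indegree[node] = deps.len();
        for &d in deps {
            users[d].push(node);
        }
    }
    // Kahn's algorithm: repeatedly emit nodes whose dependencies are satisfied.
    let mut ready: VecDeque<usize> = (0..n).filter(|&i| indegree[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(node) = ready.pop_front() {
        order.push(node);
        for &u in &users[node] {
            indegree[u] -= 1;
            if indegree[u] == 0 {
                ready.push_back(u);
            }
        }
    }
    // Fewer than n emitted nodes means a cycle, i.e. an invalid graph.
    if order.len() == n { Some(order) } else { None }
}

fn main() {
    // node 2 = add(node 0, node 1); node 3 = relu(node 2)
    let g = vec![vec![], vec![], vec![0, 1], vec![2]];
    assert_eq!(topo_order(&g), Some(vec![0, 1, 2, 3]));
    // a self-cycle is rejected
    assert_eq!(topo_order(&[vec![0]]), None);
}
```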
Usage
Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-jit = "0.6.2"
```
Basic Tracing and Compilation
```rust
use axonml_jit::{trace, JitCompiler};

// Trace operations to build a computation graph
let graph = trace(|tracer| {
    let a = tracer.input("a", &[2, 3]);
    let b = tracer.input("b", &[2, 3]);
    let c = a.add(&b);
    let d = c.mul_scalar(2.0);
    tracer.output("result", d)
});

// Compile the graph (interpreter-backed by default)
let compiler = JitCompiler::new();
let compiled = compiler.compile(&graph)?;

// Execute with real data — inputs are name/slice tuples
let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
let result = compiled.run(&[("a", &a_data[..]), ("b", &b_data[..])])?;
```
Cranelift Native Codegen
```rust
use axonml_jit::JitCompiler;

let mut compiler = JitCompiler::new();
compiler.enable_native(true); // opt in to Cranelift codegen
let compiled = compiler.compile(&graph)?;
```
Traced Operations
```rust
use axonml_jit::trace;

let graph = trace(|tracer| {
    let x = tracer.input("x", &[4, 4]);

    // Elementwise + scalar ops
    let y = x.relu()
        .mul_scalar(2.0)
        .add_scalar(1.0);

    // Activation functions
    let z = y.sigmoid().tanh().gelu();

    // Reductions
    let mean = z.mean_axis(1, true);

    // Shape operations
    let reshaped = mean.reshape(&[-1]);

    tracer.output("output", reshaped)
});
```
Custom Optimization
```rust
use axonml_jit::{Optimizer, OptimizationPass, JitCompiler};

// Build a custom pass pipeline
let mut optimizer = Optimizer::new();
optimizer.add_pass(OptimizationPass::ConstantFolding);
optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
optimizer.add_pass(OptimizationPass::DeadCodeElimination);
optimizer.add_pass(OptimizationPass::CommonSubexpressionElimination);

// Apply optimizations directly
let optimized_graph = optimizer.optimize(graph.clone());

// Or hand the optimizer to the compiler
let compiler = JitCompiler::with_optimizer(optimizer);
let compiled = compiler.compile(&graph)?;
```
Higher-Level compile_fn / CompiledModel
```rust
use axonml_jit::{compile_fn, compile_fn_with_config, CompileConfig, Mode, Backend};
use std::collections::HashMap;

// Zero-config
let model = compile_fn(|t| {
    let x = t.input("x", &[8]);
    t.output("y", x.relu())
})?;

// With config
let cfg = CompileConfig::new()
    .mode(Mode::MaxAutotune)
    .backend(Backend::Default)
    .fullgraph(true);
let model = compile_fn_with_config(|t| {
    let x = t.input("x", &[8]);
    t.output("y", x.gelu())
}, cfg)?;

// CompiledModel runs on HashMap<String, Vec<f32>>
let mut inputs = HashMap::new();
inputs.insert("x".to_string(), vec![-1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let outputs = model.run(&inputs)?;

// Inspect compilation
println!("{} -> {} nodes ({:.1}% reduction)",
    model.stats().original_nodes,
    model.stats().optimized_nodes,
    model.stats().optimization_ratio() * 100.0);
```
Lazy Compilation
```rust
use axonml_jit::LazyCompiled;

let lazy = LazyCompiled::new(|t| {
    let x = t.input("x", &[3]);
    t.output("y", x.exp())
});

// Compiled on first call, cached thereafter.
let outputs = lazy.run(&inputs)?;
```
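The deferred-compilation pattern behind `LazyCompiled` can be sketched with `std::sync::OnceLock`; the `Lazy` type here is a stand-in for illustration, not the crate's type:

```rust
use std::sync::OnceLock;

// Build-on-first-use: the expensive step runs at most once, and the
// result is reused on every later call.
struct Lazy<T, F: Fn() -> T> {
    build: F,
    slot: OnceLock<T>,
}

impl<T, F: Fn() -> T> Lazy<T, F> {
    fn new(build: F) -> Self {
        Lazy { build, slot: OnceLock::new() }
    }

    fn get(&self) -> &T {
        // get_or_init runs `build` at most once, even across threads.
        self.slot.get_or_init(|| (self.build)())
    }
}

fn main() {
    use std::cell::Cell;
    let calls = Cell::new(0);
    let lazy = Lazy::new(|| {
        calls.set(calls.get() + 1); // count "compilations"
        42
    });
    assert_eq!(*lazy.get(), 42);
    assert_eq!(*lazy.get(), 42);
    assert_eq!(calls.get(), 1); // built exactly once
}
```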
Cache Management
```rust
use axonml_jit::JitCompiler;

let compiler = JitCompiler::new();

// Compile multiple graphs
let _ = compiler.compile(&graph1)?;
let _ = compiler.compile(&graph2)?;

// Check cache statistics
let stats = compiler.cache_stats();
println!("Cached functions: {}", stats.entries);
println!("Cache utilization: {:.1}%", stats.utilization() * 100.0);

// Clear cache if needed
compiler.clear_cache();
```
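`FunctionCache` internals aren't shown in the docs; as an illustration of LRU eviction keyed by a graph hash, a minimal std-only sketch (the `u64` key stands in for a structural hash of the graph, `String` for a compiled function):

```rust
use std::collections::{HashMap, VecDeque};

struct LruCache {
    capacity: usize,
    map: HashMap<u64, String>,
    recency: VecDeque<u64>, // front = least recently used
}

impl LruCache {
    fn new(capacity: usize) -> Self {
        LruCache { capacity, map: HashMap::new(), recency: VecDeque::new() }
    }

    fn get(&mut self, key: u64) -> Option<&String> {
        if self.map.contains_key(&key) {
            self.touch(key); // a hit refreshes recency
        }
        self.map.get(&key)
    }

    fn insert(&mut self, key: u64, value: String) {
        if self.map.len() >= self.capacity && !self.map.contains_key(&key) {
            // Evict the least recently used entry.
            if let Some(old) = self.recency.pop_front() {
                self.map.remove(&old);
            }
        }
        self.map.insert(key, value);
        self.touch(key);
    }

    fn touch(&mut self, key: u64) {
        self.recency.retain(|&k| k != key);
        self.recency.push_back(key);
    }

    fn utilization(&self) -> f64 {
        self.map.len() as f64 / self.capacity as f64
    }
}

fn main() {
    let mut cache = LruCache::new(2);
    cache.insert(1, "f1".into());
    cache.insert(2, "f2".into());
    cache.get(1);                 // 1 is now most recently used
    cache.insert(3, "f3".into()); // evicts 2, the LRU entry
    assert!(cache.get(2).is_none());
    assert!(cache.get(1).is_some());
    assert_eq!(cache.utilization(), 1.0);
}
```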
Supported Operations
All operations are recorded as `Op` enum variants:

Binary Operations
`Add`, `Sub`, `Mul`, `Div`, `Pow`, `Max`, `Min`

Unary Operations
`Neg`, `Abs`, `Sqrt`, `Exp`, `Log`, `Sin`, `Cos`, `Tanh`

Activations
`Relu`, `Sigmoid`, `Gelu`, `Silu`

Scalar Operations
`AddScalar`, `MulScalar`

Reductions
`Sum`, `SumAxis`, `Mean`, `MeanAxis`, `MaxAxis`

Shape Operations
`Reshape`, `Transpose`, `Squeeze`, `Unsqueeze`, `Broadcast`

Matrix Operations
`MatMul`

Comparison / Conditional
`Gt`, `Lt`, `Eq`, `Where`

Special
`Cast` (change `DataType`), `Contiguous`, `Input`, `Output`, `Constant`
Optimization Passes
| Pass | Description |
|---|---|
| `ConstantFolding` | Evaluate constant expressions at compile time |
| `DeadCodeElimination` | Remove nodes that don't feed an output |
| `CommonSubexpressionElimination` | Reuse identical subexpressions |
| `AlgebraicSimplification` | `x * 1 = x`, `x + 0 = x`, etc. |
| `ElementwiseFusion` | Fuse consecutive elementwise ops |
| `StrengthReduction` | Replace expensive ops with cheaper equivalents |
The default `CompileConfig` enables `ConstantFolding`, `DeadCodeElimination`, and
`CommonSubexpressionElimination`. `Mode::MaxAutotune` additionally appends
`ElementwiseFusion` and `AlgebraicSimplification`.
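To make the folding and simplification rewrites concrete, here is a toy pass over an expression tree (the real passes operate on `Graph` nodes, not this illustrative `Expr` enum):

```rust
// Constant folding plus algebraic simplification on a tiny expression IR.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(f64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

fn simplify(e: Expr) -> Expr {
    use Expr::*;
    match e {
        Add(a, b) => match (simplify(*a), simplify(*b)) {
            (Const(x), Const(y)) => Const(x + y),           // constant folding
            (x, Const(c)) | (Const(c), x) if c == 0.0 => x, // x + 0 = x
            (x, y) => Add(Box::new(x), Box::new(y)),
        },
        Mul(a, b) => match (simplify(*a), simplify(*b)) {
            (Const(x), Const(y)) => Const(x * y),           // constant folding
            (x, Const(c)) | (Const(c), x) if c == 1.0 => x, // x * 1 = x
            (x, y) => Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}

fn main() {
    // (x * 1) + (2 + 3)  simplifies to  x + 5
    let e = Expr::Add(
        Box::new(Expr::Mul(Box::new(Expr::Var("x")), Box::new(Expr::Const(1.0)))),
        Box::new(Expr::Add(Box::new(Expr::Const(2.0)), Box::new(Expr::Const(3.0)))),
    );
    assert_eq!(
        simplify(e),
        Expr::Add(Box::new(Expr::Var("x")), Box::new(Expr::Const(5.0)))
    );
}
```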
Tests

```sh
cargo test -p axonml-jit
```
License
Licensed under either of:
- MIT License
- Apache License, Version 2.0
at your option.