A dependently-typed deep learning framework for Lean 4, providing compile-time tensor shape verification. Very much WIP.
Tyr uses Lean 4's dependent type system to catch tensor dimension mismatches at compile time, not runtime. This eliminates a major source of bugs while still dispatching to optimized LibTorch operations through C++ FFI bindings.
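The mechanism itself needs nothing Tyr-specific. As a minimal, self-contained Lean 4 sketch that requires neither Tyr nor LibTorch, the hypothetical `Mat` type below carries its row and column counts in its type, so a mismatched inner dimension is rejected during elaboration. `Mat`, `Mat.zeros`, and `Mat.linear` are illustrative names, not Tyr's API. `Mat.linear` mirrors the shape signature of Tyr's `linear` (`[b, m]` and `[n, m]` give `[b, n]`) but performs no arithmetic.

```lean
/-- Toy matrix whose dimensions are part of its type (hypothetical, not Tyr's `T`). -/
structure Mat (rows cols : Nat) where
  data : Array Float

/-- A `rows × cols` matrix of zeros, stored row-major. -/
def Mat.zeros (rows cols : Nat) : Mat rows cols :=
  { data := Array.mk (List.replicate (rows * cols) (0.0 : Float)) }

/-- Same shape contract as `linear`: `x : [b, m]`, `w : [n, m]` gives `[b, n]`.
The body is a placeholder; only the type signature matters in this sketch. -/
def Mat.linear {b m n : Nat} (_x : Mat b m) (_w : Mat n m) : Mat b n :=
  Mat.zeros b n

def x : Mat 32 768 := Mat.zeros 32 768
def wGood : Mat 512 768 := Mat.zeros 512 768
def wBad : Mat 768 512 := Mat.zeros 768 512

-- Accepted: the shared dimension is 768 on both sides.
def y : Mat 32 512 := Mat.linear x wGood

-- Rejected during elaboration: `wBad : Mat 768 512`, but `Mat ?n 768` is required.
-- `#check_failure` reports the expected error as information instead of failing the file.
#check_failure Mat.linear x wBad
```

Tyr's `T #[...]` follows the same idea, with the shape held in a single `Array UInt64` index rather than two `Nat` parameters.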
```lean
-- Shapes are tracked in the type system
def linear {m n b : UInt64} (x : T #[b, m]) (M : T #[n, m]) : T #[b, n] := ...

-- Mismatched dimensions fail at compile time, not runtime!
let x : T #[32, 768] := ...
let w : T #[768, 512] := ...
let y := linear x w -- Error: expected T #[n, 768], got T #[768, 512]
```

Install elan (the Lean version manager):
```bash
curl https://raw.githubusercontent.com/leanprover/elan/master/elan-init.sh -sSf | sh
```

The correct Lean nightly is pinned in `lean-toolchain` and will be installed automatically on the first `lake build`.
Download LibTorch into `external/`. macOS (Apple Silicon, `arm64` build):

```bash
cd external
LIBTORCH_VERSION=2.10.0
curl --fail --location --retry 5 --retry-all-errors --show-error \
  -o libtorch.zip "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${LIBTORCH_VERSION}.zip"
unzip -tq libtorch.zip   # verify the archive before extracting
unzip -q libtorch.zip && rm libtorch.zip
cd ..
```

Or run the helper script:
```bash
bash dependencies_macos.sh
```

Linux (CPU):
```bash
cd external
LIBTORCH_VERSION=2.10.0
curl --fail --location --retry 5 --retry-all-errors --show-error \
  -o libtorch.zip "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${LIBTORCH_VERSION}%2Bcpu.zip"
unzip -tq libtorch.zip   # verify the archive before extracting
unzip -q libtorch.zip && rm libtorch.zip
cd ..
```

Linux (CUDA 12.6):
```bash
cd external
# nightly "latest" build, not pinned to a specific LibTorch release
curl -O https://download.pytorch.org/libtorch/nightly/cu126/libtorch-cxx11-abi-shared-with-deps-latest.zip
unzip libtorch-cxx11-abi-shared-with-deps-latest.zip && rm libtorch-cxx11-abi-shared-with-deps-latest.zip
cd ..
```

OpenMP on macOS (Homebrew):
```bash
brew install libomp
```

Linux: OpenMP is typically included with GCC. Install it if needed:

```bash
sudo apt install libomp-dev  # Debian/Ubuntu
```

Apache Arrow / Parquet on macOS:

```bash
brew install apache-arrow
```

Linux:

```bash
sudo apt install libarrow-dev libparquet-dev
```

C++ compiler:

- macOS: Xcode command line tools (`xcode-select --install`)
- Linux: GCC 9+ or Clang 10+ (`sudo apt install build-essential`)
```bash
# Build all targets with Lake
lake build

# Build specific executables
lake build test_runner
lake build TrainGPT
lake build TrainDiffusion
lake build TrainNanoChat
lake build FluxDemo
```

All executables need the native library paths set at runtime. The paths below are relative, so run from the repository root.
macOS (Apple Silicon):

```bash
export DYLD_LIBRARY_PATH=external/libtorch/lib:/opt/homebrew/opt/libomp/lib:/opt/homebrew/lib
```

macOS (Intel):

```bash
export DYLD_LIBRARY_PATH=external/libtorch/lib:/usr/local/opt/libomp/lib:/usr/local/lib
```

Linux:

```bash
export LD_LIBRARY_PATH=external/libtorch/lib:/usr/lib
```

Or use the Lake helper scripts, which set these automatically:
```bash
lake run          # runs test_runner
lake run train    # runs TrainGPT
```

Run the test suites:

```bash
lake build test_runner
.lake/build/bin/test_runner

# Or use the helper script
lake run

# Experimental/in-progress suites
lake build test_runner_experimental
.lake/build/bin/test_runner_experimental
```

See `Examples/README.md` for detailed per-example documentation.
| Example | Description | Build target |
|---|---|---|
| TrainGPT | Character-level GPT on Shakespeare | lake build TrainGPT |
| TrainDiffusion | Discrete masked diffusion on ASCII text | lake build TrainDiffusion |
| TrainNanoChat | Modded-nanogpt distributed training | lake build TrainNanoChat |
| FluxDemo | Flux Klein 4B image generation | lake build FluxDemo |
| BranchingFlows | Combinatorial branching flow sampler | Part of Examples lib |
| NanoProof | Transformer theorem prover (model only) | Part of Examples lib |
Use the helper scripts to run TrainNanoChat under torchrun without pulling in a mismatched CUDA module stack:
```bash
# default: debug smoke run on 2 GPUs
./scripts/nanochat/run_train_torchrun.sh

# explicit 4-GPU run
NPROC_PER_NODE=4 ./scripts/nanochat/run_train_torchrun.sh \
  --debug --iterations 2 --data data/nanochat --val data/nanochat

# scaling check (1/2/4 GPUs by default)
./scripts/nanochat/bench_distributed.sh
```

Notes:
- `run_train_torchrun.sh` defaults `TORCHRUN_BIN` to `/grid/it/data/elzar/easybuild/software/Anaconda3/2023.07-2/bin/torchrun`, a site-specific path that most setups will need to override.
- Override the launcher path with `TORCHRUN_BIN=/path/to/torchrun`.
- Override the process counts in the benchmark script with `SIZES="2 4"` (or any space-separated list).
ThunderKittens-style GPU kernels now have a reusable end-to-end parity path, built on seeded fixture generation followed by validation on real hardware:

```bash
# Run the current GPU parity suite
./scripts/gpu/test_parity_suite.sh

# Add randomized MHA trials on top of the deterministic suite
RANDOMIZED_MHA_TRIALS=10 ./scripts/gpu/test_parity_suite.sh
```

Notes:
- The suite currently covers the `copy`, `rotary`, `layernorm`, `flashattn`, and `mha_h100` entrypoints through the executables declared in `lakefile.lean`.
- PyTorch is the default numerical oracle for fixture generation and parity checks.
- If you have a local vendored ThunderKittens reference runner, set `TYR_GPU_VENDORED_REF_RUNNER=/path/to/runner`. It is called as `runner <suite-name> <fixture-dir>` after each suite and should exit nonzero on mismatch.
- This checkout may have an empty `thirdparty/ThunderKittens` directory. The vendored hook is optional and exists so parity checks can pay off before the submodule/runtime reference path is fully wired.
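The runner contract above (`runner <suite-name> <fixture-dir>`, nonzero exit on mismatch) is small enough to sketch. The following is a hypothetical Lean 4 skeleton, not a runner shipped with Tyr. `compareSuite` is a placeholder where a real runner would load the fixtures and compare them against the ThunderKittens reference. The specific exit codes (1 for mismatch, 2 for bad usage) are arbitrary choices; the hook only requires a nonzero exit on failure.

```lean
/-- Placeholder comparison (hypothetical): a real runner would load the fixtures
in `fixtureDir`, run the reference kernels for `suite`, and compare outputs. -/
def compareSuite (suite : String) (fixtureDir : System.FilePath) : IO Bool := do
  unless (← fixtureDir.isDir) do
    IO.eprintln s!"[{suite}] fixture directory not found: {fixtureDir}"
    return false
  IO.println s!"[{suite}] checking fixtures in {fixtureDir}"
  return true  -- no actual numerical comparison in this sketch

/-- Entry point matching `runner <suite-name> <fixture-dir>`. -/
def main (args : List String) : IO UInt32 := do
  match args with
  | [suite, dir] =>
    let ok ← compareSuite suite ⟨dir⟩
    return if ok then 0 else 1  -- nonzero exit signals a parity mismatch
  | _ =>
    IO.eprintln "usage: runner <suite-name> <fixture-dir>"
    return 2
```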
The core data type is `T s`, a tensor type indexed by its shape:

```lean
-- T is parameterized by shape (Array UInt64)
def T (s : Shape) : Type := TSpec.type

-- Shape mismatches are compile-time errors
def matmul {a b c : UInt64} (x : T #[a, b]) (y : T #[b, c]) : T #[a, c] := ...
```

Generic traversal over structures containing tensors is provided by the `TensorStruct` class:
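```lean
class TensorStruct (α : Type) where
  -- apply a shape-preserving function to every tensor in the structure
  map : (∀ {s}, T s → T s) → α → α
  -- same as `map`, but the function runs in a monad `m`
  mapM : (∀ {s}, T s → m (T s)) → α → m α
  -- combine two structures of the same type tensor-by-tensor
  zipWith : (∀ {s}, T s → T s → T s) → α → α → α
  -- accumulate a value over every tensor
  fold : (∀ {s}, T s → β → β) → β → α → β
```

Use `Vector n α` instead of `Array α` for type-safe `zipWith` operations.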
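As a rough illustration of how such a class is used, here is a self-contained Lean 4 sketch. It defines its own toy `T` (a shape-indexed wrapper around `Array Float`) and a cut-down `TensorStruct` with only `map` and `fold`, all inside a `Toy` namespace so nothing clashes with Tyr. The `Linear` structure, `numParams`, and `scale` are hypothetical names, not Tyr's API.

```lean
namespace Toy

/-- Stand-in for Tyr's tensor: the shape lives in the type, the data in an array. -/
structure T (s : List Nat) where
  data : Array Float

/-- Cut-down traversal class with only `map` and `fold`. -/
class TensorStruct (α : Type) where
  map : (∀ {s}, T s → T s) → α → α
  fold {β : Type} : (∀ {s}, T s → β → β) → β → α → β

/-- A linear layer holding two differently shaped tensors. -/
structure Linear (i o : Nat) where
  weight : T [o, i]
  bias : T [o]

instance {i o : Nat} : TensorStruct (Linear i o) where
  map f l := { weight := f l.weight, bias := f l.bias }
  fold f init l := f l.bias (f l.weight init)

/-- Total number of scalars across every tensor in `x`. -/
def numParams {α : Type} [TensorStruct α] (x : α) : Nat :=
  TensorStruct.fold (β := Nat) (fun t acc => acc + t.data.size) 0 x

/-- Multiply every tensor in `x` by `c`, whatever their shapes. -/
def scale {α : Type} [TensorStruct α] (c : Float) (x : α) : α :=
  TensorStruct.map (fun t => { data := t.data.map (· * c) }) x

def layer : Linear 3 2 :=
  { weight := { data := Array.mk (List.replicate 6 0.5) },
    bias := { data := #[0.0, 0.0] } }

#eval numParams layer                 -- 8
#eval (scale 2.0 layer).weight.data   -- every weight entry becomes 1.0

end Toy
```

Because `map` and `fold` take a function that is polymorphic in the shape `s`, one traversal handles every tensor in the structure without the caller knowing their individual shapes.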
Batch iteration follows two patterns for different use cases:

```lean
-- Fixed dimensions (type-safe):
let iter := SequentialBatchIterator.new loader 8 256
let (batch, iter') := iter.next  -- batch : T #[8, 256]

-- Dynamic dimensions:
let iter := BatchIterator.new shard 8 256
let (batch, iter') ← iter.next   -- batch : T #[] (shape erased)
```

Tyr includes a Lean command-level type provider for SafeTensors schemas:
```lean
import Tyr.SafeTensors
open torch

-- Works with a single .safetensors file or a sharded directory.
-- If `model.safetensors.index.json` exists, introspection follows `weight_map`.
safetensors_type_provider "/path/to/model_dir_or_file" as ModelWeights

def inspectWeights : IO Unit := do
  IO.println s!"discovered tensors: {ModelWeights.tensorCount}"
  -- Per-tensor typed loader + typed schema metadata
  let tokEmbed ← ModelWeights.load_model_embed_tokens_weight
  IO.println s!"embed dtype: {ModelWeights.model_embed_tokens_weightSpec.dtype}"
  IO.println s!"embed shape: {tokEmbed.runtimeShape}"
  -- Hierarchical aggregate generated from tensor names
  let weights ← ModelWeights.loadAll
  let qProj := weights.model.layers[0]!.self_attn.q_proj.weight
  IO.println s!"q_proj shape: {qProj.runtimeShape}"
  -- Hierarchical subtree loaders are also generated
  let decoder ← ModelWeights.model.load
  IO.println s!"loaded subtree"
```

Notes:
- Tensor schema dtype metadata uses the core `torch.DType` type (not raw strings).
- For sharded checkpoints, unsafe index shard paths (absolute paths or `..` traversal) are rejected.
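The shard-path rule in the last note can be expressed as a small predicate. This is a hypothetical Lean 4 sketch, not Tyr's implementation: it treats a `weight_map` entry as safe only if it is non-empty, relative, and free of `..` components, using `System.FilePath.isAbsolute` and `System.FilePath.components` from Lean core. The expected results in the comments assume a Unix-style path separator.

```lean
/-- Hypothetical check for a shard file name taken from `weight_map`:
reject empty names, absolute paths, and any `..` component that could
escape the checkpoint directory. -/
def isSafeShardPath (p : String) : Bool :=
  let fp := System.FilePath.mk p
  !p.isEmpty && !fp.isAbsolute && !fp.components.contains ".."

#eval isSafeShardPath "model-00001-of-00002.safetensors"  -- true
#eval isSafeShardPath "/etc/passwd"                        -- false
#eval isSafeShardPath "shards/../../secret.safetensors"    -- false
```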
The C++ bindings use careful reference counting; see the header of `cc/src/tyr.cpp` for details:

- `borrowTensor()`: shared ownership, automatic cleanup
- `giveTensor()`: transfers ownership to Lean
- `lean_dec()`: required after extracting from an owned `lean_obj_arg`, but not for a borrowed `b_lean_obj_arg`

Monitor tensor leaks with `get_live_tensors`, which tracks outstanding C++ tensors.
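One way to use such a counter is as a guard around code that should not leak. The sketch below is hypothetical: the exact signature of `get_live_tensors` is not documented here, so the helper `checkNoLeak` takes the live-count query as an assumed `IO Nat` argument, and the demo substitutes a plain `IO.Ref` counter for real tensors so it runs without LibTorch.

```lean
/-- Hypothetical leak guard: run `action`, then fail if the live-tensor count
changed. `liveCount` stands in for however `get_live_tensors` is exposed
(assumed here to be an `IO Nat`). `action` returns `Unit` because any value
still holding a tensor after it returns would count as live. -/
def checkNoLeak (liveCount : IO Nat) (action : IO Unit) : IO Unit := do
  let before ← liveCount
  action
  let after ← liveCount
  if after != before then
    throw <| IO.userError s!"live tensor count changed: {before} -> {after}"

-- Self-contained demo: a plain counter simulates allocations and frees.
def leakDemo : IO Unit := do
  let counter ← IO.mkRef 0
  checkNoLeak counter.get do
    counter.modify (· + 1)  -- simulate allocating a tensor
    counter.modify (· - 1)  -- simulate freeing it
  IO.println "no leak detected"

#eval leakDemo  -- prints "no leak detected"
```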
To add a new operation:

- Add a Lean declaration in `Tyr/Torch.lean` with `@[extern "lean_torch_xxx"]`.
- Implement it in `cc/src/tyr.cpp`, following the reference-counting conventions above.
- Rebuild with `lake build`.
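For the first step, the usual Lean 4 FFI shape is an opaque type plus `@[extern]` declarations. The sketch below follows Lean's `NonemptyType` pattern for foreign objects; `MyTensorSpec`, `MyTensor`, `myScale`, and the symbol `lean_torch_my_scale` are hypothetical placeholders, and the matching C++ function would still have to be written in `cc/src/tyr.cpp`. Marking a parameter with `@&` makes it borrowed, so the C function receives it as `b_lean_obj_arg` and must not `lean_dec` it; unannotated object parameters arrive as owned `lean_obj_arg`.

```lean
/-- Opaque handle for a foreign tensor object (hypothetical; Tyr's own
`T s` is defined via `TSpec.type`). -/
opaque MyTensorSpec : NonemptyType.{0}
def MyTensor : Type := MyTensorSpec.type
instance : Nonempty MyTensor := MyTensorSpec.property

/-- Hypothetical op. `x` is borrowed (`@&`), so the C side sees a
`b_lean_obj_arg`; the returned tensor is owned by the caller. -/
@[extern "lean_torch_my_scale"]
opaque myScale (x : @& MyTensor) (c : Float) : MyTensor
```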
Project layout:

- `lakefile.lean` - Lake build configuration
- `lean-toolchain` - Lean version specification
- `cc/` - C++ FFI bindings (LibTorch wrapper)
- `Tyr/` - Core framework (tensors, modules, optimizers, distributed)
- `Examples/` - Training scripts and model implementations
- `Tests/` - Test suites
This repo uses scoped conventional commit subjects:

```
type(scope): summary
```
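To make the format concrete, here is a hypothetical Lean 4 predicate for a `type(scope): summary` subject. It only sketches the shape of the rule; the authoritative check is `scripts/check-commit-messages.sh`, which may, for example, restrict the allowed types. This version requires a lowercase type, a non-empty scope, and a non-empty summary, and it rejects subjects in which `): ` appears more than once.

```lean
/-- Hypothetical check: does `subject` look like `type(scope): summary`? -/
def isScopedSubject (subject : String) : Bool :=
  match subject.splitOn "): " with
  | [head, summary] =>
    match head.splitOn "(" with
    | [typ, scope] =>
      !typ.isEmpty && typ.all Char.isLower && !scope.isEmpty && !summary.isEmpty
    | _ => false  -- missing `(`, or more than one `(` before `): `
  | _ => false    -- missing `): `, or it appears more than once

#eval isScopedSubject "feat(qwen35): add video stream patchify"  -- true
#eval isScopedSubject "fix: handle empty shard list"              -- false (no scope)
#eval isScopedSubject "Update README"                              -- false
```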
A commit message template is included at `.gitmessage`. Enable it locally:

```bash
./scripts/setup-git-hooks.sh
```

This sets:

- `commit.template=.gitmessage`
- `core.hooksPath=.githooks`
Included hooks:

- `pre-commit`: fails on staged whitespace errors and conflict markers
- `commit-msg`: enforces `type(scope): summary` (e.g. `feat(qwen35): add video stream patchify`)
- `pre-push`: validates pushed commit subjects with `scripts/check-commit-messages.sh`
CI also enforces commit subjects using `scripts/check-commit-messages.sh`.

Manual check examples:

```bash
./scripts/check-commit-messages.sh HEAD~20..HEAD
COMMIT_MSG_ENFORCE_FROM=<commit> ./scripts/check-commit-messages.sh HEAD~20..HEAD
```

TBD.