15 releases (4 breaking)

Uses new Rust 2024

0.6.2 Apr 17, 2026
0.6.0 Apr 9, 2026
0.5.0 Mar 31, 2026
0.4.3 Mar 25, 2026
0.2.8 Jan 26, 2026

#237 in Machine learning


Used in 6 crates

MIT/Apache

1.5MB
33K SLoC

axonml-serialize

License: Apache-2.0 | License: MIT | Rust 1.75+ | Version 0.6.1 | Part of AxonML

Overview

axonml-serialize handles model state I/O for AxonML: named-parameter StateDicts, Checkpoints with full training state, and format conversion for PyTorch / ONNX interop. The native .axonml format is bincode-encoded binary; JSON and SafeTensors (behind the safetensors feature) are also supported. Format is detected from the file extension with magic-byte fallback.

Features

  • Multiple Formats: native bincode .axonml, .json, .safetensors (feature-gated)
  • State Dictionaries: StateDict with from_module, insert, get, entries, keys, merge, filter_prefix, strip_prefix, add_prefix, remove, set_metadata / get_metadata, total_params, size_bytes, summary
  • Training Checkpoints: Checkpoint + CheckpointBuilder + TrainingState with loss / val-loss / lr / custom metric history, best-metric tracking, epoch / step counters, ISO-8601 timestamp, config map
  • Format Detection: detect_format(path) by extension, detect_format_from_bytes(bytes) by magic bytes; Format::is_binary, Format::supports_streaming, Format::extension, Format::name, Format::all
  • PyTorch Conversion: from_pytorch_key, to_pytorch_key, pytorch_layer_mapping, convert_from_pytorch, transpose_linear_weights
  • ONNX Utilities: to_onnx_shape / from_onnx_shape (dynamic batch dim handling), OnnxOpType with parse_op / as_str
  • High-Level API: save_model(&model, path) / load_model(&model, path) (name-matched param load with positional fallback), save_state_dict / load_state_dict, save_checkpoint / load_checkpoint

Feature Flags

Flag | Effect
safetensors | Enables .safetensors save/load (f32 / f16 / bf16 / f64 input); pulls in the safetensors = "0.3" and half dependencies

Modules

Module | Description
state_dict | TensorData, StateDictEntry, StateDict
checkpoint | Checkpoint, CheckpointBuilder, TrainingState
format | Format enum and detection helpers
convert | PyTorch / ONNX conversion utilities, OnnxOpType

Usage

Add the dependency to your Cargo.toml:

[dependencies]
axonml-serialize = "0.6.1"

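# Or, with SafeTensors support (use this line instead of the one above;
# a dependency key may appear only once in the table):
# axonml-serialize = { version = "0.6.1", features = ["safetensors"] }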

Saving and Loading Models

use axonml_serialize::{save_model, load_model, load_state_dict};
use axonml_nn::Linear;

// Save a model (format detected from extension)
let model = Linear::new(10, 5);
save_model(&model, "model.axonml")?;        // Binary format
save_model(&model, "model.json")?;          // JSON format
// save_model(&model, "model.safetensors")?; // Requires `safetensors` feature

// Inspect the state dict directly
let sd = load_state_dict("model.axonml")?;
println!("Parameters: {}", sd.total_params());
println!("Size: {} bytes", sd.size_bytes());

// Or load weights back into a model (name-matched, positional fallback)
let target = Linear::new(10, 5);
let loaded = load_model(&target, "model.axonml")?;
println!("Loaded {loaded} parameters");
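
The "name-matched, positional fallback" behaviour mentioned above can be pictured with a small standalone sketch. This is not the crate's implementation: the `Saved` alias, the `load_params` function, and the rule that a candidate is only accepted when its element count matches are all assumptions made for illustration.

```rust
// Hypothetical stand-in for saved parameters: an ordered list of (name, values).
type Saved = Vec<(String, Vec<f32>)>;

// Copies saved values into `target` and returns how many parameters were loaded.
// ASSUMPTION: a name match is preferred; otherwise the saved entry at the same
// position is tried; a candidate is used only if its length matches.
fn load_params(target: &mut [(String, Vec<f32>)], saved: &Saved) -> usize {
    let mut loaded = 0;
    for (i, (name, values)) in target.iter_mut().enumerate() {
        // 1. Look for a saved entry with the same name.
        let by_name = saved.iter().find(|(n, _)| n.as_str() == name.as_str());
        // 2. Fall back to the saved entry at the same index.
        let candidate = by_name.or_else(|| saved.get(i));
        if let Some((_, src)) = candidate {
            if src.len() == values.len() {
                values.copy_from_slice(src);
                loaded += 1;
            }
        }
    }
    loaded
}

fn main() {
    let saved: Saved = vec![
        ("fc.weight".to_string(), vec![1.0, 2.0]),
        ("bias".to_string(), vec![3.0]),
    ];
    let mut target = vec![
        ("weight".to_string(), vec![0.0, 0.0]), // no name match, loaded by position
        ("bias".to_string(), vec![0.0]),        // loaded by name
    ];
    let n = load_params(&mut target, &saved);
    println!("Loaded {n} parameters: {target:?}");
}
```

A positional fallback is convenient when parameter names differ between saved and target models, but it can silently load the wrong tensor if the parameter order differs, so checking the returned count is worthwhile.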

Working with State Dictionaries

use axonml_serialize::{StateDict, TensorData};

// Create a state dictionary
let mut state_dict = StateDict::new();

let weights = TensorData {
    shape: vec![10, 5],
    values: vec![0.0; 50],
};
state_dict.insert("linear.weight".to_string(), weights);

let bias = TensorData {
    shape: vec![5],
    values: vec![0.0; 5],
};
state_dict.insert("linear.bias".to_string(), bias);

// Query the state dictionary
assert!(state_dict.contains("linear.weight"));
println!("{}", state_dict.summary());

// Filter / rename
let linear_params = state_dict.filter_prefix("linear.");
let stripped       = state_dict.strip_prefix("linear.");
let prefixed       = state_dict.add_prefix("module.");
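
To make the three prefix operations concrete, here is a minimal sketch over a plain `BTreeMap`. The `Params` alias and the free functions are hypothetical stand-ins, not the crate's `StateDict`; in particular, leaving non-matching keys unchanged in `strip_prefix` is an assumption.

```rust
use std::collections::BTreeMap;

// Hypothetical stand-in for `StateDict`: parameter name -> flat values.
type Params = BTreeMap<String, Vec<f32>>;

// Keep only the entries whose name starts with `prefix`.
fn filter_prefix(p: &Params, prefix: &str) -> Params {
    p.iter()
        .filter(|(k, _)| k.starts_with(prefix))
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect()
}

// Remove `prefix` from matching names.
// ASSUMPTION: names without the prefix are kept unchanged.
fn strip_prefix(p: &Params, prefix: &str) -> Params {
    p.iter()
        .map(|(k, v)| {
            let name = k.strip_prefix(prefix).unwrap_or(k.as_str());
            (name.to_string(), v.clone())
        })
        .collect()
}

// Prepend `prefix` to every name.
fn add_prefix(p: &Params, prefix: &str) -> Params {
    p.iter()
        .map(|(k, v)| (format!("{prefix}{k}"), v.clone()))
        .collect()
}

fn main() {
    let mut p = Params::new();
    p.insert("linear.weight".to_string(), vec![0.0; 50]);
    p.insert("linear.bias".to_string(), vec![0.0; 5]);
    p.insert("norm.scale".to_string(), vec![1.0; 5]);

    let linear = filter_prefix(&p, "linear.");
    let stripped = strip_prefix(&linear, "linear.");
    let prefixed = add_prefix(&p, "module.");
    println!("{:?}", stripped.keys().collect::<Vec<_>>());
    println!("{:?}", prefixed.keys().collect::<Vec<_>>());
}
```

Each function returns a new map and leaves the input untouched, which matches how the calls above bind their results to new variables.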

Training Checkpoints

use axonml_serialize::{Checkpoint, TrainingState, save_checkpoint, load_checkpoint};

// Track training state
let mut training_state = TrainingState::new();
training_state.record_loss(0.5);
training_state.record_loss(0.3);
training_state.record_val_loss(0.35);
training_state.record_lr(1e-3);
training_state.record_metric("accuracy", 0.92);
training_state.update_best("loss", 0.3, false);  // lower is better

training_state.next_epoch();
training_state.next_step();

// Average last N losses
let smoothed = training_state.avg_loss(10);

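// Build checkpoint (model_state_dict, optimizer_state_dict and rng_bytes
// are assumed to have been produced earlier in your training loop)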
let checkpoint = Checkpoint::builder()
    .model_state(model_state_dict)
    .optimizer_state(optimizer_state_dict)
    .training_state(training_state)
    .rng_state(rng_bytes)
    .epoch(10)
    .global_step(5000)
    .config("learning_rate", "0.001")
    .config("batch_size", "32")
    .build();

// Save and load checkpoints (bincode)
save_checkpoint(&checkpoint, "checkpoint.ckpt")?;
let loaded = load_checkpoint("checkpoint.ckpt")?;

println!("Resuming from epoch {}", loaded.epoch());
println!("Best metric: {:?}", loaded.best_metric());
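
The windowed average and best-metric tracking used above can be sketched without the crate. The `History` type below is hypothetical: the numeric type, the `Option` return of `avg_loss`, and the boolean result of `update_best` are assumptions, not the actual `TrainingState` signatures.

```rust
// Hypothetical stand-in for the loss-history part of `TrainingState`.
#[derive(Default)]
struct History {
    losses: Vec<f32>,
    best: Option<f32>,
}

impl History {
    fn record_loss(&mut self, loss: f32) {
        self.losses.push(loss);
    }

    // Mean of the most recent `n` losses; `None` if the window is empty.
    fn avg_loss(&self, n: usize) -> Option<f32> {
        let start = self.losses.len().saturating_sub(n);
        let window = &self.losses[start..];
        if window.is_empty() {
            return None;
        }
        Some(window.iter().sum::<f32>() / window.len() as f32)
    }

    // Stores `value` if it improves on the current best and reports whether it did.
    fn update_best(&mut self, value: f32, higher_is_better: bool) -> bool {
        let improved = match self.best {
            None => true,
            Some(b) if higher_is_better => value > b,
            Some(b) => value < b,
        };
        if improved {
            self.best = Some(value);
        }
        improved
    }
}

fn main() {
    let mut h = History::default();
    h.record_loss(0.5);
    h.record_loss(0.25);
    println!("smoothed: {:?}", h.avg_loss(10));
    println!("improved: {}", h.update_best(0.25, false)); // lower is better
}
```

Asking for more losses than have been recorded simply averages what is available, which is why `avg_loss(10)` works after only two recorded values.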

Format Detection

use axonml_serialize::{detect_format, detect_format_from_bytes, Format};

// Detect from file extension
assert_eq!(detect_format("model.json"),        Format::Json);
assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
assert_eq!(detect_format("model.bin"),         Format::Axonml); // default

// Detect from file contents
let bytes = b"{\"key\": \"value\"}";
let format = detect_format_from_bytes(bytes);
assert_eq!(format, Some(Format::Json));

// Format properties
assert!(Format::Axonml.is_binary());
assert!(!Format::Json.is_binary());
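
For readers curious how extension-first detection with a content fallback might fit together, here is a standalone sketch. `Fmt`, `by_extension`, `by_bytes` and `detect` are hypothetical names, the content check only recognises JSON, and the fallback order is an assumption based on the overview rather than the crate's code.

```rust
use std::path::Path;

// Hypothetical stand-in for the crate's `Format` enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Fmt {
    Axonml,
    Json,
    SafeTensors,
}

// Recognise a format from the file extension, if it is a known one.
fn by_extension(path: &str) -> Option<Fmt> {
    let ext = Path::new(path).extension()?.to_str()?.to_ascii_lowercase();
    match ext.as_str() {
        "axonml" => Some(Fmt::Axonml),
        "json" => Some(Fmt::Json),
        "safetensors" => Some(Fmt::SafeTensors),
        _ => None,
    }
}

// Recognise a format from leading bytes.
// ASSUMPTION: only JSON is sniffed here (first non-whitespace byte is `{` or `[`).
fn by_bytes(bytes: &[u8]) -> Option<Fmt> {
    let first = bytes.iter().copied().find(|b| !b.is_ascii_whitespace())?;
    match first {
        b'{' | b'[' => Some(Fmt::Json),
        _ => None,
    }
}

// Extension first, then content, then the native format as the default.
fn detect(path: &str, bytes: &[u8]) -> Fmt {
    by_extension(path)
        .or_else(|| by_bytes(bytes))
        .unwrap_or(Fmt::Axonml)
}

fn main() {
    println!("{:?}", detect("model.json", b""));
    println!("{:?}", detect("weights.bin", b"  {\"key\": 1}"));
    println!("{:?}", detect("weights.bin", &[0x00, 0x01]));
}
```

Trusting the extension first keeps the common case cheap, since the file contents only need to be inspected when the extension is missing or unfamiliar.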

PyTorch Conversion

use axonml_serialize::{
    from_pytorch_key, to_pytorch_key, pytorch_layer_mapping,
    convert_from_pytorch, transpose_linear_weights,
};

// Convert PyTorch key naming to AxonML
let key = from_pytorch_key("module.layer1.weight");

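// Convert an entire state dictionary
// (pytorch_dict is assumed to hold parameters loaded from a PyTorch export)
let axonml_dict = convert_from_pytorch(&pytorch_dict);

// Transpose linear weights if needed (PyTorch stores them as [out, in]);
// weight_data is assumed to be the tensor data of a linear layer's weight
let transposed = transpose_linear_weights(&weight_data);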
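
What transposing a linear weight means for the underlying data can be shown with a small sketch. It assumes row-major storage in a flat `Vec<f32>`; `transpose_2d` is a hypothetical helper, not the crate's `transpose_linear_weights`.

```rust
// Transpose a row-major [rows, cols] matrix stored as a flat slice.
// The result is a row-major [cols, rows] matrix.
fn transpose_2d(values: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(values.len(), rows * cols, "shape does not match data length");
    let mut out = vec![0.0; values.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the input becomes element (c, r) of the output.
            out[c * rows + r] = values[r * cols + c];
        }
    }
    out
}

fn main() {
    // A PyTorch-style weight with out_features = 2, in_features = 3.
    let weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let transposed = transpose_2d(&weight, 2, 3);
    println!("{transposed:?}"); // now laid out as [in = 3, out = 2]
}
```

After the transpose the shape must be swapped as well, so a `[2, 3]` weight is recorded as `[3, 2]`.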

ONNX Shape Utilities

use axonml_serialize::{to_onnx_shape, from_onnx_shape, OnnxOpType};

// Convert to ONNX shape (with dynamic batch)
let onnx_shape = to_onnx_shape(&[3, 224, 224], true);
assert_eq!(onnx_shape, vec![-1, 3, 224, 224]);

// Convert from ONNX shape (replace -1 with default)
let shape = from_onnx_shape(&[-1, 3, 224, 224], 1);
assert_eq!(shape, vec![1, 3, 224, 224]);

// ONNX operator name mapping
let op = OnnxOpType::parse_op("Relu");
assert_eq!(op.as_str(), "Relu");
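
The shape conversions asserted above can be reproduced with a few lines of standalone code. The `_sketch` functions below are hypothetical and only mirror the behaviour the assertions show; the integer types and the handling of `dynamic_batch = false` are assumptions.

```rust
// Prepend a dynamic batch dimension (-1) when requested.
// ASSUMPTION: with `dynamic_batch == false` the dims are copied unchanged.
fn to_onnx_shape_sketch(shape: &[usize], dynamic_batch: bool) -> Vec<i64> {
    let mut out = Vec::with_capacity(shape.len() + 1);
    if dynamic_batch {
        out.push(-1);
    }
    out.extend(shape.iter().map(|&d| d as i64));
    out
}

// Replace every dynamic (negative) dimension with `default_dim`.
fn from_onnx_shape_sketch(shape: &[i64], default_dim: usize) -> Vec<usize> {
    shape
        .iter()
        .map(|&d| if d < 0 { default_dim } else { d as usize })
        .collect()
}

fn main() {
    let onnx = to_onnx_shape_sketch(&[3, 224, 224], true);
    let concrete = from_onnx_shape_sketch(&onnx, 1);
    println!("{onnx:?} -> {concrete:?}");
}
```

Round-tripping through both functions with a default of 1 yields a shape suitable for single-sample inference.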

State Dictionary Metadata

use axonml_serialize::StateDict;

let mut state_dict = StateDict::new();
state_dict.set_metadata("framework_version", "0.6.1");
state_dict.set_metadata("model_architecture", "ResNet50");

if let Some(version) = state_dict.get_metadata("framework_version") {
    println!("Saved with version: {}", version);
}

Tests

cargo test -p axonml-serialize

License

Licensed under either of:

  • Apache License, Version 2.0
  • MIT License

at your option.
