Molikano/Sandokan


Sandokan

Train with reasoning, optimize with precision.
A CPU-first C++ neural network training engine built for on-device learning — no Python, no PyTorch, no GPU required.

Explore the docs »  ·  Report Bug  ·  Request Feature


Table of Contents
  1. About
  2. Core Features
  3. Layers
  4. Performance
  5. Accuracy
  6. Build
  7. Examples
  8. Contributing
  9. License

About

Most neural network training assumes a GPU and a Python runtime. That assumption breaks in the places where learning matters most: a microcontroller updating a sensor model in the field, a robot adapting its controller between tasks, an embedded vision system that must improve on the device it runs on. Sandokan is built for those environments.

Sandokan is a CPU-only, on-device training engine. There is no CUDA dependency, no Python interpreter, no ~1 GB LibTorch runtime to drag in. The target is small-scale models: fully connected networks with tens of thousands to a few million parameters, trained directly on the hardware where they will be used.

Drop a single header into any C++ project and get a complete training pipeline: classification, regression, custom datasets, optimizers, learning rate schedules, and model persistence. The engine is backed by a custom slab allocator (PMAD) and, on Apple Silicon, by Apple Accelerate (AMX) through Eigen's BLAS backend, which keeps batched CPU training fast.

Why CPU-only and small-scale?

The trend toward giant GPU-trained models obscures a different class of problem: systems that must keep learning after deployment, with local data, on hardware that has neither a network connection nor the power budget for a GPU. Sandokan's design constraints are deliberate: tight memory control via PMAD, mmap-backed datasets with bounded RSS, and a header-only footprint make it practical to embed a full training loop into firmware, a game engine, or a latency-critical trading system.

Built for:

  • Embedded systems and edge devices that need on-device model adaptation
  • Robotics and control systems that update parameters between episodes
  • Game engines doing real-time AI personalization without a cloud round-trip
  • Trading systems and other latency-sensitive C++ codebases with on-device inference and retraining
  • Any environment where GPU access is unavailable and Python is not an option

Built With

  • C++
  • Eigen
  • CMake
  • Apple Accelerate

(back to top)


Core Features

Module System

Define networks by composing typed submodules. Submodule<T> auto-registers with the parent on construction — you cannot accidentally forget a register_module call.

struct LetterNet : Module {
    Submodule<Linear>   proj { *this, 784, 64 };
    ReLU                relu;
    Submodule<Linear>   head { *this, 64,  26 };

    LetterNet() = default;

    Eigen::MatrixXf forward(const Eigen::MatrixXf& x) override {
        return head.forward(relu.forward(proj.forward(x)));
    }
    Eigen::MatrixXf backward(const Eigen::MatrixXf& dy) override {
        return proj.backward(relu.backward(head.backward(dy)));
    }
};
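The auto-registration idea can be sketched in isolation. This is a minimal illustration, not Sandokan's implementation: `ModuleBase`, `Registered`, `DenseStub`, and `TinyNet` are hypothetical names, and the real `Module` / `Submodule<T>` types may differ in detail. The point is that registration happens in the wrapper's constructor, so declaring the member is enough.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for a module base class: it only tracks its children.
struct ModuleBase {
    std::vector<ModuleBase*> children;   // filled in by Registered<T> members
    virtual ~ModuleBase() = default;
};

// Hypothetical wrapper: constructing it registers the wrapped layer with the parent.
template <typename T>
struct Registered : T {
    template <typename... Args>
    Registered(ModuleBase& parent, Args&&... args)
        : T(std::forward<Args>(args)...) {
        parent.children.push_back(this);   // registration happens here, not in user code
    }
};

// Placeholder layer that only remembers its shape.
struct DenseStub : ModuleBase {
    std::size_t in, out;
    DenseStub(std::size_t in_, std::size_t out_) : in(in_), out(out_) {}
};

// Members register themselves in declaration order.
struct TinyNet : ModuleBase {
    Registered<DenseStub> proj { *this, 784, 64 };
    Registered<DenseStub> head { *this, 64, 26 };
};
```

Because the parent stores raw pointers to its own members, a type built this way should not be copied or moved without re-registering its children.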

Residual blocks are first-class. The skip connection is a single + x in forward and a single + dy in backward:

struct ResBlock : Module {
    Submodule<Linear> fc1 { *this, 64, 64 };
    ReLU              relu1;
    Submodule<Linear> fc2 { *this, 64, 64 };
    ReLU              relu2;

    ResBlock() = default;

    Eigen::MatrixXf forward(const Eigen::MatrixXf& x) override {
        return relu2.forward(fc2.forward(relu1.forward(fc1.forward(x)))) + x;
    }
    Eigen::MatrixXf backward(const Eigen::MatrixXf& dy) override {
        return fc1.backward(relu1.backward(fc2.backward(relu2.backward(dy)))) + dy;
    }
};
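Why the backward pass adds dy: for y = f(x) + x, the chain rule gives dL/dx = f'(x) · dL/dy + dL/dy, so the skip path contributes the incoming gradient unchanged. The scalar sketch below checks this; `ScalarPath` and `ScalarResBlock` are hypothetical toy types, not part of Sandokan.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Toy scalar "inner path": f(x) = relu(w * x). Illustration only.
struct ScalarPath {
    double w;
    double x_cache = 0.0;   // input remembered for the backward pass
    double forward(double x) { x_cache = x; return std::max(0.0, w * x); }
    double backward(double dy) const { return (w * x_cache > 0.0) ? dy * w : 0.0; }
};

// y = f(x) + x, so dL/dx = f'(x) * dy + dy.
struct ScalarResBlock {
    ScalarPath path;
    double forward(double x) { return path.forward(x) + x; }
    double backward(double dy) const { return path.backward(dy) + dy; }
};
```

Comparing `backward(1.0)` against a central finite difference of `forward` is a quick way to validate a hand-written backward pass like this one.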

(back to top)


PMAD Slab Allocator

The standard allocator story for neural network training is painful: every gradient buffer is a separate heap allocation, the allocator is called thousands of times per epoch, and the heap fragments over long training runs as buffers are freed and reallocated at different sizes. On embedded targets — where malloc may not even be available — this is a non-starter.

PMAD (Pre-allocated Memory Arena for Derivatives) solves this at the layer level:

  • Single allocation, zero fragmentation. PMAD walks the network topology before training begins, computes the exact size class required for every gradient buffer across all layers, and satisfies all of them from one contiguous slab. During training there are zero malloc/free calls — gradient memory is carved from the slab at fixed offsets.
  • Cache-friendly layout. All gradient buffers for a forward/backward pass are packed contiguously, so for the small models Sandokan targets they can stay resident in L2/L3 cache together. Pointer chasing across scattered heap blocks disappears.
  • Deterministic latency. No allocator lock contention, no OS page-fault surprises mid-epoch. The slab is faulted in once at init time, so later accesses do not take first-touch page faults.
  • Topology-aware sizing. Size classes are derived automatically from the network architecture, so you do not need to tune buffer sizes manually. Add a layer, re-call init_pmad_for(), and the slab is rebuilt correctly.

LetterNet net;
init_pmad_for(net);   // walks topology → computes size classes → allocates slab → migrates gradient pointers
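The single-slab idea described above can be sketched as follows. This is an assumption-laden illustration, not PMAD's real interface: `GradSlab`, `buffer()`, and the 16-float rounding are hypothetical choices made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical gradient slab: all sizes are known up front, so one allocation
// serves every buffer and the training loop never calls malloc/free.
class GradSlab {
public:
    // sizes[i] = number of floats needed by gradient buffer i.
    explicit GradSlab(const std::vector<std::size_t>& sizes) {
        offsets_.reserve(sizes.size());
        std::size_t total = 0;
        for (std::size_t n : sizes) {
            offsets_.push_back(total);   // fixed offset, decided once
            total += round_up(n);
        }
        storage_.assign(total, 0.0f);    // single allocation for gradient storage; zero-fill touches every page
    }

    // Pointer to the start of buffer i inside the slab.
    float* buffer(std::size_t i) {
        assert(i < offsets_.size());
        return storage_.data() + offsets_[i];
    }

    std::size_t total_floats() const { return storage_.size(); }

private:
    // Round each buffer up to a multiple of 16 floats (64 bytes) so that
    // neighbouring buffers do not share a 64-byte block relative to the slab base.
    static std::size_t round_up(std::size_t n) { return (n + 15) / 16 * 16; }

    std::vector<std::size_t> offsets_;
    std::vector<float> storage_;
};
```

For the LetterNet shapes (784×64 weights, 64 biases, 64×26 weights, 26 biases), `GradSlab slab({784 * 64, 64, 64 * 26, 26});` reserves 51 936 floats in one block, roughly 203 KiB.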

Combined with Apple Accelerate/AMX for batched GEMM, PMAD is the primary reason Sandokan's batched training path runs about 1.6× faster than plain Eigen on EMNIST Letters (386 ms vs 614 ms total) and 1.19× faster on Fashion MNIST (34.4 ms vs 40.9 ms per epoch).

(back to top)


Dataset Abstractions

Sandokan provides two dataset backends that handle normalization, shuffling, and memory layout so the training loop never sees raw bytes.

ImageDataset — mmap-backed IDX loader

Image data is memory-mapped directly from disk. Pages are faulted on demand during batch assembly — the working set stays bounded regardless of dataset size, which matters on devices with limited RAM.

Why mmap instead of fread?

  • The OS page cache is shared across processes and runs. A second training run on the same data needs no disk I/O as long as the pages are still cached.
  • Random-access shuffling across 100k+ images needs no explicit seeks and no upfront load of the full dataset into RAM; only the pages a batch touches are faulted in.
  • Sequential batch access over a large contiguous mapping benefits from kernel read-ahead and the CPU's hardware prefetchers.

ImageDataset train = load_emnist_letters("data/Emnist Letters", /*train=*/true);
ImageDataset test  = load_emnist_letters("data/Emnist Letters", /*train=*/false);

train.compute_normalization();          // computes per-channel mean and std from training split
test.apply_normalization_from(train);   // applies training statistics — never leaks test distribution
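To make the mmap-backed loading concrete, here is a minimal POSIX sketch that maps an IDX3 image file and parses its 16-byte header (magic 0x00000803, then count, rows, cols as big-endian 32-bit integers). `IdxImages`, `parse_idx3`, and `map_idx3` are hypothetical names, not Sandokan's loader. The sketch never unmaps the file, which a real loader would do with munmap.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

struct IdxImages {
    const std::uint8_t* pixels = nullptr;   // first pixel of image 0
    std::uint32_t count = 0, rows = 0, cols = 0;
    const std::uint8_t* image(std::uint32_t i) const {
        return pixels + static_cast<std::size_t>(i) * rows * cols;
    }
};

// IDX stores integers big-endian.
inline std::uint32_t read_be32(const std::uint8_t* p) {
    return (std::uint32_t{p[0]} << 24) | (std::uint32_t{p[1]} << 16) |
           (std::uint32_t{p[2]} << 8) | std::uint32_t{p[3]};
}

// Parses the header of an IDX3 unsigned-byte file; returns false on mismatch
// or if the buffer is too short for the pixel data the header promises.
inline bool parse_idx3(const std::uint8_t* base, std::size_t len, IdxImages& out) {
    if (len < 16 || read_be32(base) != 0x00000803u) return false;
    out.count  = read_be32(base + 4);
    out.rows   = read_be32(base + 8);
    out.cols   = read_be32(base + 12);
    out.pixels = base + 16;
    return len - 16 >= static_cast<std::size_t>(out.count) * out.rows * out.cols;
}

// Maps the file read-only; pages are read from disk only when first touched.
inline bool map_idx3(const char* path, IdxImages& out) {
    const int fd = ::open(path, O_RDONLY);
    if (fd < 0) return false;
    struct stat st{};
    if (::fstat(fd, &st) != 0 || st.st_size < 16) { ::close(fd); return false; }
    const std::size_t len = static_cast<std::size_t>(st.st_size);
    void* p = ::mmap(nullptr, len, PROT_READ, MAP_PRIVATE, fd, 0);
    ::close(fd);   // the mapping stays valid after the descriptor is closed
    if (p == MAP_FAILED) return false;
    return parse_idx3(static_cast<const std::uint8_t*>(p), len, out);
}
```

Batch assembly then reads `image(i)` for each shuffled index, and the kernel pages in only what those reads touch.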

Normalization statistics are stored inside the dataset and can be serialized into .sand model files so inference-time inputs are normalized identically to training.
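The rule "compute statistics on the training split, apply them everywhere" can be sketched with plain vectors. `Stats`, `compute_stats`, and `apply_stats` are hypothetical helpers for illustration; they use a population standard deviation, which may or may not match what `compute_normalization()` does internally.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

struct Stats { float mean = 0.0f; float stddev = 1.0f; };

// Mean and population standard deviation of one channel or feature.
inline Stats compute_stats(const std::vector<float>& v) {
    assert(!v.empty());
    const double n = static_cast<double>(v.size());
    double sum = 0.0, sq = 0.0;
    for (float x : v) sum += x;
    const double mean = sum / n;
    for (float x : v) sq += (x - mean) * (x - mean);
    const double sd = std::sqrt(sq / n);
    // Guard against constant inputs so apply_stats never divides by zero.
    return { static_cast<float>(mean), static_cast<float>(sd > 1e-8 ? sd : 1.0) };
}

// Applies previously computed statistics; the same Stats object is reused
// for the training split, the test split, and inference-time inputs.
inline void apply_stats(std::vector<float>& v, const Stats& s) {
    for (float& x : v) x = (x - s.mean) / s.stddev;
}
```

Calling `compute_stats` on the test split instead would leak test-set information into preprocessing, which is what applying the training statistics avoids.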

TabularDataset — in-memory column-major store

Numeric CSV and Eigen matrix data are stored in column-major order, matching Eigen's default layout. Each feature column is contiguous in memory, so feature-wise normalization streams through memory linearly and batch slicing is done with pointer arithmetic rather than copying.

// From CSV — last column is the target by default
TabularDataset ds = load_csv("boston.csv");
ds.compute_normalization();          // z-scores each feature column independently
ds.compute_target_normalization();   // z-scores the target; train_regression() inverts on output

// Alternative: from existing Eigen matrices (X_features, y_targets); zero-copy when they are already column-major
TabularDataset ds_from_eigen = TabularDataset::from_matrices(X_features, y_targets);

Both dataset types expose the same shuffled-index interface consumed by the training loops, so swapping ImageDataset for TabularDataset requires no changes to training code.

(back to top)


Optimizers and Learning Rate Schedulers

Adam     optim(1e-3f);
LinearLR sched(optim, 150 /*total epochs*/, 1e-5f /*end lr*/);

train_module(net, sched, train_set, test_set, 150, 128);

| Optimizer | Notes |
|---|---|
| SGD | Stochastic gradient descent with fixed learning rate |
| Adam | Adaptive moments with bias correction |

| Scheduler | Notes |
|---|---|
| LinearLR | Linearly decays learning rate from start to end over N epochs |
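For reference, the two update rules named in the tables can be sketched for a single scalar parameter. `linear_lr` and `AdamScalar` are hypothetical names. The Adam defaults (β1 = 0.9, β2 = 0.999, ε = 1e-8) are the commonly used ones and are assumed here, not taken from Sandokan's source. Likewise, starting the decay from the optimizer's initial rate is an assumption.

```cpp
#include <cassert>
#include <cmath>

// Linear interpolation from start_lr (epoch 0) to end_lr (epoch == total_epochs).
inline float linear_lr(float start_lr, float end_lr, int epoch, int total_epochs) {
    assert(total_epochs > 0);
    const int clamped = epoch < total_epochs ? epoch : total_epochs;
    const float t = static_cast<float>(clamped) / static_cast<float>(total_epochs);
    return start_lr + t * (end_lr - start_lr);
}

// Adam state and update for one scalar parameter.
struct AdamScalar {
    float beta1 = 0.9f, beta2 = 0.999f, eps = 1e-8f;
    float m = 0.0f, v = 0.0f;   // first and second moment estimates
    int step = 0;

    float update(float param, float grad, float lr) {
        ++step;
        m = beta1 * m + (1.0f - beta1) * grad;
        v = beta2 * v + (1.0f - beta2) * grad * grad;
        // Bias correction: m and v start at zero, so early estimates are scaled up.
        const float m_hat = m / (1.0f - std::pow(beta1, static_cast<float>(step)));
        const float v_hat = v / (1.0f - std::pow(beta2, static_cast<float>(step)));
        return param - lr * m_hat / (std::sqrt(v_hat) + eps);
    }
};
```

With a start rate of 1e-3 and an end rate of 1e-5 over 150 epochs, the schedule reaches 5.05e-4 at epoch 75. On the very first Adam step the bias-corrected update has magnitude close to lr regardless of the gradient's scale.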

(back to top)


Loss Functions

| Loss | Output activation | Use case |
|---|---|---|
| CrossEntropyLoss | Softmax | Multi-class classification |
| BCELoss | Sigmoid | Binary / multi-label classification |
| MSELoss | Linear (none) | Regression |

CrossEntropyLoss folds the Softmax Jacobian into its backward pass — Softmax's own backward is a passthrough. This avoids computing the full Jacobian matrix while producing the correct gradient.
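The folded gradient is the standard identity: with p = softmax(z) and a one-hot target y, the gradient of the cross-entropy loss with respect to the logits is p − y. A minimal per-sample sketch using `std::vector` follows; the function names are illustrative, not Sandokan's API.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Numerically stable softmax over one sample's logits (subtract the max first).
inline std::vector<float> softmax(const std::vector<float>& z) {
    assert(!z.empty());
    const float zmax = *std::max_element(z.begin(), z.end());
    std::vector<float> p(z.size());
    float sum = 0.0f;
    for (std::size_t i = 0; i < z.size(); ++i) {
        p[i] = std::exp(z[i] - zmax);
        sum += p[i];
    }
    for (float& v : p) v /= sum;
    return p;
}

// Cross-entropy for a single integer label; the clamp avoids log(0).
inline float cross_entropy(const std::vector<float>& p, std::size_t label) {
    assert(label < p.size());
    return -std::log(std::max(p[label], 1e-12f));
}

// Gradient of the loss w.r.t. the logits: p - onehot(label).
// No K x K softmax Jacobian is ever formed.
inline std::vector<float> ce_softmax_grad(std::vector<float> p, std::size_t label) {
    assert(label < p.size());
    p[label] -= 1.0f;
    return p;
}
```

Because the loss already returns the gradient with respect to the logits, a Softmax layer placed before it only needs to pass that gradient through unchanged.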

(back to top)


Training Loops

// Classification — reports cross-entropy loss + accuracy each epoch
train_module(net, optim, train_set, test_set, epochs, batch_size);

// Regression — normalises targets during training, reports RMSE in original units
train_regression(net, optim, train_set, test_set, epochs, batch_size);

Both loops shuffle each epoch, skip partial batches, and call optim.epoch_end() for scheduler stepping.
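The per-epoch batching behaviour can be sketched as a helper that reshuffles the sample indices and yields only full batches. `epoch_batches` is a hypothetical function written for illustration; the actual loops may avoid materialising the batches like this.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Returns the index batches for one epoch: a fresh shuffle, full batches only.
inline std::vector<std::vector<std::size_t>>
epoch_batches(std::size_t n_samples, std::size_t batch_size, std::mt19937& rng) {
    assert(batch_size > 0);
    std::vector<std::size_t> order(n_samples);
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::shuffle(order.begin(), order.end(), rng);

    std::vector<std::vector<std::size_t>> batches;
    for (std::size_t start = 0; start + batch_size <= n_samples; start += batch_size) {
        const auto first = order.begin() + static_cast<std::ptrdiff_t>(start);
        batches.emplace_back(first, first + static_cast<std::ptrdiff_t>(batch_size));
    }
    // The trailing n_samples % batch_size samples are skipped this epoch;
    // a new shuffle next epoch gives them another chance.
    return batches;
}
```

With batch size 128, EMNIST Letters (124 800 samples) splits into exactly 975 batches with nothing skipped, while Fashion MNIST (60 000 samples) yields 468 batches and skips 96 samples per epoch.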

(back to top)


Model Persistence

Custom .sand binary format — 4-word header, optional normalisation block, then Linear weight blocks in DFS traversal order.

#include <sandokan/io.h>

save_model(net, "letter_net.sand");                        // weights only
save_model<TabularDataset>(net, "model.sand", ds);         // weights + normalization

load_model(net, "letter_net.sand");
load_model<TabularDataset>(net, "model.sand", ds);

(back to top)


Inference

#include <sandokan/inference.h>

auto pred = predict(net, x);              // {label, confidence}
auto topk = predict_topk(net, x, 5);     // top-5 predictions
show_prediction(raw_image, true_label, topk, label_names);  // ASCII art + ranked list
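Top-k selection over a probability vector can be sketched with `std::partial_sort`. The `topk` function below is a standalone illustration with an assumed return type of (label, probability) pairs; the real `predict_topk` signature and result type may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Returns the k highest-probability (label, probability) pairs, best first.
inline std::vector<std::pair<std::size_t, float>>
topk(const std::vector<float>& probs, std::size_t k) {
    std::vector<std::pair<std::size_t, float>> ranked;
    ranked.reserve(probs.size());
    for (std::size_t i = 0; i < probs.size(); ++i) ranked.emplace_back(i, probs[i]);

    k = std::min(k, ranked.size());   // asking for more than exists returns everything
    assert(k <= ranked.size());
    std::partial_sort(ranked.begin(),
                      ranked.begin() + static_cast<std::ptrdiff_t>(k),
                      ranked.end(),
                      [](const auto& a, const auto& b) { return a.second > b.second; });
    ranked.resize(k);
    return ranked;
}
```

`std::partial_sort` orders only the first k entries, which is cheaper than a full sort when k is small relative to the number of classes.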

(back to top)


Layers

| Layer | Description |
|---|---|
| Linear | Fully connected; He-initialised weights, PMAD-backed gradient buffers |
| ReLU | Element-wise rectifier, stores pre-activation for backward pass |
| Softmax | Numerically stable column-wise softmax, passthrough backward (CE loss folds in Jacobian) |
| Sigmoid | Element-wise sigmoid |
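As an illustration of why ReLU keeps its pre-activation, the sketch below caches the input in forward and uses it as a mask in backward. `ReluSketch` is a hypothetical vector-based stand-in, not the Eigen-based layer shipped with Sandokan.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical element-wise ReLU that caches its input for the backward pass.
struct ReluSketch {
    std::vector<float> pre;   // pre-activation values from the last forward call

    std::vector<float> forward(const std::vector<float>& x) {
        pre = x;
        std::vector<float> y(x.size());
        for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] > 0.0f ? x[i] : 0.0f;
        return y;
    }

    // dL/dx equals dL/dy where the cached input was positive, and 0 elsewhere.
    std::vector<float> backward(const std::vector<float>& dy) const {
        assert(dy.size() == pre.size());
        std::vector<float> dx(dy.size());
        for (std::size_t i = 0; i < dy.size(); ++i) dx[i] = pre[i] > 0.0f ? dy[i] : 0.0f;
        return dx;
    }
};
```

The cached values mean a single layer instance cannot serve two forward calls before the matching backward, which is why LetterNet and ResBlock above declare one ReLU member per use.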

(back to top)


Performance

Benchmarks run on Apple Silicon (M-series) with EIGEN_USE_BLAS (Apple Accelerate / AMX).
Architecture 784 → 64 → 64 → 26  |  batch = 128  |  lr = 0.01

EMNIST Letters — 124 800 training samples

| Backend | Total (ms) | ms / epoch | ms / sample | samples / sec |
|---|---|---|---|---|
| Sandokan single-sample | 7 540 | 1 508 | 0.0121 | 82 757 |
| Eigen single-sample | 9 257 | 1 851 | 0.0148 | 67 408 |
| Sandokan batched + parallel | 386 | 77 | 0.0006 | 1 615 666 |
| Eigen batched | 614 | 123 | 0.0010 | 1 015 951 |

Sandokan's batched path is 19.5× faster than its single-sample path (386 ms vs 7 540 ms) and about 1.6× faster than plain Eigen batched (386 ms vs 614 ms).

[Figure: EMNIST benchmark]

Fashion MNIST — 60 000 training samples

| Backend | ms / epoch | samples / sec |
|---|---|---|
| Sandokan batched + parallel | 34.4 | 1 742 000 |
| Eigen batched | 40.9 | 1 464 000 |

Speedup: 1.19×

[Figure: Fashion MNIST benchmark]

(back to top)


Accuracy

| Dataset | Architecture | Optimizer | Result |
|---|---|---|---|
| EMNIST Letters | 784 → 64 → ResBlock(64) → 26 | Adam + LinearLR | 88.25% test accuracy |
| Fashion MNIST | 784 → 64 → 64 → 10 | SGD | converges to ~85% |

(back to top)


Build

Requirements: C++17, CMake ≥ 3.15, Eigen 3.

cmake -B build .
cmake --build build -j

For Apple AMX acceleration (strongly recommended on Apple Silicon):

target_compile_definitions(sandokan INTERFACE EIGEN_USE_BLAS)
target_link_libraries(sandokan INTERFACE "-framework Accelerate")

(back to top)


Examples

| Example | Task | Dataset |
|---|---|---|
| examples/emnist_letters | 26-class letter recognition | EMNIST Letters |
| examples/tabular_demo | Generic CSV classification | any numeric CSV |
| examples/benchmark | Full timing sweep (single / batched / Module) | EMNIST Letters |
| examples/emnist_bench | Sandokan vs Eigen, per-epoch timing | EMNIST Letters |
| examples/fashion_mnist_bench | Sandokan vs Eigen, per-epoch timing | Fashion MNIST |

Run examples from the project root so relative data/ paths resolve:

./build/examples/emnist_letters/emnist_letters
./build/examples/emnist_bench/emnist_bench
./build/examples/fashion_mnist_bench/fashion_mnist_bench

(back to top)


Contributing

Contributions are welcome. If you have a suggestion that would make this better, please fork the repo and create a pull request, or open an issue with the tag enhancement.

  1. Fork the Project
  2. Create your Feature Branch (git checkout -b feature/AmazingFeature)
  3. Commit your Changes (git commit -m 'Add some AmazingFeature')
  4. Push to the Branch (git push origin feature/AmazingFeature)
  5. Open a Pull Request

(back to top)


License

Distributed under the MIT License. See LICENSE for more information.

(back to top)

