Skip to content

vbkaisetsu/prima-undine

Folders and files

Name
Last commit message
Last commit date

Latest commit

 

History

5 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Prima-undine ⛵ : A Neural Network Toolkit

Backends

Example

use std::fs;

use prima_undine::functions::ArithmeticFunctions;
use prima_undine::functions::BasicFunctions;
use prima_undine::{
    devices as D, initializers as I, optimizers as O, shape, Device, Model, Node, Optimizer,
    Parameter, Tensor,
};
use prima_undine_contrib::functions::ContribFunctions;

use serde::{Deserialize, Serialize};
use serde_json::json;

/// Trainable parameters of the small two-layer network used to learn XOR.
///
/// `Model` is derived so the struct can be handed to the optimizer
/// (`configure_model` / `update_model`) and to `move_to_device`, as `main`
/// does; the serde derives are what `main` uses to save it to and load it
/// from JSON. Parameter shapes are fixed by `XORModel::new`.
#[derive(Model, Serialize, Deserialize)]
struct XORModel<'dev> {
    /// Hidden-layer weight matrix (`shape![8, 2]` in `new`).
    pw1: Parameter<'dev>,
    /// Hidden-layer bias (`shape![8]` in `new`).
    pb1: Parameter<'dev>,
    /// Output-layer weight matrix (`shape![1, 8]` in `new`).
    pw2: Parameter<'dev>,
    /// Output bias, created with the empty shape `shape![]` in `new`.
    pb2: Parameter<'dev>,
}

impl<'dev> XORModel<'dev> {
    /// Allocates the four parameters of the 2-8-1 network on `device`.
    ///
    /// Every parameter is initialized with `I::Normal::new(0.5, 1.)`
    /// (presumably mean 0.5 and standard deviation 1 — confirm against the
    /// initializer's documentation).
    fn new(device: &'dev Device) -> Self {
        // Hidden layer: 8 units fed by the 2 inputs.
        let pw1 = device.new_parameter(shape![8, 2], &I::Normal::new(0.5, 1.));
        let pb1 = device.new_parameter(shape![8], &I::Normal::new(0.5, 1.));
        // Output layer: a single unit fed by the 8 hidden units.
        let pw2 = device.new_parameter(shape![1, 8], &I::Normal::new(0.5, 1.));
        let pb2 = device.new_parameter(shape![], &I::Normal::new(0.5, 1.));
        Self { pw1, pb1, pw2, pb2 }
    }
}

fn forward<'arg, 'dev, T>(x: &T, params: &'arg mut XORModel<'dev>) -> T
where
    'dev: 'arg,
    T: From<&'arg mut Parameter<'dev>> + ArithmeticFunctions<T> + BasicFunctions,
{
    let w1 = T::from(&mut params.pw1);
    let b1 = T::from(&mut params.pb1);
    let w2 = T::from(&mut params.pw2);
    let b2 = T::from(&mut params.pb2);
    let h = (w1.matmul(x) + b1).tanh();
    w2.matmul(h) + b2
}

/// Trains the XOR model and saves it, or loads a saved model and evaluates it.
///
/// With no arguments, runs 100 SGD steps (learning rate 0.1), printing the
/// loss and predictions at each step, then writes the parameters to
/// `./model.json`. With `--load`, reads that file back, moves the parameters
/// to the device, and prints the loss and predictions of the stored model.
///
/// # Panics
///
/// Panics with a message naming the file if `./model.json` cannot be written,
/// read, or parsed.
fn main() {
    const MODEL_PATH: &str = "./model.json";

    let dev = D::Naive::new();

    // Prepare training data: the four XOR input pairs and their targets.
    // NOTE(review): `shape![2; 4]` appears to mean 4 batched samples of size 2,
    // which is what `batch_mean` below averages over — confirm.
    let x_data = &dev.new_tensor_by_slice(shape![2; 4], &[0., 0., 0., 1., 1., 0., 1., 1.]);
    let t_data = &dev.new_tensor_by_slice(shape![; 4], &[0., 1., 1., 0.]);

    // This used to be hard-coded to `true`, which made the loading branch
    // unreachable. Training remains the default; pass `--load` to evaluate a
    // previously saved model instead.
    let train = !std::env::args_os().skip(1).any(|arg| arg == "--load");

    if train {
        // Initialize parameters
        let mut model = XORModel::new(&dev);

        // Use SGD Optimizer
        let mut optimizer = O::SGD::new(0.1);
        optimizer.configure_model(&mut model);

        // Train data
        // NOTE(review): nothing here clears gradients between iterations;
        // confirm that `update_model` (or `backward`) takes care of that.
        for _ in 0..100 {
            // The nodes built here may keep the `&mut model` borrow taken by
            // `forward`, so they live in their own scope and are dropped before
            // `update_model` needs `&mut model` again.
            {
                let t = &Node::from(t_data);
                let x = &Node::from(x_data);
                let y = &forward(x, &mut model);
                let diff = &(t - y);
                let loss = &(diff * diff).batch_mean();
                println!("loss: {}", loss.to_float());
                println!("  y: {:?}", y.to_vec());
                loss.backward();
            }
            optimizer.update_model(&mut model);
        }

        // Save parameters using Serde
        let model_json = json!(model).to_string();
        fs::write(MODEL_PATH, &model_json)
            .unwrap_or_else(|e| panic!("failed to write {}: {}", MODEL_PATH, e));
    } else {
        // Load parameters using Serde
        let model_json = fs::read_to_string(MODEL_PATH)
            .unwrap_or_else(|e| panic!("failed to read {}: {}", MODEL_PATH, e));
        let mut model: XORModel = serde_json::from_str(&model_json)
            .unwrap_or_else(|e| panic!("failed to parse {}: {}", MODEL_PATH, e));

        // Move parameters to the device
        model.move_to_device(&dev);

        // Calculate and print the result
        let y = &forward(x_data, &mut model);
        let diff = &(t_data - y);
        let loss = (diff * diff).batch_mean();
        println!("loss: {}", loss.to_float());
        println!("  y: {:?}", y.to_vec());
    }
}

About

Prima-undine ⛵ : A Neural Network Toolkit inspired by primitiv

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages