2 changes: 1 addition & 1 deletion dfdx-core/Cargo.toml
@@ -35,7 +35,7 @@ num-traits = { workspace = true }
safetensors = { workspace = true, optional = true }
memmap2 = { workspace = true, optional = true }
half = { version = "2.3.1", optional = true, features = ["num-traits", "rand_distr"] }
gemm = { version = "0.16.14", default-features = false, optional = true, features = ["rayon"] }
gemm = { version = "0.17.1", default-features = false, optional = true, features = ["rayon"] }
rayon = { version = "1.7.0", optional = true }
libm = { workspace = true }
wgpu = { version = "0.18.0", features = ["glsl", "spirv"], optional = true }
1 change: 1 addition & 0 deletions dfdx-core/src/data/collate.rs
@@ -55,6 +55,7 @@ impl<A, B> Collate for Vec<(A, B)> {
impl<'a, A, B> Collate for Vec<&'a (A, B)> {
type Collated = (Vec<&'a A>, Vec<&'a B>);
fn collated(self) -> Self::Collated {
#[allow(clippy::map_identity)]
self.into_iter().map(|(a, b)| (a, b)).unzip()
}
}
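
For reference, a minimal standalone sketch (not part of the diff) of why the `clippy::map_identity` allow is needed: on an iterator over `&(A, B)`, the closure `|(a, b)| (a, b)` is not an identity map, because the pattern destructures the reference and yields a `(&A, &B)` pair, which is exactly what `unzip` needs.

```rust
// Standalone sketch: unzipping a Vec<&(A, B)> into (Vec<&A>, Vec<&B>).
fn main() {
    let pairs: Vec<(i32, &str)> = vec![(1, "a"), (2, "b")];
    let refs: Vec<&(i32, &str)> = pairs.iter().collect();

    // The closure looks like an identity map to clippy, but the pattern
    // destructures `&(i32, &str)`, so `a: &i32` and `b: &&str`.
    #[allow(clippy::map_identity)]
    let (xs, ys): (Vec<&i32>, Vec<&&str>) = refs.into_iter().map(|(a, b)| (a, b)).unzip();

    assert_eq!(xs, vec![&1, &2]);
    assert_eq!(ys, vec![&"a", &"b"]);
}
```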
38 changes: 0 additions & 38 deletions dfdx-core/src/lib.rs
@@ -128,44 +128,6 @@ pub mod prelude {
pub use crate::tensor_ops::*;
}

/// Sets a CPU `sse` flag to flush denormal floating point numbers to zero. The opposite of this is [keep_denormals()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn flush_denormals_to_zero() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON) }
}
}

/// Sets a CPU flag to keep denormal floating point numbers. The opposite of this is [flush_denormals_to_zero()].
///
/// Some resources:
/// 1. [Effects of Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/the-effects-of-using-flush-to-zero-mode?lang=en)
/// 2. [When to use Flush-To-Zero mode](https://developer.arm.com/documentation/dui0473/c/neon-and-vfp-programming/when-to-use-flush-to-zero-mode?lang=en)
pub fn keep_denormals() {
#[cfg(all(target_arch = "x86", target_feature = "sse"))]
{
use std::arch::x86::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
{
use std::arch::x86_64::{_MM_FLUSH_ZERO_OFF, _MM_SET_FLUSH_ZERO_MODE};
unsafe { _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_OFF) }
}
}

#[cfg(test)]
pub(crate) mod tests {
pub use num_traits::{Float, NumCast, Zero};
15 changes: 15 additions & 0 deletions dfdx-core/src/shapes/shape.rs
@@ -121,18 +121,33 @@
pub trait Array<T>: IntoIterator<Item = T> {
type Dim: Dim;
fn dim(&self) -> Self::Dim;
fn from_fn<F>(cb: F, len: Self::Dim) -> Self
where
F: FnMut(usize) -> T;
}
impl<T, const N: usize> Array<T> for [T; N] {
type Dim = Const<N>;
fn dim(&self) -> Self::Dim {
Const
}
fn from_fn<F>(cb: F, _len: Self::Dim) -> Self
where
F: FnMut(usize) -> T,
{
std::array::from_fn(cb)
}
}
impl<T> Array<T> for std::vec::Vec<T> {
type Dim = usize;
fn dim(&self) -> Self::Dim {
self.len()
}
fn from_fn<F>(cb: F, len: Self::Dim) -> Self
where
F: FnMut(usize) -> T,
{
(0..len).map(cb).collect()
}
}

/// A collection of dimensions ([Dim]) that change how a multi-dimensional
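A hedged usage sketch of the new `Array::from_fn` constructor added above. The import path and the `Const` unit value follow what the diff shows; treat the exact paths as assumptions rather than confirmed public API.

```rust
use dfdx_core::shapes::{Array, Const};

fn main() {
    // Fixed-size case: the length lives in the type, so `Const` is just a marker value.
    let fixed: [usize; 4] = Array::from_fn(|i| i * i, Const::<4>);
    assert_eq!(fixed, [0, 1, 4, 9]);

    // Dynamic case: the length is a runtime usize.
    let dynamic: Vec<usize> = Array::from_fn(|i| i * i, 4);
    assert_eq!(dynamic, vec![0, 1, 4, 9]);
}
```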
2 changes: 1 addition & 1 deletion dfdx-core/src/tensor/gradients.rs
@@ -153,7 +153,7 @@ impl<E, D: Storage<E>> Gradients<E, D> {
#[inline]
pub(crate) fn many_and_ref<L: Shape, R: Shape>(
&mut self,
ls: &Vec<impl Tensorlike<L, E, D>>,
ls: &[impl Tensorlike<L, E, D>],
r: &impl Tensorlike<R, E, D>,
) -> (Vec<&mut D::Vec>, &D::Vec) {
for i in 0..ls.len() {
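A quick standalone illustration of the parameter change above (`&Vec<impl Tensorlike<...>>` to `&[impl Tensorlike<...>]`): accepting a slice is the more general Rust idiom, since `&Vec<T>` deref-coerces to `&[T]` while sub-slices and arrays can also be passed directly. The helper below is hypothetical and only demonstrates the idiom.

```rust
// Hypothetical helper, not dfdx API: takes &[T] so callers can pass vectors, arrays, or slices.
fn total_len(items: &[String]) -> usize {
    items.iter().map(|s| s.len()).sum()
}

fn main() {
    let v = vec!["ab".to_string(), "c".to_string()];
    assert_eq!(total_len(&v), 3); // &Vec<String> deref-coerces to &[String]
    assert_eq!(total_len(&v[..1]), 2); // a sub-slice works too, which &Vec<String> would reject
}
```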
2 changes: 2 additions & 0 deletions dfdx-core/src/tensor_ops/mod.rs
@@ -209,6 +209,7 @@ mod sum_to;
mod tanh;
mod to_dtype;
mod tri;
mod unstack;
mod upscale2d;
mod var_to;

@@ -276,6 +277,7 @@ pub use sum_to::SumTo;
pub use tanh::tanh;
pub use to_dtype::{to_dtype, ToDtypeKernel};
pub use tri::{lower_tri, upper_tri};
pub use unstack::{SubDim, TryUnstack};
pub use upscale2d::{
Bilinear, GenericUpscale2D, NearestNeighbor, TryUpscale2D, Upscale2DKernel, UpscaleMethod,
};
63 changes: 63 additions & 0 deletions dfdx-core/src/tensor_ops/unstack/cpu_kernel.rs
@@ -0,0 +1,63 @@
use crate::{
prelude::NoneTape,
shapes::*,
tensor::{unique_id, Cpu, Error, Tensor},
};

// Note: in order to return `NoneTape` items without requiring tape type information `T`,
// each element must be optional.
impl<E: Dtype> super::UnstackKernel<E> for Cpu {
fn forward<S: Shape, OptionalItems>(
&self,
stack: Tensor<S, E, Self, NoneTape>,
) -> Result<OptionalItems, Error>
where
S: super::SubDim,
OptionalItems: Array<Option<Tensor<S::Tail, E, Self, NoneTape>>, Dim = S::Head>,
{
let (head, tail) = stack.shape().sub_dim();
let stack_data = stack.data.as_slice();
let unstack_num_elements = tail.num_elements();
Ok(OptionalItems::from_fn(
|i| {
let mut data = self
.try_alloc_elem(unstack_num_elements, E::default())
// TODO: remove unwrap (needs try_from_fn)
// https://github.com/rust-lang/rust/issues/89379
.unwrap();

data.copy_from_slice(
&stack_data[i * unstack_num_elements..(i + 1) * unstack_num_elements],
);

Some(Tensor {
id: unique_id(),
data: std::sync::Arc::new(data),
shape: *tail.shape(),
strides: tail.strides(),
device: self.clone(),
tape: NoneTape,
})
},
head,
))
}
fn backward(
&self,
grad_stack: &mut Self::Vec,
grad_unstack: &Self::Vec,
unstack_idx: usize,
) -> Result<(), Error> {
let unstack_num_elements = grad_unstack.len();
for (i, stacked) in grad_stack
.iter_mut()
.skip(unstack_idx * unstack_num_elements)
.take(unstack_num_elements)
.enumerate()
{
*stacked += grad_unstack[i];
}

Ok(())
}
}
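
To make the indexing in the kernel above easier to follow, here is a minimal standalone sketch of the same buffer arithmetic with plain slices (no dfdx types; contiguous, row-major storage assumed): the forward pass copies out contiguous chunks of `tail.num_elements()` elements, and the backward pass accumulates one chunk's gradient back into its slot.

```rust
// Standalone sketch of the unstack buffer arithmetic (assumes contiguous, row-major storage,
// which is what the CPU kernel above relies on).
fn unstack_forward(stack: &[f32], chunk: usize) -> Vec<Vec<f32>> {
    // Each length-`chunk` window becomes one unstacked tensor's data.
    stack.chunks_exact(chunk).map(|c| c.to_vec()).collect()
}

fn unstack_backward(grad_stack: &mut [f32], grad_unstack: &[f32], idx: usize) {
    let chunk = grad_unstack.len();
    for (dst, src) in grad_stack[idx * chunk..(idx + 1) * chunk]
        .iter_mut()
        .zip(grad_unstack)
    {
        *dst += src; // gradients accumulate rather than overwrite
    }
}

fn main() {
    // A [2, 3] stack flattens to 6 elements; unstacking yields two length-3 buffers.
    let stack: [f32; 6] = [0., 1., 2., 10., 11., 12.];
    let parts = unstack_forward(&stack, 3);
    assert_eq!(parts[1], vec![10., 11., 12.]);

    // Backward scatters the gradient of the idx-th unstacked tensor into its slot.
    let mut grad_stack = [0.0f32; 6];
    unstack_backward(&mut grad_stack, &[1., 1., 1.], 1);
    assert_eq!(&grad_stack[3..], &[1., 1., 1.]);
}
```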
27 changes: 27 additions & 0 deletions dfdx-core/src/tensor_ops/unstack/cuda_kernel.rs
@@ -0,0 +1,27 @@
use crate::{
prelude::NoneTape,
shapes::*,
tensor::{Cuda, Error, Tensor},
};
use cudarc::types::CudaTypeName;

impl<E: Dtype + CudaTypeName> super::UnstackKernel<E> for Cuda {
fn forward<S: Shape, OptionalItems>(
&self,
_stack: Tensor<S, E, Self, NoneTape>,
) -> Result<OptionalItems, Error>
where
S: super::SubDim,
OptionalItems: Array<Option<Tensor<S::Tail, E, Self, NoneTape>>, Dim = S::Head>,
{
todo!()
}
fn backward(
&self,
_grad_stack: &mut Self::Vec,
_grad_unstack: &Self::Vec,
_unstack_idx: usize,
) -> Result<(), Error> {
todo!()
}
}