This repository contains a JAX software-based implementation of some stochastically rounded operations.
When encoding the weights of a neural network in low precision (such as bfloat16), one runs into stagnation problems: updates end up being too small, relative to the weights they are added to, to be representable at the precision of the encoding.
This leads to weights becoming stuck and the model's accuracy being significantly reduced.
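As a concrete, purely illustrative example of the problem (not taken from the paper or the repository): near 1.0, bfloat16 can only represent increments of roughly 2^-7, so a smaller update is rounded away entirely:
import jax.numpy as jnp
# a bfloat16 weight and an update that is below its rounding threshold near 1.0
weight = jnp.array(1.0, dtype=jnp.bfloat16)
update = jnp.array(1e-3, dtype=jnp.bfloat16)
# the deterministic sum rounds back to the original weight: the update is lost
print(weight + update)            # expected to print 1.0
print(weight + update == weight)  # expected to print True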
Stochastic arithmetic lets you perform the operations in such a way that the weights have a non-zero probability of being modified anyway. This avoids the stagnation problem (see figure 4 of "Revisiting BFloat16 Training") without increasing the memory usage (as might happen if one were using a compensated summation to solve the problem).
The downside is that software-based stochastic arithmetic is significantly slower than normal floating-point arithmetic. It is thus viable for things like the weight update (when using the output of an Optax optimizer for example) but would not be appropriate in a hot loop.
Do not hesitate to submit an issue or a pull request if you need additional functionality!
This repository introduces the add and tree_add operations. They take a PRNGKey and two tensors (or two pytrees, respectively) to be added, but round the result up or down randomly:
import jax
import jax.numpy as jnp
import jochastic
# problem definition
size = 10
dtype = jnp.bfloat16
key = jax.random.PRNGKey(1993)
# deterministic addition
key, keyx, keyy = jax.random.split(key, num=3)
x = jax.random.normal(keyx, shape=(size,), dtype=dtype)
y = jax.random.normal(keyy, shape=(size,), dtype=dtype)
result = x + y
print(f"deterministic addition: {result}")
# stochastic addition
result_sto = jochastic.add(key, x, y)
print(f"stochastic addition: {result_sto} ({result_sto.dtype})")
difference = result - result_sto
print(f"difference: {difference}")
Both functions take an optional is_biased boolean parameter. If is_biased is True (the default value), the random number generator is biased according to the relative error of the operation; otherwise, it will round up half of the time on average.
Jitting the functions is left to the user's discretion (you will need to indicate that is_biased
is static).
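For example, assuming is_biased is exposed as a keyword argument as described above (a sketch, not taken from the repository):
import jax
import jax.numpy as jnp
import jochastic
# mark is_biased as static so that its value is baked into the compiled function
jitted_add = jax.jit(jochastic.add, static_argnames='is_biased')
key = jax.random.PRNGKey(0)
x = jnp.ones((10,), dtype=jnp.bfloat16)
y = jnp.full((10,), 1e-3, dtype=jnp.bfloat16)
result = jitted_add(key, x, y, is_biased=True)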
NOTE:
Very low precision (16-bit floating-point arithmetic or less) is extremely brittle. We recommend using higher precision locally (such as 32-bit floating-point arithmetic to compute the optimizer's update), then casting down to 16 bits at summing / storage time (something that PyTorch does transparently when using their addcdiv in low precision).
Both functions accept mixed-precision inputs (adding a high precision number to a low precision one), use that extra precision for the rounding, then return an output in the lowest precision of their inputs (contrary to most casting conventions).
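A short illustrative sketch of this behaviour (not taken from the repository): a float32 update, such as one produced by an optimizer, is added to bfloat16 weights and the result stays in bfloat16:
import jax
import jax.numpy as jnp
import jochastic
key = jax.random.PRNGKey(42)
# low precision weights and a higher precision update (e.g. computed in float32)
weights = jnp.ones((10,), dtype=jnp.bfloat16)
update = jnp.full((10,), 1e-3, dtype=jnp.float32)
# the extra precision of the update is used to decide the rounding direction,
# but the result is returned in bfloat16, the lowest precision of the inputs
new_weights = jochastic.add(key, weights, update)
print(new_weights.dtype)  # expected to print bfloat16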
We use TwoSum to measure the numerical error introduced by the addition; our tests show that it behaves as needed on bfloat16 (some edge cases might be invalid, leading to an inexact computation of the numerical error, but it is reliable enough for our purpose).
This and the nextafter
function let us emulate various rounding modes in software (this is inspired by Verrou's backend).
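For reference, here is a sketch of the underlying idea (illustrative, not the library's actual code): TwoSum recovers the rounding error of a floating-point addition, and nextafter moves the rounded sum to the neighbouring representable value when a random draw says to round the other way:
import jax
import jax.numpy as jnp

def two_sum(a, b):
    """Error-free transformation: returns the rounded sum and its rounding error."""
    s = a + b
    a_approx = s - b
    b_approx = s - a_approx
    error = (a - a_approx) + (b - b_approx)
    return s, error

def unbiased_stochastic_add(key, a, b):
    """Sketch of an unbiased stochastic addition: keep the rounded sum or
    its neighbour in the direction of the error, each half of the time."""
    s, error = two_sum(a, b)
    # the alternative rounding lies one representable value away, on the side of the error
    direction = jnp.where(error > 0, jnp.inf, -jnp.inf).astype(s.dtype)
    alternative = jnp.nextafter(s, direction)
    # flip a coin per element to decide which rounding to keep
    use_alternative = jax.random.bernoulli(key, p=0.5, shape=s.shape)
    return jnp.where(use_alternative & (error != 0), alternative, s)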
You can use this BibTeX reference if you use Jochastic within a published work:
@misc{Jochastic,
author = {Demeure, Nestor},
title = {Jochastic: stochastically rounded operations between JAX tensors.},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/nestordemeure/jochastic}}
}
You will find a PyTorch implementation called StochasTorch here.