This is a single-file implementation of a minGRU-based language model that is trained using only integer datatypes. This is made possible by EGGROLL, our novel evolution method that is highly efficient on GPUs.
Install the version of jax that matches your system:
conda create -n nanoegg python=3.13
conda activate nanoegg
pip install datasets tyro tqdm wandb "jax[cuda12]"
This script will automatically download and convert the minipile dataset before training the model.
python run.py --parallel_generations_per_gpu 32768 --validate_every 10
Our logged fitnesses are token-level cross-entropy in bits per byte (i.e., using log base 2 rather than the natural log). For context, a unigram model scores 5.00 and 4.91 on the train and validation sets respectively, while a bigram model scores 3.97 and 3.90.
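As a quick reference for the scale, here is a minimal sketch (not taken from run.py) of converting a cross-entropy measured in nats to the base-2 values reported above:

```python
import math

def nats_to_bits(ce_nats: float) -> float:
    # Cross-entropy in bits is the natural-log value divided by ln(2).
    return ce_nats / math.log(2.0)

# A loss of about 3.466 nats corresponds to the unigram train score of ~5.00 bits.
print(round(nats_to_bits(3.466), 2))  # 5.0
```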
Explanation of atypical hyperparameters:
- parallel_generations_per_gpu: effectively the “batch size” per GPU. Multi-GPU and multi-node training is not integrated yet.
- alpha: analogous to the learning rate in typical training
  - when alpha < 1.0, it sets a threshold based on the “Z-test” from statistics, which effectively ensures that only an alpha fraction of the parameters update each iteration (see the threshold sketch after this list)
  - when alpha > 1.0, it acts as an alpha value that decays over the epochs, with a final value of 2 - alpha
- sigma_shift: sets the sigma for evolution to 2^(-sigma_shift)
- use_clt: when true, uses the standard ES gradient approximation; otherwise it uses the sign of the noise in the gradient approximation (see the toy example after this list)
- fast_fitness: when true, uses the sign of the performance difference between antithetical samples as the normalized fitness; otherwise it normalizes the difference in performance
- noise_reuse: number of epochs to reuse the same perturbation for the same thread_id. By default, it is set to 1, meaning that each epoch uses a new set of perturbations. If set to 0, it keeps the same perturbations across all epochs.
- tokens_per_update: effectively the “sequence length” for optimization
- group_size: Number of noise perturbations that should reuse the same set of input tokens. It must be a multiple of 2 (due to antithetical sampling) and a factor of parallel_generations_per_gpu.
- Note that the amount of unique data used per epoch is parallel_generations_per_gpu * tokens_per_update // group_size (see the worked example below)
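For example, with parallel_generations_per_gpu=32768 (as in the command above), group_size=2, and tokens_per_update=1024 (the last two values chosen purely for illustration), each epoch would use 32768 * 1024 // 2 = 16,777,216 unique tokens.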
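To make the alpha < 1.0 behaviour concrete, here is a rough sketch of how a two-sided Z-test threshold can gate updates so that only about an alpha fraction of parameters change. The names and the masking rule are illustrative and may not match run.py:

```python
import numpy as np
from scipy.stats import norm

def z_threshold(alpha: float) -> float:
    # Two-sided threshold t with P(|Z| > t) = alpha for a standard normal Z.
    return norm.ppf(1.0 - alpha / 2.0)

def gated_update(params, proposed_update, alpha=0.3):
    # Standardize the proposed update, then keep only entries whose
    # z-score clears the threshold, so roughly an alpha fraction survive.
    z = (proposed_update - proposed_update.mean()) / (proposed_update.std() + 1e-8)
    mask = np.abs(z) > z_threshold(alpha)
    return params + proposed_update * mask

params = np.zeros(1000)
update = np.random.randn(1000)
new_params = gated_update(params, update, alpha=0.3)
print((new_params != 0).mean())  # close to 0.3
```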
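And as a toy illustration of how use_clt and fast_fitness change the update (a float-valued sketch under assumed names, not the integer-only EGGROLL update in run.py):

```python
import numpy as np

def es_step(theta, loss_fn, n_pairs=128, sigma=0.05, lr=0.05,
            use_clt=True, fast_fitness=False):
    """One toy antithetical evolution-strategies step."""
    eps = np.random.randn(n_pairs, theta.size)             # one perturbation per antithetical pair
    f_pos = np.array([loss_fn(theta + sigma * e) for e in eps])
    f_neg = np.array([loss_fn(theta - sigma * e) for e in eps])

    diff = f_neg - f_pos                                    # > 0 when +eps is the better direction
    if fast_fitness:
        fitness = np.sign(diff)                             # sign of the antithetical difference
    else:
        fitness = (diff - diff.mean()) / (diff.std() + 1e-8)  # normalized difference

    direction = eps if use_clt else np.sign(eps)            # standard ES noise vs. sign of the noise
    grad_est = (fitness[:, None] * direction).mean(axis=0)
    return theta + lr * grad_est

# Example: a few steps on a toy quadratic loss.
theta = np.ones(8)
for _ in range(200):
    theta = es_step(theta, lambda w: float(np.sum(w ** 2)))
print(np.abs(theta).max())  # should shrink noticeably from 1.0
```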