Adam optimizer that offloads gradients and optimizer states to CPU memory, enabling full-parameter training of larger models with limited GPU memory.
```python
from offload_adam import Adam

# Create a model
model = create_model().bfloat16().cuda()

# Initialize the optimizer
optimizer = Adam(
    model.parameters(),
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    mode="stochastic_rounding",
    decoupled_weight_decay=True,  # AdamW
)

# Training loop
for input_data, target in dataloader:
    # Forward pass
    output = model(input_data)
    loss = loss_function(output, target)

    # Backward pass
    if gradient_accumulation:
        loss.backward()
    else:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```

`OffloadAdam` is constructed from the model itself instead of `model.parameters()`; its `ready_for_optimizer_step` flag tells it whether the current backward pass will be followed by an optimizer step:

```python
from offload_adam import OffloadAdam

# Create a model
model = create_model().bfloat16().cuda()

# Initialize the optimizer
optimizer = OffloadAdam(
    model,  # pass the model instead of model.parameters()
    lr=0.001,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0.01,
    mode="stochastic_rounding",
    decoupled_weight_decay=True,  # AdamW
)

# Training loop
for input_data, target in dataloader:
    # Forward pass
    output = model(input_data)
    loss = loss_function(output, target)

    # Backward pass
    if gradient_accumulation:
        # This backward pass only accumulates gradients
        optimizer.ready_for_optimizer_step = False
        loss.backward()
    else:
        # This backward pass is followed by an optimizer step
        optimizer.ready_for_optimizer_step = True
        loss.backward()
        optimizer.step()
```

Offloading is implemented with two PyTorch hooks:

- Module `register_full_backward_pre_hook`:
  - asynchronously copy states from CPU to GPU
- Parameter `register_post_accumulate_grad_hook`:
  - gradient accumulation
  - norm calculation for gradient clipping (in OffloadAdam)
  - optimizer step (in OffloadAdamV2)
  - asynchronously copy states back to CPU

The optimizer step itself is done on the GPU.
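For intuition, here is a minimal sketch of this hook pattern. It is not the library's actual implementation: the `attach_offload_hooks` helper, the `cpu_grads` dict of pinned host tensors, and the stream handling are illustrative assumptions.

```python
import torch

def attach_offload_hooks(module, cpu_grads, copy_stream):
    """Hypothetical sketch, not the library's real code: prefetch CPU-resident
    gradients before backward reaches a module, and offload each parameter's
    gradient back to (pinned) host memory as soon as it has been accumulated."""
    prefetched = {}

    def prefetch(mod, grad_output):
        # Module.register_full_backward_pre_hook: fires just before backward
        # computes this module's gradients, so the H2D copies can overlap with
        # the backward compute of later layers.
        with torch.cuda.stream(copy_stream):
            for name, p in mod.named_parameters(recurse=False):
                prefetched[name] = cpu_grads[name].to(p.device, non_blocking=True)

    module.register_full_backward_pre_hook(prefetch)

    for name, p in module.named_parameters(recurse=False):
        def offload(param, name=name):
            # Tensor.register_post_accumulate_grad_hook: fires once param.grad is
            # fully accumulated. Add the prefetched CPU gradient, then copy the
            # result back to host memory on the side stream.
            torch.cuda.current_stream().wait_stream(copy_stream)  # prefetch finished
            param.grad += prefetched[name]
            copy_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(copy_stream):
                cpu_grads[name].copy_(param.grad, non_blocking=True)
            param.grad.record_stream(copy_stream)  # keep memory alive for the copy
            param.grad = None                      # free the GPU gradient

        p.register_post_accumulate_grad_hook(offload)
```

In practice such hooks would be attached per submodule, so that prefetching for earlier layers overlaps with the backward compute of later layers.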
With offloading, it's possible to do full-parameter training of:
- 7B models on a single 24GB GPU with 42GB+ of host memory
- 14B models on a single 48GB GPU with 84GB+ of host memory
- 32B models on a single 80GB GPU with 192GB+ of host memory
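These figures follow from simple arithmetic, on my reading of them: the BF16 weights stay on the GPU at 2 bytes per parameter, while the BF16 gradients, momentum, and variance live in host memory at 6 bytes per parameter (the "+" leaves room for activations and other overhead):

```python
for n_params in (7e9, 14e9, 32e9):
    gpu_weights_gb = n_params * 2 / 1e9            # BF16 weights on the GPU
    host_states_gb = n_params * (2 + 2 + 2) / 1e9  # BF16 gradient + momentum + variance on the CPU
    print(f"{n_params / 1e9:.0f}B params: ~{gpu_weights_gb:.0f} GB GPU weights, "
          f"~{host_states_gb:.0f} GB host states")
# 7B params: ~14 GB GPU weights, ~42 GB host states
# 14B params: ~28 GB GPU weights, ~84 GB host states
# 32B params: ~64 GB GPU weights, ~192 GB host states
```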
The overhead of offloading depends on the input size (total number of tokens in a batch), GPU compute speed, and PCIe bandwidth.
Per-parameter PCIe traffic, with gradients, momentum, and variance stored in BF16:
| Stage | H2D (bytes per param) | D2H (bytes per param) |
|---|---|---|
| gradient reset | 0 | 2 |
| gradient accumulation | 2 | 2 |
| optimizer step (stochastic rounding) | 6 | 4 |
| optimizer step (fp32 master weights) | 10 | 8 |
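To put the table in absolute terms, here is a back-of-the-envelope calculation (my own arithmetic, assuming a 7B-parameter model and ~25 GB/s of effective PCIe bandwidth) of how much data each stage moves if nothing is overlapped with compute:

```python
N, BW = 7e9, 25e9  # assumed: 7B parameters, 25 GB/s effective PCIe bandwidth

stages = {  # (H2D, D2H) bytes per parameter, from the table above
    "gradient reset": (0, 2),
    "gradient accumulation": (2, 2),
    "optimizer step (stochastic rounding)": (6, 4),
    "optimizer step (fp32 master weights)": (10, 8),
}

for stage, (h2d, d2h) in stages.items():
    # PCIe is full duplex, so the larger direction sets the lower bound on time
    seconds = max(h2d, d2h) * N / BW
    print(f"{stage}: {h2d * N / 1e9:.0f} GB H2D, {d2h * N / 1e9:.0f} GB D2H, >= {seconds:.1f} s")
# e.g. "optimizer step (stochastic rounding): 42 GB H2D, 28 GB D2H, >= 1.7 s"
```

This is why hiding the transfers behind the backward pass matters, which the following estimate quantifies.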
Per-token backward time: `t_bwd ≈ 4 * N / gpu_flops`, assuming ~2 FLOPs per parameter per token for the forward pass and a backward pass costing roughly twice the forward (N is the number of parameters).
Weight transfer time: `t_xfer = bytes_per_param * N / pcie_bandwidth`.
Number of tokens to overlap weight transfer: `t_xfer / t_bwd = bytes_per_param * gpu_flops / (4 * pcie_bandwidth)`; N cancels out.
With theoretical values (165 TFLOPS, 32 GB/s PCIe 4.0):
- ~2578 tokens to overlap gradient transfer
With measured values (175 TFLOPS, 25 GB/s):
- ~3500 tokens to overlap gradient transfer
In actual training, bandwidth will be lower.
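A quick way to reproduce these figures, or to plug in your own hardware numbers, under the same assumptions (~4 FLOPs per parameter per token for backward, 2 bytes per parameter for the gradient transfer):

```python
def tokens_to_overlap(bytes_per_param, gpu_flops, pcie_bw):
    # per-token backward time:  t_bwd  = 4 * N / gpu_flops
    # transfer time:            t_xfer = bytes_per_param * N / pcie_bw
    # overlap once tokens * t_bwd >= t_xfer; N cancels out:
    return bytes_per_param * gpu_flops / (4 * pcie_bw)

print(tokens_to_overlap(2, 165e12, 32e9))  # ~2578  (theoretical: 165 TFLOPS, PCIe 4.0 at 32 GB/s)
print(tokens_to_overlap(2, 175e12, 25e9))  # ~3500  (measured: 175 TFLOPS, 25 GB/s)
```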
The token embedding table has large memory consumption but small computational cost. The rows actually involved are usually far fewer than the full table:
- For gradient accumulation, only the rows of tokens used in the current batch are involved.
- For the optimizer step, all rows ever used are involved, because their momentum is non-zero.
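For a sense of scale, with made-up (hypothetical) dimensions of a 150k-token vocabulary and hidden size 4096, one micro-batch touches only a small fraction of the table:

```python
vocab, hidden, bf16_bytes = 150_000, 4096, 2   # hypothetical model dimensions
batch_tokens = 4096                            # tokens in one micro-batch (hypothetical)

full_table_mb = vocab * hidden * bf16_bytes / 1e6
touched_mb = min(batch_tokens, vocab) * hidden * bf16_bytes / 1e6  # at most one row per token
print(f"full embedding table: {full_table_mb:.0f} MB, rows touched by one batch: <= {touched_mb:.0f} MB")
# full embedding table: 1229 MB, rows touched by one batch: <= 34 MB
```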
There are optimization opportunities here, but they are not implemented yet.
- DTensor support