49 changes: 47 additions & 2 deletions praxis/optimizers.py
@@ -2691,6 +2691,8 @@ def _get_raw_grad_transformation(self, lr: optax.Schedule):

def sharded_static_accumulation(
    num_sub_batches: int,
    clip_gradient_norm_to_value: float,
    clip_gradient_single_norm_to_value: float,
    base_tx: ShardedGradientTransformation,
) -> ShardedGradientTransformation:
  """Gradient transformation for ShardedStaticAccumulator optimizer."""
@@ -2759,10 +2761,52 @@ def update_fn(updates: NestedJTensor,
lambda: new_count)

    def _run_base_tx():

      def _compute_grad_norm(grads: NestedMap) -> JTensor:
        """Computes total grad norm."""
        grad_norms_squared = jax.tree_map(lambda x: jnp.sum(x * x), grads)
        grad_norms_squared, _ = jax.tree_util.tree_flatten(grad_norms_squared)
        return jnp.sqrt(jnp.sum(jnp.stack(grad_norms_squared)))


      def scale_gradients(
          raw_grads: NestedMap,
          clip_grad_norm_to_value: float = 0.0,
          clip_grad_single_norm_to_value: float = 0.0):

        def clip_grads(grads):
          assert not (clip_grad_norm_to_value and
                      clip_grad_single_norm_to_value)
          if clip_grad_norm_to_value:

[Review comment] Member: maybe assert only one of them is true?

            grad_norm = _compute_grad_norm(raw_grads)
            grad_scale = jnp.minimum(
                jnp.array(1, grad_norm.dtype),
                jnp.array(clip_grad_norm_to_value, grad_norm.dtype)
                / grad_norm)
            grads = jax.tree_map(lambda g: g * grad_scale, grads)
          elif clip_grad_single_norm_to_value:
            grad_single_norm = jax.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)),
                                            grads)

            def scale_gradient(grad, norm):
              return grad * jnp.minimum(
                  jnp.array(1, norm.dtype),
                  jnp.array(clip_grad_single_norm_to_value,
                            norm.dtype) / norm)

            grads = jax.tree_map(scale_gradient, grads, grad_single_norm)

          return grads

        grads = clip_grads(raw_grads)
        return grads

      averaged_updated = jax.tree_map(lambda acc: acc / num_sub_batches,
                                      new_accumulated_update)
      scaled_updated = scale_gradients(averaged_updated,
                                       clip_gradient_norm_to_value,
                                       clip_gradient_single_norm_to_value)
      emission_updates, emission_base_state = base_tx.update(
-         averaged_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
+         scaled_updated, state.base_state, params)  # pytype: disable=attribute-error  # jax-ndarray
      return (emission_updates,
              jax.tree_map(lambda u: jnp.zeros_like(u, dtype=jnp.float32),
                           updates), emission_base_state)
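
As a sanity check on the global-norm path added above, here is a minimal standalone sketch of the same min(1, clip / norm) arithmetic in plain jax.numpy, using a hypothetical two-leaf gradient tree (not part of the diff):

import jax
import jax.numpy as jnp

# Hypothetical gradient tree whose global norm is sqrt(3^2 + 4^2) = 5.
grads = {'w': jnp.array([3.0]), 'b': jnp.array([4.0])}
clip_to = 1.0  # plays the role of clip_gradient_norm_to_value

# Same arithmetic as _compute_grad_norm: sqrt of the summed squared leaves.
sq = jax.tree_util.tree_map(lambda x: jnp.sum(x * x), grads)
global_norm = jnp.sqrt(sum(jax.tree_util.tree_leaves(sq)))  # -> 5.0

# min(1, clip / norm) is a no-op when the norm is already under the threshold.
scale = jnp.minimum(1.0, clip_to / global_norm)  # -> 0.2
clipped = jax.tree_util.tree_map(lambda g: g * scale, grads)
# The clipped tree now has global norm min(5.0, clip_to) == 1.0.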
@@ -2830,4 +2874,5 @@ def _get_raw_grad_transformation(
      self, lr: optax.Schedule) -> GeneralGradientTransformation:
    p = self._hparams
    base_tx = self.base_optimizer._get_raw_grad_transformation(lr)  # pylint: disable=protected-access
-   return sharded_static_accumulation(p.num_sub_batches, base_tx)
+   return sharded_static_accumulation(p.num_sub_batches,
+                                      p.clip_gradient_norm_to_value,
+                                      p.clip_gradient_single_norm_to_value,
+                                      base_tx)
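
For completeness, the per-tensor path (clip_gradient_single_norm_to_value) clips each leaf against its own norm rather than the global one; a hedged sketch of that arithmetic with a hypothetical tree (again not from the PR):

import jax
import jax.numpy as jnp

grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array([0.5])}
clip_single = 1.0  # plays the role of clip_gradient_single_norm_to_value

# Each leaf gets its own min(1, clip / norm) factor, so the small 'b'
# gradient (norm 0.5) passes through unscaled while 'w' (norm 5.0) shrinks.
norms = jax.tree_util.tree_map(lambda x: jnp.sqrt(jnp.sum(x * x)), grads)
clipped = jax.tree_util.tree_map(
    lambda g, n: g * jnp.minimum(1.0, clip_single / n), grads, norms)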