
PolarGrad: A Class of Matrix-Gradient Optimizers from a Unifying Preconditioning Perspective

PolarGrad (Polar Gradient methods; Lau et al., 2025) is a class of matrix-gradient optimizers based on the concept of gradient-anisotropy preconditioning in optimization. It is closely related to Muon (Jordan et al., 2024) and stochastic spectral descent (SSD; Carlson et al., 2015a, 2015b). In addition to serving as an optimizer for matrix parameters in neural networks, PolarGrad can also be used as a preconditioned algorithm for matrix optimization problems such as matrix regression and low-rank matrix factorization/completion.

The main differences between PolarGrad and Muon/SSD are:

  • PolarGrad uses the QDWH (Nakatsukasa et al., 2010) or ZOLO-PD (Nakatsukasa and Freund, 2016) algorithm to compute the polar decomposition of the gradient matrix, while Muon uses the Newton-Schulz (NS) iteration (see the section below for further details). The NS iteration is a matrix polynomial method that computes the polar decomposition by repeatedly applying a fixed polynomial to the matrix; however, it requires tuning the polynomial coefficients, which can be challenging in practice. PolarGrad also includes the nuclear norm (the dual norm of the spectral norm) scaling of the update matrix, which is absent in Muon (see the sketch following this list). This scaling is necessary for the convergence of polar-decomposition-based optimizers on strongly convex and Lipschitz-smooth problems with deterministic gradients, as shown in the convergence analysis and the matrix quadratic regression example of PolarGrad (Lau et al., 2025).
  • Following the concurrent work of Amsel et al. (2025), PolarGrad also includes the Polar Express method for computing the polar decomposition, which uses polynomial (rather than rational) approximations of the matrix sign function, thus avoiding QR decompositions and involving only matrix-matrix products (in half-precision arithmetic). Its implementation is taken directly from the paper, but it has not yet been heavily tested in our experiments.
  • While SSD also includes the nuclear norm scaling, PolarGrad uses more advanced numerical linear algebra algorithms for polar decomposition, namely QDWH and ZOLO-PD, instead of the randomized SVD algorithm used in SSD.
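
To make the role of the nuclear-norm scaling concrete, below is a minimal, illustrative sketch of a single PolarGrad step on a matrix parameter. It is not the repository's code: the polar factor is computed via SVD here purely for clarity, whereas the repository uses QDWH, ZOLO-PD, or Polar Express.

import torch

def polargrad_step(X: torch.Tensor, G: torch.Tensor, lr: float) -> torch.Tensor:
    # Polar factor of the gradient: if G = U diag(S) V^T, the polar factor is U V^T.
    U, S, Vh = torch.linalg.svd(G, full_matrices=False)
    polar_factor = U @ Vh
    # Nuclear (trace) norm of G, i.e., the sum of its singular values.
    # This dual-norm scaling is the term that is absent in Muon.
    nuclear_norm = S.sum()
    return X - lr * nuclear_norm * polar_factor

Muon, in contrast, applies only the orthogonal polar factor (of its momentum buffer), without the nuclear-norm scaling.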

Overview

This repository provides PyTorch implementations of PolarGrad built on two numerical linear algebra algorithms for polar decomposition that are more advanced than the Newton-Schulz (NS) iteration:

  1. The QDWH algorithm (Nakatsukasa et al., 2010; see here and here for implementations in JAX)
  2. The ZOLO-PD algorithm (Nakatsukasa and Freund, 2016; see here for the authors' MATLAB implementation)

These two algorithms, unlike the NS iteration, do not require tuning the coefficients of a matrix iteration polynomial, and they are more numerically stable (Nakatsukasa and Higham, 2012; Nakatsukasa and Freund, 2016). They are therefore better suited to matrix parameters of varying sizes and potentially ill-conditioned initializations, making polar-decomposition-based optimizers such as PolarGrad and Muon (Jordan et al., 2024) a drop-in replacement for other adaptive gradient optimizers such as Adam(W). Currently, QDWH is particularly efficient for large matrices, while ZOLO-PD is designed for small to medium-sized matrices. Note that both algorithms involve QR decompositions, which might not be efficient on GPUs or in half-precision arithmetic. To address this issue, we also include the Polar Express method of Amsel et al. (2025), which uses polynomial rather than rational approximations of the sign function, thereby avoiding QR decompositions and involving only matrix-matrix products (in half-precision arithmetic).
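
For reference, the sketch below contrasts the SVD-based definition of the polar factor with a fixed-coefficient NS-style quintic iteration. The coefficients shown are the ones popularized by Muon's public implementation and are included only to illustrate what "tuning the coefficients of the iteration polynomial" refers to; this is not the code used in this repository.

import torch

def polar_factor_svd(G: torch.Tensor) -> torch.Tensor:
    # Reference definition: if G = U diag(S) V^T, the orthogonal polar factor is U V^T.
    U, _, Vh = torch.linalg.svd(G, full_matrices=False)
    return U @ Vh

def polar_factor_ns(G: torch.Tensor, steps: int = 5,
                    coeffs: tuple = (3.4445, -4.7750, 2.0315)) -> torch.Tensor:
    # Quintic NS-style iteration X <- a X + b (X X^T) X + c (X X^T)^2 X.
    # The fixed coefficients (a, b, c) must be chosen in advance; poor choices degrade accuracy.
    a, b, c = coeffs
    X = G / (G.norm() + 1e-7)          # normalize so all singular values are at most 1
    transposed = G.size(0) > G.size(1)
    if transposed:                     # work with the smaller Gram matrix X X^T
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if transposed else X

QDWH and ZOLO-PD, by contrast, adapt their (rational) iterations using estimates of the extreme singular values of the input, which is what makes them coefficient-free and more numerically stable.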

In particular, with the assistance of ChatGPT, we translated these JAX and MATLAB implementations to PyTorch. Currently, limited by the QR decomposition implementation in PyTorch, mixed precisions such as bfloat16 are not yet supported. Note that the current implementation is not optimized for speed or parallelization, although we also provide a DDP implementation, polar_grad_ddp.py, following the implementation of Muon. The three main files are:

  1. polar.py: includes the function polar, which mimics the JAX jax.scipy.linalg.polar function and computes the polar decomposition of a matrix using one of five numerical algorithms (a usage sketch is given after this list).

    i. method=qdwh: uses the QDWH algorithm (Nakatsukasa et al., 2010) to compute the polar decomposition of a matrix. This is suitable for large matrices and is more numerically stable than the Newton-Schulz iteration.

    ii. method=zolo-pd: uses the ZOLO-PD algorithm (Nakatsukasa and Freund, 2016) to compute the polar decomposition of a matrix. This is suitable for small to medium-sized matrices and is also more numerically stable than the Newton-Schulz iteration.

    iii. method=ns: uses the Newton-Schulz (NS) iteration to compute the polar decomposition of a matrix. This might require tuning the coefficients of the matrix iteration polynomial for different model and layer sizes, which can be challenging in practice. This is the same method used in the Muon optimizer (Jordan et al., 2024) and is adapted from its GitHub repository.

    iv. method=precond_ns: uses the preconditioned Newton-Schulz iteration of Lewis et al. (2022) to compute the polar decomposition of a matrix. This is potentially an improved variant of the NS iteration, though it still requires coefficient tuning and might suffer from the same stability issues. We include this method for completeness, but it is not heavily tested and is not used in the experiments in the paper.

    v. method=polar_express: uses the Polar Express method in Amsel et al. (2025) to compute the polar decomposition of a matrix.

  2. polar_grad.py: includes the torch.optim.Optimizer class PolarGrad, which implements the PolarGrad optimizer based on the numerical polar decomposition algorithms above applied to the gradient matrix.

    • The argument polar_first specifies whether polar-first momentum is used; the default is False, which is similar to the implementation of Muon (Jordan et al., 2024).
    • The argument method specifies which polar decomposition algorithm to use and can be one of the following: qdwh (cf. qdwh.py, adapted from the JAX implementation jax.lax.linalg.qdwh), zolo-pd (cf. zolopd.py, adapted from the authors' MATLAB implementation), ns or precond_ns (cf. newton_schulz.py, adapted from Muon's GitHub repository). The default is qdwh, which is suitable for large matrices.
    • The argument inner_steps specifies the number of (inner) steps for either the QDWH algorithm or the NS iteration. The other two algorithms (ZOLO-PD and preconditioned NS) do not require this argument. The default is 2.
    • The arguments a, b and c specify the coefficients of the matrix iteration polynomial for the NS iteration and are used only when method='ns'. The default values are the same as those in Muon, which are suitable for most hidden layers, but they can be tuned for different model and layer sizes if necessary.

    The optimizer can be used as follows:

    optimizer = PolarGrad(model.parameters(), lr=1e-3, weight_decay=0., momentum=0.9, polar_first=False, method='qdwh', inner_steps=2)
  3. polar_grad_ddp.py: includes the torch.optim.Optimizer class PolarGrad, which implements the PolarGrad optimizer with torch.distributed based on the polar decomposition algorithms above, following the implementation in Muon's GitHub repository.
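
Putting the pieces together, a usage sketch might look like the following. The return signature of polar is assumed here to mirror jax.scipy.linalg.polar (orthogonal factor first, then the positive semidefinite factor), and the model and shapes are purely illustrative.

import torch
from polar import polar             # polar.py in this repository
from polar_grad import PolarGrad    # polar_grad.py in this repository

# Standalone polar decomposition of a gradient-shaped matrix.
G = torch.randn(1024, 4096)
U, H = polar(G, method='qdwh')      # or 'zolo-pd', 'ns', 'precond_ns', 'polar_express'

# Typical training loop: PolarGrad follows the usual torch.optim.Optimizer API.
# bias=False keeps every parameter a matrix; handling non-matrix parameters
# (e.g., with a separate optimizer such as AdamW) is a design choice not covered here.
model = torch.nn.Linear(4096, 1024, bias=False)
optimizer = PolarGrad(model.parameters(), lr=1e-3, weight_decay=0.,
                      momentum=0.9, polar_first=False, method='qdwh', inner_steps=2)
for _ in range(10):
    loss = model(torch.randn(8, 4096)).square().mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()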

Installation of Required Libraries

Install PyTorch (nightly) according to the instructions at https://pytorch.org/get-started/locally/, e.g., for Linux and CUDA 12.6:

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126

Then, install some auxiliary libraries:

pip install -U numpy matplotlib tqdm fire SciencePlots

For correct LaTeX rendering in matplotlib, you might also need a LaTeX distribution installed, such as TeX Live (MacTeX) or MiKTeX, or disable LaTeX rendering in matplotlib by setting rcParams['text.usetex'] = False and adjusting some of the plot labels in the code.
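
For example, LaTeX rendering can be turned off globally before any figures are created:

import matplotlib
# Fall back to matplotlib's built-in mathtext instead of an external LaTeX installation.
matplotlib.rcParams['text.usetex'] = False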

Usage

For small-scale experiments that can be run on a CPU, you can use the following commands to test the PolarGrad optimizer on different matrix optimization problems. The --seed argument sets the random seed for reproducibility.

  • Matrix quadratic regression (a strongly convex problem with deterministic gradient):

    # PolarGrad
    python mat_quad_reg.py --steps=4000 --seed=42
    
    # PolarGradM
    python mat_quad_reg_mom.py --steps=4000 --seed=42
    
  • Matrix logistic regression (a strongly convex problem with stochastic gradient):

    # PolarSGD
    python mat_log_reg.py --steps=1500 --seed=42
    
    # PolarSGDM
    python mat_log_reg_mom.py --steps=1500 --seed=42
    
  • Low-rank matrix completion (a non-convex problem with deterministic gradient):

    # PolarGrad
    python low_rank_mat_comp.py --steps=1000 --seed=42
    
    # PolarGradM
    python low_rank_mat_comp_mom.py --steps=200 --seed=42
    

We will update the repository with examples and experiments for language model pre-training soon.

Citation

If you find this repository useful for your research, please consider citing our paper using the BibTeX entry below:

@article{lau2025polargrad,
  title={\textsc{PolarGrad}: A Class of Matrix-Gradient Optimizers from a Unifying Preconditioning Perspective},
  author={Lau, Tim Tsz-Kit and Long, Qi and Su, Weijie},
  year={2025},
  journal={arXiv preprint arXiv:2505.21799}
}

References
