Giacomo Cignoni¹, Andrea Cossu¹, Alex Gómez Villa², Joost van de Weijer², Antonio Carta¹
¹University of Pisa, ²Computer Vision Center (CVC)
This repository contains the official implementation of Continual Latent Alignment (CLA), a novel method for online continual self-supervised learning (OCSSL).
Self-supervised learning (SSL) builds latent representations that generalize well to unseen data. However, only a few SSL techniques exist for the online continual learning (CL) setting, where data arrives in small minibatches, the model must comply with a fixed computational budget, and task boundaries are absent. We introduce Continual Latent Alignment (CLA), a novel SSL strategy for online CL that aligns the representations learned by the current model with past representations to mitigate forgetting. We found that CLA speeds up the convergence of training in the online scenario, outperforming state-of-the-art approaches under the same computational budget. Surprisingly, we also discovered that using CLA as a pretraining protocol in the early stages of pretraining leads to better final performance than full i.i.d. pretraining.
Figure 1: CLA-E and CLA-R architectures.
CLA includes two variants: CLA-E and CLA-R. Both methods include replay from a FIFO buffer and introduce a regularization term to the SSL loss.
CLA-E aligns the current representations of replayed samples, passed through an aligner network, with the representations produced by an EMA-updated network.
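The regularizer that CLA adds to the SSL loss can be sketched in NumPy as below. This is a minimal sketch under stated assumptions: the linear aligner, the cosine-based alignment loss, the weighted sum, and the names `linear_aligner`, `alignment_loss`, `ema_update`, `cla_loss`, and `reg_weight` are all illustrative, not the repository's implementation. Following the strategy list, the alignment target would come from an EMA-updated network for CLA-E and from past representations stored in the buffer for CLA-R.

```python
import numpy as np

# Illustrative sketch only: the loss form, aligner, and weighting are assumptions.


def linear_aligner(z, weight):
    """Hypothetical aligner network: a single linear map on current representations."""
    return z @ weight


def alignment_loss(z_aligned, z_target):
    """Mean of (1 - cosine similarity) between matching rows.

    In an autograd framework z_target would be detached (stop-gradient),
    since it comes from the EMA network (CLA-E) or from the buffer (CLA-R).
    """
    a = z_aligned / np.linalg.norm(z_aligned, axis=1, keepdims=True)
    b = z_target / np.linalg.norm(z_target, axis=1, keepdims=True)
    return float(np.mean(1.0 - np.sum(a * b, axis=1)))


def ema_update(ema_params, params, momentum=0.99):
    """theta_ema <- m * theta_ema + (1 - m) * theta, applied per parameter array."""
    return [momentum * e + (1.0 - momentum) * p for e, p in zip(ema_params, params)]


def cla_loss(ssl_loss, z_current, aligner_weight, z_target, reg_weight=1.0):
    """SSL loss plus the weighted alignment regularizer on replayed samples."""
    z_aligned = linear_aligner(z_current, aligner_weight)
    return ssl_loss + reg_weight * alignment_loss(z_aligned, z_target)


rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))   # current representations of 8 replayed samples
identity = np.eye(16)          # aligner initialized to the identity
# Targets equal the aligned outputs, so the regularizer is ~0 and the total is ~0.5.
print(cla_loss(0.5, z, identity, z))
```

A cosine loss is used here only because it is scale-invariant, which keeps the sketch simple; the repository may use a different distance or aligner architecture.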
Figure 2: Comparison of Average Accuracy and Final Accuracy for CLA-E and CLA-R.
The main experiments were conducted on an Online Continual Self-Supervised Learning (OCSSL) stream built from class-incremental Split CIFAR-100 and Split ImageNet100, with performance measured through linear probing. CLA-E and CLA-R consistently achieved state-of-the-art Final Accuracy (probing at the end of the stream) and Average Accuracy (mean of the probing accuracies measured at the end of each task), notably outperforming existing methods and i.i.d. training baselines across various computational settings.
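The two metrics can be written as a short sketch. The function names are hypothetical, and the accuracy values in the usage lines are made-up placeholders, not results from the paper.

```python
def final_accuracy(probe_accs):
    """Final Accuracy: probing accuracy measured at the end of the stream."""
    return probe_accs[-1]


def average_accuracy(probe_accs):
    """Average Accuracy: mean of the probing accuracies measured at the end of each task."""
    return sum(probe_accs) / len(probe_accs)


# Placeholder probing accuracies (%) after tasks 1, 2, 3; not real results.
accs = [20.0, 30.0, 40.0]
print(final_accuracy(accs), average_accuracy(accs))  # 40.0 30.0
```

Average Accuracy rewards methods that learn good representations early in the stream, which is why it complements Final Accuracy in the online setting.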
The codebase is modular, making it easy to extend with new models, strategies, encoders, and buffers. The main components are:
The Trainer module runs the training loop for each experience through its train_experience() method. It requires a model, a strategy, and the experience dataset to train on. You can train for multiple epochs per experience or run multiple training passes over each minibatch of the experience. It returns the SSL model trained on that experience.
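The loop structure described above can be sketched as follows. This is a hypothetical, framework-free outline: `CountingModel`, `NoStrategy`, `training_step`, `before_forward`, and the `epochs`, `passes_per_minibatch`, and `batch_size` parameters are illustrative names, not the repository's API.

```python
class CountingModel:
    """Toy stand-in for an SSL model that only counts training steps."""

    def __init__(self):
        self.steps = 0

    def training_step(self, batch):
        # A real model would compute the SSL loss, backpropagate, and update weights.
        self.steps += 1


class NoStrategy:
    """Plain finetuning: the minibatch is used unchanged."""

    def before_forward(self, batch):
        return batch


def train_experience(model, strategy, experience, epochs=1, passes_per_minibatch=1, batch_size=4):
    """Run the training loop for one experience and return the trained model."""
    for _ in range(epochs):
        for start in range(0, len(experience), batch_size):
            batch = experience[start:start + batch_size]
            # A replay strategy could extend the batch with buffer samples here.
            batch = strategy.before_forward(batch)
            for _ in range(passes_per_minibatch):
                model.training_step(batch)
    return model


model = train_experience(CountingModel(), NoStrategy(), list(range(10)),
                         epochs=2, passes_per_minibatch=3)
print(model.steps)  # 2 epochs x 3 minibatches x 3 passes = 18
```

Each extra pass over a minibatch costs one more update, so the number of passes is one of the knobs that sets the computational budget in the online setting.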
The Model is the self-supervised model that learns the representations. Currently, SimSiam, BYOL, Barlow Twins, SimCLR, and Masked Autoencoders are implemented.
All of them are subclasses of AbstractSSLModel, and the model is selected with the --model option.
Some models can use different backbone encoders (e.g., different ResNets), which are selected with the --encoder option.
Currently implemented:
- ResNet-18
- ResNet-9
- Wide ResNet-18 (2x the block features)
- Wide ResNet-9 (2x the block features)
- Slim ResNet-18 (~1/3 the block features)
- Slim ResNet-9 (~1/3 the block features)
- ViT
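How the --model and --encoder options might select a model is sketched below with argparse. The class bodies, the registry dictionaries, and the option values (`simsiam`, `byol`, `resnet18`, ...) are hypothetical placeholders; the values actually accepted are defined in ./src/utils.py.

```python
import argparse


class AbstractSSLModel:
    """Hypothetical stand-in for the repository's base class."""

    def __init__(self, encoder_name):
        self.encoder_name = encoder_name


class SimSiam(AbstractSSLModel):
    pass


class BYOL(AbstractSSLModel):
    pass


# Hypothetical registries mapping option values to implementations.
MODELS = {"simsiam": SimSiam, "byol": BYOL}
ENCODERS = ("resnet18", "resnet9", "wide_resnet18", "slim_resnet18", "vit")


def build_model(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=sorted(MODELS), default="simsiam")
    parser.add_argument("--encoder", choices=ENCODERS, default="resnet18")
    args = parser.parse_args(argv)
    return MODELS[args.model](encoder_name=args.encoder)


model = build_model(["--model", "byol", "--encoder", "resnet9"])
print(type(model).__name__, model.encoder_name)  # BYOL resnet9
```

In this kind of design, adding a model means subclassing the base class and registering it under a new option value.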
The strategy handles how to regularize the model to counter forgetting across experiences. All strategies are implemented as subclasses of AbstractStrategy.
Currently implemented:
- No Strategy, i.e. simple finetuning.
- ER, Experience Replay from buffer.
- LUMP, interpolate buffer with stream samples (https://arxiv.org/abs/2110.06976).
- MinRed, trains only on buffer samples and removes the most correlated samples from the buffer (to be paired with the MinRed buffer) (https://arxiv.org/abs/2203.12710).
- SCALE, contrastive self-supervised method built for online CL (https://arxiv.org/abs/2208.11266).
- Osiris-R, employs a contrastive cross-task loss (https://arxiv.org/abs/2404.19132).
- CaSSLe, distillation of representations with a frozen past network (https://arxiv.org/abs/2112.04215).
- CaSSLe-r, variant of CaSSLe that uses a replay buffer to distill representations.
- CLA-b, alignment of current representations with those of an EMA-updated network.
- CLA-R, alignment of representations of replayed samples with the old representations stored in the buffer.
- CLA-E, alignment of representations of replayed samples with an EMA-updated network.
Note: the Osiris-R and SCALE strategies are "stand-alone", in the sense that they do not require an SSL model (--model option).
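Two of the replay-based strategies above can be sketched at the batch level with NumPy: ER concatenates buffer samples with the stream minibatch, while LUMP mixes stream and buffer samples mixup-style, with a coefficient drawn from a Beta distribution. The function names, the single coefficient per minibatch, and `alpha=0.4` are assumptions for illustration; see the LUMP paper linked above and the repository code for the actual details.

```python
import numpy as np


def er_batch(stream_batch, buffer_batch):
    """Experience Replay: train on the stream minibatch plus replayed buffer samples."""
    return np.concatenate([stream_batch, buffer_batch], axis=0)


def lump_batch(stream_batch, buffer_batch, alpha=0.4, rng=None):
    """LUMP-style mixup: lam * x_stream + (1 - lam) * x_buffer, with lam ~ Beta(alpha, alpha)."""
    rng = np.random.default_rng() if rng is None else rng
    lam = rng.beta(alpha, alpha)
    n = min(len(stream_batch), len(buffer_batch))  # pair samples one-to-one
    return lam * stream_batch[:n] + (1.0 - lam) * buffer_batch[:n], lam


stream = np.ones((4, 3))    # toy "images" from the current minibatch
buffer = np.zeros((6, 3))   # toy samples drawn from the replay buffer
print(er_batch(stream, buffer).shape)   # (10, 3)
mixed, lam = lump_batch(stream, buffer, rng=np.random.default_rng(0))
print(mixed.shape)                      # (4, 3)
```

ER increases the number of samples per update, while the mixup variant keeps the batch size fixed, which matters when comparing strategies under the same computational budget.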
Buffers store past experience samples to be replayed by strategies. Currently implemented:
- FIFO buffer: Stores a fixed number of samples in a FIFO queue.
- Reservoir buffer: Gives every sample in the continual stream an equal probability of being stored.
- MinRed buffer: Removes the samples with the most correlated features (https://arxiv.org/abs/2203.12710).
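The FIFO and reservoir policies can be sketched in plain Python. Class and method names are hypothetical, and the sketch stores raw items rather than the tensors a real buffer would hold. MinRed is omitted because it depends on feature correlations computed by the model.

```python
import random
from collections import deque


class FIFOBuffer:
    """Keeps the `capacity` most recent samples; the oldest are evicted first."""

    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)

    def add(self, samples):
        self.data.extend(samples)

    def sample(self, k, rng=random):
        return rng.sample(list(self.data), min(k, len(self.data)))


class ReservoirBuffer:
    """Reservoir sampling: once the buffer is full, each sample seen so far is kept with probability capacity / seen."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.data = []
        self.seen = 0

    def add(self, samples, rng=random):
        for x in samples:
            if len(self.data) < self.capacity:
                self.data.append(x)
            else:
                j = rng.randint(0, self.seen)  # uniform over [0, seen], both ends inclusive
                if j < self.capacity:
                    self.data[j] = x
            self.seen += 1

    def sample(self, k, rng=random):
        return rng.sample(self.data, min(k, len(self.data)))


fifo = FIFOBuffer(capacity=3)
fifo.add(range(10))
print(list(fifo.data))  # [7, 8, 9]
```

The FIFO buffer favors recent data, whereas the reservoir buffer keeps a uniform sample of the whole stream, including early tasks.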
Representations extracted by the model are evaluated via probing. Three probing methods are currently implemented:
- Ridge Regression (Scikit-learn)
- KNN (Scikit-learn)
- Linear probe (PyTorch)
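As an illustration of probing, the sketch below fits a closed-form ridge regression on frozen features with one-hot targets and predicts the arg-max class, in NumPy. It is a stand-in under assumptions rather than the repository's Scikit-learn probe: Scikit-learn's ridge estimators, for instance, do not penalize the intercept, whereas this sketch regularizes the bias column too. The function names and `alpha` are hypothetical.

```python
import numpy as np


def _with_bias(features):
    """Append a constant column so the probe learns a bias term."""
    return np.hstack([features, np.ones((features.shape[0], 1))])


def fit_ridge_probe(features, labels, num_classes, alpha=1.0):
    """Solve W = (X^T X + alpha I)^-1 X^T Y for one-hot targets Y (bias column is also penalized)."""
    X = _with_bias(features)
    Y = np.eye(num_classes)[labels]
    return np.linalg.solve(X.T @ X + alpha * np.eye(X.shape[1]), X.T @ Y)


def predict_ridge_probe(W, features):
    """Predict the class with the largest regression score."""
    return np.argmax(_with_bias(features) @ W, axis=1)


# Toy "frozen features": two well-separated Gaussian clusters in 8 dimensions.
rng = np.random.default_rng(0)
feats = np.vstack([rng.normal(-3.0, 1.0, (50, 8)), rng.normal(3.0, 1.0, (50, 8))])
labels = np.array([0] * 50 + [1] * 50)
W = fit_ridge_probe(feats, labels, num_classes=2)
accuracy = np.mean(predict_ridge_probe(W, feats) == labels)
print(accuracy)
```

Because the encoder stays frozen, probe accuracy measures how linearly separable the learned representations are, which is what the reported accuracies compare.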
A single experiment can be run with a simple command: python main.py --option option_arg. The full list of available options is in ./src/utils.py.
To run multiple experiments or a hyperparameter search, use python run_from_config.py. It requires a config.json file that specifies the configuration of each experiment, similar to the files in the ./configs folder of this repo.
It runs a list of experiments, each with its own set of arguments; you can also specify common arguments for all experiments, which individual experiments can override.
To run an experiment as a hyperparameter search, add the parameter hyperparams_search to that experiment: a dict of lists that maps each hyperparameter to the values to try.
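The merging and grid-expansion logic described above can be sketched as follows. Only `hyperparams_search` comes from this README; the top-level keys `common_params` and `experiments`, and the argument names `lr` and `mb_passes`, are hypothetical placeholders, so check the files in ./configs for the real layout.

```python
import itertools
import json


def expand_experiments(config):
    """Merge common arguments into each experiment and expand any hyperparams_search grid."""
    common = config.get("common_params", {})
    runs = []
    for exp in config["experiments"]:
        exp = dict(exp)  # copy so the input config is left untouched
        search = exp.pop("hyperparams_search", None)
        base = {**common, **exp}  # experiment-specific values override common ones
        if not search:
            runs.append(base)
            continue
        names = list(search)
        # Cartesian product of the value lists: one run per combination.
        for values in itertools.product(*(search[name] for name in names)):
            runs.append({**base, **dict(zip(names, values))})
    return runs


config = json.loads("""
{
  "common_params": {"epochs": 1, "lr": 0.1},
  "experiments": [
    {"name": "single", "lr": 0.01},
    {"name": "search", "hyperparams_search": {"lr": [0.1, 0.01], "mb_passes": [1, 3]}}
  ]
}
""")
runs = expand_experiments(config)
print(len(runs))  # 5: one plain run plus a 2 x 2 grid
```

A full grid grows multiplicatively with each searched hyperparameter, so short value lists keep the search tractable.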