MT2: Multi-class-token transformer for multitask self-supervised music information retrieval

This is the repository for the paper Multi-class-token transformer for multitask self-supervised music information retrieval, accepted at WASPAA 2025. The model, also written MT^2, is a music representation model that extracts both global and sequential representations from a piece of music, using an original architecture called the multi-class-token transformer, trained with multiple self-supervised losses.

Features

  • 🧠 Full training code with a toy dataloader (you will need to plug in your own dataloader).
  • ⚙️ Code to extract representations with a pre-trained model (the same model used in the paper) for further downstream training and evaluation. Downstream training and evaluation code are not provided.

Installation

git clone https://github.com/deezer/mt2.git
cd mt2
docker build . -t <image-name>
docker run -ti --gpus all -v ./:/workspace/ --entrypoint bash <image-name>
# poetry run python <script> # or
# poetry run ipython

Usage

🐍 Extraction of Features

To extract features for an audio input:

import torch
from src.inference import extract_feature, load_model

input = torch.rand(16, 64000)  # audio with sr = 16000 Hz and duration = 4 s; batch size in the first dimension; audio must be normalized between 0 and 1
model = load_model(ckpt_path="model_state_dict.pt", device="cpu")
features = extract_feature(input, model)

Parameters:

  • ckpt_path (str or None, optional): Path to the model checkpoint file. If None, the default model is used.
  • device (str or None, optional): Device to run on (default: "cpu").

Input:

  • input (Tensor): The input tensor should have shape (batch_size, sr*duration, 1), with sr = 16000 and duration = 4 for the default checkpoint model.

Return:

  • features (dict): a dictionary with keys cls_token_equiv (equivariant class token), cls_token_contrastive (contrastive class token) and seq_tokens (sequence tokens).
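
For example, continuing from the snippet above, you can sanity-check the returned dictionary like this (the exact tensor shapes depend on the model configuration, so this snippet only prints them):

features = extract_feature(input, model)
for name, tensor in features.items():
    # prints cls_token_equiv, cls_token_contrastive and seq_tokens with their shapes
    print(name, tuple(tensor.shape))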

You can use this snippet to further train MT2 for downstream tasks by attaching a head (a sketch follows this list). We recommend using:

  • the average of both class tokens for global tasks, and the sequence tokens for local tasks such as beat tracking and chord estimation;
  • or the equivariant class token for tonality-related tasks (e.g., key estimation, pitch estimation) and the contrastive class token for other global tasks (e.g., instrument recognition, music tagging).
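
As an illustration, here is a minimal sketch of such a head for a global task. GlobalTaskHead, embed_dim and n_classes are hypothetical names (embed_dim must match the class-token dimension of the checkpoint); only the feature dictionary keys come from the API above:

import torch

class GlobalTaskHead(torch.nn.Module):
    """Hypothetical linear head for a global downstream task (e.g., music tagging)."""

    def __init__(self, embed_dim, n_classes):
        super().__init__()
        self.linear = torch.nn.Linear(embed_dim, n_classes)

    def forward(self, features):
        # Average the two class tokens, as recommended above for global tasks.
        pooled = (features["cls_token_equiv"] + features["cls_token_contrastive"]) / 2
        return self.linear(pooled)

You can then train this head on top of a frozen MT2 backbone, or fine-tune both jointly, by feeding it the output of extract_feature.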

🔧 Training Code

For training, you first need to install TensorFlow (the package is only used for its Progbar utility; this dependency will be removed in the near future):

poetry add tensorflow

Then, for training with a toy dataloader:

poetry run python -m src.main config/mt2.gin .

This will train an MT2 model from scratch using a toy dataloader (randomly generated torch.Tensor), the default arguments, and the default gin file (the same configuration used in the paper). Checkpoints will be saved in the current directory.

To specify additional options, use the following arguments:

poetry run python -m src.main config/mt2.gin /save/dir -d cuda:0 -e 999 -ts 512 -vs 6e -lr 3e-4 -name basic

  • /save/dir: directory to save checkpoints and logs.
  • -d: device used for training.
  • -e: number of epochs for training.
  • -ts: number of training steps in each epoch.
  • -vs: number of validation steps in each epoch.
  • -lr: initial learning rate.
  • -name: name of the experiment, used to build the save path.

To retrain from scratch using your own data, plug in your own dataloader, which must yield data in the same shape as the toy one in mt2/src/dataloader.py: each batch should have shape (batch_size, sr*duration, 3), where 3 is the number of segments extracted from the same audio piece, and the dataloader must be able to iterate indefinitely (as sketched below). You can also uncomment and adapt the existing wandb code to plug in your own logging tool.
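
Below is a minimal sketch of such an infinite dataloader with the expected batch shape; torch.rand stands in for your own audio loading, and all names here are illustrative:

import torch

def infinite_dataloader(batch_size=16, sr=16000, duration=4, n_segments=3):
    # Loops forever, yielding batches of shape (batch_size, sr*duration, n_segments).
    # Replace torch.rand with code that loads n_segments segments per audio piece.
    while True:
        yield torch.rand(batch_size, sr * duration, n_segments)

loader = infinite_dataloader()
batch = next(loader)  # shape: (16, 64000, 3)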

🗂️ Code organization

mt2
├── config
│   └── mt2.gin
├── Dockerfile
├── LICENSE
├── poetry.lock
├── pyproject.toml
├── README.md
└── src
    ├── dataloader.py
    ├── inference.py
    ├── loss.py
    ├── main.py
    ├── model.py
    ├── training.py
    └── utils
        ├── callbacks.py
        ├── scheduler.py
        ├── training.py
        └── vit.py

Acknowledgement

Special thanks to Kamil Akesbi for his great help with refactoring the code.

📚 Reference

If you use this work in your research, please cite:

@inproceedings{kong2025multi,
  title={Multi-Class-Token Transformer for Multitask Self-supervised Music Information Retrieval},
  author={Kong, Yuexuan and Lostanlen, Vincent and Hennequin, Romain and Lagrange, Mathieu and Meseguer-Brocal, Gabriel},
  booktitle={IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA 2025)},
  year={2025}
}

📄 License

The code of MT2 is MIT-licensed.
