Leveraging 2D Information for Long-term Time Series Forecasting with Vanilla Transformers
A simple yet strong model for long-term time series forecasting.

| Model | Vanilla Transformer | Multivariate Modeling | Sequential Modeling |
|---|---|---|---|
| DLinear (AAAI 2023) | ❌ | ❌ | ❌ |
| CrossFormer (ICLR 2023) | ❌ | ✔️ | ✔️ |
| PatchTST (ICLR 2023) | ✔️ | ❌ | ✔️ |
| iTransformer (ICLR 2024) | ✔️ | ✔️ | ❌ |
| GridTST | ✔️ | ✔️ | ✔️ |
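
The table contrasts how each model treats the two axes of a multivariate series. As a rough illustration of the 2D view, the sketch below splits every channel into patches to form a (channel × patch) grid of tokens, then applies plain scaled dot-product self-attention first along the time axis (across the patches of one channel) and then along the variate axis (across channels at one patch position). This is a minimal, dependency-free sketch under stated assumptions: the patch length, stride, the lack of learned projections, and the time-then-variate ordering are illustrative choices, not GridTST's actual implementation, and the helper names `make_grid`, `self_attention`, and `grid_attention` are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def make_grid(series: List[List[float]], patch_len: int, stride: int) -> List[List[Vector]]:
    """Split each channel into patches.

    Returns grid[c][p], the p-th patch of channel c: a 2D grid of tokens with
    channels on one axis and time patches on the other.
    """
    grid = []
    for channel in series:
        n_patches = (len(channel) - patch_len) // stride + 1
        grid.append([channel[p * stride: p * stride + patch_len] for p in range(n_patches)])
    return grid


def self_attention(tokens: List[Vector]) -> List[Vector]:
    """Single-head scaled dot-product self-attention with identity Q/K/V projections."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        m = max(scores)  # subtract the max before exp for numerical stability
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[i] for w, v in zip(weights, tokens)) / total for i in range(d)])
    return out


def grid_attention(grid: List[List[Vector]]) -> List[List[Vector]]:
    """Attend along the time axis (per channel), then along the variate axis (per patch)."""
    # Time axis: row grid[c] is the sequence of patches of channel c.
    grid = [self_attention(row) for row in grid]
    # Variate axis: column p gathers the p-th patch of every channel.
    n_channels, n_patches = len(grid), len(grid[0])
    for p in range(n_patches):
        column = self_attention([grid[c][p] for c in range(n_channels)])
        for c in range(n_channels):
            grid[c][p] = column[c]
    return grid


if __name__ == "__main__":
    # Toy input (hypothetical data): 3 channels, 12 time steps each.
    series = [[float(t + c) for t in range(12)] for c in range(3)]
    out = grid_attention(make_grid(series, patch_len=4, stride=4))
    print(len(out), len(out[0]), len(out[0][0]))  # 3 3 4
```

A real model would add learned projections, multiple heads, feed-forward layers, and normalization around each attention step; the point here is only that one vanilla attention operator can be reused along either axis of the grid.
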

| Dataset | GridTST | PatchTST (ICLR 2023) | iTransformer (ICLR 2024) | DLinear (AAAI 2023) |
|---|---|---|---|---|
| Weather | 0.223 | 0.228 | 0.236 | 0.246 |
| Traffic | 0.372 | 0.396 | 0.386 | 0.433 |
| Electricity | 0.152 | 0.163 | 0.165 | 0.166 |
| Illness | 1.649 | 1.806 | 2.122 | 2.169 |
| ETTh1 | 0.416 | 0.421 | 0.450 | 0.422 |
| ETTm1 | 0.345 | 0.351 | 0.365 | 0.357 |
| Solar | 0.187 | 0.215 | 0.215 | 0.244 |
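
To make the gaps in the results table concrete, the snippet below computes GridTST's relative improvement over the best baseline on each dataset. The numbers are copied from the table; since the table does not name its metric, the sketch assumes lower values are better. The helper name `relative_improvement` is hypothetical.

```python
from typing import Dict

# Scores copied from the results table above; lower is assumed to be better.
RESULTS: Dict[str, Dict[str, float]] = {
    "Weather":     {"GridTST": 0.223, "PatchTST": 0.228, "iTransformer": 0.236, "DLinear": 0.246},
    "Traffic":     {"GridTST": 0.372, "PatchTST": 0.396, "iTransformer": 0.386, "DLinear": 0.433},
    "Electricity": {"GridTST": 0.152, "PatchTST": 0.163, "iTransformer": 0.165, "DLinear": 0.166},
    "Illness":     {"GridTST": 1.649, "PatchTST": 1.806, "iTransformer": 2.122, "DLinear": 2.169},
    "ETTh1":       {"GridTST": 0.416, "PatchTST": 0.421, "iTransformer": 0.450, "DLinear": 0.422},
    "ETTm1":       {"GridTST": 0.345, "PatchTST": 0.351, "iTransformer": 0.365, "DLinear": 0.357},
    "Solar":       {"GridTST": 0.187, "PatchTST": 0.215, "iTransformer": 0.215, "DLinear": 0.244},
}


def relative_improvement(scores: Dict[str, float], model: str = "GridTST") -> float:
    """Percentage by which `model` undercuts the best (lowest) competing score."""
    best_baseline = min(v for k, v in scores.items() if k != model)
    return (best_baseline - scores[model]) / best_baseline * 100.0


if __name__ == "__main__":
    for dataset, scores in RESULTS.items():
        print(f"{dataset:<12} {relative_improvement(scores):5.1f}%")
```

On these numbers the largest relative gap is on Solar (about 13%) and the smallest on ETTh1 (about 1.2%).
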

We recommend using Conda to manage a virtual environment:
conda create -n gridtst python=3.8 && conda activate gridtst
pip install -r requirements.txt

Set up logging with Weights & Biases and multi-GPU training with Accelerate:
wandb login
accelerate config

The datasets we use are listed below. You can download them here and put all CSV files in the dataset folder.

| Dataset | # Channels | # Time Steps | Prediction Length | Domain |
|---|---|---|---|---|
| Weather | 21 | 52696 | {96,192,336,720} | Weather |
| Traffic | 862 | 17544 | {96,192,336,720} | Transportation |
| Electricity | 321 | 26304 | {96,192,336,720} | Electricity |
| Illness | 7 | 966 | {12,24,48,60} | Illness |
| ETTh1 | 7 | 17420 | {96,192,336,720} | Electricity |
| ETTm1 | 7 | 69680 | {96,192,336,720} | Electricity |
| Solar | 137 | 52560 | {96,192,336,720} | Energy |
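
Each dataset is a CSV of time steps by channels, and a forecasting sample pairs a lookback window with the segment of prediction length that follows it. The sketch below shows one common way to slice such samples with the standard library. It assumes each CSV has a header row, a leading date column, and one numeric column per channel (the layout of the widely shared versions of these benchmarks); the helper names `load_channels`, `sliding_windows`, and `num_windows` are hypothetical, and the repository's own loader, normalization, and train/validation/test split may differ.

```python
import csv
import io
from typing import Iterator, List, Tuple

Rows = List[List[float]]


def load_channels(fileobj) -> Tuple[List[str], Rows]:
    """Read a CSV whose first column is a date; return channel names and float rows."""
    reader = csv.reader(fileobj)
    header = next(reader)
    rows = [[float(x) for x in row[1:]] for row in reader]  # drop the date column
    return header[1:], rows


def sliding_windows(rows: Rows, lookback: int, horizon: int) -> Iterator[Tuple[Rows, Rows]]:
    """Yield (input, target): `lookback` past steps and the next `horizon` steps, stride 1."""
    for start in range(len(rows) - lookback - horizon + 1):
        yield rows[start:start + lookback], rows[start + lookback:start + lookback + horizon]


def num_windows(n_steps: int, lookback: int, horizon: int) -> int:
    """Number of windows a stride-1 slider produces over `n_steps` time steps."""
    return max(0, n_steps - lookback - horizon + 1)


if __name__ == "__main__":
    # Tiny in-memory CSV standing in for a real file (hypothetical values).
    text = "date,a,b\n" + "".join(f"2020-01-{d:02d},{d},{2 * d}\n" for d in range(1, 11))
    names, rows = load_channels(io.StringIO(text))
    print(names, len(list(sliding_windows(rows, lookback=4, horizon=2))))  # ['a', 'b'] 5
    # Upper bound for Weather (52696 steps), lookback 336, horizon 96, before any split.
    print(num_windows(52696, 336, 96))  # 52265
```

For a downloaded file, the same helpers would be called as `with open("dataset/weather.csv", newline="") as f: names, rows = load_channels(f)`.
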

All scripts are provided in the scripts folder.
For example, to train on the Weather dataset with a lookback window of 336:
bash scripts/lookback_window_336/weather.sh

We provide our trained models on the Hugging Face Space.
To evaluate these models, you can either specify a particular model or evaluate all of them at once.
To evaluate a single model, for example GridTST on the Traffic dataset with a lookback window of 336 and a prediction length of 96:
python benchmark.py --data_file dataset/traffic.csv --seq_len 336 --label_len 96

To evaluate them all:
python benchmark.py --all
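
When evaluating several single-model configurations, it can help to generate the commands programmatically. The sketch below only assembles argument lists using the three flags from the single-model example (--data_file, --seq_len, --label_len, with the prediction length passed to --label_len as in that example). It assumes every dataset lives at dataset/<name>.csv, which is confirmed only for traffic.csv, and it does not check which trained models are actually available; the name `benchmark_commands` is hypothetical.

```python
import shlex
from typing import Dict, List

# Prediction lengths from the dataset table; file names other than traffic.csv are assumed.
PRED_LENS: Dict[str, List[int]] = {
    "traffic": [96, 192, 336, 720],
    "weather": [96, 192, 336, 720],
    "electricity": [96, 192, 336, 720],
}


def benchmark_commands(seq_len: int = 336) -> List[List[str]]:
    """Build one benchmark.py invocation per (dataset, prediction length) pair."""
    commands = []
    for name, lengths in PRED_LENS.items():
        for pred_len in lengths:
            commands.append([
                "python", "benchmark.py",
                "--data_file", f"dataset/{name}.csv",
                "--seq_len", str(seq_len),
                "--label_len", str(pred_len),  # the README passes the prediction length here
            ])
    return commands


if __name__ == "__main__":
    for cmd in benchmark_commands():
        print(shlex.join(cmd))
        # To actually run each evaluation: import subprocess; subprocess.run(cmd, check=True)
```

The first generated command is exactly the Traffic example above; printing rather than executing by default keeps the sketch safe to run before the checkpoints are downloaded.
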