Graph transformer-based traffic prediction using Waze data.
This project implements a graph transformer architecture for traffic forecasting using Waze data. The model represents road topology and traffic patterns as dynamic graphs, where nodes represent junctions and endpoints, and edges capture traffic flow characteristics.
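The graph representation described above can be sketched in a few lines. This is a minimal illustration, assuming each road segment record carries its two endpoint coordinates and one speed reading; the field names (`start`, `end`, `speed_kmh`) and the `build_graph` helper are hypothetical and are not taken from the project's `graph_builder.py`.

```python
def build_graph(segments):
    """Map segment endpoints to node ids and collect edge features.

    `segments` is an iterable of dicts with the hypothetical keys
    'start' and 'end' (coordinate tuples) and 'speed_kmh'.
    """
    node_ids = {}      # coordinate -> integer node id
    edges = []         # (source id, target id) pairs
    edge_speeds = []   # one traffic feature per edge

    def node_id(coord):
        # Assign ids in order of first appearance.
        if coord not in node_ids:
            node_ids[coord] = len(node_ids)
        return node_ids[coord]

    for seg in segments:
        u, v = node_id(seg["start"]), node_id(seg["end"])
        edges.append((u, v))
        edge_speeds.append(seg["speed_kmh"])
    return node_ids, edges, edge_speeds


segments = [
    {"start": (0.0, 0.0), "end": (0.0, 1.0), "speed_kmh": 42.0},
    {"start": (0.0, 1.0), "end": (1.0, 1.0), "speed_kmh": 18.5},
]
node_ids, edges, edge_speeds = build_graph(segments)
```

The shared endpoint `(0.0, 1.0)` becomes a single junction node, so two segments yield three nodes and two edges.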
- Graph transformer architecture for traffic prediction
- Support for both subgraph and full graph training
- Sparse tensor implementation for memory efficiency
- Temporal snapshot creation for time-series analysis
- Multi-GPU training via HuggingFace Accelerate
- Experiment tracking with Weights & Biases
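As a rough illustration of temporal snapshot creation with a sparse layout, the sketch below buckets timestamped observations into fixed windows and stores only the edges that were actually observed in each window. The `make_snapshots` helper, the 15-minute window, and the `(timestamp, edge, speed)` record shape are assumptions made for this example, not the project's actual data format.

```python
from collections import defaultdict


def make_snapshots(observations, window_seconds=900):
    """Group (timestamp, edge, speed) observations into time windows.

    Returns {window_index: {edge: mean_speed}}. Only edges observed in a
    window are stored, which keeps each snapshot sparse.
    """
    sums = defaultdict(lambda: defaultdict(float))
    counts = defaultdict(lambda: defaultdict(int))
    for timestamp, edge, speed in observations:
        window = int(timestamp // window_seconds)
        sums[window][edge] += speed
        counts[window][edge] += 1
    return {
        window: {
            edge: sums[window][edge] / counts[window][edge]
            for edge in sums[window]
        }
        for window in sorted(sums)
    }


observations = [
    (0, (0, 1), 40.0),
    (600, (0, 1), 50.0),    # same 15-minute window as the first reading
    (1000, (1, 2), 20.0),   # falls into the next window
]
snapshots = make_snapshots(observations)
```

Each snapshot is a dictionary keyed by edge, which is the same idea as a coordinate-format sparse tensor: memory grows with the number of observed edges rather than with the square of the node count.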
# Create and activate conda environment
conda env create -f environment.yaml
conda activate waze-traffic
# Install in development mode
pip install -e .
The model configuration is controlled through a YAML file (config.yaml). Important settings include:
data:
  # Training on full graph vs subgraph
  full_graph: false      # Set to true for full graph training
  subgraph_nodes: 5000   # Max nodes when using subgraph
  batch_size: 1000       # Mini-batch size for full graph training
To train the model using the default sparse subgraph approach:
python scripts/train_model.py --config config.yaml
To train on the full graph using mini-batch training:
# Edit config.yaml to set data.full_graph: true
python scripts/train_model.py --config config.yaml
# Basic command with optimization
python scripts/train_model.py --config config.yaml --optimize
# Maximum performance setup
python scripts/train_model.py --config config.yaml --optimize --mixed_precision fp16 --batch_multiplier 8 --cache_data --num_workers 16
For multi-GPU training:
accelerate launch scripts/train_model.py --config config.yaml
# Track experiments with Weights & Biases
python scripts/train_model.py --config config.yaml --wandb_project "waze-traffic"
# Resume from checkpoint
python scripts/train_model.py --config config.yaml --resume_from checkpoints/best_model.pt
# Use specific GPUs
CUDA_VISIBLE_DEVICES=0,1 accelerate launch scripts/train_model.py --config config.yaml
The implementation is based on the STGformer architecture, which combines:
- Graph Propagation Layer: Models spatial dependencies through message passing
- Spatiotemporal Attention: Captures both local and global dependencies
- Temporal Positional Encoding: Preserves temporal ordering
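Two of the components listed above can be illustrated in plain Python. This is a sketch under stated assumptions: mean aggregation stands in for the graph propagation step, and the encoding uses the standard sinusoidal form from the original Transformer. The layers in `stgformer.py` may differ, and both function names are hypothetical.

```python
import math


def propagate(features, edges):
    """One round of mean-aggregation message passing.

    features: list of per-node feature lists.
    edges: list of (source, target) pairs, treated as undirected.
    Each node's new feature is the mean over itself and its neighbours.
    """
    neighbours = {i: [i] for i in range(len(features))}  # include self-loop
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)
    dim = len(features[0])
    return [
        [sum(features[j][d] for j in neighbours[i]) / len(neighbours[i])
         for d in range(dim)]
        for i in range(len(features))
    ]


def positional_encoding(num_steps, dim):
    """Sinusoidal encoding: sine on even indices, cosine on odd indices."""
    table = []
    for t in range(num_steps):
        row = []
        for k in range(dim):
            # Paired sine/cosine entries share one frequency.
            angle = t / (10000 ** ((k - k % 2) / dim))
            row.append(math.sin(angle) if k % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


features = [[1.0], [3.0], [5.0]]
edges = [(0, 1), (1, 2)]
smoothed = propagate(features, edges)
encoding = positional_encoding(num_steps=4, dim=4)
```

On the three-node path graph, propagation pulls the end nodes toward the middle value. The encoding gives each time step a distinct vector, which is what lets an attention layer recover temporal order.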
The model can work with two training approaches:
- Sparse Subgraph (Default):
  - Samples a connected subgraph of important nodes
  - Uses sparse tensors for memory efficiency
  - Suitable for limited hardware resources
- Full Graph with Mini-Batches:
  - Processes the entire graph through node neighborhood sampling
  - Trains on all nodes and edges
  - Requires more computational resources
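To make the sparse subgraph approach concrete, here is one way a connected subgraph could be sampled. It is a minimal sketch that assumes node degree is the importance measure and uses a breadth-first expansion capped at `max_nodes` (playing the role of `subgraph_nodes` in the configuration). The `sample_connected_subgraph` helper is hypothetical, and the project's own sampling strategy may differ.

```python
from collections import defaultdict, deque


def sample_connected_subgraph(edges, max_nodes):
    """Breadth-first sample of at most `max_nodes` connected nodes.

    Starts from the highest-degree node and visits higher-degree
    neighbours first. Using degree as importance is an assumption.
    """
    adjacency = defaultdict(set)
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)
    if not adjacency or max_nodes <= 0:
        return [], []

    degree = {n: len(nbrs) for n, nbrs in adjacency.items()}
    # Tie-break on node id so the result is deterministic.
    seed = max(degree, key=lambda n: (degree[n], -n))
    selected, queue, seen = [seed], deque([seed]), {seed}
    while queue and len(selected) < max_nodes:
        node = queue.popleft()
        for nbr in sorted(adjacency[node], key=lambda n: (-degree[n], n)):
            if nbr not in seen and len(selected) < max_nodes:
                seen.add(nbr)
                selected.append(nbr)
                queue.append(nbr)

    kept = set(selected)
    sub_edges = [(u, v) for u, v in edges if u in kept and v in kept]
    return selected, sub_edges


edges = [(0, 1), (1, 2), (1, 3), (3, 4), (5, 6)]
nodes, sub_edges = sample_connected_subgraph(edges, max_nodes=3)
```

Because the expansion only ever follows edges from already selected nodes, the result is connected by construction, and the disconnected pair `(5, 6)` is never reached from the seed.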
waze-traffic-forecast/
├── waze_traffic_forecast/ # Main package
│ ├── data/ # Data processing modules
│ │ ├── graph_builder.py # Graph construction
│ │ ├── preprocessor.py # Data preprocessing
│ │ └── inspector.py # Schema inspection
│ ├── models/ # Model implementations
│ │ ├── layers.py # Model layers
│ │ └── stgformer.py # STGformer implementation
│ ├── dataset.py # Dataset implementation
│ └── _config.py # Configuration handling
├── scripts/ # Executable scripts
│ ├── train_model.py # Training script
│ ├── build_waze_graph.py # Graph building script
│ └── inspect_waze_schema.py # Schema inspection script
├── config.yaml # Configuration file
├── environment.yaml # Conda environment file
├── setup.py # Package installation
└── README.md # This file
- config.yaml: Default YAML configuration file for data, model, and training settings.
- environment.yaml: Conda environment specification for dependencies and setup.
- run_slurm.slurm: SLURM batch submission script for running jobs on HPC clusters.
- output/: Directory where model outputs, logs, and copied configurations are saved.
- scripts/: Collection of utility scripts for data processing, schema inspection, and training.
- tests/: Unit tests for validating data processing and model components.
- waze_traffic_forecast/: Core Python package containing modules for configuration, data handling, and model implementation.
- waze_traffic_forecast.egg-info/: Package metadata generated by setup.py (not for direct modification).
To validate the functionality of data processing and model components, run the test suite:
pytest
@inproceedings{waze-traffic-forecast,
title={Graph Transformers for Traffic Forecasting},
author={Potluri, Sravanth and Jerge, Michael M. and Sahay, Shreejeet},
year={2025},
organization={University of Virginia}
}
This project is licensed under the terms of the MIT license.