This is a clean implementation of CNN-LSX, a framework for training neural networks with high-quality saliency explanations.
CNN-LSX implements a novel approach for training neural networks that not only make accurate predictions but also produce meaningful explanations for those predictions. It uses a learner-critic framework:
- Learner: A neural network that makes predictions and generates explanations
- Critic: A separate network that evaluates the quality of those explanations
- Multiple explanation methods (input × gradient, integrated gradients, GradCAM)
- Flexible training pipeline with multiple stages (pretraining, joint training, finetuning)
- Support for different datasets (MNIST, DecoyMNIST, ColorMNIST)
- Comprehensive evaluation metrics for explanation quality
git clone https://github.com/yourusername/cnn-lsx.git
cd cnn-lsx
pip install -r requirements.txt# Train on MNIST with default parameters
python main.py --dataset mnist --training_mode pretrain_and_joint_and_finetuning
# Train on a small portion of the data
python main.py --dataset mnist --few_shot_train_percent 0.02 --training_mode pretrain_and_joint_and_finetuning- Classification only:
--training_mode only_classification - Joint training only:
--training_mode joint - Pretrain then joint train:
--training_mode pretrain_and_joint - Complete pipeline:
--training_mode pretrain_and_joint_and_finetuning - Finetuning only:
--training_mode finetuning --model_pt path/to/model.pt - Testing a model:
--training_mode test --model_pt path/to/model.pt
Control the balance between classification accuracy and explanation quality:
python main.py --classification_loss_weight 1 --explanation_loss_weight 100 --explanation_loss_weight_finetune 100- Pretraining Phase: The learner is trained for classification only.
- Joint Training Phase: The learner is trained to optimize both classification performance and explanation quality, as judged by the critic.
- Finetuning Phase: The model is refined to further improve explanation quality.
Training progress and model evaluations are logged to TensorBoard:
tensorboard --logdir runsFor more information about CNN-LSX, including the theoretical background and implementation details, please refer to the documentation in the docs/ directory.