gingasan/adversarialSA

Adversarial Self-Attention

This repository contains the code for the AAAI 2023 paper "Adversarial Self-Attention for Language Understanding".

"it has all the ==excitement== of ==eating oatmeal== ." (label 0)

Self-Attention (SA): predicts 1
Adversarial Self-Attention (ASA): predicts 0

The vanilla SA model puts too much attention on "excitement" and is misled into predicting the wrong label.
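
To make the idea of "too much attention on one token" concrete, here is a minimal sketch of how an additive bias on the attention logits shifts the attention distribution. It is not the repository's implementation: the scores and the bias value are made up for illustration, the bias is set by hand rather than learned, and the function names are hypothetical.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(scores, bias=None):
    """Attention distribution for one query; `bias` is added to the logits."""
    if bias is None:
        bias = [0.0] * len(scores)
    return softmax([s + b for s, b in zip(scores, bias)])


tokens = ["it", "has", "all", "the", "excitement", "of", "eating", "oatmeal", "."]

# Hypothetical query-key scores for one query position (illustrative only).
scores = [0.1, 0.2, 0.3, 0.1, 3.0, 0.1, 1.0, 1.2, 0.0]

# Hand-picked bias that penalises the dominant token (an assumption,
# standing in for whatever bias a trained model would produce).
bias = [0.0] * len(tokens)
bias[tokens.index("excitement")] = -2.0

plain = attention_weights(scores)
biased = attention_weights(scores, bias)

for tok, p, b in zip(tokens, plain, biased):
    print(f"{tok:>10}  plain={p:.3f}  biased={b:.3f}")
```

With the bias applied, the weight on "excitement" drops and the remaining mass is redistributed over the other tokens, including "eating" and "oatmeal".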

Dependencies

  • torch 1.9

  • transformers 4.17
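
A quick way to check an environment against the versions listed above is sketched below. It assumes the second dependency refers to the Hugging Face `transformers` package and that matching on major.minor is sufficient; `REQUIRED`, `matches`, and `check` are hypothetical names, not part of this repository.

```python
from importlib import metadata

# Versions as listed in this README (major.minor only).
REQUIRED = {"torch": "1.9", "transformers": "4.17"}


def matches(installed, required):
    """True if the installed version has the same major.minor as required."""
    return installed.split(".")[:2] == required.split(".")[:2]


def check(required=REQUIRED):
    """Return a {package: status} report for the current environment."""
    report = {}
    for name, wanted in required.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "not installed"
        else:
            if matches(installed, wanted):
                report[name] = "ok"
            else:
                report[name] = f"found {installed}, README lists {wanted}"
    return report


if __name__ == "__main__":
    for name, status in check().items():
        print(f"{name}: {status}")
```

Other versions may also work; the check only reports differences from what is listed here.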

Quick start

Sentence classification

python run_sent_clas.py \
    --do_train \
    --do_eval \
    --task_name SST-2 \
    --learning_rate 2e-5 \
    --train_batch_size 32 \
    --do_lower_case \
    --model_type bert \
    --load_model_path bert-base-uncased \
    --output_dir sst_bert \
    --fp16

Training log:

Epoch 0: global step = 2105 | train loss = 0.250 | eval score = 92.55 | eval loss = 0.211
Epoch 1: global step = 4210 | train loss = 0.114 | eval score = 93.00 | eval loss = 0.202
Epoch 2: global step = 6315 | train loss = 0.073 | eval score = 93.46 | eval loss = 0.223
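
The global step counts in this log can be sanity-checked with a little arithmetic. The sketch below assumes the GLUE SST-2 training set of 67,349 examples, one optimizer step per batch (no gradient accumulation), and a single device; `steps_per_epoch` is a hypothetical helper, not a function from this repository.

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Optimizer steps per epoch when the last partial batch is kept."""
    return math.ceil(num_examples / batch_size)


SST2_TRAIN_EXAMPLES = 67_349  # GLUE SST-2 training set size (assumption)
BATCH_SIZE = 32               # --train_batch_size from the command above

per_epoch = steps_per_epoch(SST2_TRAIN_EXAMPLES, BATCH_SIZE)
print(per_epoch)  # 2105

# Cumulative global step at the end of each of the three logged epochs.
print([per_epoch * (epoch + 1) for epoch in range(3)])  # [2105, 4210, 6315]
```

If your own run reports different step counts, the batch size, accumulation setting, or number of devices probably differs from these assumptions.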

Testing:

python run_sent_clas.py \
    --do_test \
    --task_name SST-2 \
    --learning_rate 2e-5 \
    --eval_batch_size 128 \
    --do_lower_case \
    --model_type bert \
    --load_model_path bert-base-uncased \
    --output_dir sst_bert \
    --test_model_file sst_bert/2_pytorch_model.bin

Multiple choice

python run_multi_cho.py \
    --do_train \
    --do_eval \
    --task_name DREAM \
    --eval_on test \
    --num_train_epochs 6 \
    --learning_rate 2e-5 \
    --train_batch_size 16 \
    --model_type roberta \
    --load_model_path roberta-base \
    --output_dir dream_roberta \
    --fp16

Training log:

Epoch 0: global step = 242 | loss = 1.066 | eval score = 56.41 | eval loss = 0.908
Epoch 1: global step = 484 | loss = 0.825 | eval score = 67.13 | eval loss = 0.749
Epoch 2: global step = 726 | loss = 0.540 | eval score = 68.76 | eval loss = 0.731
Epoch 3: global step = 968 | loss = 0.329 | eval score = 69.54 | eval loss = 0.867
Epoch 4: global step = 1210 | loss = 0.221 | eval score = 69.70 | eval loss = 0.966
Epoch 5: global step = 1452 | loss = 0.167 | eval score = 69.23 | eval loss = 1.037
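
To choose which checkpoint to test, one option is to parse a log like this one and take the epoch with the highest eval score. The sketch below is not part of the repository: `parse_log` and `best_epoch` are hypothetical helpers, and the `<epoch>_pytorch_model.bin` file name is an assumption inferred from the `--test_model_file` argument in the SST-2 testing command.

```python
import re

EPOCH_LINE = re.compile(r"^Epoch (\d+): (.*)$")

LOG = """\
Epoch 0: global step = 242 | loss = 1.066 | eval score = 56.41 | eval loss = 0.908
Epoch 1: global step = 484 | loss = 0.825 | eval score = 67.13 | eval loss = 0.749
Epoch 2: global step = 726 | loss = 0.540 | eval score = 68.76 | eval loss = 0.731
Epoch 3: global step = 968 | loss = 0.329 | eval score = 69.54 | eval loss = 0.867
Epoch 4: global step = 1210 | loss = 0.221 | eval score = 69.70 | eval loss = 0.966
Epoch 5: global step = 1452 | loss = 0.167 | eval score = 69.23 | eval loss = 1.037
"""


def parse_log(text):
    """Turn 'Epoch N: key = value | ...' lines into a list of dicts."""
    rows = []
    for line in text.splitlines():
        m = EPOCH_LINE.match(line.strip())
        if not m:
            continue  # skip anything that is not an epoch summary line
        row = {"epoch": int(m.group(1))}
        for field in m.group(2).split("|"):
            key, value = field.split("=")
            row[key.strip()] = float(value)
        rows.append(row)
    return rows


def best_epoch(rows, key="eval score"):
    """Epoch index with the highest value for `key`."""
    return max(rows, key=lambda r: r[key])["epoch"]


rows = parse_log(LOG)
best = best_epoch(rows)
print(best)  # 4
# Checkpoint naming is an assumption based on the SST-2 testing command.
print(f"dream_roberta/{best}_pytorch_model.bin")
```

Selecting by eval score rather than eval loss matters here: the eval loss is lowest at epoch 2, while the eval score peaks at epoch 4.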
