Skip to content

Zeng-WH/User-Intent

Repository files navigation

Multi-label Intent Classification

此代码包括multi-label意图识别分类的代码。BackBone是GPT-2.

预训练模型来自:https://huggingface.co/uer/gpt2-chinese-cluecorpussmall

训练模型:

bash train_intent_gpt_fgm.sh

需要修改数据路径等,在OOM的情况下需要灵活修改batch size以及learning rate,极端情况下可能需要修改train epoch.

目前主要调节的超参为对抗的扰动程度,即为adver_eplsilon

测试模型:

bash test_intent_gpt.sh

注:

  1. 训练模型的过程中使用训练数据训练,验证集数据搜索最佳概率阈值。这个阈值得去log里看。

然后使用得到的阈值在测试集上测结果(切记!)

  1. 在训练模型过程中,在run_intent_gpt_search_fgm.py中使用的是compute_metrics_search函数,而在测试的过程中使用的应是compute_metrics函数,同时修改该函数中temp_gate为验证集上的最佳概率阈值。(将在后续版本中优化该设置)

函数切换代码在:

    trainer = Intent_Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics_search,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

About

Multi-Label User Intent Classification

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published