TabNet 是 Google 提出的一种专门用于表格数据(类似 Excel 表格的数据)的深度学习模型。它的特点是:像决策树一样可以告诉你「哪些特征对预测最重要」,同时又拥有深度学习的强大能力。
📄 原论文:TabNet: Attentive Interpretable Tabular Learning (Arik & Pfister, 2019)
| 方法 | 优点 | 缺点 |
|---|---|---|
| XGBoost / LightGBM | 表格数据效果好、可解释 | 不支持端到端学习、无法利用未标注数据 |
| 传统神经网络 (MLP) | 端到端学习、可迁移 | 表格数据效果一般、不可解释 |
| TabNet ✨ | 兼具两者优点!可解释 + 效果好 + 支持预训练 | 训练稍慢 |
- 可解释性:能输出每个特征的重要性(类似决策树)
- 注意力机制:自动选择重要特征,每一步只关注少数特征
- 半监督学习:可以利用没有标签的数据进行预训练
- 端到端训练:不需要手动特征工程
只需几行代码就能训练一个 TabNet 模型:
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
# 1. 创建模型(分类任务)
clf = TabNetClassifier()
# 2. 训练模型
clf.fit(
X_train, y_train, # 训练数据
eval_set=[(X_valid, y_valid)] # 验证数据(用于早停)
)
# 3. 预测
predictions = clf.predict(X_test)pip install pytorch-tabnetconda install -c conda-forge pytorch-tabnetTabNet 提供了三种模型,对应不同的任务:
| 模型 | 适用场景 | 示例 |
|---|---|---|
| TabNetClassifier | 二分类、多分类 | 预测用户是否流失、图片分类 |
| TabNetRegressor | 回归、多目标回归 | 预测房价、股票价格 |
| TabNetMultiTaskClassifier | 多任务多分类 | 同时预测用户年龄段和购买意向 |
💡 这部分帮助你理解 TabNet 的工作原理,不需要记住所有细节
TabNet 最核心的思想是「逐步决策」:
第1步: 关注「年龄」和「收入」→ 做出部分决策
第2步: 关注「学历」和「职业」→ 完善决策
第3步: 关注「地区」→ 最终决策
每一步只看少数几个特征,这样既高效又可解释。
- Softmax:给所有特征都分配权重(虽然有大有小)
- Sparsemax:直接把不重要的特征权重设为 0(更稀疏、更可解释)
TabNet 默认使用 Sparsemax,所以它的特征选择更「干净」。
一种特殊的批归一化方法,能让模型在大 batch size 下也能稳定训练。
创建模型时可以调整的参数:
| 参数 | 默认值 | 说明 | 调参建议 |
|---|---|---|---|
n_d |
8 | 决策层宽度 | 8-64,越大容量越大,但容易过拟合 |
n_a |
8 | 注意力层宽度 | 通常设成和 n_d 一样 |
n_steps |
3 | 决策步数 | 3-10,越大模型越复杂 |
gamma |
1.3 | 特征重用系数 | 1.0-2.0,越接近 1 每步选的特征越不同 |
| 参数 | 默认值 | 说明 |
|---|---|---|
cat_idxs |
[] | 类别特征的列索引(比如 [0, 3, 5]) |
cat_dims |
[] | 每个类别特征的类别数 |
cat_emb_dim |
1 | 类别特征的嵌入维度 |
lambda_sparse |
1e-3 | 稀疏性正则化系数,越大越稀疏 |
mask_type |
"sparsemax" | 可选 "sparsemax" 或 "entmax" |
| 参数 | 默认值 | 说明 |
|---|---|---|
device_name |
"auto" | 设备:"auto"、"cpu" 或 "cuda" |
seed |
0 | 随机种子,设置相同值可复现结果 |
verbose |
1 | 是否打印训练过程 |
调用 fit() 方法时可以调整的参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
max_epochs |
200 | 最大训练轮数 |
patience |
10 | 早停耐心值(连续多少轮不提升就停止) |
batch_size |
1024 | 批大小,TabNet 推荐用大 batch |
eval_set |
None | 验证集,格式为 [(X_val, y_val)] |
eval_metric |
自动 | 评估指标列表 |
分类任务:
'auc':ROC-AUC(二分类默认)'accuracy':准确率(多分类默认)'balanced_accuracy':平衡准确率'logloss':对数损失
回归任务:
'mse':均方误差(默认)'mae':平均绝对误差'rmse':均方根误差'rmsle':对数均方根误差
| 参数 | 默认值 | 说明 |
|---|---|---|
virtual_batch_size |
128 | Ghost Batch Normalization 的虚拟 batch 大小 |
num_workers |
0 | 数据加载的并行进程数 |
weights |
0 | 类别权重(0=不加权,1=自动平衡) |
warm_start |
False | 是否从上次训练继续 |
如果你有大量没有标签的数据,可以先进行预训练:
from pytorch_tabnet.pretraining import TabNetPretrainer
# 1. 无监督预训练
unsupervised_model = TabNetPretrainer()
unsupervised_model.fit(
X_train=X_unlabeled, # 无标签数据
pretraining_ratio=0.8, # 随机遮盖 80% 的特征进行重建
)
# 2. 有监督微调
clf = TabNetClassifier()
clf.fit(
X_train, y_train,
from_unsupervised=unsupervised_model # 使用预训练权重
)# 保存模型
clf.save_model("./my_tabnet_model")
# 加载模型
loaded_clf = TabNetClassifier()
loaded_clf.load_model("./my_tabnet_model.zip")from pytorch_tabnet.metrics import Metric
from sklearn.metrics import roc_auc_score
class Gini(Metric):
def __init__(self):
self._name = "gini"
self._maximize = True # 越大越好
def __call__(self, y_true, y_score):
auc = roc_auc_score(y_true, y_score[:, 1])
return max(2 * auc - 1, 0.)
# 使用自定义指标
clf.fit(X_train, y_train, eval_metric=[Gini])# 训练后可以直接获取
feature_importance = clf.feature_importances_
# 或者获取更详细的解释
explain_matrix, masks = clf.explain(X_test)tabnet/
├── 📂 pytorch_tabnet/ # 核心代码目录
│ ├── abstract_model.py # 抽象基类
│ ├── augmentations.py # 数据增强
│ ├── callbacks.py # 回调函数
│ ├── metrics.py # 评估指标
│ ├── multiclass_utils.py # 多分类工具
│ ├── multitask.py # 多任务学习
│ ├── pretraining.py # 预训练模块
│ ├── pretraining_utils.py # 预训练工具
│ ├── sparsemax.py # 稀疏激活函数
│ ├── tab_model.py # 主要模型类
│ ├── tab_network.py # 神经网络架构
│ └── utils.py # 通用工具
│
├── 📓 示例 Notebooks
│ ├── census_example.ipynb # 二分类示例(人口普查数据)
│ ├── forest_example.ipynb # 多分类示例(森林覆盖)
│ ├── regression_example.ipynb # 回归示例
│ ├── multi_regression_example.ipynb # 多目标回归示例
│ ├── multi_task_example.ipynb # 多任务示例
│ ├── pretraining_example.ipynb # 预训练示例
│ └── customizing_example.ipynb # 自定义示例
│
├── 📂 tests/ # 测试代码
├── 📂 docs/ # 文档
└── pyproject.toml # 项目配置
下面对每个文件进行详细解释,帮助你理解整个代码库的架构。
这是你使用 TabNet 时最直接接触的文件。
包含两个核心类:
| 类名 | 用途 | 默认损失函数 | 默认评估指标 |
|---|---|---|---|
TabNetClassifier |
分类任务(二分类/多分类) | CrossEntropy | AUC(二分类)/ Accuracy(多分类) |
TabNetRegressor |
回归任务(单目标/多目标) | MSE | MSE |
主要方法:
clf = TabNetClassifier()
# 训练
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)])
# 预测类别
predictions = clf.predict(X_test)
# 预测概率(仅分类器)
probabilities = clf.predict_proba(X_test)内部工作流程:
fit() 被调用
↓
检查输入数据格式 → 推断输出维度 → 创建数据加载器
↓
初始化网络结构(调用 tab_network.py)
↓
开始训练循环:
每个 epoch:
- 训练阶段:前向传播 → 计算损失 → 反向传播
- 验证阶段:计算评估指标
- 检查早停条件
↓
训练结束,加载最佳权重
当你需要同时预测多个分类目标时使用。
例如:同时预测用户的「年龄段」和「购买意向」两个分类任务。
from pytorch_tabnet.multitask import TabNetMultiTaskClassifier
clf = TabNetMultiTaskClassifier()
# y_train 形状: (n_samples, n_tasks)
# 例如: [[25岁以下, 高意向], [25-35岁, 低意向], ...]
clf.fit(X_train, y_train)
# 返回一个列表,每个元素是一个任务的预测结果
predictions = clf.predict(X_test) # [task1_preds, task2_preds, ...]与普通分类器的区别:
- 输出层是多个并行的分类头,每个任务一个
- 损失函数是所有任务损失的平均
- 可以为每个任务指定不同的损失函数
利用无标签数据进行预训练,提升模型效果。
这是 TabNet 的一大亮点!当你有大量无标签数据时,可以先进行预训练。
预训练原理:
原始数据: [特征1, 特征2, 特征3, 特征4, 特征5]
↓ 随机遮盖部分特征
遮盖后: [特征1, ? , 特征3, ? , 特征5]
↓ 模型尝试重建
重建目标: [特征1, 特征2, 特征3, 特征4, 特征5]
模型通过「猜测被遮盖的特征」来学习数据的内在结构。
使用方法:
from pytorch_tabnet.pretraining import TabNetPretrainer
# 第一步:无监督预训练
pretrainer = TabNetPretrainer(
n_d=8, n_a=8,
mask_type='sparsemax'
)
pretrainer.fit(
X_train=X_unlabeled, # 大量无标签数据
pretraining_ratio=0.8, # 遮盖 80% 的特征
)
# 第二步:用预训练权重初始化分类器
clf = TabNetClassifier()
clf.fit(
X_train, y_train,
from_unsupervised=pretrainer # 关键!使用预训练权重
)何时使用预训练:
- 有大量无标签数据(比标签数据多 10 倍以上)
- 标签数据较少(< 10000 条)
- 特征之间有复杂的内在关系
所有模型(Classifier、Regressor、Pretrainer)的共同父类。
你不需要直接使用这个文件,但了解它有助于理解代码结构。
定义的核心功能:
| 方法 | 功能 |
|---|---|
fit() |
完整的训练循环 |
predict() |
模型预测 |
explain() |
获取特征重要性和注意力掩码 |
save_model() |
保存模型到 .zip 文件 |
load_model() |
从 .zip 文件加载模型 |
继承关系图:
TabModel (abstract_model.py)
│
┌───────────────┼───────────────┐
↓ ↓ ↓
TabNetClassifier TabNetRegressor TabNetPretrainer
(tab_model.py) (tab_model.py) (pretraining.py)
│
↓
TabNetMultiTaskClassifier
(multitask.py)
定义了 TabNet 的所有神经网络组件。
这是理解 TabNet 原理的关键文件!
整体架构图:
┌─────────────────────────────────────────────────────────────────────┐
│ TabNet 完整架构 │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ 输入: [年龄, 收入, 学历, 城市, ...] │
│ ↓ │
│ ┌─────────────────────────┐ │
│ │ EmbeddingGenerator │ ← 把类别特征转成向量 │
│ │ (类别特征嵌入) │ 例: 城市"北京" → [0.2, 0.8, 0.1] │
│ └───────────┬─────────────┘ │
│ ↓ │
│ ┌─────────────────────────┐ │
│ │ Initial BN │ ← 初始批归一化 │
│ └───────────┬─────────────┘ │
│ ↓ │
│ ╔═══════════════════════════════════════════════════════════════╗ │
│ ║ TabNetEncoder ║ │
│ ║ ┌─────────────────────────────────────────────────────────┐ ║ │
│ ║ │ Step 1 │ ║ │
│ ║ │ ┌──────────────────┐ ┌──────────────────┐ │ ║ │
│ ║ │ │ AttentiveTransf. │ → │ FeatTransformer │ → d₁ │ ║ │
│ ║ │ │ (选择重要特征) │ │ (特征变换) │ │ ║ │
│ ║ │ └──────────────────┘ └──────────────────┘ │ ║ │
│ ║ └─────────────────────────────────────────────────────────┘ ║ │
│ ║ ┌─────────────────────────────────────────────────────────┐ ║ │
│ ║ │ Step 2 │ ║ │
│ ║ │ ┌──────────────────┐ ┌──────────────────┐ │ ║ │
│ ║ │ │ AttentiveTransf. │ → │ FeatTransformer │ → d₂ │ ║ │
│ ║ │ └──────────────────┘ └──────────────────┘ │ ║ │
│ ║ └─────────────────────────────────────────────────────────┘ ║ │
│ ║ ┌─────────────────────────────────────────────────────────┐ ║ │
│ ║ │ Step 3 │ ║ │
│ ║ │ ┌──────────────────┐ ┌──────────────────┐ │ ║ │
│ ║ │ │ AttentiveTransf. │ → │ FeatTransformer │ → d₃ │ ║ │
│ ║ │ └──────────────────┘ └──────────────────┘ │ ║ │
│ ║ └─────────────────────────────────────────────────────────┘ ║ │
│ ╚═══════════════════════════════════════════════════════════════╝ │
│ ↓ │
│ d₁ + d₂ + d₃ (聚合所有步骤的输出) │
│ ↓ │
│ ┌─────────────────────────┐ │
│ │ Final Mapping │ ← 输出层 │
│ └───────────┬─────────────┘ │
│ ↓ │
│ 输出: 预测结果 │
│ │
└─────────────────────────────────────────────────────────────────────┘
核心组件详解:
作用:将类别特征转换为稠密向量。
原始输入: [25, 50000, "本科", "北京"]
↓
数值特征直接保留: [25, 50000]
类别特征嵌入: "本科" → [0.3, 0.7]
"北京" → [0.2, 0.8, 0.1]
↓
合并后: [25, 50000, 0.3, 0.7, 0.2, 0.8, 0.1]
作用:决定每一步关注哪些特征。
# 简化版工作原理
class AttentiveTransformer:
def forward(self, priors, processed_feat):
# priors: 之前步骤已经使用过的特征(用于避免重复)
# processed_feat: 当前处理的特征
x = Linear(processed_feat) # 全连接层
x = BatchNorm(x) # 批归一化
x = x * priors # 结合先验信息
x = Sparsemax(x) # 稀疏化!大部分特征权重变成 0
return x # 返回注意力掩码关键点:使用 Sparsemax 让注意力「稀疏」,即每步只关注少数特征。
作用:对选中的特征进行非线性变换。
由多个 GLU_Block 组成:
- Shared GLU blocks:所有步骤共享的变换层(促进特征复用)
- Independent GLU blocks:每步独立的变换层(学习步骤特定的模式)
作用:TabNet 的基本计算单元。
输入 x
↓
Linear(x) → [h, g] # 分成两半
↓
h * sigmoid(g) # 门控机制
↓
输出
门控机制让网络能「选择性」地传递信息,类似 LSTM 中的门。
作用:让大 batch 训练更稳定。
普通 BN: 整个 batch 一起归一化
batch_size = 1024 → 统计量来自 1024 个样本
Ghost BN: 拆分成多个虚拟小 batch 分别归一化
batch_size = 1024, virtual_batch_size = 128
→ 拆成 8 个小 batch,各自归一化后拼接
为什么需要 GBN?
- TabNet 推荐使用大 batch(1024+)
- 但大 batch 的 BN 统计量可能不稳定
- GBN 通过虚拟小 batch 解决这个问题
作用:仅用于预训练,将编码器输出重建为原始特征。
编码器输出 [d₁, d₂, d₃]
↓
FeatTransformer
↓
Reconstruction Layer
↓
重建的原始特征(与被遮盖前对比计算损失)
作用:预训练时随机遮盖特征。
# pretraining_ratio = 0.8 表示遮盖 80%
原始: [特征1, 特征2, 特征3, 特征4, 特征5]
遮盖: [特征1, 0 , 0 , 0 , 特征5] # 随机遮盖了 2,3,4实现 Sparsemax 和 Entmax15 激活函数。
这是让 TabNet 具有「特征选择」能力的关键!
Softmax vs Sparsemax 对比:
输入: [2.0, 1.0, 0.1, 0.1, 0.1]
Softmax 输出: [0.48, 0.18, 0.11, 0.11, 0.11]
↑ 所有特征都有权重(虽然有大小之分)
Sparsemax 输出: [0.85, 0.15, 0.00, 0.00, 0.00]
↑ 不重要的特征直接为 0!更稀疏、更可解释
包含的类:
| 类名 | 描述 | 稀疏程度 |
|---|---|---|
Sparsemax |
原始 sparsemax | 高 |
Entmax15 |
α=1.5 的 entmax | 中等(介于 softmax 和 sparsemax 之间) |
使用哪个?
- 默认用
sparsemax(更稀疏,可解释性更好) - 如果效果不好,可以试试
entmax(更平滑)
在训练过程中的特定时机执行自定义逻辑。
内置回调:
| 回调类 | 功能 | 触发时机 |
|---|---|---|
History |
记录训练历史(loss、metrics) | 每个 batch/epoch 结束 |
EarlyStopping |
早停(验证集不再提升就停止) | 每个 epoch 结束 |
LRSchedulerCallback |
学习率调度 | 每个 batch/epoch 结束 |
早停工作原理:
epoch 1: val_loss = 0.5 ← 最佳
epoch 2: val_loss = 0.4 ← 新的最佳!重置计数器
epoch 3: val_loss = 0.45 ← 没提升,计数 = 1
epoch 4: val_loss = 0.43 ← 没提升,计数 = 2
...
epoch 12: val_loss = 0.42 ← 没提升,计数 = 10
→ patience=10 达到,停止训练,恢复 epoch 2 的权重
自定义回调示例:
from pytorch_tabnet.callbacks import Callback
class MyCallback(Callback):
def on_epoch_end(self, epoch, logs=None):
print(f"Epoch {epoch} 结束, loss = {logs['loss']}")定义各种评估指标,用于监控训练和验证。
内置指标:
| 指标类 | 名称 | 任务类型 | 优化方向 |
|---|---|---|---|
AUC |
'auc' | 二分类 | 越大越好 ↑ |
Accuracy |
'accuracy' | 分类 | 越大越好 ↑ |
BalancedAccuracy |
'balanced_accuracy' | 分类 | 越大越好 ↑ |
LogLoss |
'logloss' | 分类 | 越小越好 ↓ |
MSE |
'mse' | 回归 | 越小越好 ↓ |
MAE |
'mae' | 回归 | 越小越好 ↓ |
RMSE |
'rmse' | 回归 | 越小越好 ↓ |
RMSLE |
'rmsle' | 回归 | 越小越好 ↓ |
UnsupervisedMetric |
'unsup_loss' | 预训练 | 越小越好 ↓ |
自定义指标模板:
from pytorch_tabnet.metrics import Metric
class MyMetric(Metric):
def __init__(self):
self._name = "my_metric" # 指标名称
self._maximize = True # True=越大越好, False=越小越好
def __call__(self, y_true, y_score):
# y_true: 真实标签
# y_score: 模型预测(分类是概率矩阵,回归是预测值)
return your_calculation(y_true, y_score)训练时对数据进行随机变换,提升模型泛化能力。
内置增强方法:
| 类名 | 适用任务 | 原理 |
|---|---|---|
RegressionSMOTE |
回归 | 将两个样本按比例混合(特征和标签都混合) |
ClassificationSMOTE |
分类 | 只混合特征,标签保持主样本的类别 |
SMOTE 原理图解:
样本 A: 特征=[1,2,3], 标签=0.5
样本 B: 特征=[4,5,6], 标签=0.8
混合比例 λ=0.7
回归 SMOTE:
新特征 = 0.7*[1,2,3] + 0.3*[4,5,6] = [1.9, 2.9, 3.9]
新标签 = 0.7*0.5 + 0.3*0.8 = 0.59
分类 SMOTE:
新特征 = 0.7*[1,2,3] + 0.3*[4,5,6] = [1.9, 2.9, 3.9]
新标签 = A 的标签 = 0(保持不变,因为 λ>0.5)
使用方法:
from pytorch_tabnet.augmentations import ClassificationSMOTE
aug = ClassificationSMOTE(p=0.8) # 80% 的样本会被增强
clf.fit(X_train, y_train, augmentations=aug)各种辅助功能的集合。
主要功能分类:
# 用于将 numpy 数组转换为 PyTorch Dataset
TorchDataset(x, y) # 普通数组
SparseTorchDataset(x, y) # 稀疏矩阵
PredictDataset(x) # 预测用(无标签)
SparsePredictDataset(x) # 稀疏矩阵预测用create_dataloaders(X_train, y_train, eval_set, ...)
# 创建训练和验证的 DataLoader
# 支持类别权重采样、稀疏矩阵等define_device("auto") # 自动检测 CUDA
# 返回 "cuda" 或 "cpu"create_group_matrix(list_groups, input_dim)
# 创建特征分组矩阵
# 例:[[0,1,2], [3,4]] 表示前三个特征为一组,后两个为一组check_input(X) # 检查输入格式
check_embedding_parameters() # 检查嵌入参数
validate_eval_set() # 验证验证集格式处理分类任务中的标签相关逻辑。
主要功能:
| 函数 | 功能 |
|---|---|
unique_labels(y) |
提取所有唯一的类别标签 |
type_of_target(y) |
自动判断任务类型(二分类/多分类/多标签等) |
infer_output_dim(y_train) |
推断输出维度(类别数) |
check_output_dim(labels, y) |
检查验证集标签是否在训练集中出现过 |
任务类型判断示例:
type_of_target([0, 1, 0, 1]) # → 'binary'(二分类)
type_of_target([0, 1, 2, 3]) # → 'multiclass'(多分类)
type_of_target([[0,1], [1,0]]) # → 'multilabel-indicator'(多标签)专门为无监督预训练服务的工具函数。
与 utils.py 的区别:
utils.py:需要 (X, y) 对pretraining_utils.py:只需要 X(无标签)
主要函数:
create_dataloaders(X_train, eval_set, ...)
# 创建预训练用的 DataLoader(没有 y_train)
validate_eval_set(eval_set, eval_name, X_train)
# 验证预训练验证集(只检查 X 的形状)import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from pytorch_tabnet.tab_model import TabNetClassifier
# 1. 准备数据
data = load_breast_cancer()
X = data.data
y = data.target
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42
)
X_train, X_valid, y_train, y_valid = train_test_split(
X_train, y_train, test_size=0.2, random_state=42
)
# 2. 创建并训练模型
clf = TabNetClassifier(
n_d=8,
n_a=8,
n_steps=3,
gamma=1.3,
lambda_sparse=1e-3,
verbose=1,
seed=42
)
clf.fit(
X_train, y_train,
eval_set=[(X_valid, y_valid)],
eval_name=['valid'],
eval_metric=['auc', 'accuracy'],
max_epochs=100,
patience=10,
batch_size=256,
)
# 3. 预测和评估
predictions = clf.predict(X_test)
proba = clf.predict_proba(X_test)
# 4. 查看特征重要性
print("特征重要性:", clf.feature_importances_)
# 5. 保存模型
clf.save_model("./breast_cancer_tabnet")from pytorch_tabnet.tab_model import TabNetRegressor
import numpy as np
# 注意:回归任务的 y 必须是 2D 的 (n_samples, n_targets)
y_train = y_train.reshape(-1, 1)
y_valid = y_valid.reshape(-1, 1)
reg = TabNetRegressor()
reg.fit(
X_train, y_train,
eval_set=[(X_valid, y_valid)],
eval_metric=['mae', 'rmse'],
max_epochs=100,
patience=10,
)
predictions = reg.predict(X_test)- 数据量太小(< 1000 条)时,传统方法可能更好
- 特征非常少时(< 10 个),简单模型可能就够了
- 需要极快推理速度时,树模型更快
TabNet 不直接处理缺失值,建议:
- 用均值/中位数填充
- 用 -1 或特殊值标记缺失
推荐顺序:
- 先用默认参数跑一个 baseline
- 调整
n_d和n_a(一起调,8→16→32) - 调整
n_steps(3→5→7) - 调整
batch_size(1024→2048→4096) - 调整正则化
lambda_sparse
💡 提示:如果你是第一次使用,建议先跑通
census_example.ipynb示例,再根据自己的数据修改。
有问题欢迎在 Issues 中讨论!🎉