Skip to content

langjihao/tabnet

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

TabNet:可解释的表格数据深度学习模型

🎯 一句话介绍

TabNet 是 Google 提出的一种专门用于表格数据(类似 Excel 表格的数据)的深度学习模型。它的特点是:像决策树一样可以告诉你「哪些特征对预测最重要」,同时又拥有深度学习的强大能力。

📄 原论文:TabNet: Attentive Interpretable Tabular Learning (Arik & Pfister, 2019)


📖 目录


为什么选择 TabNet?

🆚 与传统方法对比

方法 优点 缺点
XGBoost / LightGBM 表格数据效果好、可解释 不支持端到端学习、无法利用未标注数据
传统神经网络 (MLP) 端到端学习、可迁移 表格数据效果一般、不可解释
TabNet 兼具两者优点!可解释 + 效果好 + 支持预训练 训练稍慢

🌟 TabNet 的独特优势

  1. 可解释性:能输出每个特征的重要性(类似决策树)
  2. 注意力机制:自动选择重要特征,每一步只关注少数特征
  3. 半监督学习:可以利用没有标签的数据进行预训练
  4. 端到端训练:不需要手动特征工程

快速开始

只需几行代码就能训练一个 TabNet 模型:

from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor

# 1. 创建模型(分类任务)
clf = TabNetClassifier()

# 2. 训练模型
clf.fit(
    X_train, y_train,           # 训练数据
    eval_set=[(X_valid, y_valid)]  # 验证数据(用于早停)
)

# 3. 预测
predictions = clf.predict(X_test)

安装方法

使用 pip 安装

pip install pytorch-tabnet

使用 conda 安装

conda install -c conda-forge pytorch-tabnet

能解决什么问题?

TabNet 提供了三种模型,对应不同的任务:

模型 适用场景 示例
TabNetClassifier 二分类、多分类 预测用户是否流失、图片分类
TabNetRegressor 回归、多目标回归 预测房价、股票价格
TabNetMultiTaskClassifier 多任务多分类 同时预测用户年龄段和购买意向

核心概念解释

💡 这部分帮助你理解 TabNet 的工作原理,不需要记住所有细节

1. 注意力机制(Attention)

TabNet 最核心的思想是「逐步决策」:

第1步: 关注「年龄」和「收入」→ 做出部分决策
第2步: 关注「学历」和「职业」→ 完善决策
第3步: 关注「地区」→ 最终决策

每一步只看少数几个特征,这样既高效又可解释。

2. Sparsemax vs Softmax

  • Softmax:给所有特征都分配权重(虽然有大有小)
  • Sparsemax:直接把不重要的特征权重设为 0(更稀疏、更可解释)

TabNet 默认使用 Sparsemax,所以它的特征选择更「干净」。

3. Ghost Batch Normalization

一种特殊的批归一化方法,能让模型在大 batch size 下也能稳定训练。


模型参数详解

创建模型时可以调整的参数:

🔑 最重要的参数

参数 默认值 说明 调参建议
n_d 8 决策层宽度 8-64,越大容量越大,但容易过拟合
n_a 8 注意力层宽度 通常设成和 n_d 一样
n_steps 3 决策步数 3-10,越大模型越复杂
gamma 1.3 特征重用系数 1.0-2.0,越接近 1 每步选的特征越不同

📝 其他常用参数

参数 默认值 说明
cat_idxs [] 类别特征的列索引(比如 [0, 3, 5])
cat_dims [] 每个类别特征的类别数
cat_emb_dim 1 类别特征的嵌入维度
lambda_sparse 1e-3 稀疏性正则化系数,越大越稀疏
mask_type "sparsemax" 可选 "sparsemax" 或 "entmax"

💻 运行相关参数

参数 默认值 说明
device_name "auto" 设备:"auto"、"cpu" 或 "cuda"
seed 0 随机种子,设置相同值可复现结果
verbose 1 是否打印训练过程

训练参数详解

调用 fit() 方法时可以调整的参数:

🔑 最重要的参数

参数 默认值 说明
max_epochs 200 最大训练轮数
patience 10 早停耐心值(连续多少轮不提升就停止)
batch_size 1024 批大小,TabNet 推荐用大 batch
eval_set None 验证集,格式为 [(X_val, y_val)]
eval_metric 自动 评估指标列表

📊 可用的评估指标

分类任务

  • 'auc':ROC-AUC(二分类默认)
  • 'accuracy':准确率(多分类默认)
  • 'balanced_accuracy':平衡准确率
  • 'logloss':对数损失

回归任务

  • 'mse':均方误差(默认)
  • 'mae':平均绝对误差
  • 'rmse':均方根误差
  • 'rmsle':对数均方根误差

⚙️ 其他参数

参数 默认值 说明
virtual_batch_size 128 Ghost Batch Normalization 的虚拟 batch 大小
num_workers 0 数据加载的并行进程数
weights 0 类别权重(0=不加权,1=自动平衡)
warm_start False 是否从上次训练继续

进阶功能

🔄 半监督预训练

如果你有大量没有标签的数据,可以先进行预训练:

from pytorch_tabnet.pretraining import TabNetPretrainer

# 1. 无监督预训练
unsupervised_model = TabNetPretrainer()
unsupervised_model.fit(
    X_train=X_unlabeled,  # 无标签数据
    pretraining_ratio=0.8,  # 随机遮盖 80% 的特征进行重建
)

# 2. 有监督微调
clf = TabNetClassifier()
clf.fit(
    X_train, y_train,
    from_unsupervised=unsupervised_model  # 使用预训练权重
)

💾 保存和加载模型

# 保存模型
clf.save_model("./my_tabnet_model")

# 加载模型
loaded_clf = TabNetClassifier()
loaded_clf.load_model("./my_tabnet_model.zip")

📈 自定义评估指标

from pytorch_tabnet.metrics import Metric
from sklearn.metrics import roc_auc_score

class Gini(Metric):
    def __init__(self):
        self._name = "gini"
        self._maximize = True  # 越大越好

    def __call__(self, y_true, y_score):
        auc = roc_auc_score(y_true, y_score[:, 1])
        return max(2 * auc - 1, 0.)

# 使用自定义指标
clf.fit(X_train, y_train, eval_metric=[Gini])

🔍 获取特征重要性

# 训练后可以直接获取
feature_importance = clf.feature_importances_

# 或者获取更详细的解释
explain_matrix, masks = clf.explain(X_test)

代码目录结构

tabnet/
├── 📂 pytorch_tabnet/          # 核心代码目录
│   ├── abstract_model.py       # 抽象基类
│   ├── augmentations.py        # 数据增强
│   ├── callbacks.py            # 回调函数
│   ├── metrics.py              # 评估指标
│   ├── multiclass_utils.py     # 多分类工具
│   ├── multitask.py            # 多任务学习
│   ├── pretraining.py          # 预训练模块
│   ├── pretraining_utils.py    # 预训练工具
│   ├── sparsemax.py            # 稀疏激活函数
│   ├── tab_model.py            # 主要模型类
│   ├── tab_network.py          # 神经网络架构
│   └── utils.py                # 通用工具
│
├── 📓 示例 Notebooks
│   ├── census_example.ipynb        # 二分类示例(人口普查数据)
│   ├── forest_example.ipynb        # 多分类示例(森林覆盖)
│   ├── regression_example.ipynb    # 回归示例
│   ├── multi_regression_example.ipynb  # 多目标回归示例
│   ├── multi_task_example.ipynb    # 多任务示例
│   ├── pretraining_example.ipynb   # 预训练示例
│   └── customizing_example.ipynb   # 自定义示例
│
├── 📂 tests/                   # 测试代码
├── 📂 docs/                    # 文档
└── pyproject.toml              # 项目配置

📦 pytorch_tabnet 目录详解

下面对每个文件进行详细解释,帮助你理解整个代码库的架构。


📄 tab_model.py - 主要模型入口(最常用!)

这是你使用 TabNet 时最直接接触的文件。

包含两个核心类:

类名 用途 默认损失函数 默认评估指标
TabNetClassifier 分类任务(二分类/多分类) CrossEntropy AUC(二分类)/ Accuracy(多分类)
TabNetRegressor 回归任务(单目标/多目标) MSE MSE

主要方法:

clf = TabNetClassifier()

# 训练
clf.fit(X_train, y_train, eval_set=[(X_val, y_val)])

# 预测类别
predictions = clf.predict(X_test)

# 预测概率(仅分类器)
probabilities = clf.predict_proba(X_test)

内部工作流程:

fit() 被调用
    ↓
检查输入数据格式 → 推断输出维度 → 创建数据加载器
    ↓
初始化网络结构(调用 tab_network.py)
    ↓
开始训练循环:
    每个 epoch:
        - 训练阶段:前向传播 → 计算损失 → 反向传播
        - 验证阶段:计算评估指标
        - 检查早停条件
    ↓
训练结束,加载最佳权重

📄 multitask.py - 多任务分类

当你需要同时预测多个分类目标时使用。

例如:同时预测用户的「年龄段」和「购买意向」两个分类任务。

from pytorch_tabnet.multitask import TabNetMultiTaskClassifier

clf = TabNetMultiTaskClassifier()

# y_train 形状: (n_samples, n_tasks)
# 例如: [[25岁以下, 高意向], [25-35岁, 低意向], ...]
clf.fit(X_train, y_train)

# 返回一个列表,每个元素是一个任务的预测结果
predictions = clf.predict(X_test)  # [task1_preds, task2_preds, ...]

与普通分类器的区别:

  • 输出层是多个并行的分类头,每个任务一个
  • 损失函数是所有任务损失的平均
  • 可以为每个任务指定不同的损失函数

📄 pretraining.py - 自监督预训练

利用无标签数据进行预训练,提升模型效果。

这是 TabNet 的一大亮点!当你有大量无标签数据时,可以先进行预训练。

预训练原理:

原始数据: [特征1, 特征2, 特征3, 特征4, 特征5]
                         ↓ 随机遮盖部分特征
遮盖后:   [特征1,   ?  , 特征3,   ?  , 特征5]
                         ↓ 模型尝试重建
重建目标: [特征1, 特征2, 特征3, 特征4, 特征5]

模型通过「猜测被遮盖的特征」来学习数据的内在结构。

使用方法:

from pytorch_tabnet.pretraining import TabNetPretrainer

# 第一步:无监督预训练
pretrainer = TabNetPretrainer(
    n_d=8, n_a=8,
    mask_type='sparsemax'
)
pretrainer.fit(
    X_train=X_unlabeled,      # 大量无标签数据
    pretraining_ratio=0.8,    # 遮盖 80% 的特征
)

# 第二步:用预训练权重初始化分类器
clf = TabNetClassifier()
clf.fit(
    X_train, y_train,
    from_unsupervised=pretrainer  # 关键!使用预训练权重
)

何时使用预训练:

  • 有大量无标签数据(比标签数据多 10 倍以上)
  • 标签数据较少(< 10000 条)
  • 特征之间有复杂的内在关系

📄 abstract_model.py - 抽象基类

所有模型(Classifier、Regressor、Pretrainer)的共同父类。

你不需要直接使用这个文件,但了解它有助于理解代码结构。

定义的核心功能:

方法 功能
fit() 完整的训练循环
predict() 模型预测
explain() 获取特征重要性和注意力掩码
save_model() 保存模型到 .zip 文件
load_model() 从 .zip 文件加载模型

继承关系图:

                    TabModel (abstract_model.py)
                         │
         ┌───────────────┼───────────────┐
         ↓               ↓               ↓
  TabNetClassifier  TabNetRegressor  TabNetPretrainer
   (tab_model.py)   (tab_model.py)   (pretraining.py)
         │
         ↓
TabNetMultiTaskClassifier
    (multitask.py)

📄 tab_network.py - 神经网络架构(核心!)

定义了 TabNet 的所有神经网络组件。

这是理解 TabNet 原理的关键文件!

整体架构图:

┌─────────────────────────────────────────────────────────────────────┐
│                           TabNet 完整架构                            │
├─────────────────────────────────────────────────────────────────────┤
│                                                                     │
│  输入: [年龄, 收入, 学历, 城市, ...]                                  │
│              ↓                                                      │
│  ┌─────────────────────────┐                                        │
│  │   EmbeddingGenerator    │  ← 把类别特征转成向量                    │
│  │   (类别特征嵌入)          │    例: 城市"北京" → [0.2, 0.8, 0.1]     │
│  └───────────┬─────────────┘                                        │
│              ↓                                                      │
│  ┌─────────────────────────┐                                        │
│  │      Initial BN         │  ← 初始批归一化                         │
│  └───────────┬─────────────┘                                        │
│              ↓                                                      │
│  ╔═══════════════════════════════════════════════════════════════╗  │
│  ║                    TabNetEncoder                              ║  │
│  ║  ┌─────────────────────────────────────────────────────────┐  ║  │
│  ║  │                      Step 1                             │  ║  │
│  ║  │  ┌──────────────────┐    ┌──────────────────┐          │  ║  │
│  ║  │  │ AttentiveTransf. │ →  │  FeatTransformer │ → d₁     │  ║  │
│  ║  │  │ (选择重要特征)    │    │  (特征变换)       │          │  ║  │
│  ║  │  └──────────────────┘    └──────────────────┘          │  ║  │
│  ║  └─────────────────────────────────────────────────────────┘  ║  │
│  ║  ┌─────────────────────────────────────────────────────────┐  ║  │
│  ║  │                      Step 2                             │  ║  │
│  ║  │  ┌──────────────────┐    ┌──────────────────┐          │  ║  │
│  ║  │  │ AttentiveTransf. │ →  │  FeatTransformer │ → d₂     │  ║  │
│  ║  │  └──────────────────┘    └──────────────────┘          │  ║  │
│  ║  └─────────────────────────────────────────────────────────┘  ║  │
│  ║  ┌─────────────────────────────────────────────────────────┐  ║  │
│  ║  │                      Step 3                             │  ║  │
│  ║  │  ┌──────────────────┐    ┌──────────────────┐          │  ║  │
│  ║  │  │ AttentiveTransf. │ →  │  FeatTransformer │ → d₃     │  ║  │
│  ║  │  └──────────────────┘    └──────────────────┘          │  ║  │
│  ║  └─────────────────────────────────────────────────────────┘  ║  │
│  ╚═══════════════════════════════════════════════════════════════╝  │
│              ↓                                                      │
│         d₁ + d₂ + d₃ (聚合所有步骤的输出)                            │
│              ↓                                                      │
│  ┌─────────────────────────┐                                        │
│  │     Final Mapping       │  ← 输出层                               │
│  └───────────┬─────────────┘                                        │
│              ↓                                                      │
│  输出: 预测结果                                                       │
│                                                                     │
└─────────────────────────────────────────────────────────────────────┘

核心组件详解:

1️⃣ EmbeddingGenerator(嵌入生成器)

作用:将类别特征转换为稠密向量。

原始输入: [25, 50000, "本科", "北京"]
              ↓
数值特征直接保留: [25, 50000]
类别特征嵌入:     "本科" → [0.3, 0.7]
                 "北京" → [0.2, 0.8, 0.1]
              ↓
合并后: [25, 50000, 0.3, 0.7, 0.2, 0.8, 0.1]
2️⃣ AttentiveTransformer(注意力变换器)

作用:决定每一步关注哪些特征。

# 简化版工作原理
class AttentiveTransformer:
    def forward(self, priors, processed_feat):
        # priors: 之前步骤已经使用过的特征(用于避免重复)
        # processed_feat: 当前处理的特征
        
        x = Linear(processed_feat)  # 全连接层
        x = BatchNorm(x)            # 批归一化
        x = x * priors              # 结合先验信息
        x = Sparsemax(x)            # 稀疏化!大部分特征权重变成 0
        return x  # 返回注意力掩码

关键点:使用 Sparsemax 让注意力「稀疏」,即每步只关注少数特征。

3️⃣ FeatTransformer(特征变换器)

作用:对选中的特征进行非线性变换。

由多个 GLU_Block 组成:

  • Shared GLU blocks:所有步骤共享的变换层(促进特征复用)
  • Independent GLU blocks:每步独立的变换层(学习步骤特定的模式)
4️⃣ GLU_Block & GLU_Layer(门控线性单元)

作用:TabNet 的基本计算单元。

输入 x
    ↓
Linear(x) → [h, g]  # 分成两半
    ↓
h * sigmoid(g)  # 门控机制
    ↓
输出

门控机制让网络能「选择性」地传递信息,类似 LSTM 中的门。

5️⃣ GBN - Ghost Batch Normalization

作用:让大 batch 训练更稳定。

普通 BN: 整个 batch 一起归一化
         batch_size = 1024 → 统计量来自 1024 个样本

Ghost BN: 拆分成多个虚拟小 batch 分别归一化
         batch_size = 1024, virtual_batch_size = 128
         → 拆成 8 个小 batch,各自归一化后拼接

为什么需要 GBN?

  • TabNet 推荐使用大 batch(1024+)
  • 但大 batch 的 BN 统计量可能不稳定
  • GBN 通过虚拟小 batch 解决这个问题
6️⃣ TabNetDecoder(解码器)

作用:仅用于预训练,将编码器输出重建为原始特征。

编码器输出 [d₁, d₂, d₃]
           ↓
     FeatTransformer
           ↓
     Reconstruction Layer
           ↓
重建的原始特征(与被遮盖前对比计算损失)
7️⃣ RandomObfuscator(随机遮盖器)

作用:预训练时随机遮盖特征。

# pretraining_ratio = 0.8 表示遮盖 80%
原始: [特征1, 特征2, 特征3, 特征4, 特征5]
遮盖: [特征1,   0  ,   0  ,   0  , 特征5]  # 随机遮盖了 2,3,4

📄 sparsemax.py - 稀疏激活函数

实现 Sparsemax 和 Entmax15 激活函数。

这是让 TabNet 具有「特征选择」能力的关键!

Softmax vs Sparsemax 对比:

输入: [2.0, 1.0, 0.1, 0.1, 0.1]

Softmax 输出:  [0.48, 0.18, 0.11, 0.11, 0.11]  
               ↑ 所有特征都有权重(虽然有大小之分)

Sparsemax 输出: [0.85, 0.15, 0.00, 0.00, 0.00]
                ↑ 不重要的特征直接为 0!更稀疏、更可解释

包含的类:

类名 描述 稀疏程度
Sparsemax 原始 sparsemax
Entmax15 α=1.5 的 entmax 中等(介于 softmax 和 sparsemax 之间)

使用哪个?

  • 默认用 sparsemax(更稀疏,可解释性更好)
  • 如果效果不好,可以试试 entmax(更平滑)

📄 callbacks.py - 回调函数

在训练过程中的特定时机执行自定义逻辑。

内置回调:

回调类 功能 触发时机
History 记录训练历史(loss、metrics) 每个 batch/epoch 结束
EarlyStopping 早停(验证集不再提升就停止) 每个 epoch 结束
LRSchedulerCallback 学习率调度 每个 batch/epoch 结束

早停工作原理:

epoch 1: val_loss = 0.5    ← 最佳
epoch 2: val_loss = 0.4    ← 新的最佳!重置计数器
epoch 3: val_loss = 0.45   ← 没提升,计数 = 1
epoch 4: val_loss = 0.43   ← 没提升,计数 = 2
...
epoch 12: val_loss = 0.42  ← 没提升,计数 = 10
→ patience=10 达到,停止训练,恢复 epoch 2 的权重

自定义回调示例:

from pytorch_tabnet.callbacks import Callback

class MyCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f"Epoch {epoch} 结束, loss = {logs['loss']}")

📄 metrics.py - 评估指标

定义各种评估指标,用于监控训练和验证。

内置指标:

指标类 名称 任务类型 优化方向
AUC 'auc' 二分类 越大越好 ↑
Accuracy 'accuracy' 分类 越大越好 ↑
BalancedAccuracy 'balanced_accuracy' 分类 越大越好 ↑
LogLoss 'logloss' 分类 越小越好 ↓
MSE 'mse' 回归 越小越好 ↓
MAE 'mae' 回归 越小越好 ↓
RMSE 'rmse' 回归 越小越好 ↓
RMSLE 'rmsle' 回归 越小越好 ↓
UnsupervisedMetric 'unsup_loss' 预训练 越小越好 ↓

自定义指标模板:

from pytorch_tabnet.metrics import Metric

class MyMetric(Metric):
    def __init__(self):
        self._name = "my_metric"  # 指标名称
        self._maximize = True     # True=越大越好, False=越小越好

    def __call__(self, y_true, y_score):
        # y_true: 真实标签
        # y_score: 模型预测(分类是概率矩阵,回归是预测值)
        return your_calculation(y_true, y_score)

📄 augmentations.py - 数据增强

训练时对数据进行随机变换,提升模型泛化能力。

内置增强方法:

类名 适用任务 原理
RegressionSMOTE 回归 将两个样本按比例混合(特征和标签都混合)
ClassificationSMOTE 分类 只混合特征,标签保持主样本的类别

SMOTE 原理图解:

样本 A: 特征=[1,2,3], 标签=0.5
样本 B: 特征=[4,5,6], 标签=0.8
混合比例 λ=0.7

回归 SMOTE:
  新特征 = 0.7*[1,2,3] + 0.3*[4,5,6] = [1.9, 2.9, 3.9]
  新标签 = 0.7*0.5 + 0.3*0.8 = 0.59

分类 SMOTE:
  新特征 = 0.7*[1,2,3] + 0.3*[4,5,6] = [1.9, 2.9, 3.9]
  新标签 = A 的标签 = 0(保持不变,因为 λ>0.5)

使用方法:

from pytorch_tabnet.augmentations import ClassificationSMOTE

aug = ClassificationSMOTE(p=0.8)  # 80% 的样本会被增强

clf.fit(X_train, y_train, augmentations=aug)

📄 utils.py - 工具函数

各种辅助功能的集合。

主要功能分类:

数据集类
# 用于将 numpy 数组转换为 PyTorch Dataset
TorchDataset(x, y)         # 普通数组
SparseTorchDataset(x, y)   # 稀疏矩阵
PredictDataset(x)          # 预测用(无标签)
SparsePredictDataset(x)    # 稀疏矩阵预测用
数据加载器创建
create_dataloaders(X_train, y_train, eval_set, ...)
# 创建训练和验证的 DataLoader
# 支持类别权重采样、稀疏矩阵等
设备检测
define_device("auto")  # 自动检测 CUDA
# 返回 "cuda" 或 "cpu"
特征分组
create_group_matrix(list_groups, input_dim)
# 创建特征分组矩阵
# 例:[[0,1,2], [3,4]] 表示前三个特征为一组,后两个为一组
参数验证
check_input(X)               # 检查输入格式
check_embedding_parameters() # 检查嵌入参数
validate_eval_set()          # 验证验证集格式

📄 multiclass_utils.py - 多分类工具

处理分类任务中的标签相关逻辑。

主要功能:

函数 功能
unique_labels(y) 提取所有唯一的类别标签
type_of_target(y) 自动判断任务类型(二分类/多分类/多标签等)
infer_output_dim(y_train) 推断输出维度(类别数)
check_output_dim(labels, y) 检查验证集标签是否在训练集中出现过

任务类型判断示例:

type_of_target([0, 1, 0, 1])       # → 'binary'(二分类)
type_of_target([0, 1, 2, 3])       # → 'multiclass'(多分类)
type_of_target([[0,1], [1,0]])     # → 'multilabel-indicator'(多标签)

📄 pretraining_utils.py - 预训练工具

专门为无监督预训练服务的工具函数。

utils.py 的区别:

  • utils.py:需要 (X, y) 对
  • pretraining_utils.py:只需要 X(无标签)

主要函数:

create_dataloaders(X_train, eval_set, ...)
# 创建预训练用的 DataLoader(没有 y_train)

validate_eval_set(eval_set, eval_name, X_train)
# 验证预训练验证集(只检查 X 的形状)

示例代码

完整的分类示例

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from pytorch_tabnet.tab_model import TabNetClassifier

# 1. 准备数据
data = load_breast_cancer()
X = data.data
y = data.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42
)

# 2. 创建并训练模型
clf = TabNetClassifier(
    n_d=8,
    n_a=8,
    n_steps=3,
    gamma=1.3,
    lambda_sparse=1e-3,
    verbose=1,
    seed=42
)

clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_name=['valid'],
    eval_metric=['auc', 'accuracy'],
    max_epochs=100,
    patience=10,
    batch_size=256,
)

# 3. 预测和评估
predictions = clf.predict(X_test)
proba = clf.predict_proba(X_test)

# 4. 查看特征重要性
print("特征重要性:", clf.feature_importances_)

# 5. 保存模型
clf.save_model("./breast_cancer_tabnet")

回归示例

from pytorch_tabnet.tab_model import TabNetRegressor
import numpy as np

# 注意:回归任务的 y 必须是 2D 的 (n_samples, n_targets)
y_train = y_train.reshape(-1, 1)
y_valid = y_valid.reshape(-1, 1)

reg = TabNetRegressor()
reg.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric=['mae', 'rmse'],
    max_epochs=100,
    patience=10,
)

predictions = reg.predict(X_test)

📚 学习资源


🤔 常见问题

Q: TabNet 什么时候不适用?

  • 数据量太小(< 1000 条)时,传统方法可能更好
  • 特征非常少时(< 10 个),简单模型可能就够了
  • 需要极快推理速度时,树模型更快

Q: 如何处理缺失值?

TabNet 不直接处理缺失值,建议:

  • 用均值/中位数填充
  • 用 -1 或特殊值标记缺失

Q: 如何调参?

推荐顺序:

  1. 先用默认参数跑一个 baseline
  2. 调整 n_dn_a(一起调,8→16→32)
  3. 调整 n_steps(3→5→7)
  4. 调整 batch_size(1024→2048→4096)
  5. 调整正则化 lambda_sparse

💡 提示:如果你是第一次使用,建议先跑通 census_example.ipynb 示例,再根据自己的数据修改。

有问题欢迎在 Issues 中讨论!🎉

About

For my girl

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 60.0%
  • Jupyter Notebook 35.4%
  • Makefile 1.9%
  • Shell 1.7%
  • Batchfile 0.3%
  • CSS 0.3%
  • Other 0.4%