Skip to content

[问题/Issue] 章节5.1.6:残差投影缩放初始化处疑似将 w2 写成 w3 / Chapter5.1.6: Brief description #200

@Tw1stzz259

Description

@Tw1stzz259

1. 遇到问题的章节 / Affected Chapter

Chapter5.1.6

2. 具体问题描述 / Problem Description

在第五章 5.1.6 搭建完整模型中,模型初始化部分有如下逻辑:

# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
    if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
        torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers))

这里注释写的是对残差投影进行特殊的缩放初始化”,但在当前 MLP 定义中return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
其中w1  gate 分支
w3  up/value 分支
w2 是从 hidden_dim 投影回 dim 的输出投影层
因此如果按残差投影 / residual projection的语义理解MLP 中应当被缩放初始化的层更像是 w2.weight而不是 w3.weightAttention 中对应的残差投影是 wo.weight建议确认这里是否应修改为if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
    torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers))
或者如果保留 w3.weight 是有意设计建议修改注释避免读者误解为 MLP 的输出投影层### 3. 问题重现材料 / Reproduction Materials


**3. 问题重现材料 / Reproduction Materials**

```markdown
相关代码位置- 第五章 5.1.6 搭建完整模型
- `Transformer.__init__` 中参数初始化部分
- `MLP.forward`  `w1 / w2 / w3` 的使用关系

对照依据Meta LLaMA  MLP 写法为return self.w2(F.silu(self.w1(x)) * self.w3(x))

其中 w2 是输出投影层因此 w2 / down_proj 才是 MLP 中投影回 residual stream 的层### 确认事项 / Verification

- [x] 此问题未在过往Issue中被报告过 / This issue hasn't been reported before

Metadata

Metadata

Assignees

No one assigned

    Labels

    documentationImprovements or additions to documentation

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions