1. 遇到问题的章节 / Affected Chapter
Chapter5.1.6
2. 具体问题描述 / Problem Description
在第五章 5.1.6 搭建完整模型中,模型初始化部分有如下逻辑:
# 对残差投影进行特殊的缩放初始化
for pn, p in self.named_parameters():
if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers))
这里注释写的是“对残差投影进行特殊的缩放初始化”,但在当前 MLP 定义中:
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
其中:
w1 是 gate 分支
w3 是 up/value 分支
w2 是从 hidden_dim 投影回 dim 的输出投影层
因此如果按“残差投影 / residual projection”的语义理解,MLP 中应当被缩放初始化的层更像是 w2.weight,而不是 w3.weight。Attention 中对应的残差投影是 wo.weight。
建议确认这里是否应修改为:
if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * args.n_layers))
或者如果保留 w3.weight 是有意设计,建议修改注释,避免读者误解为 MLP 的输出投影层。
### 3. 问题重现材料 / Reproduction Materials
**3. 问题重现材料 / Reproduction Materials**
```markdown
相关代码位置:
- 第五章 5.1.6 搭建完整模型
- `Transformer.__init__` 中参数初始化部分
- `MLP.forward` 中 `w1 / w2 / w3` 的使用关系
对照依据:
Meta LLaMA 的 MLP 写法为:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
其中 w2 是输出投影层。
因此 w2 / down_proj 才是 MLP 中投影回 residual stream 的层。
### 确认事项 / Verification
- [x] 此问题未在过往Issue中被报告过 / This issue hasn't been reported before
1. 遇到问题的章节 / Affected Chapter
Chapter5.1.6
2. 具体问题描述 / Problem Description
在第五章 5.1.6 搭建完整模型中,模型初始化部分有如下逻辑: