Nice work. The MM-DiT block has a concat operation between the image modality and the text modality before the Q/K/V attention, but I could not find it in the snippet below. Looking forward to your reply.
Lines 119 to 150 in cca1e6c:

```python
for x, mask, to_qkv, q_rmsnorm, k_rmsnorm in zip(inputs, masks, self.to_qkv, self.q_rmsnorms, self.k_rmsnorms):
    qkv = to_qkv(x)
    qkv = self.split_heads(qkv)

    # optional qk rmsnorm per modality

    if self.qk_rmsnorm:
        q, k, v = qkv
        q = q_rmsnorm(q)
        k = k_rmsnorm(k)
        qkv = torch.stack((q, k, v))

    all_qkvs.append(qkv)

    # handle mask per modality

    if not exists(mask):
        mask = torch.ones(x.shape[:2], device = device, dtype = torch.bool)

    all_masks.append(mask)

# combine all qkv and masks

all_qkvs, packed_shape = pack(all_qkvs, 'qkv b h * d')
all_masks, _ = pack(all_masks, 'b *')

# attention

q, k, v = all_qkvs

outs, *_ = self.attend(q, k, v, mask = all_masks)
```
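For clarity, here is a minimal sketch of the concat I have in mind. The `image_tokens` / `text_tokens` names and shapes are hypothetical and only illustrate the question; they are not from this repo:

```python
import torch

# hypothetical token tensors, just for illustration
# image_tokens: (batch, num_image_tokens, dim)
# text_tokens:  (batch, num_text_tokens, dim)
image_tokens = torch.randn(2, 256, 512)
text_tokens = torch.randn(2, 77, 512)

# the concat I am referring to: joining both modalities along the
# sequence dimension so attention is computed over the joint sequence
joint_tokens = torch.cat((image_tokens, text_tokens), dim = 1)  # (2, 333, 512)

# a single projection over the concatenated sequence would then produce Q, K, V
# (hypothetical projection, not from this repo)
to_qkv = torch.nn.Linear(512, 512 * 3, bias = False)
q, k, v = to_qkv(joint_tokens).chunk(3, dim = -1)
```

In the snippet above, each modality goes through its own `to_qkv`, so I am not sure where this joining of the two modalities happens before the attention.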