
A little question about mmdit attention #8

@AlphaNext


Nice work. The MM-DiT block has a concat operation between the image modality and the text modality before the Q/K/V attention, but I could not find it in the code.

Looking forward to your reply.

for x, mask, to_qkv, q_rmsnorm, k_rmsnorm in zip(inputs, masks, self.to_qkv, self.q_rmsnorms, self.k_rmsnorms):
    qkv = to_qkv(x)
    qkv = self.split_heads(qkv)

    # optional qk rmsnorm per modality
    if self.qk_rmsnorm:
        q, k, v = qkv
        q = q_rmsnorm(q)
        k = k_rmsnorm(k)
        qkv = torch.stack((q, k, v))

    all_qkvs.append(qkv)

    # handle mask per modality
    if not exists(mask):
        mask = torch.ones(x.shape[:2], device = device, dtype = torch.bool)

    all_masks.append(mask)

# combine all qkv and masks
all_qkvs, packed_shape = pack(all_qkvs, 'qkv b h * d')
all_masks, _ = pack(all_masks, 'b *')

# attention
q, k, v = all_qkvs

outs, *_ = self.attend(q, k, v, mask = all_masks)
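For context, the concat the question asks about happens inside the pack call: einops.pack concatenates the per-modality tensors along the wildcard '*' (sequence) axis, and unpack splits them back apart after attention. Below is a minimal standalone sketch of that behavior; all shapes (2 batch, 8 heads, 256 image tokens, 77 text tokens, 64-dim heads) are illustrative assumptions, not taken from the repo:

import torch
from einops import pack, unpack

# per-modality stacked q/k/v with different sequence lengths (shapes are illustrative)
image_qkv = torch.randn(3, 2, 8, 256, 64)   # (qkv, batch, heads, image tokens, head dim)
text_qkv  = torch.randn(3, 2, 8,  77, 64)   # (qkv, batch, heads, text tokens, head dim)

# pack concatenates along the '*' axis - this is the image/text concat
packed, packed_shape = pack([image_qkv, text_qkv], 'qkv b h * d')
assert packed.shape == (3, 2, 8, 256 + 77, 64)

# the boolean masks are concatenated the same way along the sequence axis
image_mask = torch.ones(2, 256, dtype = torch.bool)
text_mask  = torch.ones(2,  77, dtype = torch.bool)
mask, _ = pack([image_mask, text_mask], 'b *')
assert mask.shape == (2, 256 + 77)

# after attention, unpack splits the joint sequence back per modality
image_out, text_out = unpack(packed, packed_shape, 'qkv b h * d')
assert image_out.shape == image_qkv.shape
assert text_out.shape == text_qkv.shape

So the image and text tokens attend jointly over the packed sequence, and no explicit torch.cat appears in the code.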
