Attention flops calculation doesn't reflect causal masking #1972

@philip-essential

Description

Bug report

It looks like the attention FLOPs calculation assumes the full QK^T and attention-weight × V products are always computed, but with causal masking each query only needs roughly half of the keys/values on average. We noticed this because the local attention calculation counts more precisely how much of the key/value each query needs, and that reduces the counted FLOPs substantially. The top-line TFLOP/s/device number then looks significantly worse, but this is mostly a false signal: the global attention calculation is crediting FLOPs that never need to be done.
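For concreteness, here is a rough back-of-the-envelope count (a sketch with made-up names, not the project's actual flops code), assuming a single head with sequence length `seq_len` and head dimension `head_dim`:

```python
# Rough per-head FLOP count for the two attention matmuls (QK^T and AV),
# counting 2 FLOPs per multiply-add. Illustrative only; the names are not
# taken from the repo.

def attention_matmul_flops(seq_len: int, head_dim: int, causal: bool) -> int:
    if not causal:
        # Every query attends to every key: T * T attended positions per matmul.
        positions = seq_len * seq_len
    else:
        # Query i only attends to keys 0..i, so the total is
        # 1 + 2 + ... + T = T * (T + 1) / 2, roughly half of T^2.
        positions = seq_len * (seq_len + 1) // 2
    # Two matmuls (QK^T and AV), each costing 2 * head_dim FLOPs per position.
    return 2 * (2 * head_dim * positions)

full = attention_matmul_flops(4096, 128, causal=False)
causal = attention_matmul_flops(4096, 128, causal=True)
print(causal / full)  # ~0.5 for long sequences
```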

We weren't sure whether there's an accepted best practice for this kind of calculation, so we checked Megatron-LM, and it does divide the attention FLOPs by 2 to account for causal masking. I think we should do the same here.
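A minimal sketch of what that change could look like, assuming a standalone flops helper (the function name and signature here are hypothetical, not the project's actual API):

```python
def attention_flops(batch: int, seq_len: int, num_heads: int,
                    head_dim: int, causal: bool = True) -> int:
    """FLOPs for the QK^T and AV matmuls across a batch.

    Hypothetical helper showing the proposed adjustment: halve the attention
    matmul FLOPs when a causal mask is used, matching Megatron-LM's accounting.
    """
    # 2 matmuls * 2 FLOPs per multiply-add * head_dim per attended position.
    flops = batch * num_heads * 2 * 2 * head_dim * seq_len * seq_len
    if causal:
        flops //= 2  # only the lower-triangular half is actually computed
    return flops
```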

Logs/Output

No response

Environment Information

No response

Additional Context

No response
