Description
Bug report
It looks like the attention FLOPs calculation assumes the full QK and attention-times-V products must always be computed, but with causal masking each query only needs roughly half of the key/value positions. We noticed this because our local attention calculations count precisely how much of the key/value each query needs, which reduces the FLOP count substantially. As a result, the top-line TFLOP/s/device number looks significantly worse, but this is mostly a false signal: the global attention calculation is crediting FLOPs that never need to be done.
We weren't sure whether there is an accepted best practice for this sort of calculation, so we checked Megatron, and it does divide the attention FLOPs by 2 to account for causal masking. I think we should do the same here.
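To make the proposed accounting concrete, here is a minimal sketch of the standard attention matmul FLOP count with an optional causal halving factor. The function name and signature are hypothetical, not the repo's actual code; it counts only the two big matmuls (QK^T and the attention-weighted V product), ignoring softmax and projections.

```python
def attention_matmul_flops(batch: int, heads: int, seq_len: int,
                           head_dim: int, causal: bool = False) -> int:
    """Count FLOPs for the QK^T and attn@V matmuls (hypothetical helper)."""
    # QK^T: an (s x d) @ (d x s) matmul costs 2 * s * s * d FLOPs per head.
    qk_flops = 2 * batch * heads * seq_len * seq_len * head_dim
    # attn @ V: an (s x s) @ (s x d) matmul costs the same.
    av_flops = 2 * batch * heads * seq_len * seq_len * head_dim
    total = qk_flops + av_flops
    # With a causal mask, query i only attends to keys 0..i, so on
    # average half the score matrix is needed -- hence the divide-by-2.
    return total // 2 if causal else total
```

With this convention, the reported model FLOPs drop by half for the attention term, which is what brings the TFLOP/s/device number in line with the work actually performed.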
Logs/Output
No response
Environment Information
No response
Additional Context
No response