Hi,
While implementing mixed-precision training in Flax for my project, I noticed that the force_fp32_for_softmax flag defined in attention.MultiHeadDotProductAttention is not passed through to dot_product_attention (the default attention function).

I think this means we lose control over the precision of the softmax operator, which could cause stability issues under bf16 or fp16. Is there an alternative? Thanks!
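
For reference, here is a rough sketch of the kind of workaround I have in mind: a custom attention function that upcasts the logits to fp32 before the softmax and casts the weights back afterwards. The name fp32_softmax_attention is hypothetical (it is not part of Flax), the layout [batch..., length, num_heads, head_dim] is my assumption about what the module passes in, and attention dropout is deliberately left out to keep it short.

```python
import jax
import jax.numpy as jnp


def fp32_softmax_attention(query, key, value, bias=None, mask=None,
                           dtype=None, **unused_kwargs):
    """Dot-product attention with the softmax computed in float32.

    Hypothetical helper, not a Flax API. Assumes inputs shaped
    [batch..., length, num_heads, head_dim]. Extra keyword arguments
    (dropout settings, precision, ...) are accepted and ignored, so
    attention dropout is NOT applied here.
    """
    out_dtype = dtype if dtype is not None else query.dtype

    # Scale queries by 1/sqrt(head_dim), staying in the input dtype.
    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(query.dtype)

    # Logits laid out as [batch..., num_heads, q_length, kv_length].
    logits = jnp.einsum('...qhd,...khd->...hqk', query, key)

    # Upcast before bias, masking and softmax so the reduction runs in fp32.
    logits = logits.astype(jnp.float32)
    if bias is not None:
        logits = logits + bias.astype(jnp.float32)
    if mask is not None:
        logits = jnp.where(mask, logits, jnp.finfo(jnp.float32).min)
    weights = jax.nn.softmax(logits, axis=-1)

    # Cast back to the low-precision dtype for the value contraction.
    weights = weights.astype(out_dtype)
    return jnp.einsum('...hqk,...khd->...qhd', weights,
                      value.astype(out_dtype))
```

If the installed Flax version exposes the attention_fn argument on MultiHeadDotProductAttention, something like this could presumably be passed there in place of the default, but I have not checked that it matches the exact keyword arguments the module forwards in every release.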