
Force fp32 in attention.MultiHeadDotProductAttention for softmax operator #4008

@willisma

Description


Hi,

While implementing mixed-precision training in Flax for my project, I noticed that the force_fp32_for_softmax flag defined in attention.MultiHeadDotProductAttention is never passed through to dot_product_attention (the default attention function).
[Three screenshots taken 2024-06-18 were attached here; the images are not preserved.]
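For reference, the following is a minimal sketch of what forcing fp32 for the softmax amounts to. It is written in plain JAX and is not Flax's own code: attention_fp32_softmax is a hypothetical helper that omits masking, bias and dropout. It assumes the [batch..., length, num_heads, head_dim] layout used by Flax's dot_product_attention. The two matmuls stay in the input dtype (e.g. bfloat16), and only the logits are upcast to float32 for the softmax before the weights are cast back.

```python
import jax
import jax.numpy as jnp


def attention_fp32_softmax(query, key, value):
    """Hypothetical multi-head dot-product attention with an fp32 softmax.

    query: [..., q_len, heads, head_dim]; key/value: [..., kv_len, heads, head_dim].
    Masking, bias and dropout are deliberately left out of this sketch.
    """
    head_dim = query.shape[-1]
    # Scale queries by 1/sqrt(head_dim), keeping the input dtype.
    query = query * jnp.asarray(head_dim ** -0.5, dtype=query.dtype)
    # Attention logits in the input dtype: [..., heads, q_len, kv_len].
    logits = jnp.einsum("...qhd,...khd->...hqk", query, key)
    # The "force fp32" step: softmax in float32, then cast the weights back.
    weights = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
    weights = weights.astype(query.dtype)
    # Weighted sum of values in the input dtype: [..., q_len, heads, head_dim].
    return jnp.einsum("...hqk,...khd->...qhd", weights, value)


# Usage: bfloat16 inputs in, bfloat16 output out; only the softmax ran in fp32.
q = jax.random.normal(jax.random.PRNGKey(0), (2, 5, 4, 8)).astype(jnp.bfloat16)
out = attention_fp32_softmax(q, q, q)
print(out.shape, out.dtype)  # (2, 5, 4, 8) bfloat16
```

Casting the weights back to the input dtype keeps the value matmul in bf16, so the extra cost of the upcast is limited to the softmax itself.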

As a result, users lose control over the precision of the softmax, which could cause numerical stability issues under bf16 or fp16. Is there an alternative way to force fp32 for the softmax? Thanks!
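One way to make the precision concern concrete is to compare, against a float32 reference, a softmax evaluated directly on bfloat16 logits with one computed in float32 and only rounded to bfloat16 at the end. The sketch below is illustrative only: softmax_max_error is a hypothetical helper, and the logit scale and shapes in the usage lines are arbitrary choices, not values from this issue.

```python
import jax
import jax.numpy as jnp


def softmax_max_error(logits_bf16):
    """Hypothetical helper: max abs error of two bf16 softmax paths vs. fp32."""
    # Reference: softmax computed in float32 from the same bf16 logits.
    reference = jax.nn.softmax(logits_bf16.astype(jnp.float32), axis=-1)
    # Path A: softmax evaluated directly on the bfloat16 inputs.
    direct_bf16 = jax.nn.softmax(logits_bf16, axis=-1).astype(jnp.float32)
    # Path B: fp32 softmax rounded to bfloat16 only at the very end.
    via_fp32 = reference.astype(jnp.bfloat16).astype(jnp.float32)
    err_bf16 = jnp.max(jnp.abs(direct_bf16 - reference))
    err_fp32 = jnp.max(jnp.abs(via_fp32 - reference))
    return float(err_bf16), float(err_fp32)


# Usage: arbitrary random logits, scaled up so the softmax is fairly peaked.
logits = (10.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 1024))).astype(jnp.bfloat16)
err_bf16, err_fp32 = softmax_max_error(logits)
print(f"direct bf16 max error: {err_bf16:.2e}, fp32-then-round max error: {err_fp32:.2e}")
```

Because converting float32 to bfloat16 rounds to the nearest representable value, the fp32-then-round path can never be further from the reference than the direct bf16 path; how large the gap is depends on the backend and the inputs.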
