
Commit

No public description
PiperOrigin-RevId: 674413113
tensorflower-gardener committed Sep 13, 2024
1 parent b6b0b41 commit 2a88649
Showing 2 changed files with 166 additions and 5 deletions.
official/nlp/modeling/layers/block_sparse_attention.py (75 changes: 73 additions & 2 deletions)
@@ -14,14 +14,42 @@

"""Block sparse attention converts query/key/value into blocks and performs diagonal block sparse attention."""
import collections
import logging

import tensorflow as tf, tf_keras


def _large_compatible_negative(tensor_type):
  """Large negative number as Tensor.

  This function is necessary because the standard value for epsilon
  in this module (-1e9) cannot be represented using tf.float16

  Args:
    tensor_type: a dtype to determine the type.

  Returns:
    a large negative number.
  """
  # In case of dtype=float16 (e.g., for mixed-precision), the largest
  # negative number (dtypes.float16.min) is divided by 2, in order to
  # avoid overflows when summing negative inputs.
  if tensor_type == tf.float16:
    return tf.float16.min / 2.0
  return -1e9
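
# Illustrative note (not part of this file): -1e9 overflows float16 and
# becomes -inf (tf.float16 can only represent magnitudes up to ~65504), which
# can produce NaNs once the adder is combined with other terms. Halving
# tf.float16.min (about -65504 / 2 = -32752) keeps the mask adder finite while
# still driving masked logits to effectively zero probability.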


class MultiHeadAttention(tf_keras.layers.MultiHeadAttention):
  """Multi-head block sparse attention layer."""

  def __init__(self, src_block_size=None, tgt_block_size=None, **kwargs):
  def __init__(
      self,
      src_block_size=None,
      tgt_block_size=None,
      use_sigmoid_attn=False,
      sigmoid_attn_bias=None,
      **kwargs
  ):
    """Initializes the block sparse attention layer.

    Args:
@@ -30,18 +58,34 @@ def __init__(self, src_block_size=None, tgt_block_size=None, **kwargs):
      tgt_block_size: The block size of the key/value. An integer that divides
        the sequence length into blocks. The number of blocks in the source and
        target must be the same.
      use_sigmoid_attn: If enabled, uses sigmoid instead of softmax to compute
        attn probs. https://arxiv.org/pdf/2409.04431
      sigmoid_attn_bias: Bias for sigmoid attn. Suggested value -ln(seq_len).
      **kwargs: Args passed to the base class.
    """
    super().__init__(**kwargs)
    if src_block_size is None or src_block_size <= 0:
      raise ValueError("src_block_size must be specified.")
    self._src_block_size = src_block_size
    self._tgt_block_size = tgt_block_size or self._src_block_size
    self._use_sigmoid_attn = use_sigmoid_attn
    self._sigmoid_attn_bias = sigmoid_attn_bias
    if self._use_sigmoid_attn:
      if self._sigmoid_attn_bias is None:
        raise ValueError(
            "sigmoid_attn_bias must be specified for sigmoid attn."
        )
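
    # Illustrative note (not part of this file): with the suggested bias of
    # -ln(seq_len), an all-zero score row gives sigmoid(0 - ln(n)) = 1 / (n + 1),
    # which is close to the uniform softmax weight 1 / n. For n = 64 the
    # sigmoid prob is ~0.0154 vs. 0.015625 for softmax, so output magnitudes
    # stay comparable when switching attention functions.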

  def _build_from_signature(self, query, value, key=None):
    # pytype: disable=attribute-error
    super()._build_from_signature(query, value, key)
    # pytype: enable=attribute-error
    # If block sizes are same as sequence lengths, we defer to default attn.
    if (
        self._query_shape[-2] == self._src_block_size
        and self._key_shape[-2] == self._tgt_block_size
    ):
      return
    # The following capital letters are used to denote the tensor dimension
    # parameters:
    # B = batch size
@@ -127,11 +171,38 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
    if attention_mask is not None:
      # `attention_mask` = [B, 1, L, T, S]
      attention_mask = tf.expand_dims(attention_mask, axis=1)
    return self._softmax(attention_scores, attention_mask)
    if self._use_sigmoid_attn:
      if attention_mask is not None:
        adder = (1.0 - tf.cast(attention_mask, attention_scores.dtype)) * (
            _large_compatible_negative(attention_scores.dtype)
        )
        attention_scores += adder
      attention_scores += self._sigmoid_attn_bias
      return tf_keras.activations.sigmoid(attention_scores)
    else:
      return self._softmax(attention_scores, attention_mask)
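
    # Illustrative note (not part of this file): sigmoid is applied
    # elementwise, so unlike softmax the probabilities along the key axis do
    # not sum to one; the mask is therefore folded in as a large negative
    # adder before the activation instead of being handled by the softmax op.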

  def _compute_attention(
      self, query, key, value, attention_mask=None, training=None
  ):
    # If block sizes are same as sequence lengths, we defer to default attn.
    if (
        self._query_shape[-2] == self._src_block_size
        and self._key_shape[-2] == self._tgt_block_size
    ):
      logging.info(
          "Computing default attention as block sizes are equal to sequence"
          " lengths."
      )
      # pytype: disable=attribute-error
      return super()._compute_attention(
          query,
          key,
          value,
          attention_mask=attention_mask,
          training=training,
      )
      # pytype: enable=attribute-error
    # src_num_blocks and tgt_num_blocks are the number of blocks in the source
    # and target. Care should be taken to ensure that the number of blocks in
    # the source and target are the same.
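
A minimal standalone sketch of the attention-probability path these changes target, under assumed shapes and a hand-written einsum rather than the layer's own projection code: scores are computed only within matching query/key blocks, masked positions get a large negative adder, and the sigmoid option adds the scalar bias before the elementwise activation.

import tensorflow as tf

B, N, H = 2, 4, 8                      # batch, heads, head dim (example values)
L, T, S = 3, 4, 4                      # num blocks, src block size, tgt block size
q = tf.random.normal([B, L, T, N, H])  # per-block queries
k = tf.random.normal([B, L, S, N, H])  # per-block keys
mask = tf.ones([B, L, T, S])           # 1 = attend, 0 = masked

# Diagonal block-sparse scores: block i of the queries only sees block i of the
# keys, giving [B, N, L, T, S] instead of a full [B, N, seq_len, seq_len] matrix.
scores = tf.einsum("bltnh,blsnh->bnlts", q, k) / tf.math.sqrt(float(H))
adder = (1.0 - mask[:, tf.newaxis, ...]) * (-1e9)   # broadcasts to [B, 1, L, T, S]
sigmoid_attn_bias = -tf.math.log(float(L * S))      # on the order of -ln(seq_len)
probs = tf.sigmoid(scores + adder + sigmoid_attn_bias)
print(probs.shape)                                  # (2, 4, 3, 4, 4)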
official/nlp/modeling/layers/block_sparse_attention_test.py (96 changes: 93 additions & 3 deletions)
@@ -14,6 +14,8 @@

"""Tests for block sparse attention layer."""

import math

from absl.testing import parameterized
import numpy as np
import tensorflow as tf, tf_keras
@@ -53,12 +55,29 @@ def test_non_masked_self_attention(self):
    output = test_layer(query, query)
    self.assertEqual(output.shape.as_list(), [None, 40, 80])

@parameterized.named_parameters(("with_bias", True), ("no_bias", False))
def test_masked_attention(self, use_bias):
@parameterized.named_parameters(
("with_bias", True),
("no_bias", False),
("with_sigmoid_attn", True, True),
)
def test_masked_attention(
self,
use_bias,
use_sigmoid_attn=False,
):
"""Test with a mask tensor."""
if use_sigmoid_attn:
sigmoid_attn_bias = -math.log(2)
else:
sigmoid_attn_bias = None
test_layer = block_sparse_attention.MultiHeadAttention(
num_heads=4, key_dim=2, use_bias=use_bias, src_block_size=2,
num_heads=4,
key_dim=2,
use_bias=use_bias,
src_block_size=2,
tgt_block_size=1,
use_sigmoid_attn=use_sigmoid_attn,
sigmoid_attn_bias=sigmoid_attn_bias,
)
# Create a 3-dimensional input (the first dimension is implicit).
batch_size = 3
@@ -112,6 +131,77 @@ def test_masked_attention(self, use_bias):
    self.assertLen(test_layer._query_dense.trainable_variables, 1)
    self.assertLen(test_layer._output_dense.trainable_variables, 1)

  @parameterized.named_parameters(
      ("default_with_softmax", False),
      ("default_with_sigmoid", True),
  )
  def test_default_masked_attention(
      self,
      use_sigmoid_attn=False,
  ):
    """Test with a mask tensor."""
    seq_len = 8
    if use_sigmoid_attn:
      sigmoid_attn_bias = -math.log(seq_len)
    else:
      sigmoid_attn_bias = None
    test_layer = block_sparse_attention.MultiHeadAttention(
        num_heads=4,
        key_dim=2,
        use_bias=True,
        src_block_size=seq_len,
        tgt_block_size=seq_len,
        use_sigmoid_attn=use_sigmoid_attn,
        sigmoid_attn_bias=sigmoid_attn_bias,
    )
    # Create a 3-dimensional input (the first dimension is implicit).
    batch_size = 3
    query = tf_keras.Input(shape=(seq_len, 8))
    value = tf_keras.Input(shape=(seq_len, 8))
    mask_tensor = tf_keras.Input(shape=(seq_len, seq_len))
    output = test_layer(query=query, value=value, attention_mask=mask_tensor)

    # Create a model containing the test layer.
    model = tf_keras.Model([query, value, mask_tensor], output)

    # Generate data for the input (non-mask) tensors.
    from_data = 10 * np.random.random_sample((batch_size, seq_len, 8))
    to_data = 10 * np.random.random_sample((batch_size, seq_len, 8))

    # Invoke the data with a random set of mask data. This should mask at
    # least one element.
    mask_data = np.random.randint(2, size=(batch_size, seq_len, seq_len))
    masked_output_data = model.predict([from_data, to_data, mask_data])

    # Invoke the same data, but with a null mask (where no elements are
    # masked).
    null_mask_data = np.ones((batch_size, seq_len, seq_len))
    unmasked_output_data = model.predict([from_data, to_data, null_mask_data])

    # Because one data is masked and one is not, the outputs should not be
    # the same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    # Tests the layer with three inputs: Q, K, V.
    key = tf_keras.Input(shape=(seq_len, 8))
    output = test_layer(
        query, value=value, key=key, attention_mask=mask_tensor
    )
    model = tf_keras.Model([query, value, key, mask_tensor], output)

    masked_output_data = model.predict(
        [from_data, to_data, to_data, mask_data]
    )
    unmasked_output_data = model.predict(
        [from_data, to_data, to_data, null_mask_data]
    )
    # Because one data is masked and one is not, the outputs should not be
    # the same.
    self.assertNotAllClose(masked_output_data, unmasked_output_data)

    self.assertLen(test_layer._query_dense.trainable_variables, 2)
    self.assertLen(test_layer._output_dense.trainable_variables, 2)
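
    # Illustrative note (not part of this file): each projection layer holds
    # two trainable variables (kernel and bias) because use_bias=True above,
    # which is what the two assertLen checks verify.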

  def test_masked_attention_with_scores(self):
    """Test with a mask tensor."""
    test_layer = block_sparse_attention.MultiHeadAttention(
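
A possible usage sketch patterned on the tests above, not an excerpt from this commit: the shapes, block sizes, and bias value are example choices, and the import path is assumed from the file's location in the repository.

import math

import tensorflow as tf

from official.nlp.modeling.layers import block_sparse_attention  # assumed import path

seq_len = 8
layer = block_sparse_attention.MultiHeadAttention(
    num_heads=4,
    key_dim=2,
    src_block_size=2,                        # 4 blocks over a length-8 sequence
    tgt_block_size=2,
    use_sigmoid_attn=True,
    sigmoid_attn_bias=-math.log(seq_len),    # docstring's suggested value
)
query = tf.random.normal([3, seq_len, 8])
value = tf.random.normal([3, seq_len, 8])
output = layer(query=query, value=value)
print(output.shape)                          # (3, 8, 8)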
