When I use SnapKV on Qwen2-VL to compress the visual tokens, the key/value states are compressed successfully. However, when I print the input shapes of `_flash_attention_forward`, they are the same as the original, uncompressed ones. Even if I set window_size to 1 and max_capacity_prompt to 2, the result does not change.
Here is my code:

```python
import math
from typing import Optional, Tuple

from loguru import logger as eval_logger
from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
import transformers
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
    apply_multimodal_rotary_pos_emb,
    repeat_kv,
)
from transformers.utils import (
    is_flash_attn_2_available,
    logging,
)
import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.get_logger(__name__)

try:
    from qwen_vl_utils import process_vision_info, extract_vision_info, fetch_image, fetch_video
except ImportError:
    eval_logger.warning("Failed to import qwen_vl_utils; please install it via `pip install qwen-vl-utils`")

if is_flash_attn_2_available():
    from flash_attn import flash_attn_varlen_func
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    raise RuntimeError("Only flash attention 2 is supported for now; please install flash attention 2")
# NOTE: copied from SnapKV/snapkv_utils.py
class SnapKVCluster():
    def __init__(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling='avgpool'):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def reset(self, window_size=64, max_capacity_prompt=256 + 64, kernel_size=5, pooling='avgpool'):
        self.window_size = window_size
        self.max_capacity_prompt = max_capacity_prompt
        assert self.max_capacity_prompt - self.window_size > 0
        self.kernel_size = kernel_size
        self.pooling = pooling

    def update_kv(self, key_states, query_states, value_states, attention_mask, num_key_value_groups):
        # check if we are in the prefill phase
        assert key_states.shape[-2] == query_states.shape[-2]
        bsz, num_heads, q_len, head_dim = query_states.shape
        if q_len < self.max_capacity_prompt:
            return key_states, value_states
        else:
            # attention scores of the last L_obs queries against all keys
            attn_weights = torch.matmul(query_states[..., -self.window_size:, :], key_states.transpose(2, 3)) / math.sqrt(head_dim)
            # build a lower-triangular mask: 0 on and below the diagonal, -inf above it
            mask = torch.full((self.window_size, self.window_size), torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
            mask = mask.to(attn_weights.device)
            # expand the mask to 4D so it matches the attention weights
            attention_mask = mask[None, None, :, :]
            # mask the attention scores inside the L_obs window; positions masked with 0 are unaffected
            # by the softmax, while positions masked with -inf become 0 after the softmax
            attn_weights[:, :, -self.window_size:, -self.window_size:] += attention_mask
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            # NOTE: C from Eq. (2) in the paper: sum over the L_obs window queries of their attention scores to the prefix
            attn_weights_sum = attn_weights[:, :, -self.window_size:, :-self.window_size].sum(dim=-2)
            # pooling (not described in the paper)
            if self.pooling == 'avgpool':
                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            elif self.pooling == 'maxpool':
                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=self.kernel_size, padding=self.kernel_size // 2, stride=1)
            else:
                raise ValueError('Pooling method not supported')
            # I from Eq. (3) in the paper; attn_cache.shape = (batch, num_heads, L_prompt), indices.shape = (batch, num_heads, k)
            indices = attn_cache.topk(self.max_capacity_prompt - self.window_size, dim=-1).indices
            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
            # gather the selected key/value states from the prefix
            k_past_compress = key_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
            v_past_compress = value_states[:, :, :-self.window_size, :].gather(dim=2, index=indices)
            # concatenate the compressed prefix with the key/value states of the L_obs window
            k_cur = key_states[:, :, -self.window_size:, :]
            v_cur = value_states[:, :, -self.window_size:, :]
            key_states = torch.cat([k_past_compress, k_cur], dim=2)
            value_states = torch.cat([v_past_compress, v_cur], dim=2)
            return key_states, value_states
def init_snapkv(self):
    if not hasattr(self, "kv_cluster"):
        if not hasattr(self.config, 'window_size'):
            self.config.window_size = 512
        if not hasattr(self.config, 'max_capacity_prompt'):
            self.config.max_capacity_prompt = 2048
        if not hasattr(self.config, 'kernel_size'):
            self.config.kernel_size = 5
        if not hasattr(self.config, 'pooling'):
            self.config.pooling = 'avgpool'
        self.kv_cluster = SnapKVCluster(
            window_size=self.config.window_size,
            max_capacity_prompt=self.config.max_capacity_prompt,
            kernel_size=self.config.kernel_size,
            pooling=self.config.pooling,
        )
def qwen2_vl_flash_attn2_forward(
    self: torch.nn.Module,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    init_snapkv(self)
    output_attentions = False
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                "with a layer index."
            )
        if hasattr(self, "kv_seq_len"):  # [SnapKV] add kv_seq_len
            if self.kv_seq_len != 0:
                kv_seq_len += self.kv_seq_len
            else:
                kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
        # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # print('kv_seq_len:', kv_seq_len)
        # print('key_states.shape:', key_states.shape)
        if key_states.shape[-2] == kv_seq_len:  # [SnapKV] prefill: compress and store the kv cache
            self.kv_seq_len = kv_seq_len  # [SnapKV] register kv_seq_len
            key_states_compress, value_states_compress = self.kv_cluster.update_kv(
                key_states, query_states, value_states, attention_mask, self.num_key_value_groups
            )
            past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
            if self.layer_idx == 0:
                eval_logger.info(f"SnapKV compressing... After compression, the length of kv is {key_states_compress.shape[-2]}")
        else:
            self.kv_seq_len += q_len
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
            if self.layer_idx == 0:
                eval_logger.info(f"SnapKV does nothing, the length of kv is {key_states.shape[-2]}")

    # TODO: These transposes are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
    # to be able to avoid many of these transpose/reshape/view.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
    # therefore the input hidden states get silently cast to float32. Hence, we need to
    # cast them back to the correct dtype just to be sure everything works as expected.
    # This might slow down training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype
        logger.warning_once(
            f"The input hidden states seem to be silently cast to float32; this might be related to "
            f"the fact that you have upcast embedding or layer norm layers in float32. We will cast back the input to "
            f"{target_dtype}."
        )
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    print(f"SnapKV flash attn: Q: {query_states.shape}, K: {key_states.shape}, V: {value_states.shape}")
    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
def qwen2_vl_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    **kwargs,
):
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

    # If we have a cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
    # Exception 1: when passing input_embeds, input_ids may be missing entries
    # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
    if past_key_values is None or (isinstance(past_key_values, DynamicCache) and past_key_values.get_seq_length() == 0):  # [SnapKV]
        for layer in self.model.layers:
            layer.self_attn.kv_seq_len = 0

    if past_key_values is not None:
        if inputs_embeds is not None:  # Exception 1
            input_ids = input_ids[:, -cache_position.shape[0]:]
        elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no-op, is Exception 2)
            input_ids = input_ids[:, cache_position]

    if cache_position[0] != 0:
        pixel_values = None
        pixel_values_videos = None

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and cache_position[0] == 0:
        model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
    else:
        model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

    if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
        if model_inputs["inputs_embeds"] is not None:
            batch_size, sequence_length, _ = inputs_embeds.shape
            device = inputs_embeds.device
        else:
            batch_size, sequence_length = input_ids.shape
            device = input_ids.device

        attention_mask = self.model._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=past_key_values.get_max_cache_shape(),
            dtype=self.lm_head.weight.dtype,
            device=device,
            cache_position=cache_position,
            batch_size=batch_size,
            config=self.config,
            past_key_values=past_key_values,
        )

    model_inputs.update(
        {
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "pixel_values": pixel_values,
            "pixel_values_videos": pixel_values_videos,
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
            "cache_position": cache_position,
        }
    )
    return model_inputs
# patch
def patch_qwen2_vl():
    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLFlashAttention2.forward = qwen2_vl_flash_attn2_forward
    transformers.models.qwen2_vl.modeling_qwen2_vl.Qwen2VLForConditionalGeneration.prepare_inputs_for_generation = qwen2_vl_prepare_inputs_for_generation


patch_qwen2_vl()
def main():
    qwen2_vl_path = "path/to/your checkpoints"
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen2_vl_path,
        device_map="auto",
        torch_dtype="auto",
        attn_implementation="flash_attention_2",
    ).eval()
    processor = AutoProcessor.from_pretrained(qwen2_vl_path)
    tokenizer = AutoTokenizer.from_pretrained(qwen2_vl_path)

    prompt = "Hello, who are you?"
    input_ids = tokenizer([prompt], return_tensors="pt")
    input_ids = input_ids.to("cuda")
    output = model.generate(**input_ids)
    text = tokenizer.batch_decode(output)
    print(text)


if __name__ == "__main__":
    main()
```

The output text is always `['Hello, who are you? I am a language model created by Alibaba Cloud. I am called Qwen. I am a large']`, whatever values I use for `window_size` and `max_capacity_prompt`.
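For reference, a minimal sketch for checking whether the compressed keys/values actually end up in the cache is to pass an explicit `DynamicCache` to `generate` and inspect its per-layer length afterwards. This is only a diagnostic under assumptions: it reuses a `model`/`tokenizer` built as in `main()` above, the `inspect_cache_after_generate` helper is hypothetical, and the `cache.key_cache` attribute name depends on the installed transformers version.

```python
from transformers.cache_utils import DynamicCache


def inspect_cache_after_generate(model, tokenizer, prompt="Hello, who are you?", max_new_tokens=8):
    # Hypothetical helper, not part of the script above: run one generation with an
    # explicit cache object so it can be inspected afterwards.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    cache = DynamicCache()
    model.generate(**inputs, past_key_values=cache, max_new_tokens=max_new_tokens, use_cache=True)

    prompt_len = inputs["input_ids"].shape[1]
    # Total cached length = prefill length written by the patched forward + one entry per generated token.
    # NOTE: `key_cache` is the attribute used by DynamicCache in transformers ~4.45;
    # newer versions may expose the per-layer tensors differently.
    cached_len = cache.key_cache[0].shape[-2]
    print(f"prompt tokens: {prompt_len}, cached kv length (layer 0): {cached_len}")
    # If SnapKV compression took effect during prefill, cached_len should be close to
    # max_capacity_prompt + max_new_tokens rather than prompt_len + max_new_tokens.
```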