2 changes: 2 additions & 0 deletions swift/llm/template/base.py
@@ -347,6 +347,8 @@ def _extend_tokens(
         return input_ids, labels, loss_scale

     def forward_context(self, model, inputs):
+        # This function is only used to handle scenarios where the model needs
+        # to be patched during the forward pass.
         return nullcontext()

     @staticmethod
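The new comment documents the hook's contract: `forward_context` returns a context manager that wraps the model's forward pass. A minimal sketch of the kind of override this enables; the subclass and the patched `use_cache` attribute are illustrative assumptions, not code from this PR:

```python
from contextlib import contextmanager

from swift.llm.template.base import Template


class PatchedTemplate(Template):  # hypothetical subclass, for illustration only

    @contextmanager
    def forward_context(self, model, inputs):
        # Patch the model only for the duration of the forward pass,
        # then restore the original attribute.
        original_use_cache = getattr(model.config, 'use_cache', None)
        model.config.use_cache = False
        try:
            yield
        finally:
            model.config.use_cache = original_use_cache
```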
16 changes: 7 additions & 9 deletions swift/llm/template/template/glm.py
@@ -319,15 +319,6 @@ def _get_position_ids(self, inputs: Dict[str, Any]):
             attention_mask=inputs.get('attention_mask'))
         return self._concat_text_position_ids(position_ids)

-    def forward_context(self, model, inputs):
-        position_ids = inputs['position_ids']
-        inputs['position_ids'] = position_ids[1:]
-        inputs['text_position_ids'] = text_position_ids = position_ids[0]
-        # https://github.com/huggingface/transformers/pull/40194
-        if text_position_ids.shape[0] == 1:
-            inputs.update(get_packed_seq_params(text_position_ids))
-        return super().forward_context(model, inputs)
-
     def _post_encode(self, model, inputs: Dict[str, Any]) -> Dict[str, Any]:
         if not self.is_training:
             return inputs
@@ -341,6 +332,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         res = super()._data_collator(batch, padding_to=padding_to)
         if not self.padding_free and self.is_training:
             res['position_ids'] = self._get_position_ids(res)
+        if 'position_ids' in res:
+            position_ids = res['position_ids']
+            res['position_ids'] = position_ids[1:]
+            res['text_position_ids'] = text_position_ids = position_ids[0]
+            # https://github.com/huggingface/transformers/pull/40194
+            if text_position_ids.shape[0] == 1:
+                res.update(get_packed_seq_params(text_position_ids))
Comment on lines +335 to +341 (Contributor review, severity: medium):
This logic for processing position_ids appears to be duplicated in swift/llm/template/template/qwen.py. To improve maintainability and reduce code duplication, consider abstracting this block into a shared helper function. The function could take the res dictionary and an optional condition (for the transformers_version check in qwen.py) as arguments. A sketch of such a helper appears below, after this hunk.
return res
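A sketch of the shared helper the reviewer suggests; the function name and signature here are hypothetical, and `get_packed_seq_params` is the existing utility both templates already use:

```python
from typing import Any, Dict, Optional

from packaging import version


def process_position_ids(res: Dict[str, Any],
                         transformers_version: Optional[version.Version] = None) -> Dict[str, Any]:
    # Hypothetical helper: splits the stacked position_ids produced by
    # _get_position_ids into the multimodal (mrope) rows and the text row,
    # and attaches packed-sequence params for flattened single-row batches.
    if 'position_ids' not in res:
        return res
    position_ids = res['position_ids']
    res['position_ids'] = position_ids[1:]
    res['text_position_ids'] = text_position_ids = position_ids[0]
    # https://github.com/huggingface/transformers/pull/40194
    version_ok = transformers_version is None or transformers_version >= version.parse('4.53.0.dev')
    if version_ok and text_position_ids.shape[0] == 1:
        res.update(get_packed_seq_params(text_position_ids))  # assumed in scope, as in both templates
    return res
```

glm.py could then call `process_position_ids(res)` unconditionally, while qwen.py would pass its cached `self.transformers_version` to keep the version gate.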


22 changes: 13 additions & 9 deletions swift/llm/template/template/qwen.py
@@ -240,6 +240,10 @@ class Qwen2VLTemplate(Template):
     use_model = True
     support_padding_free = True

+    def init_env_args(self):
+        super().init_env_args()
+        self.transformers_version = version.parse(transformers.__version__)
+
     def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                     inputs: StdTemplateInputs) -> List[Context]:
         from qwen_vl_utils import fetch_image, fetch_video
@@ -313,16 +317,9 @@ def _get_new_tokens(i):
         return encoded

     def forward_context(self, model, inputs):
-        position_ids = inputs['position_ids']
-        inputs['position_ids'] = position_ids[1:]
-        inputs['text_position_ids'] = text_position_ids = position_ids[0]
-        transformers_version = version.parse(transformers.__version__)
-        if transformers_version >= version.parse('4.53.0.dev') and text_position_ids.shape[0] == 1:
-            # https://github.com/huggingface/transformers/pull/40194
-            inputs.update(get_packed_seq_params(text_position_ids))
-            return super().forward_context(model, inputs)
-        if not self.padding_free:
+        if not self.padding_free or self.transformers_version >= version.parse('4.53.0.dev'):
             return super().forward_context(model, inputs)
+        text_position_ids = inputs['text_position_ids']
         if self.version == 'v2':
             from transformers.models.qwen2_vl import modeling_qwen2_vl as modeling_module
         elif self.version == 'v2_5':
@@ -382,6 +379,13 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         res = super()._data_collator(batch, padding_to=padding_to)
         if not self.padding_free and self.is_training:
             res['position_ids'] = self._get_position_ids(res)
+        if 'position_ids' in res:
+            position_ids = res['position_ids']
+            res['position_ids'] = position_ids[1:]
+            res['text_position_ids'] = text_position_ids = position_ids[0]
+            if self.transformers_version >= version.parse('4.53.0.dev') and text_position_ids.shape[0] == 1:
+                # https://github.com/huggingface/transformers/pull/40194
+                res.update(get_packed_seq_params(text_position_ids))
         return res
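Both collators gate the packed path on `text_position_ids.shape[0] == 1`, i.e. the padding-free layout where every sequence is concatenated into one row and positions restart at 0 at each boundary. A hedged sketch of how cumulative sequence lengths can be recovered from such a row; this illustrates the idea behind `get_packed_seq_params`, not its actual implementation:

```python
import torch


def cu_seqlens_from_position_ids(text_position_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. text_position_ids has shape (1, total_tokens),
    # e.g. [[0, 1, 2, 0, 1, 0, 1, 2, 3]] for packed sequences of lengths 3, 2, 4.
    pos = text_position_ids.flatten()
    starts = torch.nonzero(pos == 0).flatten()
    total = torch.tensor([pos.numel()], device=pos.device)
    # For the example this returns [0, 3, 5, 9], the boundaries varlen attention kernels expect.
    return torch.cat([starts, total]).to(torch.int32)
```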


3 changes: 1 addition & 2 deletions swift/megatron/utils/convert.py
@@ -146,8 +146,7 @@ def test_convert_precision(hf_model, mg_model, template, torch_dtype=torch.float
     ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else []

     hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules)
-    with torch.inference_mode(), _model_cpu_forward_context(
-            hf_modules, torch_dtype, share_embedding=share_embedding), template.forward_context(hf_model, inputs):
+    with torch.inference_mode(), _model_cpu_forward_context(hf_modules, torch_dtype, share_embedding=share_embedding):
         inputs.pop('text_position_ids', None)
         hf_logits = hf_model(**inputs).logits
         hf_model.to('cpu')
6 changes: 3 additions & 3 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -254,6 +254,9 @@ def __init__(self,
         self.enable_server_multi_turn = False
         # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs
         self.dynamic_num_samples = False
+        self.padding_free = self.template.padding_free
+        self.template.padding_free = False
+        self.template.packing = False
         if self.use_vllm:
             if not is_vllm_available():
                 raise ImportError('vLLM is not available and `use_vllm` is set to True. '
@@ -327,9 +330,6 @@ def __init__(self,
         # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
         # self.model_accepts_loss_kwargs to False to enable scaling.
         self.model_accepts_loss_kwargs = False
-        self.padding_free = self.template.padding_free
-        self.template.padding_free = False
-        self.template.packing = False
         for i, reward_func in enumerate(self.reward_funcs):
             if isinstance(reward_func, PreTrainedModel):
                 if self.is_deepspeed_enabled: