Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ kernels = [
"kernels"
]
liger = [
"liger-kernel>=0.7.0"
"liger-kernel>=0.8.0"
]
peft = [
"peft>=0.8.0"
Expand Down Expand Up @@ -103,7 +103,7 @@ dev = [
# kernels
"kernels",
# liger
"liger-kernel>=0.7.0",
"liger-kernel>=0.8.0",
# peft
"peft>=0.8.0",
# quality
Expand Down
12 changes: 9 additions & 3 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,9 @@ def test_training(self, config_name):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

@pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)])
@pytest.mark.parametrize("loss_type", ["bnpo", "dr_grpo", "dapo", "cispo", "sapo", "luspo", "vespo"])
def test_training_loss_types(self, loss_type):
def test_training_loss_types(self, loss_type, use_liger_kernel):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
Expand All @@ -290,6 +291,7 @@ def test_training_loss_types(self, loss_type):
max_completion_length=32, # reduce the completion length to reduce memory usage
gradient_accumulation_steps=2, # set to 2 to test than DAPO can operate with accumulated batch
loss_type=loss_type,
use_liger_kernel=use_liger_kernel,
report_to="none",
)
trainer = GRPOTrainer(
Expand Down Expand Up @@ -1188,7 +1190,8 @@ def gen_with_is_ratio(*args, **kwargs):

release_memory(trainer.model, trainer)

def test_training_with_bias_correction_kl(self):
@pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)])
def test_training_with_bias_correction_kl(self, use_liger_kernel):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
Expand All @@ -1198,6 +1201,7 @@ def test_training_with_bias_correction_kl(self):
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
use_liger_kernel=use_liger_kernel,
report_to="none",
)
trainer = GRPOTrainer(
Expand Down Expand Up @@ -1687,7 +1691,8 @@ def test_training_num_generations_larger_than_batch_size(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_delta_clipping(self):
@pytest.mark.parametrize("use_liger_kernel", [False, pytest.param(True, marks=require_liger_kernel)])
def test_training_delta_clipping(self, use_liger_kernel):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
Expand All @@ -1697,6 +1702,7 @@ def test_training_delta_clipping(self):
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
delta=2.0, # set delta to a non-None value
use_liger_kernel=use_liger_kernel,
report_to="none",
)
trainer = GRPOTrainer(
Expand Down
2 changes: 1 addition & 1 deletion trl/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from packaging.version import Version


LIGER_KERNEL_MIN_VERSION = "0.7.0"
LIGER_KERNEL_MIN_VERSION = "0.8.0"
PACKAGE_DISTRIBUTION_MAPPING = importlib.metadata.packages_distributions()


Expand Down
3 changes: 0 additions & 3 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,3 @@ def __post_init__(self):
"GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
f"{self.num_generations}, which is less than the minimum required."
)

if self.delta is not None and self.use_liger_kernel:
raise ValueError("Liger kernel does not support two-sided GRPO loss yet.")
8 changes: 8 additions & 0 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,14 @@ def cast_outputs_to_original_dtype(module, args, output):
loss_type=self.loss_type,
max_completion_length=self.max_completion_length,
importance_sampling_level=self.importance_sampling_level,
delta=args.delta,
use_bias_correction_kl=args.use_bias_correction_kl,
sapo_temperature_pos=args.sapo_temperature_pos,
sapo_temperature_neg=args.sapo_temperature_neg,
vespo_k_pos=args.vespo_k_pos,
Comment thread
kashif marked this conversation as resolved.
vespo_lambda_pos=args.vespo_lambda_pos,
vespo_k_neg=args.vespo_k_neg,
vespo_lambda_neg=args.vespo_lambda_neg,
)

# Initialize the metrics
Expand Down
Loading