Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 65 additions & 83 deletions configs/experiments/code_gen_base.yaml
Original file line number Diff line number Diff line change
@@ -1,131 +1,113 @@
# Production-ready configuration for a code generation task using GRPO.
# This file is validated against the pydantic schemas in src/mlx_rl_trainer/core/config.py.
# Scenario: 4B model (full fine-tuning) on a 10k-sample dataset, targeting reduced verbosity (shorter thinking and answers).

trainer:
algorithm: "grpo"
output_dir: "./outputs/code_gen_run_01"
num_training_steps: 80000
learning_rate: 3e-5 # CORRECTED: Lowered to a safe value for fine-tuning.
output_dir: "./outputs/full_finetune_run_01" # New directory for a fresh start

# Adjusted for a 10k dataset and an effective batch size of 8.
# This will run for approximately two epochs.
num_training_steps: 2500

# --- Optimizer & Scheduler (Tuned for stable full fine-tuning) ---
learning_rate: 1e-5 # ⭐ CRITICAL: Very low LR is essential for stable full fine-tuning.
optimizer_beta1: 0.9
optimizer_beta2: 0.95
optimizer_weight_decay: 0.01

lr_schedule_config:
name: "cosine_decay"
arguments: [1e-5, 1000, 1e-6] # Decay from peak (1e-5) to end (1e-6) over 1000 steps (2500 - 1500); assumes (init, decay_steps, end) order - TODO confirm
warmup: 250 # A long warmup is crucial for stability.

# --- Batching & Algorithm ---
ppo_batch_size: 1
num_rollout_samples: 2
grad_accum_steps: 2
# CORRECTED: Increased beta to prevent model collapse.
grpo_beta: 0.01
num_rollout_samples: 1
grad_accum_steps: 8 # Effective batch size of 8, keeps memory low but ensures stable updates.
grpo_beta: 0.0015
seed: -1

# Dual gradients for System 1/2
# --- Dual Gradients ---
use_dual_gradients: false
thinking_layer_start: 26
thinking_layer_end: 36
answer_layer_start: 21
answer_layer_end: 30
answer_gradient_weight: 1.5
thinking_layer_start: 18
thinking_layer_end: 30
answer_layer_start: 24
answer_layer_end: 32
answer_gradient_weight: 2.5

# SFT configuration
# --- SFT (Disabled as per your config) ---
use_sft_on_answer: false
sft_mode: "all" # Recommended for System 1/2
sft_weight: 0.1
sft_thinking_weight: 0.1 # For 'weighted' mode
sft_answer_weight: 1.0 # For 'weighted' mode

model:
model_path: "/Users/adeelahmad/work/mlx_rl_trainer/outputs/model"
model_path: "/Users/adeelahmad/.cache/lm-studio/models/lmstudio-community/Qwen-4B-Thinking-2507"
ref_model_path: "/Users/adeelahmad/.cache/lm-studio/models/lmstudio-community/Qwen-4B-Thinking-2507"
base_model_path: "/Users/adeelahmad/.cache/lm-studio/models/lmstudio-community/Qwen-4B-Thinking-2507"

# As requested: LoRA is disabled.
use_lora: false
lora_rank: 16

generation:
force_close_think_after: 96
# ... your existing temperature settings ...
# Tag definitions
think_start_tag: "<think>"
think_end_tag: "</think>"

# For structural guidance
bias_close_think: 3.0
# Biases for structural guidance
bias_close_think: -2.0
bias_answer_start: 6.0
punish_extra_think_end: -12.0
min_think_tokens: 16
think_end_early_bias: -10.0

# For MCQ handling
hard_mask_mcq_first_token: true
mcq_letter_lift: 8.0
mcq_ban_first_bias: -14.0
think_end_early_bias: 12.0

data:
train_path: "/Users/adeelahmad/work/SiLLM-examples/helpsteer/mlx-grpo/judge/train.jsonl"
val_path: "/Users/adeelahmad/work/SiLLM-examples/helpsteer/mlx-grpo/judge/valid.jsonl"
train_path: "/Users/adeelahmad/work/SiLLM-examples/helpsteer/mlx-grpo/strat/train.jsonl"

max_prompt_len: 150
max_gen_len: 256
max_gen_len: 128 # ⭐ IMPORTANT: Force shorter generations.
loader_type: "jsonl"
shuffle_data: true

# The system_prompt from your ExperimentConfig will be used automatically by the trainer.
# You don't need to specify it here unless you want to override the default.

rewards:
- name: "format_structure"
weight: 0.07
weight: 0.05
config:
min_think_length: 20
min_answer_length: 15
think_length_target_min: 80 # ⭐ Characters, not tokens!
think_length_target_max: 250 # ⭐ Was 80 - MUST CHANGE
length_penalty_strength: 0.6
verbosity_penalty_factor: 1.5
debug_logging: true
min_think_length: 10
min_answer_length: 2
# ⭐ Tighter targets to discourage long thinking
think_length_target_min: 40
think_length_target_max: 90

- name: "thinking_quality"
weight: 0.18
weight: 0.2
config:
target_length_min: 80 # ⭐ MUST UPDATE
target_length_max: 250 # ⭐ Was 80 - MUST CHANGE
optimal_length_min: 100
optimal_length_max: 200
excessive_length_threshold: 300 # ⭐ Was 90 - CRITICAL FIX!
excessive_length_penalty: 0.5
conciseness_bonus: 0.15
use_trainer_thinking_limits: false
debug_logging: true
# ⭐ Tighter targets to discourage long thinking
target_length_min: 40
target_length_max: 90
excessive_length_threshold: 120 # Penalize heavily after this

- name: "answer_quality"
weight: 0.20
weight: 0.1
config:
max_penalty: 1.0
case_sensitive: false
debug_logging: true
phrase_penalty: 0.25

- name: "semantic_similarity"
weight: 0.60
weight: 0.65 # ⭐ Increased weight to prioritize relevance and conciseness
config:
method: "jaccard"
min_length: 5
remove_stop_words: true
debug_logging: true
method: "tfidf"
min_length: 10
apply_length_penalty: false
apply_verbosity_penalty: false
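# ⭐ AGGRESSIVE PENALTY: Punish verbosity more harshly.
# NOTE(review): apply_verbosity_penalty is set to false in this block - confirm the strength below is actually applied.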
verbosity_penalty_strength: 0.01

monitoring:
log_samples_every: 1 # Log samples at every update step to debug easily.
max_logged_samples: 50 # Log a few samples to see the outputs.

# Chart generation
generate_charts_every: 10
generate_at_checkpoints: true

log_samples_every: 1
max_logged_samples: 50
use_wandb: true
wandb_project: "mlx-grpo-qwen3-v10"
wandb_entity: "adeelahmad"
wandb_run_name: "experiment-001"

# Logging frequency
wandb_project: "mlx-grpo-qwen3-full-finetune"
log_prompts: true

# Checkpointing Configuration - Enhanced with Retry Logic
checkpointing:
checkpoint_dir: "checkpoints"
save_every: 10 # Save checkpoint every N steps
keep_last_n: 2 # Keep only last N checkpoints
save_every: 500
keep_last_n: 3
save_optimizer_state: false

# NEW: Retry Logic
max_retries: 3 # Number of retry attempts for failed checkpoint saves
retry_delay_seconds: 2 # Initial delay (exponential backoff)
Loading