

🧠 MLX Guided GRPO

Train reasoning models on your Mac. No cloud needed.

The first production-ready GRPO training framework for Apple Silicon.
Fine-tune LLMs to think step-by-step using your M1/M2/M3/M4 Mac.


Quick Start • Features • Why Guided GRPO • Installation • Examples • Docs


🎯 Train Your Own Reasoning Model in 5 Minutes

# Install
pip install mlx-guided-grpo

# Train (yes, it's this simple)
mlx-grpo --model mlx-community/Qwen2.5-3B-Instruct-4bit \
         --data ./your_data.jsonl \
         --train --train-type lora \
         --curriculum-enabled

That's it. Your Mac is now training a reasoning model with curriculum learning.


🤔 Why Guided GRPO?

The Problem

Training reasoning models (like DeepSeek-R1, o1) requires:

  • ❌ Expensive cloud GPUs ($$$)
  • ❌ Complex distributed setups
  • ❌ NVIDIA-only frameworks
  • ❌ Weeks of engineering

Most developers can't train reasoning models.

The Solution

MLX Guided GRPO gives you:

  • ✅ Train on your Mac - M1/M2/M3/M4
  • ✅ One command - No config hell
  • ✅ Curriculum learning - Progressive difficulty
  • ✅ Production ready - Crash recovery, logging

Train reasoning models on consumer hardware.


✨ Features

🎓 Curriculum Learning

Gradually reduce scaffolding so models learn to think independently. Start with 100% guidance, end with 0%.

🔄 Two-Phase Generation

Automatic recovery for incomplete <think> outputs. Never lose a training sample.

🎯 Smart Token Masking

Only train on tokens the model generated. Scaffolded tokens are properly masked from loss.

⚡ Apple Silicon Native

Built on MLX for maximum Metal GPU utilization. 2-3x faster than PyTorch on Mac.

🧠 Conditional Gradient Scaling

Train different layers for thinking vs answering. Fine-grained control over what the model learns.

💾 Crash Recovery

Automatic checkpointing and resume. Metal GPU crashes? Training continues.

Full Feature List

  • Training: GRPO, DR-GRPO, BNPO loss variants
  • Adapters: LoRA, DoRA, Full fine-tuning
  • Type System: Extensible type-aware rewards for tool calling, MCQ, and general Q&A (docs)
  • Memory: Gradient checkpointing, cache management
  • Rewards: Type-dispatched rewards, custom reward functions (see the sketch after this list)
  • Logging: WandB integration, rollout logging
  • Monitoring: Threshold-based early stopping
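
A minimal sketch of what a custom reward function for this answer format could look like. The signature here is an illustrative assumption, not the framework's actual plugin API (see docs/rewards.md for that); the check simply rewards a closed <think> block plus a \boxed answer that matches the reference.

import re

def format_and_answer_reward(completion: str, reference_answer: str) -> float:
    """Illustrative reward: 0.5 for a well-formed <think>...</think> block,
    plus 0.5 if the \\boxed{...} answer matches the reference."""
    score = 0.0
    if "<think>" in completion and "</think>" in completion:
        score += 0.5
    got = re.search(r"\\boxed\{([^}]*)\}", completion)
    ref = re.search(r"\\boxed\{([^}]*)\}", reference_answer)
    if got and ref and got.group(1).strip() == ref.group(1).strip():
        score += 0.5
    return score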

📊 Benchmarks

| Model             | Hardware    | Tokens/sec | Memory |
|-------------------|-------------|------------|--------|
| Qwen2.5-3B-4bit   | M3 Max 64GB | ~150       | 12GB   |
| Qwen2.5-7B-4bit   | M3 Max 64GB | ~80        | 24GB   |
| Llama-3.2-3B-4bit | M2 Pro 32GB | ~120       | 10GB   |

GRPO training with group_size=4, batch_size=2


🚀 Installation

From PyPI (Recommended)

pip install mlx-guided-grpo

From Source

git clone https://github.com/adeelahmad/mlx-guided-grpo.git
cd mlx-guided-grpo
pip install -e ".[all]"

Requirements

  • macOS 13.5+ with Apple Silicon (M1/M2/M3/M4)
  • Python 3.10+
  • 16GB+ RAM recommended

πŸƒ Quick Start

1. Prepare Your Data

Create a JSONL file with prompts and reasoning traces:

{"prompt": "What is 15 * 7?", "answer": "<think>\nI need to multiply 15 by 7.\n15 * 7 = 105\n</think>\n\n\\boxed{105}"}
{"prompt": "Solve: 2x + 5 = 13", "answer": "<think>\nSubtract 5 from both sides:\n2x = 8\nDivide by 2:\nx = 4\n</think>\n\n\\boxed{4}"}

2. Train Your Model

mlx-grpo \
    --model mlx-community/Qwen2.5-3B-Instruct-4bit \
    --data ./math_data.jsonl \
    --train \
    --train-type lora \
    --iters 1000 \
    --batch-size 2 \
    --group-size 4 \
    --curriculum-enabled \
    --adapter-path ./my-reasoning-model

3. Use Your Model

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen2.5-3B-Instruct-4bit",
                        adapter_path="./my-reasoning-model")

prompt = "What is 23 * 17?"
response = generate(model, tokenizer, prompt=prompt, max_tokens=500)
print(response)
# <think>
# I need to multiply 23 by 17...
# </think>
# \boxed{391}

📖 Examples

Basic GRPO Training

mlx-grpo \
    --model mlx-community/Qwen2.5-0.5B-Instruct-4bit \
    --data ./data \
    --train --train-type lora \
    --group-size 4 \
    --learning-rate 1e-5

Curriculum Learning (Recommended for Reasoning)

mlx-grpo \
    --model mlx-community/Qwen2.5-3B-Instruct-4bit \
    --data ./reasoning_data \
    --train --train-type lora \
    --curriculum-enabled \
    --curriculum-start-ratio 1.0 \
    --curriculum-end-ratio 0.0 \
    --curriculum-warmup-iters 100 \
    --curriculum-taper-iters 500 \
    --enforce-thinking

With WandB Logging

mlx-grpo \
    --model mlx-community/Qwen2.5-3B-Instruct-4bit \
    --data ./data \
    --train --train-type lora \
    --wandb my-experiment \
    --log-rollouts \
    --log-rollouts-to-wandb

Advanced: Dual-Gradient Mode (CGS)

mlx-grpo \
    --model mlx-community/Qwen2.5-7B-Instruct-4bit \
    --data ./data \
    --train --train-type lora \
    --thinking-layers "0-15" \
    --answer-layers "16-31" \
    --thinking-gradient-weight 0.5 \
    --answer-gradient-weight 1.0
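
Roughly, these flags split the transformer stack into a "thinking" range and an "answering" range and weight their gradients differently. The sketch below only illustrates that idea; the parameter-name pattern and the helper functions are assumptions, not the framework's internals.

import re

def parse_layer_range(spec: str) -> set[int]:
    # "0-15" -> {0, 1, ..., 15}
    start, end = (int(x) for x in spec.split("-"))
    return set(range(start, end + 1))

def gradient_weight(param_name: str,
                    thinking_layers: set[int], answer_layers: set[int],
                    thinking_weight: float = 0.5, answer_weight: float = 1.0) -> float:
    # Assumes parameter names contain "layers.<idx>." as in most MLX/HF checkpoints.
    m = re.search(r"layers\.(\d+)\.", param_name)
    if m is None:
        return 1.0  # embeddings, head, norms: left untouched
    idx = int(m.group(1))
    if idx in thinking_layers:
        return thinking_weight
    if idx in answer_layers:
        return answer_weight
    return 1.0

thinking = parse_layer_range("0-15")
answering = parse_layer_range("16-31")
print(gradient_weight("model.layers.3.self_attn.q_proj.weight", thinking, answering))   # 0.5
print(gradient_weight("model.layers.20.mlp.down_proj.weight", thinking, answering))     # 1.0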

🔧 Key Concepts

Curriculum Learning

Progressive scaffolding teaches models to reason independently:

Iteration 0-100:   [████████████] 100% scaffolding (model learns format)
Iteration 100-400: [████████░░░░]  66% scaffolding (gradual reduction)
Iteration 400-700: [████░░░░░░░░]  33% scaffolding (increasing independence)
Iteration 700+:    [░░░░░░░░░░░░]   0% scaffolding (full independence)
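
The exact internal schedule may differ, but a linear warmup-then-taper like the sketch below (the function name and the interpolation are assumptions) reproduces the shape of the diagram and lines up with the --curriculum-* flags shown in the examples:

def scaffold_ratio(iteration: int,
                   start_ratio: float = 1.0, end_ratio: float = 0.0,
                   warmup_iters: int = 100, taper_iters: int = 600) -> float:
    """Fraction of the reference reasoning trace fed to the model as scaffolding."""
    if iteration < warmup_iters:
        return start_ratio                      # hold full guidance during warmup
    progress = min(1.0, (iteration - warmup_iters) / taper_iters)
    return start_ratio + (end_ratio - start_ratio) * progress  # linear decay

for it in (0, 300, 500, 800):
    print(it, round(scaffold_ratio(it), 2))     # 1.0, 0.67, 0.33, 0.0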

Smart Token Masking

Only train on what the model actually generated:

[PROMPT] [SCAFFOLD PREFIX] [MODEL GENERATION]
   ↓           ↓                  ↓
 masked      masked         LOSS COMPUTED

This prevents the model from getting "free credit" for scaffolded tokens.
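
Conceptually, the loss mask is zero over prompt and scaffold tokens and one over model-generated tokens, and only masked-in positions contribute to the GRPO loss. A plain-Python sketch of that bookkeeping (variable names are illustrative):

def build_loss_mask(prompt_len: int, scaffold_len: int, generated_len: int) -> list[int]:
    # 0 = excluded from loss, 1 = model-generated token that is trained on
    return [0] * prompt_len + [0] * scaffold_len + [1] * generated_len

def masked_mean_loss(per_token_loss: list[float], mask: list[int]) -> float:
    kept = [l for l, m in zip(per_token_loss, mask) if m]
    return sum(kept) / max(len(kept), 1)

mask = build_loss_mask(prompt_len=3, scaffold_len=2, generated_len=4)
print(mask)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]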

Two-Phase Generation

Automatic recovery for incomplete structured outputs:

Phase 1: Model generates → "<think>Let me solve this... 2+2="
         (Incomplete! Missing </think>)

Phase 2: Inject "</think>\n\boxed{" → Continue generation → "4}"
         (Complete! Injected tokens masked from loss)
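
In spirit, the recovery step looks like the sketch below, written against the public mlx_lm generate API. The injected suffix and the completeness check are assumptions meant to illustrate the idea; the framework performs this (and the masking of injected tokens) internally.

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Qwen2.5-3B-Instruct-4bit")

prompt = "What is 2 + 2?"
draft = generate(model, tokenizer, prompt=prompt, max_tokens=200)

if "</think>" not in draft:
    # Phase 2: close the thinking block ourselves and let the model finish the answer.
    recovery_prompt = prompt + draft + "</think>\n\n\\boxed{"
    tail = generate(model, tokenizer, prompt=recovery_prompt, max_tokens=20)
    draft = draft + "</think>\n\n\\boxed{" + tail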

📚 Documentation

| Topic              | Link                 |
|--------------------|----------------------|
| Full CLI Reference | docs/cli.md          |
| Training Arguments | docs/arguments.md    |
| Custom Rewards     | docs/rewards.md      |
| Type System        | TYPE_SYSTEM.md       |
| Architecture       | docs/architecture.md |
| API Reference      | docs/api.md          |

🆚 Comparison

| Feature                | MLX Guided GRPO | TRL (HuggingFace) | OpenRLHF |
|------------------------|-----------------|-------------------|----------|
| Apple Silicon Native   | ✅              | ❌                | ❌       |
| Curriculum Learning    | ✅              | ❌                | ❌       |
| Scaffold Token Masking | ✅              | ❌                | ❌       |
| Two-Phase Generation   | ✅              | ❌                | ❌       |
| Single GPU Training    | ✅              | ✅                | ⚠️       |
| Consumer Hardware      | ✅              | ⚠️                | ❌       |
| One-Command Training   | ✅              | ❌                | ❌       |

🛠️ Troubleshooting

Out of Memory?

# Reduce memory usage
mlx-grpo ... \
    --grad-checkpoint \
    --batch-size 1 \
    --group-size 2 \
    --max-completion-length 256

Metal GPU Crash?

Training auto-saves checkpoints. Just resume:

mlx-grpo ... --resume

Slow Training?

# Use a quantized model
--model mlx-community/Qwen2.5-3B-Instruct-4bit

# Reduce group size
--group-size 2

🤝 Contributing

Contributions are welcome! See CONTRIBUTING.md for guidelines.

# Setup development environment
git clone https://github.com/adeelahmad/mlx-guided-grpo.git
cd mlx-guided-grpo
pip install -e ".[dev]"

# Run formatting
black mlx_grpo/
isort mlx_grpo/

📜 Citation

If you use MLX Guided GRPO in your research, please cite:

@software{mlx_guided_grpo,
  author = {Ahmad, Adeel},
  title = {MLX Guided GRPO: Reasoning Model Training for Apple Silicon},
  year = {2024},
  url = {https://github.com/adeelahmad/mlx-guided-grpo}
}

📄 License

MIT License - see LICENSE for details.


🙏 Acknowledgments

  • MLX - Apple's ML framework
  • mlx-lm - MLX language model utilities
  • DeepSeek - GRPO algorithm
  • Qwen - Excellent base models

Built with ❀️ for the Mac ML community

LinkedIn • GitHub • Contact

If this project helps you, please ⭐ star the repo!
