Transformer Reinforcement Learning X

trlx allows you to fine-tune 🤗 Hugging Face supported language models (gpt2, gpt-j, gpt-neo and gpt-neox based) of up to 20B parameters using reinforcement learning, driven by either a provided reward function or a reward-labeled dataset. Two training algorithms are implemented: Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL).

You can read more about trlX in our documentation.

Installation

From Source

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113 # for CUDA 11.3
pip install -e .

How to Train

You can train your model using either a reward function or a reward-labeled dataset.

Using a reward function

import trlx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# optimize some reward function
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])

# model is a wrapper with some logit preprocessing
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
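In practice the reward is often computed by another model rather than a string count. Below is a minimal sketch, not one of the repo's examples, that scores samples with an off-the-shelf sentiment classifier from 🤗 Transformers and uses the positive-class probability as the reward; the classifier name and its label strings are assumptions about that particular model.

import trlx
from transformers import pipeline

# off-the-shelf sentiment classifier (assumed model name)
sentiment_fn = pipeline('sentiment-analysis', model='lvwerra/distilbert-imdb')

def reward_fn(samples):
    # the pipeline returns dicts like {'label': 'POSITIVE', 'score': 0.98};
    # convert each into a scalar reward for the corresponding sample
    outputs = sentiment_fn(samples)
    return [out['score'] if out['label'] == 'POSITIVE' else 1.0 - out['score'] for out in outputs]

model = trlx.train('gpt2', reward_fn=reward_fn)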

Using a reward-labeled dataset

import trlx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-j-6B')

# Steer a model with a collection of rated samples
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])

# model is a wrapper with some logit preprocessing
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)
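The dataset is simply two parallel collections: the texts to learn from and a scalar rating for each text. A minimal sketch with made-up texts and ratings, assuming the dataset unpacks into (texts, rewards) exactly as in the example above:

import trlx

# texts and their ratings are parallel lists (the values here are invented for illustration)
texts = ['the movie was great', 'the movie was terrible', 'an average film']
rewards = [1.0, -1.0, 0.0]

# ILQL steers generation toward the higher-rated texts
model = trlx.train('gpt2', dataset=[texts, rewards])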

Using 🤗 Accelerate to speed up training

Launch distributed training with 🤗 Accelerate (only the DeepSpeed integration is tested so far):

accelerate config
accelerate launch examples/simulacra.py
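The script passed to accelerate launch is just an ordinary trlx training script; Accelerate takes care of the distributed setup. A hypothetical minimal script is sketched below (the file name and the toy reward are illustrative, not part of the repo) that could be launched the same way once accelerate config has been run.

# my_ppo_script.py -- hypothetical file name; launch with `accelerate launch my_ppo_script.py`
import trlx

def reward_fn(samples):
    # toy reward: count occurrences of 'cats', as in the reward-function example above
    return [sample.count('cats') for sample in samples]

if __name__ == '__main__':
    trlx.train('gpt2', reward_fn=reward_fn)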

For more usage, see the examples folder.

Contributing

For development, check out these guidelines and also read our docs.

Acknowledgements

Thanks to Leandro for starting the original trl.
