This repository contains a script for fine-tuning the Phi3-V model with Parameter-Efficient Fine-Tuning (PEFT) techniques. Supported features:
- Training on a mixture of NLP data and vision-language data
- Flexible selection of LoRA target modules
- DeepSpeed ZeRO-2
- DeepSpeed ZeRO-3
- PyTorch FSDP
- Gradient checkpointing (only compatible with ZeRO-3 for now)
- QLoRA
- Disable/enable Flash Attention 2
Install the required packages using either requirements.txt:
pip install -r requirements.txt
or environment.yml:
conda env create -f environment.yml
conda activate phi3v
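After installation, a quick import check can catch a missing dependency before a long training run starts. The sketch below uses only the standard library (`importlib.util.find_spec`, which locates a module without importing it). The module list is an assumption for illustration; the authoritative list is whatever requirements.txt and environment.yml pin.

```python
import importlib.util


def check_packages(names):
    """Return a dict mapping each top-level module name to whether it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    # Hypothetical selection of modules a PEFT + DeepSpeed setup typically relies on.
    wanted = ["torch", "transformers", "peft", "deepspeed", "flash_attn"]
    status = check_packages(wanted)
    for name, ok in status.items():
        print(f"{name:12s} {'ok' if ok else 'MISSING'}")
    if not status["flash_attn"]:
        print("flash_attn not found: consider passing --disable_flash_attn2")
```

If `flash_attn` is reported missing, the `--disable_flash_attn2` flag described below is the documented way to train without Flash Attention 2.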
Before training, download the Phi3-V model from Hugging Face. We recommend using huggingface-cli for this:
- Install the HuggingFace CLI:
pip install -U "huggingface_hub[cli]"
- Download the model:
huggingface-cli download microsoft/Phi-3-vision-128k-instruct --local-dir Phi-3-vision-128k-instruct --resume-download
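Before pointing `--model_id` at the download, it can help to confirm the snapshot is complete. This is a minimal sketch assuming the downloaded folder holds a `config.json` plus one or more `*.safetensors` weight shards; adjust the expectations if the repository layout differs.

```python
import json
from pathlib import Path


def check_model_dir(model_dir):
    """Return a list of problems found in model_dir (an empty list means it looks OK)."""
    root = Path(model_dir)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = []
    config = root / "config.json"
    if not config.is_file():
        problems.append("config.json is missing")
    else:
        try:
            json.loads(config.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            problems.append("config.json is not valid JSON")
    # Weight shards are assumed to be stored as safetensors files.
    if not any(root.glob("*.safetensors")):
        problems.append("no *.safetensors weight files found")
    return problems


if __name__ == "__main__":
    # Same directory name as passed to --local-dir in the command above.
    issues = check_model_dir("Phi-3-vision-128k-instruct")
    print("model dir looks OK" if not issues else "\n".join(issues))
```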
To run the training script, use the following command:
bash scripts/train.sh
Note: Replace the paths in scripts/train.sh with your own paths before running it.
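Because `--per_device_train_batch_size` counts samples per GPU per forward step, the number of samples behind each optimizer update also depends on `--gradient_accumulation_steps` and the number of GPUs. The arithmetic below is a sketch assuming plain data parallelism (each GPU processes its own micro-batch); `gpu_count` is a hypothetical input, not a flag of the training script, and the exact per-epoch step count depends on how the trainer rounds partial batches.

```python
import math


def global_batch_size(per_device_train_batch_size, gradient_accumulation_steps, gpu_count):
    """Number of samples that contribute to each optimizer update."""
    return per_device_train_batch_size * gradient_accumulation_steps * gpu_count


def approx_steps_per_epoch(num_examples, per_device_train_batch_size,
                           gradient_accumulation_steps, gpu_count):
    """Rough count of optimizer updates per epoch (trainers may round slightly differently)."""
    batch = global_batch_size(per_device_train_batch_size, gradient_accumulation_steps, gpu_count)
    return math.ceil(num_examples / batch)


if __name__ == "__main__":
    # 4 samples per GPU, the default 4 accumulation steps, 2 GPUs.
    print(global_batch_size(4, 4, 2))
    print(approx_steps_per_epoch(10_000, 4, 4, 2))
```

With these numbers each update sees 32 samples, so a 10,000-example dataset gives roughly 313 updates per epoch.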
Training arguments:
- --data_path (str): Path to the LLaVA-formatted training data (a JSON file). (Required)
- --image_folder (str): Path to the folder containing the images referenced in the LLaVA-formatted training data. (Required)
- --model_id (str): Path to the Phi3-V model. (Required)
- --proxy (str): Proxy settings (default: None).
- --output_dir (str): Output directory for model checkpoints (default: "output/test_train").
- --num_train_epochs (int): Number of training epochs (default: 1).
- --per_device_train_batch_size (int): Training batch size per GPU per forward step.
- --gradient_accumulation_steps (int): Number of gradient accumulation steps (default: 4).
- --deepspeed_config (str): Path to the DeepSpeed config file (default: "scripts/zero2.json").
- --num_lora_modules (int): Number of target modules to add LoRA to (-1 means all layers).
- --lora_namespan_exclude (str): Modules whose names match these namespans are excluded from LoRA.
- --max_seq_length (int): Maximum sequence length (default: 3072).
- --quantization (flag): Enable quantization.
- --disable_flash_attn2 (flag): Disable Flash Attention 2.
- --report_to (str): Reporting tool (choices: 'tensorboard', 'wandb', 'none'; default: 'tensorboard').
- --logging_dir (str): Logging directory (default: "./tf-logs").
- --lora_rank (int): LoRA rank (default: 128).
- --lora_alpha (int): LoRA alpha (default: 256).
- --lora_dropout (float): LoRA dropout (default: 0.05).
- --logging_steps (int): Log every N steps (default: 1).
- --dataloader_num_workers (int): Number of data loader workers (default: 4).
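The LoRA arguments interact in ways that are easy to get wrong. The sketch below illustrates one plausible reading of `--num_lora_modules` and `--lora_namespan_exclude`, plus the scaling factor implied by `--lora_rank` and `--lora_alpha`. `select_lora_targets` is a hypothetical helper working on plain module names; the training script's own selection logic may differ (for example, in which modules it excludes by default or which end of the list it keeps). `lora_scaling` is the standard LoRA rule, where the low-rank update is multiplied by alpha / r.

```python
def select_lora_targets(linear_module_names, num_lora_modules=-1, lora_namespan_exclude=()):
    """Drop names containing any excluded namespan, then keep at most num_lora_modules.

    num_lora_modules == -1 keeps every remaining module. Keeping the *last*
    names (closest to the output in module order) is an assumption of this sketch.
    """
    kept = [
        name for name in linear_module_names
        if not any(span in name for span in lora_namespan_exclude)
    ]
    if num_lora_modules > 0:
        kept = kept[-num_lora_modules:]
    return kept


def lora_scaling(lora_alpha, lora_rank):
    """Standard LoRA scaling: the update B @ A is multiplied by alpha / r."""
    return lora_alpha / lora_rank


if __name__ == "__main__":
    # Toy module names in the style of a decoder-only transformer (hypothetical).
    names = [
        "model.layers.0.self_attn.qkv_proj",
        "model.layers.0.mlp.gate_up_proj",
        "model.layers.1.self_attn.qkv_proj",
        "model.layers.1.mlp.gate_up_proj",
        "lm_head",
    ]
    # Prints the two layer-1 modules; lm_head is excluded by its namespan.
    print(select_lora_targets(names, num_lora_modules=2, lora_namespan_exclude=["lm_head"]))
    # With the defaults (alpha 256, rank 128) the update is scaled by 2.0.
    print(lora_scaling(256, 128))
```

Keeping alpha at twice the rank, as the defaults do, holds the scaling factor at 2.0 if you change `--lora_rank` and adjust `--lora_alpha` along with it.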
The script requires a dataset in the LLaVA format: a JSON file containing a list of entries, each with an id, an image file name, and a list of conversation turns. Image file names in the dataset are referenced relative to --image_folder, so make sure they match the files in that folder.
Example Dataset
[
{
"id": "000000033471",
"image": "000000033471.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat are the colors of the bus in the image?"
},
{
"from": "gpt",
"value": "The bus in the image is white and red."
},
{
"from": "human",
"value": "What feature can be seen on the back of the bus?"
},
{
"from": "gpt",
"value": "The back of the bus features an advertisement."
},
{
"from": "human",
"value": "Is the bus driving down the street or pulled off to the side?"
},
{
"from": "gpt",
"value": "The bus is driving down the street, which is crowded with people and other vehicles."
}
]
}
...
]
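A quick pre-flight check of the data file can save a failed run. The rules below are assumptions drawn from the example above and from the feature list: every entry has "id" and "conversations"; turns alternate "human" / "gpt", starting with "human" and ending with "gpt"; entries with an "image" key carry the `<image>` placeholder in the first human turn; and text-only (NLP) entries are assumed to simply omit "image". The training script itself may be stricter or more lenient.

```python
import json
from pathlib import Path


def validate_entry(entry, image_folder=None):
    """Return a list of human-readable problems for one dataset entry."""
    problems = []
    for key in ("id", "conversations"):
        if key not in entry:
            problems.append(f"missing key {key!r}")
    turns = entry.get("conversations") or []
    if not turns:
        problems.append("no conversation turns")
    # Turns are expected to alternate human -> gpt -> human -> gpt ...
    for i, turn in enumerate(turns):
        expected = "human" if i % 2 == 0 else "gpt"
        if turn.get("from") != expected:
            problems.append(f"turn {i}: expected {expected!r}, got {turn.get('from')!r}")
    if turns and len(turns) % 2 != 0:
        problems.append("conversation does not end with a gpt turn")
    if "image" in entry:
        # Vision-language entry: the first human turn should carry the placeholder.
        if turns and "<image>" not in turns[0].get("value", ""):
            problems.append("first human turn has no <image> placeholder")
        if image_folder is not None and not (Path(image_folder) / entry["image"]).is_file():
            problems.append(f"image not found in image_folder: {entry['image']}")
    return problems


def validate_file(data_path, image_folder=None):
    """Map entry index -> problems, for every entry that has at least one problem."""
    with open(data_path, encoding="utf-8") as f:
        entries = json.load(f)
    report = {}
    for idx, entry in enumerate(entries):
        problems = validate_entry(entry, image_folder)
        if problems:
            report[idx] = problems
    return report
```

For example, `validate_file("train.json", image_folder="images")` returns an empty dict when every entry passes; the file and folder names here are placeholders for your own --data_path and --image_folder.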
- Add support for DeepSpeed ZeRO-3.
- Add support for FSDP
- Add support for simultaneously finetuning img_projector
- Add support for full finetuning
- Add support for grounded finetuning
- Add support for multi-image finetuning
- More advanced PEFT methods (e.g., DoRA)
- FSDP with ActivationCheckpointing Wrapper
- Integration with Chuanhu Chat
This project is licensed under the Apache-2.0 License. See the LICENSE file for details.
This project borrowed code from LLaVA and Microsoft Phi-3-vision-128k-instruct. Thanks to both projects for their contributions.
If you use this codebase in your work, please cite this project:
@misc{phi3vfinetuning2023,
author = {Gai, Zhenbiao and Shao, Zhenwei},
title = {Phi3V-Finetuning},
year = {2023},
publisher = {GitHub},
url = {https://github.com/GaiZhenbiao/Phi3V-Finetuning},
note = {GitHub repository},
}