Official PyTorch Implementation for "Let's Think with Images Efficiently! An Interleaved-Modal Chain-of-Thought Reasoning Framework with Dynamic and Precise Visual Thoughts"
This repository contains the official implementation for our paper, DaP-ICoT.
Recently, Interleaved-modal Chain-of-Thought (ICoT) reasoning has shown promising performance by leveraging both multimodal inputs and outputs. However, existing ICoT methods suffer from two fundamental limitations:
- Static Visual Thought Positioning: Visual information is statically inserted at fixed steps, leading to inefficient and inflexible reasoning.
- Broken Visual Thought Representation: Fragmented visual cues hinder semantic coherence and precision, undermining the quality of the reasoning process.
To address these critical issues, we introduce DaP-ICoT, an Interleaved-modal Chain-of-Thought reasoning framework with Dynamic and Precise Visual Thoughts.
DaP-ICoT incorporates two key components:
- 🧠 Dynamic Visual Thought Integration: Adaptively introduces visual inputs based on the model's real-time reasoning needs. This reduces redundancy by focusing only on key visual cues, making the process more efficient and human-like.
- 🎯 Precise Visual Thought Guidance: Ensures that the generated visual representations are semantically coherent and contextually aligned with the reasoning chain. This enhances the accuracy and reliability of the model's outputs.
Our experiments across multiple benchmarks and models demonstrate that DaP-ICoT not only achieves state-of-the-art performance but also significantly improves efficiency. It leads to a 72.6% decrease in token consumption by reducing the number of inserted images, paving the way for more practical and scalable ICoT reasoning.
- Introduction
- Key Features
- Prerequisites
- Installation and Setup
- Data Preparation
- Running the Code
- Project Structure
- Python 3.10
- Conda
- Git
Follow these steps carefully to set up the project environment and all necessary components.
First, clone this repository to your local machine.
git clone https://github.com/67L1/DaP-ICoT.git
cd DaP-ICoT/src
We recommend using Conda to manage dependencies. Create and activate a new environment with Python 3.10.
conda create -n dapicot python=3.10
conda activate dapicot
Install all the required Python packages using the requirements.txt file.
pip install -r requirements.txt
💡 Note on PyTorch Installation
The requirements.txt file may not automatically install the correct version of PyTorch for your specific hardware (especially CUDA). If you encounter errors related to torch or CUDA during installation, we strongly recommend installing PyTorch manually first. For CUDA 12.1, you can use the following command:
pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
For other CUDA versions or CPU-only installations, please visit the official PyTorch website to find the correct command for your system, so that the installed build matches your hardware.
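To confirm that the pinned PyTorch packages actually ended up in the dapicot environment, a small check like the one below can help. It is a minimal sketch that reads installed versions with the standard library's importlib.metadata; the helper name check_pins is hypothetical, and the pins are the CUDA 12.1 versions from the command above.

```python
from importlib import metadata

# Pins taken from the CUDA 12.1 install command above; edit them if you
# installed a different build.
PINNED = {"torch": "2.3.1", "torchvision": "0.18.1", "torchaudio": "2.3.1"}


def check_pins(pins, get_version=metadata.version):
    """Return a list of problems; an empty list means every pin matches."""
    problems = []
    for package, expected in pins.items():
        try:
            installed = get_version(package)
        except metadata.PackageNotFoundError:
            problems.append(f"{package}: not installed (expected {expected})")
            continue
        # CUDA wheels carry a local suffix such as "2.3.1+cu121",
        # so only the part before "+" is compared.
        if installed.split("+")[0] != expected:
            problems.append(f"{package}: found {installed}, expected {expected}")
    return problems


if __name__ == "__main__":
    issues = check_pins(PINNED)
    print("\n".join(issues) if issues else "torch/torchvision/torchaudio versions match")
    try:
        import torch

        print("CUDA available:", torch.cuda.is_available(), "| built for CUDA:", torch.version.cuda)
    except ImportError:
        print("torch is not importable in this environment")
```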
⚠️ IMPORTANT: This project requires a manual modification to the transformers library to support custom visual token handling for the Qwen model. Without this patch, the model will not function correctly.
You need to find the generation/utils.py file within your installed transformers library and modify the _sample method of the GenerationMixin class.
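The patch changes how each decoding step extends input_ids. The pure-Python sketch below mimics that logic on a plain list for a single sequence (batch size 1): when the model returns selected visual tokens, a vision block is spliced in before the sampled token. The token IDs 151652, 151655 and 151653 are the ones hard-coded in the patch and correspond to <|vision_start|>, <|image_pad|> and <|vision_end|> in the Qwen2-VL tokenizer. append_step is a hypothetical illustration, not a function from this repository.

```python
# Special-token IDs used by the patch (Qwen2-VL tokenizer).
VISION_START = 151652  # <|vision_start|>
IMAGE_PAD = 151655     # <|image_pad|>
VISION_END = 151653    # <|vision_end|>


def append_step(input_ids, next_token, num_vokens=None):
    """Mimic one step of the patched _sample loop on a single sequence.

    num_vokens=None plays the role of outputs['selected_vokens'] being absent
    or None: only the sampled token is appended. Otherwise a vision block with
    num_vokens placeholder tokens is inserted before the sampled token.
    """
    if num_vokens is not None:
        block = [VISION_START] + [IMAGE_PAD] * num_vokens + [VISION_END]
        return input_ids + block + [next_token]
    return input_ids + [next_token]


ids = append_step([1, 2, 3], 42)         # [1, 2, 3, 42]
ids = append_step(ids, 7, num_vokens=2)  # [1, 2, 3, 42, 151652, 151655, 151655, 151653, 7]
```

Keeping the placeholders inside input_ids lets later steps see where visual thoughts were inserted, which is why the stock one-token append has to be replaced.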
a. Find the file location:
You can find the path to utils.py by running this Python command in your activated dapicot environment:
python -c "import transformers; import os; print(os.path.join(os.path.dirname(transformers.__file__), 'generation', 'utils.py'))"
This will print the full path to the file you need to edit.
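Because the exact line number differs between transformers releases, it can be easier to ask Python where the statement sits inside _sample. The sketch below uses inspect.getsourcelines, which returns a function's source lines and the file line number of its first line. It assumes your transformers version names the method _sample (as this README does); the helper names find_in_source and is_patched are hypothetical. After patching, the else branch still contains the same statement, so rely on the patch flag rather than on the line search.

```python
import inspect

# The statement to replace, and a string that only the patch introduces.
TARGET = "input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)"
PATCH_MARKER = "selected_vokens"


def find_in_source(source_lines, start_line, needle=TARGET):
    """Return file line numbers (1-based) of lines whose stripped text equals needle."""
    return [start_line + i for i, line in enumerate(source_lines) if line.strip() == needle]


def is_patched(source_lines, marker=PATCH_MARKER):
    """True if any line of the method already mentions the patch marker."""
    return any(marker in line for line in source_lines)


if __name__ == "__main__":
    from transformers.generation.utils import GenerationMixin

    # (lines of _sample, line number of its first line in utils.py)
    lines, start = inspect.getsourcelines(GenerationMixin._sample)
    print("file:", inspect.getsourcefile(GenerationMixin._sample))
    print("line(s) to replace:", find_in_source(lines, start))
    print("patch already applied:", is_patched(lines))
```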
b. Apply the patch:
Open the utils.py file and locate the following lines (around line 3257; the exact line number depends on your installed transformers version):
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
Replace these two lines with the code block below, keeping the indentation of the surrounding code:
- # update generated ids, model inputs, and length for next step
- input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ # update generated ids, model inputs, and length for next step
+
+ # qwen: if the model selected visual tokens ("vokens") at this step, append
+ # <|vision_start|>, one <|image_pad|> placeholder per voken, and <|vision_end|>
+ # before the newly sampled token. Assumes batch size 1.
+ if 'selected_vokens' in outputs and outputs['selected_vokens'] is not None:
+     num_vokens = outputs['selected_vokens'].shape[0]
+     voken_ids = torch.full(
+         (1, num_vokens),
+         fill_value=151655,  # <|image_pad|>
+         dtype=input_ids.dtype,
+         device=input_ids.device
+     )
+     start_token = torch.full((1, 1), 151652, dtype=input_ids.dtype, device=input_ids.device)  # <|vision_start|>
+     end_token = torch.full((1, 1), 151653, dtype=input_ids.dtype, device=input_ids.device)  # <|vision_end|>
+     input_ids = torch.cat([input_ids, start_token, voken_ids, end_token, next_tokens[:, None]], dim=-1)
+ else:
+     input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
We use Segment Anything Model 2 (SAM2) to locate and segment objects in the dataset images.
💡 Note on SAM2 Dependencies
SAM2 has its own set of dependencies. Although our requirements.txt covers all of them, if you encounter any installation or dependency errors specifically when running SAM2 scripts, please refer to the official SAM2 GitHub repository for detailed installation instructions and troubleshooting.
a. Clone the SAM2 repository (from the DaP-ICoT/src directory):
git clone https://github.com/facebookresearch/sam2.git
b. Download SAM2 checkpoints:
Navigate into the sam2 directory and run the official script to download the model weights.
cd sam2/checkpoints
# On some systems you might need to make the script executable first: chmod +x download_ckpts.sh
./download_ckpts.sh
cd ../../
This will place the checkpoints in the src/sam2/checkpoints/ directory.
💡 We use the sam2.1_hiera_large.pt checkpoint.
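Before moving on, it may be worth confirming that the large checkpoint downloaded completely. This is a minimal sketch assuming it is run from DaP-ICoT/src, so the file sits at sam2/checkpoints/sam2.1_hiera_large.pt; the 1 MB threshold is an arbitrary heuristic for catching an interrupted download, not a documented size, and check_checkpoint is a hypothetical helper.

```python
from pathlib import Path

# Assumed location relative to DaP-ICoT/src (see the download step above).
DEFAULT_CKPT = Path("sam2/checkpoints/sam2.1_hiera_large.pt")


def check_checkpoint(path=DEFAULT_CKPT, min_bytes=1_000_000):
    """Return an error message, or None if the file exists and is not suspiciously small."""
    path = Path(path)
    if not path.is_file():
        return f"missing checkpoint: {path}"
    size = path.stat().st_size
    if size < min_bytes:  # heuristic: a real checkpoint is far larger than this
        return f"checkpoint looks truncated ({size} bytes): {path}"
    return None


if __name__ == "__main__":
    print(check_checkpoint() or f"found {DEFAULT_CKPT}")
```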
Download the test set for the M3CoT dataset from Hugging Face:
- Dataset Link: M3CoT
Place the downloaded files into a directory of your choice. You will need to specify this path later in the config file.
Run the pq_jsonl.py script to filter out entries with empty images and convert all images to the .png format. This script will generate a test.jsonl file.
cd data_all
python pq_jsonl.py
By default, the output test.jsonl and processed images will be stored in the data_all/m3cot/ directory.
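A quick sanity check on test.jsonl can confirm that the filtering and PNG conversion did what this step describes. The sketch below is hedged: the schema of test.jsonl is not documented here, so the field name "image" is an assumption (inspect one record and change image_key if needed), and summarize_jsonl is a hypothetical helper. It assumes you run it from DaP-ICoT/src.

```python
import json
from pathlib import Path


def summarize_jsonl(path, image_key="image"):
    """Count records, records with an empty image field, and image paths not ending in .png.

    image_key is an assumption about the test.jsonl schema.
    """
    total = missing = non_png = 0
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # ignore blank lines
            record = json.loads(line)
            total += 1
            image = record.get(image_key)
            if not image:
                missing += 1
            elif not str(image).lower().endswith(".png"):
                non_png += 1
    return {"records": total, "missing_image": missing, "non_png": non_png}


if __name__ == "__main__":
    # After pq_jsonl.py, missing_image and non_png should both be 0.
    print(summarize_jsonl(Path("data_all/m3cot/test.jsonl")))
```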
a. Move custom scripts into the sam2 directory:
Our custom scripts for SAM2 pre-processing must be located inside the sam2 folder.
# Ensure you are in the root directory 'DaP-ICoT/src'
mv preprocess_pool.py process_res.py sam2_detect.py sam2/
b. Generate the image pool: This step uses SAM2 to detect objects in the dataset images and creates a pre-processed "image pool".
cd sam2
Next, modify the config.yaml file located in the src/config/ directory. You will need to set the sam2_checkpoint path and the correct path for your dataset.
π‘ We recommend using absolute paths directly for these settings.
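Since absolute paths are recommended, a small check can catch relative or mistyped entries before the long pre-processing run. This sketch assumes the settings are flat top-level keys in src/config/config.yaml and that PyYAML is installed; sam2_checkpoint is the key named in this README, while data_path is a hypothetical placeholder for whatever key your config uses for the dataset. check_config_paths is likewise a hypothetical helper.

```python
from pathlib import Path


def check_config_paths(config, keys):
    """Return problems for entries that are unset, relative, or point to nothing."""
    problems = []
    for key in keys:
        value = config.get(key)
        if value is None:
            problems.append(f"{key}: not set")
            continue
        path = Path(value)
        if not path.is_absolute():
            problems.append(f"{key}: not an absolute path ({value})")
        elif not path.exists():
            problems.append(f"{key}: path does not exist ({value})")
    return problems


if __name__ == "__main__":
    import yaml  # PyYAML, assumed to be installed in the dapicot environment

    # Run from the sam2 directory, so the config lives one level up.
    with open("../config/config.yaml", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # "data_path" is a hypothetical key name; replace it with your config's dataset key.
    for problem in check_config_paths(cfg, ["sam2_checkpoint", "data_path"]):
        print(problem)
```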
After configuring, run the script:
python preprocess_pool.py
The resulting image pool will be stored in data_all/m3cot/ (or your configured path).
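The internal structure of the image pool pickle (image_pool_qwen.pkl in the project structure) is not documented here, so the sketch below only reports its top-level type and size as a smoke test that pre-processing produced something loadable. describe_pickle is a hypothetical helper. Only unpickle files you generated yourself, since loading a pickle can execute arbitrary code.

```python
import pickle
from pathlib import Path


def describe_pickle(path, preview=3):
    """Load a pickle and summarize its top-level structure."""
    with open(path, "rb") as f:
        obj = pickle.load(f)
    info = {"type": type(obj).__name__}
    if hasattr(obj, "__len__"):
        info["len"] = len(obj)
    if isinstance(obj, dict):
        info["sample_keys"] = list(obj)[:preview]  # first few keys, in insertion order
    return info


if __name__ == "__main__":
    # Path assumes you are still in the sam2 directory and kept the default output location.
    print(describe_pickle(Path("../data_all/m3cot/image_pool_qwen.pkl")))
```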
Now you are ready to run the main experiment.
- Navigate back to the project root directory DaP-ICoT/src:
cd ../  # If you are still in the sam2 directory
- Configure the main run: Before running, open the main config.yaml file (DaP-ICoT/src/config/config.yaml) and adjust the paths and other parameters as needed for your setup.
- Execute the main script:
python run.py
Here is a simplified overview of the project directory structure:
DaP-ICoT/src/
├── sam2/                        # Cloned SAM2 repository
│   ├── checkpoints/
│   ├── config.yaml              # Config for SAM2 pre-processing
│   ├── preprocess_pool.py       # (Moved here)
│   └── ...
├── data_all/
│   ├── pq_jsonl.py              # Dataset filtering script
│   └── m3cot/                   # Processed data and image pools
│       ├── images/              # M3CoT's images
│       ├── test.jsonl
│       └── image_pool_qwen.pkl  # Image pool for Qwen
├── config/                      # Main configuration file for run.py
│   └── config.yaml
├── requirements.txt             # Python dependencies
├── run.py                       # Main script to run the experiment
└── README.md
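Before launching run.py, a layout check against this overview can catch a skipped step. The sketch assumes it is run from DaP-ICoT/src; the expected entries come from the directory overview plus the checkpoint and the scripts moved into sam2/ during pre-processing, and missing_entries is a hypothetical helper.

```python
from pathlib import Path

# Expected files relative to DaP-ICoT/src, per the overview and setup steps.
EXPECTED = [
    "sam2/checkpoints/sam2.1_hiera_large.pt",
    "sam2/preprocess_pool.py",
    "sam2/process_res.py",
    "sam2/sam2_detect.py",
    "data_all/pq_jsonl.py",
    "data_all/m3cot/test.jsonl",
    "data_all/m3cot/image_pool_qwen.pkl",
    "config/config.yaml",
    "requirements.txt",
    "run.py",
]


def missing_entries(root, expected=EXPECTED):
    """Return the expected relative paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).exists()]


if __name__ == "__main__":
    missing = missing_entries(".")
    print("all expected files present" if not missing else "missing:\n  " + "\n  ".join(missing))
```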