Implementation of generation-augmented retrieval plus reinforcement learning for open-domain question answering

jerhadf/gar_plus

gar_plus

This is an implementation of generation-augmented retrieval (GAR), an older ML method, distinct from retrieval-augmented generation (RAG), for improving the performance of retrieval-based models. GAR generates additional relevant context for a given input; that context is then used to retrieve more accurate and relevant results. This is particularly useful in tasks such as question answering, where the quality of the retrieved documents can significantly affect the final answer. Here, I implement a version of GAR with reinforcement learning on top to perform question answering (QA) across different domains. I also include fine-tuning and evaluation scripts.
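
The idea can be illustrated with a small, self-contained sketch. It is not the code in this repository: the generator is a hypothetical lookup table standing in for a fine-tuned seq2seq model, the retriever is a simple term-overlap scorer rather than a real sparse or dense retriever, and all function names and example passages are made up for illustration.

```python
import re
from collections import Counter


def tokenize(text):
    """Lowercase the text and split it into alphanumeric tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())


def generate_context(question):
    """Hypothetical stand-in for a trained generator.

    A real GAR setup would run a fine-tuned seq2seq model here; this
    lookup table only makes the sketch runnable.
    """
    canned = {
        "Who wrote the novel about the white whale?": "Moby-Dick Herman Melville 1851",
    }
    return canned.get(question, "")


def augment_query(question, generator=generate_context):
    """Append the generated context to the original question."""
    context = generator(question)
    return f"{question} {context}".strip()


def overlap_score(query, passage):
    """Count the tokens shared by query and passage (with multiplicity)."""
    q, p = Counter(tokenize(query)), Counter(tokenize(passage))
    return sum(min(count, p[token]) for token, count in q.items())


def retrieve(query, passages, k=1):
    """Return the k passages with the highest overlap score."""
    ranked = sorted(passages, key=lambda p: overlap_score(query, p), reverse=True)
    return ranked[:k]


# Illustrative data, not taken from any dataset.
passages = [
    "Herman Melville published Moby-Dick in 1851.",
    "The white whale is a nickname for the beluga.",
]
question = "Who wrote the novel about the white whale?"

plain_hit = retrieve(question, passages)[0]
augmented_hit = retrieve(augment_query(question), passages)[0]
print("plain query     ->", plain_hit)
print("augmented query ->", augmented_hit)
```

In this toy example the plain question only shares surface words with the beluga passage, while the augmented query also matches the passage that actually contains the answer. That is the effect GAR aims for.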

To fine-tune a model, run:

python fine_tune.py \
  --model_name_or_path <model_name> \
  --do_train \
  --do_eval \
  --do_predict \
  --train_file <train_filepath> \
  --validation_file <validation_filepath> \
  --test_file <test_filepath> \
  --output_dir <path_to_output_dir> \
  --overwrite_output_dir \
  --text_column "question" \
  --summary_column "augment" \
  --predict_with_generate \
  --learning_rate 0.00005 \
  --num_train_epochs 15 \
  --sc_scaling 0.98
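
The command above names "question" as the text column and "augment" as the target column. As a rough sketch, a training file with those two columns could be prepared as shown here. The JSON Lines format is an assumption (this README does not state which file formats fine_tune.py accepts), and the helper name, file name, and example records are hypothetical.

```python
import json
import tempfile
from pathlib import Path


def write_jsonl(records, path):
    """Write one JSON object per line with 'question' and 'augment' keys.

    Assumption: fine_tune.py can read JSON Lines input; check the script
    before relying on this format.
    """
    path = Path(path)
    with path.open("w", encoding="utf-8") as f:
        for record in records:
            row = {"question": record["question"], "augment": record["augment"]}
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


# Hypothetical example records: each question is paired with the context
# the generator should learn to produce for it.
records = [
    {
        "question": "Who wrote the novel about the white whale?",
        "augment": "Moby-Dick Herman Melville 1851",
    },
    {
        "question": "What is the capital of France?",
        "augment": "Paris is the capital of France.",
    },
]

out_dir = Path(tempfile.mkdtemp())
train_path = write_jsonl(records, out_dir / "train.jsonl")

# Read the file back to confirm each line parses as a standalone JSON object.
rows = [json.loads(line) for line in train_path.read_text(encoding="utf-8").splitlines()]
print(f"wrote {len(rows)} records to {train_path}")
```

The resulting path would then be passed as --train_file; validation and test files would follow the same layout.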
