shenwzh3 99348f5a98 upload jgr 2023-05-23 15:21:21 +08:00

Joint Generator-Ranker Learning for Natural Language Generation

This repo contains the code, data and models for our paper Joint Generator-Ranker Learning for Natural Language Generation.

🚀 Overview

In this paper, we propose a novel Joint training paradigm of both Generator and Ranker (JGR) for NLG tasks. Unlike previous works, which train the generator and ranker models separately, we explore a joint and iterative training algorithm that updates both models in turn.

The JGR framework consists of a generator and a ranker. During training, the two models take turns updating their parameters, and each uses the other's outputs as part of its own training signal. Specifically, in the ranker training phase, the ranker is trained to rank the candidates produced by the generator for a given input text by assigning each a ranking score. In the generator training phase, the generator uses a combination of the ranker score and a matching score (e.g., BLEU) as the reward for each sampled candidate and is trained with policy gradients. This encourages the generator to produce candidates with higher rewards and mitigates the exposure bias of teacher-forcing training.
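
The generator update described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the repository's implementation: the function names, the unweighted sum used to combine the two scores, and the mean-reward baseline are all assumptions made for the example.

```python
def combined_rewards(ranker_scores, matching_scores):
    """Combine ranker and matching scores into one reward per candidate.

    Assumption: a plain sum; the combination actually used by JGR may differ.
    """
    return [r + m for r, m in zip(ranker_scores, matching_scores)]


def policy_gradient_loss(log_probs, rewards):
    """REINFORCE-style loss over the sampled candidates of one input.

    log_probs: sequence log-probability of each candidate under the generator.
    rewards:   reward of each candidate.
    Assumption: the mean reward is subtracted as a baseline to reduce variance.
    """
    baseline = sum(rewards) / len(rewards)
    weighted = [(r - baseline) * lp for r, lp in zip(rewards, log_probs)]
    return -sum(weighted) / len(weighted)


if __name__ == "__main__":
    # Toy numbers for two candidates of a single input.
    rewards = combined_rewards([0.2, 0.8], [0.5, 0.3])
    loss = policy_gradient_loss([-1.0, -2.0], rewards)
    print(rewards, loss)
```

Minimizing this loss raises the log-probability of candidates whose reward is above the baseline and lowers it for the others, which is the behavior the paragraph above describes.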

You can find more details in the paper.

⚙️ Experiment Preparation

Dependencies:

  • torch>=1.7
  • transformers==4.8.1
  • datasets==1.12.1
  • nltk==3.7
  • rouge-score
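
Before running anything, it can help to confirm that the environment matches the versions listed above. The following is a small sketch using only the standard library (`importlib.metadata`, Python 3.8+); the `PINNED` table and the helper names are made up for this example, and the `torch>=1.7` lower bound is only checked for presence, not compared.

```python
from importlib import metadata

# Exact pins as listed above; None means no exact version is pinned.
PINNED = {
    "torch": None,  # listed as torch>=1.7, only presence is checked here
    "transformers": "4.8.1",
    "datasets": "1.12.1",
    "nltk": "3.7",
    "rouge-score": None,
}


def installed_versions(names):
    """Return {package: installed version, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def report(pinned):
    """Build one status line per package comparing installed and pinned versions."""
    lines = []
    for name, have in installed_versions(pinned).items():
        want = pinned[name]
        if have is None:
            status = "missing"
        elif want is not None and have != want:
            status = f"installed {have}, listed {want}"
        else:
            status = f"ok ({have})"
        lines.append(f"{name}: {status}")
    return lines


if __name__ == "__main__":
    print("\n".join(report(PINNED)))
```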

Description of Codes:

  • data -> datasets and the data preprocessing code
  • data_utils -> code for data loading and evaluation metrics
  • model_utils -> generator and ranker models
  • trainer_utils -> trainer and trainer configuration
  • warmup-generator -> warming up the generator
  • warmup-ranker -> first training iteration of the ranker
  • -> the main function to run JGR

Data Preprocessing:

JGR currently supports CNN/DailyMail, SAMSum, SquadQG and Personachat. Download the raw data of these datasets and run the code in ./data to preprocess them. The scripts preprocess and reorganize the data samples into JSON files for the later steps. For details on preprocessing each dataset, please check the instructions in ./data.
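
After preprocessing, a quick sanity check of the generated files can catch an empty or malformed split early. The sketch below makes no claim about the schema the scripts in ./data produce: it only counts records and lists the field names it finds. It assumes each file holds either a JSON list of objects or one JSON object per line; the helper names are hypothetical.

```python
import json
import sys


def load_records(path):
    """Load records from a JSON file holding a list, or from JSON Lines."""
    with open(path, encoding="utf-8") as f:
        text = f.read()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to one JSON object per line.
        return [json.loads(line) for line in text.splitlines() if line.strip()]
    return data if isinstance(data, list) else [data]


def summarize(records):
    """Return the record count and the sorted set of field names seen."""
    keys = set()
    for record in records:
        if isinstance(record, dict):
            keys.update(record)
    return len(records), sorted(keys)


if __name__ == "__main__":
    # Pass one or more preprocessed files on the command line.
    for path in sys.argv[1:]:
        count, fields = summarize(load_records(path))
        print(f"{path}: {count} records, fields: {fields}")
```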


We also provide our trained checkpoints for you:

Generator Ranker
CNNDM (initialized with BRIO)

💡 Warm-up generator

To get better performance from JGR, it is necessary to first fine-tune the generator with MLE loss on the target training set. Follow the steps described in ./warmup-generator/ and you will get a warmed-up generator for the target dataset, stored in ./warmup-generator/saves.

Note: Before the first ranker training iteration, you should use the warmed-up generator to generate the candidates for that iteration. Don't forget to execute step "2. Generate candidates for warming-up ranker" in ./warmup-generator/.

💡 Warm-up ranker

As mentioned in the paper, in order to initialize the ranker with a more general and reasonable ranking function, we increase the number of training steps and add a certain number of warm-up steps in the first ranker training iteration. Here we use the fine-tuned generator to generate the candidates for the first ranker training iteration. Follow the steps described in ./warmup-ranker/ and you will get a warmed-up ranker for the target dataset, stored in ./warmup-ranker/saves.
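
To make the ranking task concrete, here is a small sketch that orders generated candidates by how well they match a reference. It is an illustration under assumptions, not the code in ./warmup-ranker/: a whitespace-token unigram F1 stands in for the real matching metric (BLEU or ROUGE), the function names are hypothetical, and using this ordering as the ranker's training target is an assumption about the setup.

```python
from collections import Counter


def unigram_f1(candidate, reference):
    """Unigram F1 between two strings, a simplified stand-in for ROUGE-1."""
    cand = candidate.lower().split()
    ref = reference.lower().split()
    if not cand or not ref:
        return 0.0
    # Multiset intersection counts each shared token at most min(count) times.
    overlap = sum((Counter(cand) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(cand)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def rank_candidates(candidates, reference):
    """Return (candidate, score) pairs sorted from best to worst match."""
    scored = [(c, unigram_f1(c, reference)) for c in candidates]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


if __name__ == "__main__":
    reference = "the cat sat on the mat"
    candidates = ["a dog ran", "the cat sat", "the cat sat on the mat"]
    for text, score in rank_candidates(candidates, reference):
        print(f"{score:.3f}  {text}")
```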

JGR training

After obtaining the warmed-up generator and ranker, you can move on to JGR training. Taking cnndm as an example, to train JGR, run:

export generator=warmup-generator/saves/bart-large-cnndm # the warmed-up generator
export ranker=warmup-ranker/saves/roberta-large-cnndm # the ranker after the first iteration
export save_name=JGR-large-cnndm # the save name of the model

python -m torch.distributed.launch --nproc_per_node 8 --overwrite_output_dir \
    --task_name sum --dataset_name cnndm \
    --train_data_path data/cnndm \
    --dev_data_path data/cnndm \
    --test_data_path data/cnndm \
    --load_tokenized_data False \
    --evaluate_generator True \
    --generator_num_cand_generated 8 --generator_num_cand_picked 8 \
    --num_cand_generated 16 --num_cand_picked 3 --candidate_pick_strategy bottom \
    --do_train True --do_eval True --do_predict True --prediction_loss_only False \
    --per_device_train_batch_size 2 --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --generator_learning_rate 5e-5 --reranker_learning_rate 1e-5 \
    --num_train_epochs 3 \
    --evaluation_strategy steps --eval_steps 1000 \
    --logging_strategy steps --logging_steps 500 \
    --save_strategy steps --save_steps 1000 --save_total_limit 20 \
    --iteration_steps 1000 --iteration_reranker_steps 500 \
    --load_best_model_at_end True \
    --metric_for_best_model generator_eval_rouge1 --greater_is_better True \
    --reranker_model_name_or_path $ranker \
    --generator_model_name_or_path $generator \
    --output_dir saves/$save_name \
    --generator_max_source_length 1020 --reranker_max_source_length 400 --generator_max_target_length 109 --reranker_max_target_length 109 \
    --cache_data \
    --disable_tqdm False 
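
When adapting the command to a different number of GPUs, the number of samples consumed per optimizer step is the quantity to keep fixed. The sketch below works through that arithmetic, assuming the batch-size flags follow the usual Hugging Face Trainer semantics (per-device batch size × number of processes × gradient accumulation steps); the helper names are hypothetical.

```python
def effective_batch_size(per_device_batch_size, num_processes, accumulation_steps):
    """Samples consumed per optimizer step across all processes."""
    return per_device_batch_size * num_processes * accumulation_steps


def scaled_accumulation_steps(target_batch_size, per_device_batch_size, num_processes):
    """Accumulation steps needed to reach target_batch_size on a given setup."""
    per_step = per_device_batch_size * num_processes
    if target_batch_size % per_step != 0:
        raise ValueError(
            f"target {target_batch_size} is not a multiple of {per_step} samples per forward pass"
        )
    return target_batch_size // per_step


if __name__ == "__main__":
    # Values from the command above: batch size 2, 8 processes, 4 accumulation steps.
    target = effective_batch_size(2, 8, 4)
    print("effective batch size:", target)
    # Same effective batch size on 2 GPUs instead of 8.
    print("accumulation steps on 2 GPUs:", scaled_accumulation_steps(target, 2, 2))
```

Under that assumption, running with `--nproc_per_node 2` would need `--gradient_accumulation_steps 16` to keep the same number of samples per optimizer step.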

The above command stores the trained generator and ranker in saves/JGR-large-cnndm/generator and saves/JGR-large-cnndm/reranker, respectively. For JGR training on other datasets, check

💬 Evaluate

To evaluate the trained generator and ranker, run:

export generator=saves/JGR-large-cnndm/generator # the trained generator
export ranker=saves/JGR-large-cnndm/ranker # the trained ranker
export save_name=JGR-large-cnndm # the save name of the model

python -m torch.distributed.launch --nproc_per_node 8 --overwrite_output_dir \
    --task_name sum --dataset_name cnndm \
    --train_data_path data/cnndm \
    --dev_data_path data/cnndm \
    --test_data_path data/cnndm \
    --load_tokenized_data False \
    --generator_num_cand_generated 8 --generator_num_cand_picked 8 \
    --num_cand_generated 16 --num_cand_picked 3 --candidate_pick_strategy bottom \
    --do_predict True --prediction_loss_only False \
    --per_device_eval_batch_size 4 \
    --evaluation_strategy steps --eval_steps 1000 \
    --logging_strategy steps --logging_steps 500 \
    --save_strategy steps --save_steps 1000 --save_total_limit 20 \
    --iteration_steps 1000 --iteration_reranker_steps 500 \
    --load_best_model_at_end True \
    --metric_for_best_model generator_eval_rouge1 --greater_is_better True \
    --reranker_model_name_or_path $ranker \
    --generator_model_name_or_path $generator \
    --output_dir saves/$save_name \
    --generator_max_source_length 1020 --reranker_max_source_length 400 --generator_max_target_length 109 --reranker_max_target_length 109 \
    --cache_data \
    --disable_tqdm False