Add superglue fine-tuning tasks

This commit is contained in:
Pengcheng He 2021-02-04 03:12:39 -05:00 committed by Pengcheng He
Parent a2e7630023
Commit 4cfa08e7c8
12 changed files with 1495 additions and 16 deletions

View File

@@ -1,3 +1,4 @@
from .ner import *
from .multi_choice import *
from .sequence_classification import *
from .record_qa import *

View File

@@ -17,18 +17,24 @@ import math
from ...deberta import *
from ...utils import *
import pdb
__all__ = ['MultiChoiceModel']
class MultiChoiceModel(NNModule):
def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
super().__init__(config)
self.deberta = DeBERTa(config)
self.num_labels = num_labels
self.classifier = torch.nn.Linear(config.hidden_size, 1)
self._register_load_state_dict_pre_hook(self._pre_load_hook)
self.deberta = DeBERTa(config)
self.config = config
pool_config = PoolConfig(self.config)
output_dim = self.deberta.config.hidden_size
self.pooler = ContextPooler(pool_config)
output_dim = self.pooler.output_dim()
drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
self.classifier = torch.nn.Linear(output_dim, 1)
self.dropout = StableDropout(drop_out)
self.apply(self.init_weights)
self.deberta.apply_state()
def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
num_opts = input_ids.size(1)
@@ -41,16 +47,9 @@ class MultiChoiceModel(NNModule):
input_mask = input_mask.view([-1, input_mask.size(-1)])
encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
position_ids=position_ids, output_all_encoded_layers=True)
seqout = encoder_layers[-1]
cls = seqout[:,:1,:]
cls = cls/math.sqrt(seqout.size(-1))
att_score = torch.matmul(cls, seqout.transpose(-1,-2))
att_mask = input_mask.unsqueeze(1).to(att_score)
att_score = att_mask*att_score + (att_mask-1)*10000.0
att_score = torch.nn.functional.softmax(att_score, dim=-1)
pool = torch.matmul(att_score, seqout).squeeze(-2)
cls = self.dropout(pool)
logits = self.classifier(cls).float().squeeze(-1)
hidden_states = encoder_layers[-1]
logits = self.classifier(self.dropout(self.pooler(hidden_states)))
logits = logits.float().squeeze(-1)
logits = logits.view([-1, num_opts])
loss = 0
if labels is not None:
@@ -60,3 +59,14 @@ class MultiChoiceModel(NNModule):
return (logits, loss)
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
new_state = dict()
bert_prefix = prefix + 'bert.'
deberta_prefix = prefix + 'deberta.'
for k in list(state_dict.keys()):
if k.startswith(bert_prefix):
nk = deberta_prefix + k[len(bert_prefix):]
value = state_dict[k]
del state_dict[k]
state_dict[nk] = value
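The multiple-choice head no longer builds its own [CLS]-attention pooling inside `forward`: pooling is delegated to the shared `ContextPooler` (configured through `PoolConfig`), and `_pre_load_hook` rewrites legacy checkpoint keys from the `bert.` prefix to `deberta.` so older checkpoints still load. As a rough illustration of what the pooler contributes, here is a minimal sketch that simply projects the first ([CLS]) token through a dense layer with GELU and dropout; the exact behaviour of DeBERTa's `ContextPooler` may differ.

```python
import torch

class SimpleContextPooler(torch.nn.Module):
    """Minimal stand-in for ContextPooler (sketch only, not DeBERTa's implementation)."""
    def __init__(self, hidden_size, dropout=0.0):
        super().__init__()
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(dropout)

    def output_dim(self):
        # Mirrors the pooler.output_dim() call used to size the classifier above
        return self.dense.out_features

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden]; pool by projecting the first token
        cls_token = self.dropout(hidden_states[:, 0])
        return torch.nn.functional.gelu(self.dense(cls_token))
```

Since the classifier emits one logit per candidate, the final `logits.view([-1, num_opts])` reshape recovers a [batch, num_options] score matrix for the cross-entropy loss.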

View File

@@ -0,0 +1,69 @@
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/25/2019
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
from torch.nn import CrossEntropyLoss,BCEWithLogitsLoss
from ...deberta import *
from ...utils import *
__all__ = ['ReCoRDQAModel']
class ReCoRDQAModel(NNModule):
def __init__(self, config, drop_out=None, **kwargs):
super().__init__(config)
self.deberta = DeBERTa(config)
self.config = config
self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)
self.classifier = torch.nn.Linear(config.hidden_size, 1)
drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
self.dropout = StableDropout(drop_out)
self.apply(self.init_weights)
self.deberta.apply_state()
def forward(self, input_ids, entity_indice, type_ids=None, input_mask=None, labels=None, position_ids=None, placeholder=None, **kwargs):
encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
position_ids=position_ids, output_all_encoded_layers=True)
# bxexsp
entity_mask = entity_indice>0
tokens = encoder_layers[-1]
# bxexspxd
entities = torch.gather(tokens.unsqueeze(1).expand(entity_indice.size()[:2]+tokens.size()[1:]), index=entity_indice.long().unsqueeze(-1).expand(entity_indice.size()+(tokens.size(-1),)), dim=-2)
ctx = tokens[:,:1,:]/math.sqrt(tokens.size(-1))
# bxsx1
att_score = torch.matmul(tokens, ctx.transpose(-1,-2))
# bxexspx1
entity_score = torch.gather(att_score.unsqueeze(1).expand(entity_indice.size()[:2]+att_score.size()[1:]), index=entity_indice.long().unsqueeze(-1).expand(entity_indice.size()+(att_score.size(-1),)), dim=-2)
entity_score = entity_score.squeeze(-1)*entity_mask.to(entity_score) - (1-entity_mask.to(entity_score))*10000.0
att_prob = torch.nn.functional.softmax(entity_score, dim=-1).unsqueeze(-2)
# bxexd
entity_ebd = torch.matmul(att_prob, entities).squeeze(-2)
entity_ebd = self.proj(entity_ebd)
entity_ebd = ACT2FN['gelu'](entity_ebd)
sequence_out = self.dropout(entity_ebd)
logits = self.classifier(sequence_out).float().squeeze(-1)
entity_mask = (entity_mask.sum(-1)>0).to(logits)
logits = logits*entity_mask + (entity_mask-1)*10000.0
loss = 0
if labels is not None:
entity_index = entity_mask.view(-1).nonzero().view(-1)
sp_logits = logits.view(-1)
labels = labels.view(-1)
sp_logits = torch.gather(sp_logits, index=entity_index, dim=0)
labels = torch.gather(labels, index=entity_index, dim=0)
loss_fn = BCEWithLogitsLoss()
loss = loss_fn(sp_logits, labels.to(sp_logits))
return (logits, loss)
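`entity_indice` holds, for each example, the token positions of every candidate entity (shape batch × entities × span, zero-padded, hence the `bxexsp` comments). The forward pass gathers the encoder states at those positions, uses the scaled [CLS] vector as a query to attend over each entity's span, and scores the pooled span with a one-logit classifier trained under `BCEWithLogitsLoss`. The gather/expand step is the least obvious part; below is a small shape-level sketch with made-up sizes.

```python
import torch

# Toy shapes (hypothetical): batch=2, seq_len=8, hidden=4, entities=3, span=2
tokens = torch.randn(2, 8, 4)                              # encoder output, [b, s, d]
entity_indice = torch.tensor([[[1, 2], [4, 0], [0, 0]],
                              [[3, 4], [6, 7], [0, 0]]])   # [b, e, sp], 0 = padding

# Gather the hidden state of every entity token -> [b, e, sp, d]
idx = entity_indice.long().unsqueeze(-1).expand(entity_indice.size() + (tokens.size(-1),))
entities = torch.gather(
    tokens.unsqueeze(1).expand(entity_indice.size()[:2] + tokens.size()[1:]),
    index=idx, dim=-2)
print(entities.shape)   # torch.Size([2, 3, 2, 4])
```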

View File

@@ -14,7 +14,6 @@ from __future__ import print_function
import torch
from torch.nn import CrossEntropyLoss
import math
import pdb
from ...deberta import *
from ...utils import *

View File

@@ -1,5 +1,3 @@
#from .ner_task import *
from .glue_tasks import *
#from .race_task import *
from .task import *
from .task_registry import *

View File

@@ -0,0 +1,86 @@
"""
Official evaluation script for ReCoRD v1.0.
(Some functions are adopted from the SQuAD evaluation script.)
"""
from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys
from ...utils import get_logger
logger=get_logger()
__all__=['normalize_answer', 'evaluate']
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
def exact_match_score(prediction, ground_truth):
return normalize_answer(prediction) == normalize_answer(ground_truth)
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
for ground_truth in ground_truths:
score = metric_fn(prediction, ground_truth)
scores_for_ground_truths.append(score)
return max(scores_for_ground_truths)
def evaluate(answers, predictions):
f1 = exact_match = total = 0
correct_ids = []
for qid in answers:
total += 1
if qid not in predictions:
message = 'Unanswered question {} will receive score 0.'.format(qid)
logger.warning(message)
continue
ground_truths = list(map(lambda x: x['text'], answers[qid]))
prediction = predictions[qid]
_exact_match = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
if int(_exact_match) == 1:
correct_ids.append(qid)
exact_match += _exact_match
f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)
exact_match = 100.0 * exact_match / total
f1 = 100.0 * f1 / total
return {'exact_match': exact_match, 'f1': f1}
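As in the SQuAD script it is adapted from, `evaluate` expects `answers` to map a query id to a list of gold-answer dicts with a `text` field, and `predictions` to map the same ids to predicted strings. A toy call with made-up ids and answers:

```python
answers = {
    "q1": [{"text": "Barack Obama"}, {"text": "Obama"}],
    "q2": [{"text": "the United Nations"}],
}
predictions = {"q1": "Obama", "q2": "United Nations Security Council"}

print(evaluate(answers, predictions))
# {'exact_match': 50.0, 'f1': 83.33...}: q1 matches exactly, q2 only overlaps partially
```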

File diff suppressed because it is too large. Load Diff

View File

@@ -0,0 +1,7 @@
# Dev performance on SuperGLUE
|Model     | COPA (Acc) | RTE (Acc) | ReCoRD (F1/EM) |
|----------|------------|-----------|----------------|
|XXLarge-v2|97          |93.5       |94.1/93.7       |
|XLarge-v2 |97          |93.5       |93.8/93.8       |

View File

@@ -0,0 +1,6 @@
{
"pooling": {
"dropout": 0,
"hidden_act": "gelu"
}
}
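This new `pooling` section of `config.json` feeds the classification heads (the `PoolConfig`/`ContextPooler` pair used above). A minimal sketch of reading it, assuming the two keys map directly to the pooler's dropout rate and dense-layer activation; the real `PoolConfig` wrapper may expose them differently:

```python
import json

with open("config.json") as f:
    cfg = json.load(f)

pooling = cfg.get("pooling", {})
dropout = pooling.get("dropout", 0)              # dropout applied inside the pooler
hidden_act = pooling.get("hidden_act", "gelu")   # activation for the pooler's dense layer
```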

experiments/superglue/copa.sh Executable file
View File

@@ -0,0 +1,81 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR
cache_dir=/tmp/DeBERTa/
function setup_data(){
task=$1
mkdir -p $cache_dir
if [[ ! -e $cache_dir/superglue_tasks/${task}/train.jsonl ]]; then
./download_data.sh $cache_dir/superglue_tasks
fi
}
Task=COPA
setup_data $Task
init=$1
tag=$init
case ${init,,} in
base-mnli)
parameters=" --num_train_epochs 3 \
--warmup 50 \
--learning_rate 1e-5 \
--train_batch_size 32 \
--cls_drop_out 0.1 "
;;
large-mnli)
parameters=" --num_train_epochs 4 \
--warmup 30 \
--learning_rate 5e-6 \
--train_batch_size 16 \
--cls_drop_out 0.1 \
--fp16 True "
;;
xlarge-mnli)
parameters=" --num_train_epochs 4 \
--warmup 30 \
--learning_rate 3e-6 \
--train_batch_size 24 \
--cls_drop_out 0.1 \
--fp16 True "
;;
xlarge-v2-mnli)
parameters=" --num_train_epochs 4 \
--warmup 30 \
--learning_rate 2e-6 \
--train_batch_size 24 \
--cls_drop_out 0.1 \
--fp16 True "
;;
xxlarge-v2-mnli)
parameters=" --num_train_epochs 8 \
--warmup 30 \
--learning_rate 1e-6 \
--train_batch_size 24 \
--cls_drop_out 0.1 \
--fp16 True "
;;
*)
echo "usage $0 <Pretrained model configuration>"
echo "Supported configurations"
echo "base - Pretrained DeBERTa v1 model with 140M parameters (12 layers, 768 hidden size)"
echo "large - Pretrained DeBERta v1 model with 380M parameters (24 layers, 1024 hidden size)"
echo "xlarge - Pretrained DeBERTa v1 model with 750M parameters (48 layers, 1024 hidden size)"
echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
exit 0
;;
esac
python -m DeBERTa.apps.run --model_config config.json \
--tag $tag \
--do_train \
--max_seq_len 136 \
--task_name $Task \
--data_dir $cache_dir/superglue_tasks/$Task \
--init_model $init \
--output_dir /tmp/ttonly/$tag/$task $parameters
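For example, `./copa.sh xxlarge-v2-mnli` fine-tunes the MNLI-initialized XXLarge-v2 checkpoint on COPA with the hyper-parameters from the matching case branch and writes its output under `/tmp/ttonly/xxlarge-v2-mnli/COPA`.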

View File

@@ -0,0 +1,11 @@
#!/bin/bash
cache_dir=$1
if [[ -z $cache_dir ]]; then
cache_dir=/tmp/DeBERTa/superglue
fi
mkdir -p $cache_dir
curl -s -J -L https://dl.fbaipublicfiles.com/glue/superglue/data/v2/combined.zip -o $cache_dir/super.zip
unzip $cache_dir/super.zip -d $cache_dir
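The COPA and ReCoRD scripts call this helper automatically (as `./download_data.sh /tmp/DeBERTa/superglue_tasks`) whenever a task's `train.jsonl` is not yet cached.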

experiments/superglue/record.sh Executable file
View File

@@ -0,0 +1,82 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR
cache_dir=/tmp/DeBERTa/
function setup_data(){
task=$1
mkdir -p $cache_dir
if [[ ! -e $cache_dir/superglue_tasks/${task}/train.jsonl ]]; then
./download_data.sh $cache_dir/superglue_tasks
fi
}
Task=ReCoRD
setup_data $Task
init=$1
tag=$init
case ${init,,} in
base)
parameters=" --num_train_epochs 3 \
--warmup 1000 \
--learning_rate 2e-5 \
--train_batch_size 64 \
--cls_drop_out 0.3 "
;;
large)
parameters=" --num_train_epochs 3 \
--warmup 1000 \
--learning_rate 1e-5 \
--train_batch_size 64 \
--cls_drop_out 0.3 \
--fp16 True "
;;
xlarge)
parameters=" --num_train_epochs 3 \
--warmup 1000 \
--learning_rate 7e-6 \
--train_batch_size 64 \
--cls_drop_out 0.3 \
--fp16 True "
;;
xlarge-v2)
parameters=" --num_train_epochs 3 \
--warmup 1000 \
--learning_rate 5e-6 \
--train_batch_size 64 \
--cls_drop_out 0.15 \
--fp16 True "
;;
xxlarge-v2)
parameters=" --num_train_epochs 3 \
--warmup 1000 \
--learning_rate 3e-6 \
--accumulative_update 2 \
--train_batch_size 64 \
--cls_drop_out 0.15 \
--fp16 True "
;;
*)
echo "usage $0 <Pretrained model configuration>"
echo "Supported configurations"
echo "base - Pretrained DeBERTa v1 model with 140M parameters (12 layers, 768 hidden size)"
echo "large - Pretrained DeBERta v1 model with 380M parameters (24 layers, 1024 hidden size)"
echo "xlarge - Pretrained DeBERTa v1 model with 750M parameters (48 layers, 1024 hidden size)"
echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
exit 0
;;
esac
python -m DeBERTa.apps.run --model_config config.json \
--tag $tag \
--do_train \
--max_seq_len 512 \
--task_name $Task \
--data_dir $cache_dir/superglue_tasks/$Task \
--init_model $init \
--output_dir /tmp/ttonly/$tag/$task $parameters
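Usage mirrors `copa.sh`: for example, `./record.sh xxlarge-v2` fine-tunes the XXLarge-v2 model on ReCoRD at a maximum sequence length of 512, with the 1.5B-parameter configuration additionally using gradient accumulation (`--accumulative_update 2`).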