Mirror of https://github.com/microsoft/DeBERTa.git
Add superglue fine-tuning tasks
This commit is contained in:
Parent
a2e7630023
Commit
4cfa08e7c8
@@ -1,3 +1,4 @@
from .ner import *
from .multi_choice import *
from .sequence_classification import *
from .record_qa import *
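As a hedged import sketch only: assuming this __init__.py lives in the DeBERTa.apps.models package (the diff does not name the file), the newly exported ReCoRD model becomes importable next to the existing ones.

# Hypothetical package path; the diff itself does not show the file location.
from DeBERTa.apps.models import MultiChoiceModel, ReCoRDQAModel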

@@ -17,18 +17,24 @@ import math

from ...deberta import *
from ...utils import *
import pdb

__all__ = ['MultiChoiceModel']

class MultiChoiceModel(NNModule):
  def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
    super().__init__(config)
    self.deberta = DeBERTa(config)
    self.num_labels = num_labels
    self.classifier = torch.nn.Linear(config.hidden_size, 1)
    self._register_load_state_dict_pre_hook(self._pre_load_hook)
    self.deberta = DeBERTa(config)
    self.config = config
    pool_config = PoolConfig(self.config)
    output_dim = self.deberta.config.hidden_size
    self.pooler = ContextPooler(pool_config)
    output_dim = self.pooler.output_dim()
    drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
    self.classifier = torch.nn.Linear(output_dim, 1)
    self.dropout = StableDropout(drop_out)
    self.apply(self.init_weights)
    self.deberta.apply_state()

  def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
    num_opts = input_ids.size(1)

@@ -41,16 +47,9 @@ class MultiChoiceModel(NNModule):
    input_mask = input_mask.view([-1, input_mask.size(-1)])
    encoder_layers = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask,
        position_ids=position_ids, output_all_encoded_layers=True)
    seqout = encoder_layers[-1]
    cls = seqout[:,:1,:]
    cls = cls/math.sqrt(seqout.size(-1))
    att_score = torch.matmul(cls, seqout.transpose(-1,-2))
    att_mask = input_mask.unsqueeze(1).to(att_score)
    att_score = att_mask*att_score + (att_mask-1)*10000.0
    att_score = torch.nn.functional.softmax(att_score, dim=-1)
    pool = torch.matmul(att_score, seqout).squeeze(-2)
    cls = self.dropout(pool)
    logits = self.classifier(cls).float().squeeze(-1)
    hidden_states = encoder_layers[-1]
    logits = self.classifier(self.dropout(self.pooler(hidden_states)))
    logits = logits.float().squeeze(-1)
    logits = logits.view([-1, num_opts])
    loss = 0
    if labels is not None:

@@ -60,3 +59,14 @@ class MultiChoiceModel(NNModule):
    return (logits, loss)

  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
      missing_keys, unexpected_keys, error_msgs):
    new_state = dict()
    bert_prefix = prefix + 'bert.'
    deberta_prefix = prefix + 'deberta.'
    for k in list(state_dict.keys()):
      if k.startswith(bert_prefix):
        nk = deberta_prefix + k[len(bert_prefix):]
        value = state_dict[k]
        del state_dict[k]
        state_dict[nk] = value
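A minimal shape sketch of the multi-choice flattening used in the forward pass above, with made-up tensor sizes (2 options, 136 tokens per option, matching the COPA script further down); it only illustrates the reshape steps, not the model itself.

import torch

batch, num_opts, seq_len = 8, 2, 136                  # hypothetical COPA-style batch
input_ids = torch.randint(0, 1000, (batch, num_opts, seq_len))
flat_ids = input_ids.view(-1, input_ids.size(-1))     # [16, 136], every option encoded separately
per_option = torch.randn(flat_ids.size(0))            # stand-in for the per-option classifier output
logits = per_option.view(-1, num_opts)                # [8, 2], one score column per option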

@@ -0,0 +1,69 @@
# Copyright (c) Microsoft, Inc. 2020
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# Author: penhe@microsoft.com
# Date: 01/25/2019
#

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import torch
from torch.nn import CrossEntropyLoss,BCEWithLogitsLoss
from ...deberta import *
from ...utils import *

__all__ = ['ReCoRDQAModel']

class ReCoRDQAModel(NNModule):
  def __init__(self, config, drop_out=None, **kwargs):
    super().__init__(config)
    self.deberta = DeBERTa(config)
    self.config = config
    self.proj = torch.nn.Linear(config.hidden_size, config.hidden_size)
    self.classifier = torch.nn.Linear(config.hidden_size, 1)
    drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
    self.dropout = StableDropout(drop_out)
    self.apply(self.init_weights)
    self.deberta.apply_state()

  def forward(self, input_ids, entity_indice, type_ids=None, input_mask=None, labels=None, position_ids=None, placeholder=None, **kwargs):
    encoder_layers = self.deberta(input_ids, attention_mask=input_mask, token_type_ids=type_ids,\
        position_ids=position_ids, output_all_encoded_layers=True)
    # entity_indice/entity_mask: b x e x sp (batch x entities x span positions)
    entity_mask = entity_indice>0
    tokens = encoder_layers[-1]
    # gathered entity token states: b x e x sp x d
    entities = torch.gather(tokens.unsqueeze(1).expand(entity_indice.size()[:2]+tokens.size()[1:]), index=entity_indice.long().unsqueeze(-1).expand(entity_indice.size()+(tokens.size(-1),)), dim=-2)
    ctx = tokens[:,:1,:]/math.sqrt(tokens.size(-1))
    # attention of every token to the scaled [CLS] context: b x s x 1
    att_score = torch.matmul(tokens, ctx.transpose(-1,-2))
    # scores of the entity tokens: b x e x sp x 1
    entity_score = torch.gather(att_score.unsqueeze(1).expand(entity_indice.size()[:2]+att_score.size()[1:]), index=entity_indice.long().unsqueeze(-1).expand(entity_indice.size()+(att_score.size(-1),)), dim=-2)
    entity_score = entity_score.squeeze(-1)*entity_mask.to(entity_score) - (1-entity_mask.to(entity_score))*10000.0
    att_prob = torch.nn.functional.softmax(entity_score, dim=-1).unsqueeze(-2)
    # attention-pooled entity embedding: b x e x d
    entity_ebd = torch.matmul(att_prob, entities).squeeze(-2)

    entity_ebd = self.proj(entity_ebd)
    entity_ebd = ACT2FN['gelu'](entity_ebd)

    sequence_out = self.dropout(entity_ebd)
    logits = self.classifier(sequence_out).float().squeeze(-1)
    entity_mask = (entity_mask.sum(-1)>0).to(logits)
    logits = logits*entity_mask + (entity_mask-1)*10000.0
    loss = 0
    if labels is not None:
      entity_index = entity_mask.view(-1).nonzero().view(-1)
      sp_logits = logits.view(-1)
      labels = labels.view(-1)
      sp_logits = torch.gather(sp_logits, index=entity_index, dim=0)
      labels = torch.gather(labels, index=entity_index, dim=0)
      loss_fn = BCEWithLogitsLoss()
      loss = loss_fn(sp_logits, labels.to(sp_logits))

    return (logits, loss)
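The double gather in ReCoRDQAModel.forward is the densest part of this change. Below is a small self-contained sketch of the same expand-then-gather pattern with invented shapes (batch x entities x span indices selecting rows of a batch x seq x hidden tensor); it is illustrative only, not code from the diff.

import torch

b, e, sp, s, d = 2, 3, 4, 16, 8                        # batch, entities, span, seq len, hidden
tokens = torch.randn(b, s, d)                          # stand-in for the last encoder layer
entity_indice = torch.randint(0, s, (b, e, sp))        # token positions of each candidate entity

# replicate the token states once per entity, then pick each entity's span positions
expanded = tokens.unsqueeze(1).expand(b, e, s, d)
index = entity_indice.long().unsqueeze(-1).expand(b, e, sp, d)
entities = torch.gather(expanded, index=index, dim=-2) # [b, e, sp, d]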

@@ -14,7 +14,6 @@ from __future__ import print_function
import torch
from torch.nn import CrossEntropyLoss
import math
import pdb

from ...deberta import *
from ...utils import *

@@ -1,5 +1,3 @@
#from .ner_task import *
from .glue_tasks import *
#from .race_task import *
from .task import *
from .task_registry import *

@@ -0,0 +1,86 @@
"""
Official evaluation script for ReCoRD v1.0.
(Some functions are adopted from the SQuAD evaluation script.)
"""

from __future__ import print_function
from collections import Counter
import string
import re
import argparse
import json
import sys
from ...utils import get_logger

logger=get_logger()

__all__=['normalize_answer', 'evaluate']


def normalize_answer(s):
  """Lower text and remove punctuation, articles and extra whitespace."""
  def remove_articles(text):
    return re.sub(r'\b(a|an|the)\b', ' ', text)

  def white_space_fix(text):
    return ' '.join(text.split())

  def remove_punc(text):
    exclude = set(string.punctuation)
    return ''.join(ch for ch in text if ch not in exclude)

  def lower(text):
    return text.lower()

  return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
  prediction_tokens = normalize_answer(prediction).split()
  ground_truth_tokens = normalize_answer(ground_truth).split()
  common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
  num_same = sum(common.values())
  if num_same == 0:
    return 0
  precision = 1.0 * num_same / len(prediction_tokens)
  recall = 1.0 * num_same / len(ground_truth_tokens)
  f1 = (2 * precision * recall) / (precision + recall)
  return f1


def exact_match_score(prediction, ground_truth):
  return normalize_answer(prediction) == normalize_answer(ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
  scores_for_ground_truths = []
  for ground_truth in ground_truths:
    score = metric_fn(prediction, ground_truth)
    scores_for_ground_truths.append(score)
  return max(scores_for_ground_truths)


def evaluate(answers, predictions):
  f1 = exact_match = total = 0
  correct_ids = []
  for qid in answers:
    total += 1
    if qid not in predictions:
      message = 'Unanswered question {} will receive score 0.'.format(qid)
      logger.warning(message)
      continue

    ground_truths = list(map(lambda x: x['text'], answers[qid]))
    prediction = predictions[qid]

    _exact_match = metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
    if int(_exact_match) == 1:
      correct_ids.append(qid)
    exact_match += _exact_match

    f1 += metric_max_over_ground_truths(f1_score, prediction, ground_truths)

  exact_match = 100.0 * exact_match / total
  f1 = 100.0 * f1 / total

  return {'exact_match': exact_match, 'f1': f1}
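A usage sketch for the script's evaluate(): answers maps a query id to its gold entity surface forms (dicts with a 'text' field, exactly as read in the loop above), and predictions maps the same id to the predicted string. The data below is invented for illustration.

answers = {'q1': [{'text': 'Microsoft'}, {'text': 'Microsoft Corp.'}]}
predictions = {'q1': 'microsoft'}
print(evaluate(answers, predictions))   # {'exact_match': 100.0, 'f1': 100.0}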
File diff suppressed because it is too large

@@ -0,0 +1,7 @@
# Dev performance on SuperGLUE

|Models     | COPA | RTE  | ReCoRD (F1/EM) |
|-----------|------|------|----------------|
|XXLarge-v2 | 97   | 93.5 | 94.1/93.7      |
|XLarge-v2  | 97   | 93.5 | 93.8/93.8      |

@@ -0,0 +1,6 @@
{
  "pooling": {
    "dropout": 0,
    "hidden_act": "gelu"
  }
}
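A hedged sketch of how this config is consumed: the run scripts below pass it via --model_config, and the "pooling" block presumably feeds the PoolConfig/ContextPooler pair used in MultiChoiceModel above; the loading code here is illustrative only, not taken from the diff.

import json

with open('config.json') as fs:      # path as referenced by the run scripts below
  pooling = json.load(fs)['pooling']
print(pooling)                       # {'dropout': 0, 'hidden_act': 'gelu'}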

@@ -0,0 +1,81 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR

cache_dir=/tmp/DeBERTa/

function setup_data(){
  task=$1
  mkdir -p $cache_dir
  if [[ ! -e $cache_dir/superglue_tasks/${task}/train.jsonl ]]; then
    ./download_data.sh $cache_dir/superglue_tasks
  fi
}

Task=COPA
setup_data $Task

init=$1
tag=$init
case ${init,,} in
  base-mnli)
    parameters=" --num_train_epochs 3 \
      --warmup 50 \
      --learning_rate 1e-5 \
      --train_batch_size 32 \
      --cls_drop_out 0.1 "
    ;;
  large-mnli)
    parameters=" --num_train_epochs 4 \
      --warmup 30 \
      --learning_rate 5e-6 \
      --train_batch_size 16 \
      --cls_drop_out 0.1 \
      --fp16 True "
    ;;
  xlarge-mnli)
    parameters=" --num_train_epochs 4 \
      --warmup 30 \
      --learning_rate 3e-6 \
      --train_batch_size 24 \
      --cls_drop_out 0.1 \
      --fp16 True "
    ;;
  xlarge-v2-mnli)
    parameters=" --num_train_epochs 4 \
      --warmup 30 \
      --learning_rate 2e-6 \
      --train_batch_size 24 \
      --cls_drop_out 0.1 \
      --fp16 True "
    ;;
  xxlarge-v2-mnli)
    parameters=" --num_train_epochs 8 \
      --warmup 30 \
      --learning_rate 1e-6 \
      --train_batch_size 24 \
      --cls_drop_out 0.1 \
      --fp16 True "
    ;;
  *)
    echo "usage $0 <Pretrained model configuration>"
    echo "Supported configurations"
    echo "base - Pretrained DeBERTa v1 model with 140M parameters (12 layers, 768 hidden size)"
    echo "large - Pretrained DeBERTa v1 model with 380M parameters (24 layers, 1024 hidden size)"
    echo "xlarge - Pretrained DeBERTa v1 model with 750M parameters (48 layers, 1024 hidden size)"
    echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
    echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
    exit 0
    ;;
esac

python -m DeBERTa.apps.run --model_config config.json \
  --tag $tag \
  --do_train \
  --max_seq_len 136 \
  --task_name $Task \
  --data_dir $cache_dir/superglue_tasks/$Task \
  --init_model $init \
  --output_dir /tmp/ttonly/$tag/$task $parameters

@@ -0,0 +1,11 @@
#!/bin/bash

cache_dir=$1
if [[ -z $cache_dir ]]; then
  cache_dir=/tmp/DeBERTa/superglue
fi

mkdir -p $cache_dir
curl -s -J -L https://dl.fbaipublicfiles.com/glue/superglue/data/v2/combined.zip -o $cache_dir/super.zip
unzip $cache_dir/super.zip -d $cache_dir

@@ -0,0 +1,82 @@
#!/bin/bash
SCRIPT=$(readlink -f "$0")
SCRIPT_DIR=$(dirname "$SCRIPT")
cd $SCRIPT_DIR

cache_dir=/tmp/DeBERTa/

function setup_data(){
  task=$1
  mkdir -p $cache_dir
  if [[ ! -e $cache_dir/superglue_tasks/${task}/train.jsonl ]]; then
    ./download_data.sh $cache_dir/superglue_tasks
  fi
}

Task=ReCoRD
setup_data $Task

init=$1
tag=$init
case ${init,,} in
  base)
    parameters=" --num_train_epochs 3 \
      --warmup 1000 \
      --learning_rate 2e-5 \
      --train_batch_size 64 \
      --cls_drop_out 0.3 "
    ;;
  large)
    parameters=" --num_train_epochs 3 \
      --warmup 1000 \
      --learning_rate 1e-5 \
      --train_batch_size 64 \
      --cls_drop_out 0.3 \
      --fp16 True "
    ;;
  xlarge)
    parameters=" --num_train_epochs 3 \
      --warmup 1000 \
      --learning_rate 7e-6 \
      --train_batch_size 64 \
      --cls_drop_out 0.3 \
      --fp16 True "
    ;;
  xlarge-v2)
    parameters=" --num_train_epochs 3 \
      --warmup 1000 \
      --learning_rate 5e-6 \
      --train_batch_size 64 \
      --cls_drop_out 0.15 \
      --fp16 True "
    ;;
  xxlarge-v2)
    parameters=" --num_train_epochs 3 \
      --warmup 1000 \
      --learning_rate 3e-6 \
      --accumulative_update 2 \
      --train_batch_size 64 \
      --cls_drop_out 0.15 \
      --fp16 True "
    ;;
  *)
    echo "usage $0 <Pretrained model configuration>"
    echo "Supported configurations"
    echo "base - Pretrained DeBERTa v1 model with 140M parameters (12 layers, 768 hidden size)"
    echo "large - Pretrained DeBERTa v1 model with 380M parameters (24 layers, 1024 hidden size)"
    echo "xlarge - Pretrained DeBERTa v1 model with 750M parameters (48 layers, 1024 hidden size)"
    echo "xlarge-v2 - Pretrained DeBERTa v2 model with 900M parameters (24 layers, 1536 hidden size)"
    echo "xxlarge-v2 - Pretrained DeBERTa v2 model with 1.5B parameters (48 layers, 1536 hidden size)"
    exit 0
    ;;
esac

python -m DeBERTa.apps.run --model_config config.json \
  --tag $tag \
  --do_train \
  --max_seq_len 512 \
  --task_name $Task \
  --data_dir $cache_dir/superglue_tasks/$Task \
  --init_model $init \
  --output_dir /tmp/ttonly/$tag/$task $parameters