Black auto formatting.
This commit is contained in:
Parent
3d1c1862d9
Commit
2473e1a75c
The diff for one of the changed files is not shown because it is too large.
@@ -1,8 +1,11 @@
"""This script reuses some code from
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py"""
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples
/run_classifier.py"""

import pandas as pd
import csv, sys, random
import csv
import sys
import random
from collections import namedtuple


@@ -29,26 +32,27 @@ from collections import namedtuple
# self.label = label

## New version of BERTInputData using namedtuple
"""
"""
A single BERT input data containing three fields:
1. text_a: text of the first sentence,
2. text_b: text of the second sentence, optional, required for
two-sentence tasks.
2. text_b: text of the second sentence, optional, required for
two-sentence tasks.
3. label: label, optional, required for training and validation data
"""
BertInputData = namedtuple('BertInputData', ['text_a', 'text_b', 'label'],
defaults=[None, None])
BertInputData = namedtuple(
"BertInputData", ["text_a", "text_b", "label"], defaults=[None, None]
)


class DataProcessor(object):
"""Base class for data converters for classification data sets."""

def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
"""Gets a collection of `BertInputData`s for the train set."""
raise NotImplementedError()

def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
"""Gets a collection of `BertInputData`s for the dev set."""
raise NotImplementedError()

def get_labels(self):
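A minimal usage sketch of the namedtuple defined above; the sentence and the tag values are invented for illustration and are not taken from the commit:

    from collections import namedtuple

    BertInputData = namedtuple(
        "BertInputData", ["text_a", "text_b", "label"], defaults=[None, None]
    )

    # One whitespace-separated sentence with one tag per word (illustrative values).
    example = BertInputData(
        text_a="Hudson Bay is in Canada",
        text_b=None,
        label=["B-geo", "I-geo", "O", "O", "B-geo"],
    )
    print(example.text_a, example.label)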
@@ -63,7 +67,7 @@ class DataProcessor(object):
lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
line = list(unicode(cell, "utf-8") for cell in line)
lines.append(line)
return lines


@@ -73,6 +77,7 @@ class KaggleNERProcessor(DataProcessor):
Data processor for the Kaggle NER dataset:
https://www.kaggle.com/abhinavwalia95/entity-annotated-corpus
"""

def __init__(self, data_dir, dev_percentage):
"""
Initializes the data processor.

@@ -100,21 +105,23 @@ class KaggleNERProcessor(DataProcessor):

def get_train_examples(self):
"""
Gets the training examples.
Gets the training data.
"""
data = self._read_data(self.data_dir)
train_data = data.loc[
data["Sentence #"].isin(self.train_sentence_nums)].copy()
data["Sentence #"].isin(self.train_sentence_nums)
].copy()

return self._create_examples(train_data)

def get_dev_examples(self):
"""
Gets the dev/validation examples.
Gets the dev/validation data.
"""
data = self._read_data(self.data_dir)
dev_data = data.loc[
data["Sentence #"].isin(self.dev_sentence_nums)].copy()
data["Sentence #"].isin(self.dev_sentence_nums)
].copy()

return self._create_examples(dev_data)


@@ -133,10 +140,14 @@ class KaggleNERProcessor(DataProcessor):
"""
Converts input data into BertInputData type.
"""
agg_func = lambda s: [(w, p, t) for w, p, t in
zip(s["Word"].values.tolist(),
s["POS"].values.tolist(),
s["Tag"].values.tolist())]
agg_func = lambda s: [
(w, p, t)
for w, p, t in zip(
s["Word"].values.tolist(),
s["POS"].values.tolist(),
s["Tag"].values.tolist(),
)
]
data_grouped = data.groupby("Sentence #").apply(agg_func)
sentences = [s for s in data_grouped]
examples = []

@@ -144,6 +155,7 @@ class KaggleNERProcessor(DataProcessor):
text_a = " ".join([s[0] for s in sent])
label = [s[2] for s in sent]
examples.append(
BertInputData(text_a=text_a, text_b=None, label=label))
BertInputData(text_a=text_a, text_b=None, label=label)
)

return examples
return examples
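A short usage sketch of KaggleNERProcessor; the path and the dev_percentage value are made up, and whether dev_percentage is a 0-100 percentage or a 0-1 fraction should be checked against the __init__ docstring:

    processor = KaggleNERProcessor(data_dir="./data/ner_dataset.csv", dev_percentage=10)
    train_examples = processor.get_train_examples()
    dev_examples = processor.get_dev_examples()

    print(len(train_examples), len(dev_examples))
    print(train_examples[0].text_a)  # whitespace-joined sentence
    print(train_examples[0].label)   # one tag per word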
@@ -1,13 +1,18 @@
"""This script reuses some code from
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples/run_classifier.py"""
https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/examples
/run_classifier.py"""
from enum import Enum, auto

import numpy as np
from tqdm import tqdm, trange

import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data import (
DataLoader,
RandomSampler,
SequentialSampler,
TensorDataset,
)
from torch.optim import Adam

from pytorch_pretrained_bert.optimization import BertAdam

@@ -16,10 +21,10 @@ from pytorch_pretrained_bert.modeling import BertForTokenClassification

def get_device():
"""
Helper function for detecting devices and number of gpus.
Helper function for detecting devices and number of GPUs.

Returns:
(str, int): tuple of device and number of gpus
(str, int): tuple of device and number of GPUs
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
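A one-line usage sketch of the helper above:

    device, n_gpu = get_device()
    print(device, n_gpu)  # a CUDA device and the GPU count, or a CPU device and 0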
@@ -27,12 +32,14 @@ def get_device():
return device, n_gpu


def create_token_feature_dataset(data,
tokenizer,
label_map,
max_seq_length,
true_label_available,
trailing_piece_tag="X"):
def create_token_feature_dataset(
data,
tokenizer,
label_map,
max_seq_length,
true_label_available,
trailing_piece_tag="X",
):
"""
Converts data from text to TensorDataset containing numerical features.


@@ -43,11 +50,11 @@ def create_token_feature_dataset(data,
text_b: text of the second sentence(optional)
label: label (optional)
tokenizer (BertTokenizer): Tokenizer for splitting sentence into
word pieces.
word piece tokens.
label_map (dict): Dictionary for mapping token labels to integers.
max_seq_length (int): Maximum length of the list of tokens. Lists
longer than this are truncated and shorter ones are padded with
zeros.
"O"s.
true_label_available (bool): Whether data labels are available.
trailing_piece_tag (str): Tags used to label trailing word pieces.
For example, "playing" is broken down into "play" and "##ing",

@@ -56,10 +63,13 @@ def create_token_feature_dataset(data,
Returns:
TensorDataset: A TensorDataset consisted of the following numerical
feature tensors:
1. token ids
2. attention mask
3. segment ids
4. label ids, if true_label_available is True
1. token ids: numerical values each corresponds to a token.
2. attention mask: 1 for input tokens and 0 for padded tokens,
so that padded tokens are not attended to.
3. segment ids: 0 for the first sentence and 1 for the second
sentence, only used in two sentence tasks.
4. label ids: numerical values each corresponds to a token
label, if true_label_available is True
"""

features = []
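A sketch of how this function might be called, assuming the BertTokenizer from pytorch_pretrained_bert and an illustrative label_map; the label set must cover every tag in the data, the padding label "O", and the trailing_piece_tag "X":

    from pytorch_pretrained_bert import BertTokenizer

    label_map = {"O": 0, "B-geo": 1, "I-geo": 2, "X": 3}  # illustrative label set
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

    train_dataset = create_token_feature_dataset(
        data=train_examples,  # list of BertInputData records, e.g. from the processor sketch earlier
        tokenizer=tokenizer,
        label_map=label_map,
        max_seq_length=128,
        true_label_available=True,
    )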
@@ -69,7 +79,6 @@ def create_token_feature_dataset(data,
new_labels = []
tokens = []
for word, tag in zip(text_lower.split(), example.label):
# print('splitting: ', word)
sub_words = tokenizer.wordpiece_tokenizer.tokenize(word)
for count, sub_word in enumerate(sub_words):
if tag is not None and count > 0:

@@ -98,30 +107,25 @@ def create_token_feature_dataset(data,
segment_ids += padding
new_labels += label_padding

assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(new_labels) == max_seq_length

label_ids = [label_map[label] for label in new_labels]

features.append((input_ids, input_mask, segment_ids, label_ids))

all_input_ids = torch.tensor([f[0] for f in features],
dtype=torch.long)
all_input_mask = torch.tensor([f[1] for f in features],
dtype=torch.long)
all_segment_ids = torch.tensor([f[2] for f in features],
dtype=torch.long)
all_input_ids = torch.tensor([f[0] for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f[1] for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f[2] for f in features], dtype=torch.long)

if true_label_available:
all_label_ids = torch.tensor([f[3] for f in features],
dtype=torch.long)
tensor_data = TensorDataset(all_input_ids, all_input_mask,
all_segment_ids, all_label_ids)
all_label_ids = torch.tensor(
[f[3] for f in features], dtype=torch.long
)
tensor_data = TensorDataset(
all_input_ids, all_input_mask, all_segment_ids, all_label_ids
)
else:
tensor_data = TensorDataset(all_input_ids, all_input_mask,
all_segment_ids)
tensor_data = TensorDataset(
all_input_ids, all_input_mask, all_segment_ids
)

return tensor_data


@@ -139,23 +143,32 @@ class Language(Enum):
class BertTokenClassifier:
"""BERT-based token classifier."""

def __init__(self, config, label_map, device, n_gpu,
language=Language.ENGLISH, cache_dir="."):
def __init__(
self,
config,
label_map,
device,
n_gpu,
language=Language.ENGLISH,
cache_dir=".",
):

"""
Initializes the classifier and the underlying pretrained model and
Initializes the classifier and the underlying pre-trained model and
optimizer.

Args:
config (BERTFineTuneConfig): A configuration object contains
settings of model, training, and optimizer.
label_map (dict): Dictionary used to map token labels to
integers during data preprocessing.
integers during data pre-processing. This is used to
convert token ids back to original token labels during
prediction.
device (str): "cuda" or "cpu". Can be obtained by calling
get_device.
n_gpu (int): Number of GPUs available.Can be obtained by calling
n_gpu (int): Number of GPUs available.m,Can be obtained by calling
get_device.
language (Language, optinal): The pretrained model's language.
language (Language, optional): The pre-trained model's language.
Defaults to Language.ENGLISH.
cache_dir (str, optional): Location of BERT's cache directory.
Defaults to ".".
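A construction sketch for the classifier, assuming a config_dict shaped as documented for BERTFineTuneConfig further down; the model name, config values, and the reuse of label_map from the earlier sketch are illustrative:

    device, n_gpu = get_device()

    config = BERTFineTuneConfig(
        {
            "ModelConfig": {"bert_model": "bert-base-uncased", "max_seq_length": 128},
            "TrainConfig": {"batch_size": 32, "num_train_epochs": 3},
            "OptimizerConfig": {"optimizer_name": "BertAdam", "learning_rate": 5e-5},
        }
    )

    classifier = BertTokenClassifier(
        config=config,
        label_map=label_map,
        device=device,
        n_gpu=n_gpu,
        language=Language.ENGLISH,
        cache_dir=".",
    )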
@@ -190,10 +203,12 @@ class BertTokenClassifier:
self._is_trained = False

def _load_model(self):
"""Loads the pretrained BERT model."""
"""Loads the pre-trained BERT model."""
model = BertForTokenClassification.from_pretrained(
self.bert_model, cache_dir=self.cache_dir,
num_labels=self.num_labels)
self.bert_model,
cache_dir=self.cache_dir,
num_labels=self.num_labels,
)

model.to(self.device)


@@ -209,20 +224,32 @@ class BertTokenClassifier:
"""
param_optimizer = list(self.model.named_parameters())
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer if
not any(nd in n for nd in self.no_decay_params)],
'weight_decay': self.params_weight_decay},
{'params': [p for n, p in param_optimizer if
any(nd in n for nd in self.no_decay_params)],
'weight_decay': 0.0}
{
"params": [
p
for n, p in param_optimizer
if not any(nd in n for nd in self.no_decay_params)
],
"weight_decay": self.params_weight_decay,
},
{
"params": [
p
for n, p in param_optimizer
if any(nd in n for nd in self.no_decay_params)
],
"weight_decay": 0.0,
},
]

if self.optimizer_name == 'BertAdam':
optimizer = BertAdam(optimizer_grouped_parameters,
lr=self.learning_rate)
elif self.optimizer_name == 'Adam':
optimizer = Adam(optimizer_grouped_parameters,
lr=self.learning_rate)
if self.optimizer_name == "BertAdam":
optimizer = BertAdam(
optimizer_grouped_parameters, lr=self.learning_rate
)
elif self.optimizer_name == "Adam":
optimizer = Adam(
optimizer_grouped_parameters, lr=self.learning_rate
)

return optimizer


@@ -233,30 +260,36 @@ class BertTokenClassifier:
Args:
train_dataset (TensorDataset): TensorDataset consisted of the
following numerical feature tensors.
1. token ids
2. attention mask
3. segment ids
4. label ids
1. token ids: numerical values each corresponds to a token.
2. attention mask: 1 for input tokens and 0 for padded tokens,
so that padded tokens are not attended to.
3. segment ids: 0 for the first sentence and 1 for the second
sentence, only used in two sentence tasks.
4. label ids: numerical values each corresponds to a token
label
"""
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
batch_size=self.batch_size)
train_dataloader = DataLoader(
train_dataset, sampler=train_sampler, batch_size=self.batch_size
)

global_step = 0
self.model.train()
for _ in trange(int(self.num_train_epochs), desc="Epoch"):
tr_loss = 0
nb_tr_steps = 0
for step, batch in enumerate(tqdm(train_dataloader,
desc="Iteration",
mininterval=30)):
for step, batch in enumerate(
tqdm(train_dataloader, desc="Iteration", mininterval=30)
):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch

loss = self.model(input_ids=input_ids,
token_type_ids=segment_ids,
attention_mask=input_mask,
labels=label_ids)
loss = self.model(
input_ids=input_ids,
token_type_ids=segment_ids,
attention_mask=input_mask,
labels=label_ids,
)

if self.n_gpu > 1:
# mean() to average on multi-gpu.

@@ -270,13 +303,14 @@ class BertTokenClassifier:
if self.clip_gradient:
torch.nn.utils.clip_grad_norm_(
parameters=self.model.parameters(),
max_norm=self.max_gradient_norm)
max_norm=self.max_gradient_norm,
)

self.optimizer.step()
self.optimizer.zero_grad()
global_step += 1

train_loss = tr_loss/nb_tr_steps
train_loss = tr_loss / nb_tr_steps
print("Train loss: {}".format(train_loss))

self._is_trained = True

@@ -288,32 +322,38 @@ class BertTokenClassifier:
Args:
test_dataset (TensorDataset): TensorDataset consisted of the
following numerical feature tensors.
1. token ids
2. attention mask
3. segment ids
4. label ids, optional
1. token ids: numerical values each corresponds to a token.
2. attention mask: 1 for input tokens and 0 for padded tokens,
so that padded tokens are not attended to.
3. segment ids: 0 for the first sentence and 1 for the second
sentence, only used in two sentence tasks.
4. label ids: numerical values each corresponds to a token
label, optional

Returns:
tuple: The first element of the tuple is the predicted token
labels. If the testing dataset contain label ids, the second
element of the tuple is the true token labels.
labels. If the testing dataset contains label ids, the second
element of the tuple is the true token labels, otherwise
it's None.
"""
test_sampler = SequentialSampler(test_dataset)
test_dataloader = DataLoader(test_dataset, sampler=test_sampler,
batch_size=self.batch_size)
test_dataloader = DataLoader(
test_dataset, sampler=test_sampler, batch_size=self.batch_size
)

if not self._is_trained:
raise Exception("Model is not trained. Please train model before "
"predict.")
raise Exception(
"Model is not trained. Please train model before " "predict."
)

self.model.eval()
predictions = []
true_labels = []
eval_loss, eval_accuracy = 0, 0
eval_loss = 0
nb_eval_steps = 0
for step, batch in enumerate(tqdm(test_dataloader,
desc="Iteration",
mininterval=10)):
for step, batch in enumerate(
tqdm(test_dataloader, desc="Iteration", mininterval=10)
):
batch = tuple(t.to(self.device) for t in batch)
true_label_available = False
if len(batch) == 3:

@@ -324,19 +364,23 @@ class BertTokenClassifier:

with torch.no_grad():
if true_label_available:
tmp_eval_loss = self.model(b_input_ids,
token_type_ids=None,
attention_mask=b_input_mask,
labels=b_labels)
logits = self.model(b_input_ids,
token_type_ids=None,
attention_mask=b_input_mask)
tmp_eval_loss = self.model(
b_input_ids,
token_type_ids=b_segment_ids,
attention_mask=b_input_mask,
labels=b_labels,
)
logits = self.model(
b_input_ids,
token_type_ids=b_segment_ids,
attention_mask=b_input_mask,
)

logits = logits.detach().cpu().numpy()
predictions.extend([list(p) for p in np.argmax(logits, axis=2)])

if true_label_available:
label_ids = b_labels.to('cpu').numpy()
label_ids = b_labels.to("cpu").numpy()
true_labels.append(label_ids)
eval_loss += tmp_eval_loss.mean().item()


@@ -346,16 +390,17 @@ class BertTokenClassifier:
print("Validation loss: {}".format(validation_loss))

reversed_label_map = {v: k for k, v in self.label_map.items()}
pred_tags = [[reversed_label_map[p_i] for p_i in p] for p in
predictions]
pred_tags = [
[reversed_label_map[p_i] for p_i in p] for p in predictions
]

if true_label_available:
valid_tags = [[reversed_label_map[l_ii] for l_ii in l_i] for
l in true_labels for l_i in l]
valid_tags = [
[reversed_label_map[l_ii] for l_ii in l_i]
for l in true_labels
for l_i in l
]

return pred_tags, valid_tags
else:
return pred_tags,

return (pred_tags,)


@@ -1,7 +1,8 @@
class BERTFineTuneConfig:
"""
Configurations for fine tuning pre-trained bert model.
Configurations for fine tuning pre-trained BERT models.
"""

def __init__(self, config_dict):
"""
Initializes a BERTFineTuneConfig object.

@@ -9,23 +10,30 @@ class BERTFineTuneConfig:
Args:
config_dict (dict): A nested dictionary containing three key,
value pairs.
"ModelConfig": model configuration dictionary
"bert_model": str, name of the bert pretrained model.
Accepted values are
"ModelConfig": model configuration dictionary:
"bert_model": str, name of the bert pre-trained model.
Accepted values are:
"bert-base-uncased"
"bert-large-uncased"
"bert-base-cased"
"bert-large-cased"
"bert-base-multilingual-uncased"
"bert-base-multilingual-cased"
"bert-base-chinese"
"max_seq_length": int, optional. Maximum length of token
sequence. Default value is 512.
"do_lower_case": bool, optional. Whether to convert
capital letters to lower case during tokenization.
Default value is True.
"TrainConfig" (optional): training configuration dictionary
"TrainConfig" (optional): training configuration dictionary:
"batch_size": int, optional. Default value is 32.
"num_train_epochs": int, optional. Default value is 3.
"OptimizerConfig" (optional): optimizer configuration
dictionary
dictionary:
"optimizer_name": str, optional. Name of the optimizer
to use. Accepted values are "BertAdam", "Adam".
Default value is "BertAdam"
"learning_rate": float, optional, default value is 35e-05,
"learning_rate": float, optional, default value is 5e-05,
"no_decay_params": list of strings, optional. Names of
parameters to apply weight decay on. Default value
is [].
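An example config_dict matching the structure documented above; the specific values are illustrative, the defaults noted in comments come from the docstring and the code below, and the no_decay_params fragments are the usual pytorch-pretrained-BERT convention rather than something taken from this commit:

    config_dict = {
        "ModelConfig": {
            "bert_model": "bert-base-uncased",
            "max_seq_length": 128,   # optional, default 512
            "do_lower_case": True,   # optional, default True
        },
        "TrainConfig": {
            "batch_size": 32,        # optional, default 32
            "num_train_epochs": 3,   # optional, default 3
        },
        "OptimizerConfig": {
            "optimizer_name": "BertAdam",  # "BertAdam" or "Adam", default "BertAdam"
            "learning_rate": 5e-5,         # default 5e-5
            "no_decay_params": ["bias", "gamma", "beta"],  # name fragments excluded from weight decay, default []
            "params_weight_decay": 0.01,   # default 0.01
            "clip_gradient": False,        # default False
            "max_gradient_norm": 1.0,      # default 1.0
        },
    }
    config = BERTFineTuneConfig(config_dict)
    print(config)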
@@ -61,15 +69,21 @@ class BERTFineTuneConfig:
self.batch_size = config_dict["batch_size"]
else:
self.batch_size = 32
print("batch_size is set to default value: {}.".format(
self.batch_size))
print(
"batch_size is set to default value: {}.".format(
self.batch_size
)
)

if "num_train_epochs" in config_dict:
self.num_train_epochs = config_dict["num_train_epochs"]
else:
self.num_train_epochs = 3
print("num_train_epochs is set to default value: {}.".format(
self.num_train_epochs))
print(
"num_train_epochs is set to default value: {}.".format(
self.num_train_epochs
)
)

def _configure_model_settings(self, config_dict):
self.bert_model = config_dict["bert_model"]

@@ -78,63 +92,88 @@ class BERTFineTuneConfig:
self.max_seq_length = config_dict["max_seq_length"]
else:
self.max_seq_length = 512
print("max_seq_length is set to default value: {}.".format(
self.max_seq_length))
print(
"max_seq_length is set to default value: {}.".format(
self.max_seq_length
)
)

if 'do_lower_case' in config_dict:
self.do_lower_case = config_dict['do_lower_case']
if "do_lower_case" in config_dict:
self.do_lower_case = config_dict["do_lower_case"]
else:
self.do_lower_case = True
print("do_lower_case is set to default value: {}.".format(
self.do_lower_case))
print(
"do_lower_case is set to default value: {}.".format(
self.do_lower_case
)
)

def _configure_optimizer_settings(self, config_dict):
if "optimizer_name" in config_dict:
self.optimizer_name = config_dict["optimizer_name"]
else:
self.optimizer_name = "BertAdam"
print("optimizer_name is set to default value: {}.".format(
self.optimizer_name))
print(
"optimizer_name is set to default value: {}.".format(
self.optimizer_name
)
)

if 'learning_rate' in config_dict:
self.learning_rate = config_dict['learning_rate']
if "learning_rate" in config_dict:
self.learning_rate = config_dict["learning_rate"]
else:
self.learning_rate = 5e-5
print("learning_rate is set to default value: {}.".format(
self.learning_rate))
print(
"learning_rate is set to default value: {}.".format(
self.learning_rate
)
)

if "no_decay_params" in config_dict:
self.no_decay_params = config_dict["no_decay_params"]
else:
self.no_decay_params = []
print("no_decay_params is set to default value: {}.".format(
self.no_decay_params))
print(
"no_decay_params is set to default value: {}.".format(
self.no_decay_params
)
)

if "params_weight_decay" in config_dict:
self.params_weight_decay = config_dict["params_weight_decay"]
else:
self.params_weight_decay = 0.01
print("Default params_weight_decay, {}, is used".format(
self.params_weight_decay))
print(
"Default params_weight_decay, {}, is used".format(
self.params_weight_decay
)
)

if "clip_gradient" in config_dict:
self.clip_gradient = config_dict["clip_gradient"]
else:
self.clip_gradient = False
print("clip_gradient is set to default value: {}.".format(
self.clip_gradient))
print(
"clip_gradient is set to default value: {}.".format(
self.clip_gradient
)
)

if "max_gradient_norm" in config_dict:
self.max_gradient_norm = config_dict["max_gradient_norm"]
else:
self.max_gradient_norm = 1.0
print("max_gradient_norm is set to default value: {}.".format(
self.max_gradient_norm))
print(
"max_gradient_norm is set to default value: {}.".format(
self.max_gradient_norm
)
)

def __str__(self):
sb = []
for key in self.__dict__:
sb.append(
"{key}={value}".format(key=key, value=self.__dict__[key]))
"{key}={value}".format(key=key, value=self.__dict__[key])
)

return '\n'.join(sb)
return "\n".join(sb)