Fit chinese wwm to new datasets (#9887)

* MOD: fit chinese wwm to new datasets

* MOD: move wwm to new folder

* MOD: formate code

* Styling

* MOD add param and recover trainer

Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
This commit is contained in:
wlhgtc 2021-02-01 16:37:59 +08:00 коммит произвёл GitHub
Родитель 24881008a6
Коммит 1682804ebd
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 249 добавлений и 67 удалений

Просмотреть файл

@ -100,72 +100,7 @@ sure all your batches have the same length.
### Whole word masking
The BERT authors released a new version of BERT using Whole Word Masking in May 2019. Instead of masking randomly
selected tokens (which may be part of words), they mask randomly selected words (masking all the tokens corresponding
to that word). This technique has been refined for Chinese in [this paper](https://arxiv.org/abs/1906.08101).
To fine-tune a model using whole word masking, use the following script:
```bash
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--do_eval \
--output_dir /tmp/test-mlm-wwm
```
For Chinese models, we need to generate a reference files (which requires the ltp library), because it's tokenized at
the character level.
**Q :** Why a reference file?
**A :** Suppose we have a Chinese sentence like: `我喜欢你` The original Chinese-BERT will tokenize it as
`['我','喜','欢','你']` (character level). But `喜欢` is a whole word. For whole word masking proxy, we need a result
like `['我','喜','##欢','你']`, so we need a reference file to tell the model which position of the BERT original token
should be added `##`.
**Q :** Why LTP ?
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT.
It works well on so many Chines Task like CLUE (Chinese GLUE). They use LTP, so if we want to fine-tune their model,
we need LTP.
Now LTP only only works well on `transformers==3.2.0`. So we don't add it to requirements.txt.
You need to create a separate environment with this version of Transformers to run the `run_chinese_ref.py` script that
will create the reference files. The script is in `examples/contrib`. Once in the proper environment, run the
following:
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt
python examples/contrib/run_chinese_ref.py \
--file_name=path_to_train_or_eval_file \
--ltp=path_to_ltp_tokenizer \
--bert=path_to_bert_tokenizer \
--save_path=path_to_reference_file
```
Then you can run the script like this:
```bash
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--train_ref_file path_to_train_chinese_ref_file \
--validation_ref_file path_to_validation_chinese_ref_file \
--do_train \
--do_eval \
--output_dir /tmp/test-mlm-wwm
```
**Note:** On TPU, you should the flag `--pad_to_max_length` to make sure all your batches have the same length.
This part was moved to `examples/research_projects/mlm_wwm`.
### XLNet and permutation language modeling

Просмотреть файл

@ -0,0 +1,92 @@
<!---
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## Whole Word Mask Language Model
These scripts leverage the 🤗 Datasets library and the Trainer API. You can easily customize them to your needs if you
need extra processing on your datasets.
The following examples, will run on a datasets hosted on our [hub](https://huggingface.co/datasets) or with your own
text files for training and validation. We give examples of both below.
The BERT authors released a new version of BERT using Whole Word Masking in May 2019. Instead of masking randomly
selected tokens (which may be part of words), they mask randomly selected words (masking all the tokens corresponding
to that word). This technique has been refined for Chinese in [this paper](https://arxiv.org/abs/1906.08101).
To fine-tune a model using whole word masking, use the following script:
```bash
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--do_train \
--do_eval \
--output_dir /tmp/test-mlm-wwm
```
For Chinese models, we need to generate a reference files (which requires the ltp library), because it's tokenized at
the character level.
**Q :** Why a reference file?
**A :** Suppose we have a Chinese sentence like: `我喜欢你` The original Chinese-BERT will tokenize it as
`['我','喜','欢','你']` (character level). But `喜欢` is a whole word. For whole word masking proxy, we need a result
like `['我','喜','##欢','你']`, so we need a reference file to tell the model which position of the BERT original token
should be added `##`.
**Q :** Why LTP ?
**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT.
It works well on so many Chines Task like CLUE (Chinese GLUE). They use LTP, so if we want to fine-tune their model,
we need LTP.
You could run the following:
```bash
export TRAIN_FILE=/path/to/dataset/wiki.train.raw
export LTP_RESOURCE=/path/to/ltp/tokenizer
export BERT_RESOURCE=/path/to/bert/tokenizer
export SAVE_PATH=/path/to/data/ref.txt
python run_chinese_ref.py \
--file_name=path_to_train_or_eval_file \
--ltp=path_to_ltp_tokenizer \
--bert=path_to_bert_tokenizer \
--save_path=path_to_reference_file
```
Then you can run the script like this:
```bash
python run_mlm_wwm.py \
--model_name_or_path roberta-base \
--train_file path_to_train_file \
--validation_file path_to_validation_file \
--train_ref_file path_to_train_chinese_ref_file \
--validation_ref_file path_to_validation_chinese_ref_file \
--do_train \
--do_eval \
--output_dir /tmp/test-mlm-wwm
```
**Note1:** On TPU, you should the flag `--pad_to_max_length` to make sure all your batches have the same length.
**Note2:** And if you have any questions or something goes wrong when runing this code, don't hesitate to pin @wlhgtc.

Просмотреть файл

@ -0,0 +1,4 @@
datasets >= 1.1.3
sentencepiece != 0.1.92
protobuf
ltp

Просмотреть файл

@ -0,0 +1,147 @@
import argparse
import json
from typing import List
from ltp import LTP
from transformers.models.bert.tokenization_bert import BertTokenizer
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like the all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
def is_chinese(word: str):
# word like '180' or '身高' or '神'
for char in word:
char = ord(char)
if not _is_chinese_char(char):
return 0
return 1
def get_chinese_word(tokens: List[str]):
word_set = set()
for token in tokens:
chinese_word = len(token) > 1 and is_chinese(token)
if chinese_word:
word_set.add(token)
word_list = list(word_set)
return word_list
def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set()):
if not chinese_word_set:
return bert_tokens
max_word_len = max([len(w) for w in chinese_word_set])
bert_word = bert_tokens
start, end = 0, len(bert_word)
while start < end:
single_word = True
if is_chinese(bert_word[start]):
l = min(end - start, max_word_len)
for i in range(l, 1, -1):
whole_word = "".join(bert_word[start : start + i])
if whole_word in chinese_word_set:
for j in range(start + 1, start + i):
bert_word[j] = "##" + bert_word[j]
start = start + i
single_word = False
break
if single_word:
start += 1
return bert_word
def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
ltp_res = []
for i in range(0, len(lines), 100):
res = ltp_tokenizer.seg(lines[i : i + 100])[0]
res = [get_chinese_word(r) for r in res]
ltp_res.extend(res)
assert len(ltp_res) == len(lines)
bert_res = []
for i in range(0, len(lines), 100):
res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
bert_res.extend(res["input_ids"])
assert len(bert_res) == len(lines)
ref_ids = []
for input_ids, chinese_word in zip(bert_res, ltp_res):
input_tokens = []
for id in input_ids:
token = bert_tokenizer._convert_id_to_token(id)
input_tokens.append(token)
input_tokens = add_sub_symbol(input_tokens, chinese_word)
ref_id = []
# We only save pos of chinese subwords start with ##, which mean is part of a whole word.
for i, token in enumerate(input_tokens):
if token[:2] == "##":
clean_token = token[2:]
# save chinese tokens' pos
if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
ref_id.append(i)
ref_ids.append(ref_id)
assert len(ref_ids) == len(bert_res)
return ref_ids
def main(args):
# For Chinese (Ro)Bert, the best result is from : RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
# If we want to fine-tune these model, we have to use same tokenizer : LTP (https://github.com/HIT-SCIR/ltp)
with open(args.file_name, "r", encoding="utf-8") as f:
data = f.readlines()
data = [line.strip() for line in data if len(line) > 0 and not line.isspace()] # avoid delimiter like '\u2029'
ltp_tokenizer = LTP(args.ltp) # faster in GPU device
bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)
with open(args.save_path, "w", encoding="utf-8") as f:
data = [json.dumps(ref) + "\n" for ref in ref_ids]
f.writelines(data)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="prepare_chinese_ref")
parser.add_argument(
"--file_name",
type=str,
default="./resources/chinese-demo.txt",
help="file need process, same as training data in lm",
)
parser.add_argument(
"--ltp", type=str, default="./resources/ltp", help="resources for LTP tokenizer, usually a path"
)
parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for Bert tokenizer")
parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save res")
args = parser.parse_args()
main(args)

Просмотреть файл

@ -337,6 +337,10 @@ def main():
tokenized_datasets["validation"] = add_chinese_references(
tokenized_datasets["validation"], data_args.validation_ref_file
)
# If we have ref files, need to avoid it removed by trainer
has_ref = data_args.train_ref_file or data_args.validation_ref_file
if has_ref:
training_args.remove_unused_columns = False
# Data collator
# This one will take care of randomly masking the tokens.

Просмотреть файл

@ -402,7 +402,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
# For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
if "chinese_ref" in e:
ref_pos = tolist(e["chinese_ref"])
len_seq = e["input_ids"].size(0)
len_seq = len(e["input_ids"])
for i in range(len_seq):
if i in ref_pos:
ref_tokens[i] = "##" + ref_tokens[i]