From 1682804ebd504d3381523116773583a52f35afd1 Mon Sep 17 00:00:00 2001
From: wlhgtc
Date: Mon, 1 Feb 2021 16:37:59 +0800
Subject: [PATCH] Fit chinese wwm to new datasets (#9887)

* MOD: fit chinese wwm to new datasets

* MOD: move wwm to new folder

* MOD: formate code

* Styling

* MOD add param and recover trainer

Co-authored-by: Sylvain Gugger
---
 examples/language-modeling/README.md         |  67 +-------
 examples/research_projects/mlm_wwm/README.md |  92 +++++++++++
 .../mlm_wwm/requirements.txt                 |   4 +
 .../mlm_wwm/run_chinese_ref.py               | 147 ++++++++++++++++++
 .../mlm_wwm}/run_mlm_wwm.py                  |   4 +
 src/transformers/data/data_collator.py       |   2 +-
 6 files changed, 249 insertions(+), 67 deletions(-)
 create mode 100644 examples/research_projects/mlm_wwm/README.md
 create mode 100644 examples/research_projects/mlm_wwm/requirements.txt
 create mode 100644 examples/research_projects/mlm_wwm/run_chinese_ref.py
 rename examples/{language-modeling => research_projects/mlm_wwm}/run_mlm_wwm.py (98%)

diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md
index d47430e2e..6d913bbfa 100644
--- a/examples/language-modeling/README.md
+++ b/examples/language-modeling/README.md
@@ -100,72 +100,7 @@ sure all your batches have the same length.
 
 ### Whole word masking
 
-The BERT authors released a new version of BERT using Whole Word Masking in May 2019. Instead of masking randomly
-selected tokens (which may be part of words), they mask randomly selected words (masking all the tokens corresponding
-to that word). This technique has been refined for Chinese in [this paper](https://arxiv.org/abs/1906.08101).
-
-To fine-tune a model using whole word masking, use the following script:
-```bash
-python run_mlm_wwm.py \
-    --model_name_or_path roberta-base \
-    --dataset_name wikitext \
-    --dataset_config_name wikitext-2-raw-v1 \
-    --do_train \
-    --do_eval \
-    --output_dir /tmp/test-mlm-wwm
-```
-
-For Chinese models, we need to generate a reference files (which requires the ltp library), because it's tokenized at
-the character level.
-
-**Q :** Why a reference file?
-
-**A :** Suppose we have a Chinese sentence like: `我喜欢你` The original Chinese-BERT will tokenize it as
-`['我','喜','欢','你']` (character level). But `喜欢` is a whole word. For whole word masking proxy, we need a result
-like `['我','喜','##欢','你']`, so we need a reference file to tell the model which position of the BERT original token
-should be added `##`.
-
-**Q :** Why LTP ?
-
-**A :** Cause the best known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT.
-It works well on so many Chines Task like CLUE (Chinese GLUE). They use LTP, so if we want to fine-tune their model,
-we need LTP.
-
-Now LTP only only works well on `transformers==3.2.0`. So we don't add it to requirements.txt.
-You need to create a separate environment with this version of Transformers to run the `run_chinese_ref.py` script that
-will create the reference files. The script is in `examples/contrib`. Once in the proper environment, run the
-following:
-
-
-```bash
-export TRAIN_FILE=/path/to/dataset/wiki.train.raw
-export LTP_RESOURCE=/path/to/ltp/tokenizer
-export BERT_RESOURCE=/path/to/bert/tokenizer
-export SAVE_PATH=/path/to/data/ref.txt
-
-python examples/contrib/run_chinese_ref.py \
-    --file_name=path_to_train_or_eval_file \
-    --ltp=path_to_ltp_tokenizer \
-    --bert=path_to_bert_tokenizer \
-    --save_path=path_to_reference_file
-```
-
-Then you can run the script like this:
-
-
-```bash
-python run_mlm_wwm.py \
-    --model_name_or_path roberta-base \
-    --train_file path_to_train_file \
-    --validation_file path_to_validation_file \
-    --train_ref_file path_to_train_chinese_ref_file \
-    --validation_ref_file path_to_validation_chinese_ref_file \
-    --do_train \
-    --do_eval \
-    --output_dir /tmp/test-mlm-wwm
-```
-
-**Note:** On TPU, you should the flag `--pad_to_max_length` to make sure all your batches have the same length.
+This part was moved to `examples/research_projects/mlm_wwm`.
 
 ### XLNet and permutation language modeling
diff --git a/examples/research_projects/mlm_wwm/README.md b/examples/research_projects/mlm_wwm/README.md
new file mode 100644
index 000000000..33ff7ab6d
--- /dev/null
+++ b/examples/research_projects/mlm_wwm/README.md
@@ -0,0 +1,92 @@
+
+
+## Whole Word Mask Language Model
+
+
+These scripts leverage the 🤗 Datasets library and the Trainer API. You can easily customize them to your needs if you
+need extra processing on your datasets.
+
+The following examples will run on datasets hosted on our [hub](https://huggingface.co/datasets) or with your own
+text files for training and validation. We give examples of both below.
+
+
+
+The BERT authors released a new version of BERT using Whole Word Masking in May 2019. Instead of masking randomly
+selected tokens (which may be part of words), they mask randomly selected words (masking all the tokens corresponding
+to that word). This technique has been refined for Chinese in [this paper](https://arxiv.org/abs/1906.08101).
+
+To fine-tune a model using whole word masking, use the following script:
+```bash
+python run_mlm_wwm.py \
+    --model_name_or_path roberta-base \
+    --dataset_name wikitext \
+    --dataset_config_name wikitext-2-raw-v1 \
+    --do_train \
+    --do_eval \
+    --output_dir /tmp/test-mlm-wwm
+```
+
+For Chinese models, we need to generate a reference file (which requires the `ltp` library), because the text is
+tokenized at the character level.
+
+**Q:** Why a reference file?
+
+**A:** Suppose we have a Chinese sentence like `我喜欢你`. The original Chinese BERT tokenizes it as
+`['我','喜','欢','你']` (character level). But `喜欢` is a whole word. To approximate whole word masking, we need a
+result like `['我','喜','##欢','你']`, so we need a reference file that tells the model which positions of the original
+BERT tokens should get a `##` prefix (see the example below).
+
+**Q:** Why LTP?
+
+**A:** Because the best-known Chinese WWM BERT is [Chinese-BERT-wwm](https://github.com/ymcui/Chinese-BERT-wwm) by HIT.
+It works well on many Chinese tasks such as CLUE (the Chinese GLUE). Its authors use LTP for word segmentation, so if
+we want to fine-tune their model, we need LTP.
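+
+For illustration, suppose the training file contains only the sentence above. Each line of the reference file written
+by `run_chinese_ref.py` (described below) is a JSON list with the positions of the BERT tokens that should receive a
+`##` prefix, so here the whole reference file would be:
+
+```text
+[3]
+```
+
+`3` is the position of `欢` in `['[CLS]','我','喜','欢','你','[SEP]']` (special tokens included), assuming LTP segments
+the sentence into `我 / 喜欢 / 你`, so that `喜欢` is the only multi-character word.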
+
+You could run the following:
+
+
+```bash
+export TRAIN_FILE=/path/to/dataset/wiki.train.raw
+export LTP_RESOURCE=/path/to/ltp/tokenizer
+export BERT_RESOURCE=/path/to/bert/tokenizer
+export SAVE_PATH=/path/to/data/ref.txt
+
+python run_chinese_ref.py \
+    --file_name=path_to_train_or_eval_file \
+    --ltp=path_to_ltp_tokenizer \
+    --bert=path_to_bert_tokenizer \
+    --save_path=path_to_reference_file
+```
+
+Then you can run the script like this:
+
+
+```bash
+python run_mlm_wwm.py \
+    --model_name_or_path roberta-base \
+    --train_file path_to_train_file \
+    --validation_file path_to_validation_file \
+    --train_ref_file path_to_train_chinese_ref_file \
+    --validation_ref_file path_to_validation_chinese_ref_file \
+    --do_train \
+    --do_eval \
+    --output_dir /tmp/test-mlm-wwm
+```
+
+**Note 1:** On TPU, you should use the flag `--pad_to_max_length` to make sure all your batches have the same length.
+
+**Note 2:** If you have any questions or something goes wrong when running this code, don't hesitate to ping @wlhgtc.
\ No newline at end of file
diff --git a/examples/research_projects/mlm_wwm/requirements.txt b/examples/research_projects/mlm_wwm/requirements.txt
new file mode 100644
index 000000000..2d0f26bd4
--- /dev/null
+++ b/examples/research_projects/mlm_wwm/requirements.txt
@@ -0,0 +1,4 @@
+datasets >= 1.1.3
+sentencepiece != 0.1.92
+protobuf
+ltp
diff --git a/examples/research_projects/mlm_wwm/run_chinese_ref.py b/examples/research_projects/mlm_wwm/run_chinese_ref.py
new file mode 100644
index 000000000..8c4250a36
--- /dev/null
+++ b/examples/research_projects/mlm_wwm/run_chinese_ref.py
@@ -0,0 +1,147 @@
+import argparse
+import json
+from typing import List
+
+from ltp import LTP
+from transformers.models.bert.tokenization_bert import BertTokenizer
+
+
+def _is_chinese_char(cp):
+    """Checks whether CP is the codepoint of a CJK character."""
+    # This defines a "chinese character" as anything in the CJK Unicode block:
+    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+    #
+    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+    # despite its name. The modern Korean Hangul alphabet is a different block,
+    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+    # space-separated words, so they are not treated specially and handled
+    # like all of the other languages.
+    if (
+        (cp >= 0x4E00 and cp <= 0x9FFF)
+        or (cp >= 0x3400 and cp <= 0x4DBF)  #
+        or (cp >= 0x20000 and cp <= 0x2A6DF)  #
+        or (cp >= 0x2A700 and cp <= 0x2B73F)  #
+        or (cp >= 0x2B740 and cp <= 0x2B81F)  #
+        or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
+        or (cp >= 0xF900 and cp <= 0xFAFF)
+        or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
+    ):  #
+        return True
+
+    return False
+
+
+def is_chinese(word: str):
+    # word like '180' or '身高' or '神'
+    for char in word:
+        char = ord(char)
+        if not _is_chinese_char(char):
+            return 0
+    return 1
+
+
+def get_chinese_word(tokens: List[str]):
+    word_set = set()
+
+    for token in tokens:
+        chinese_word = len(token) > 1 and is_chinese(token)
+        if chinese_word:
+            word_set.add(token)
+    word_list = list(word_set)
+    return word_list
+
+
+def add_sub_symbol(bert_tokens: List[str], chinese_word_set: set):
+    # Prefix "##" to every non-leading BERT token that falls inside a word found by LTP.
+    if not chinese_word_set:
+        return bert_tokens
+    max_word_len = max([len(w) for w in chinese_word_set])
+
+    bert_word = bert_tokens
+    start, end = 0, len(bert_word)
+    while start < end:
+        single_word = True
+        if is_chinese(bert_word[start]):
+            l = min(end - start, max_word_len)
+            for i in range(l, 1, -1):
+                whole_word = "".join(bert_word[start : start + i])
+                if whole_word in chinese_word_set:
+                    for j in range(start + 1, start + i):
+                        bert_word[j] = "##" + bert_word[j]
+                    start = start + i
+                    single_word = False
+                    break
+            if single_word:
+                start += 1
+    return bert_word
+
+
+def prepare_ref(lines: List[str], ltp_tokenizer: LTP, bert_tokenizer: BertTokenizer):
+    # Segment the lines with LTP (in batches of 100) and keep only the multi-character Chinese words.
+    ltp_res = []
+
+    for i in range(0, len(lines), 100):
+        res = ltp_tokenizer.seg(lines[i : i + 100])[0]
+        res = [get_chinese_word(r) for r in res]
+        ltp_res.extend(res)
+    assert len(ltp_res) == len(lines)
+
+    # Tokenize the same lines with the BERT tokenizer.
+    bert_res = []
+    for i in range(0, len(lines), 100):
+        res = bert_tokenizer(lines[i : i + 100], add_special_tokens=True, truncation=True, max_length=512)
+        bert_res.extend(res["input_ids"])
+    assert len(bert_res) == len(lines)
+
+    ref_ids = []
+    for input_ids, chinese_word in zip(bert_res, ltp_res):
+
+        input_tokens = []
+        for id in input_ids:
+            token = bert_tokenizer._convert_id_to_token(id)
+            input_tokens.append(token)
+        input_tokens = add_sub_symbol(input_tokens, chinese_word)
+        ref_id = []
+        # We only save the positions of Chinese subwords that start with ##, which means they are part of a whole word.
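+        # For example, input_tokens ['[CLS]', '我', '喜', '##欢', '你', '[SEP]'] would give ref_id == [3],
+        # the position of '##欢' (a single Chinese character marked as a subword).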
+        for i, token in enumerate(input_tokens):
+            if token[:2] == "##":
+                clean_token = token[2:]
+                # save chinese tokens' pos
+                if len(clean_token) == 1 and _is_chinese_char(ord(clean_token)):
+                    ref_id.append(i)
+        ref_ids.append(ref_id)
+
+    assert len(ref_ids) == len(bert_res)
+
+    return ref_ids
+
+
+def main(args):
+    # For Chinese (Ro)BERT, the best result comes from RoBERTa-wwm-ext (https://github.com/ymcui/Chinese-BERT-wwm)
+    # If we want to fine-tune these models, we have to use the same tokenizer: LTP (https://github.com/HIT-SCIR/ltp)
+    with open(args.file_name, "r", encoding="utf-8") as f:
+        data = f.readlines()
+    data = [line.strip() for line in data if len(line) > 0 and not line.isspace()]  # avoid delimiter like '\u2029'
+    ltp_tokenizer = LTP(args.ltp)  # faster on a GPU device
+    bert_tokenizer = BertTokenizer.from_pretrained(args.bert)
+
+    ref_ids = prepare_ref(data, ltp_tokenizer, bert_tokenizer)
+
+    with open(args.save_path, "w", encoding="utf-8") as f:
+        data = [json.dumps(ref) + "\n" for ref in ref_ids]
+        f.writelines(data)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="prepare_chinese_ref")
+    parser.add_argument(
+        "--file_name",
+        type=str,
+        default="./resources/chinese-demo.txt",
+        help="file to process, same as the training data for the LM",
+    )
+    parser.add_argument(
+        "--ltp", type=str, default="./resources/ltp", help="resources for the LTP tokenizer, usually a path"
+    )
+    parser.add_argument("--bert", type=str, default="./resources/robert", help="resources for the BERT tokenizer")
+    parser.add_argument("--save_path", type=str, default="./resources/ref.txt", help="path to save the result")
+
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/language-modeling/run_mlm_wwm.py b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
similarity index 98%
rename from examples/language-modeling/run_mlm_wwm.py
rename to examples/research_projects/mlm_wwm/run_mlm_wwm.py
index e08c5e979..5f1926c1b 100644
--- a/examples/language-modeling/run_mlm_wwm.py
+++ b/examples/research_projects/mlm_wwm/run_mlm_wwm.py
@@ -337,6 +337,10 @@ def main():
         tokenized_datasets["validation"] = add_chinese_references(
             tokenized_datasets["validation"], data_args.validation_ref_file
         )
+    # If we have ref files, we need to keep the Trainer from dropping them as unused columns
+    has_ref = data_args.train_ref_file or data_args.validation_ref_file
+    if has_ref:
+        training_args.remove_unused_columns = False
 
     # Data collator
     # This one will take care of randomly masking the tokens.
diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py
index 79e30b97b..d585b419e 100644
--- a/src/transformers/data/data_collator.py
+++ b/src/transformers/data/data_collator.py
@@ -402,7 +402,7 @@ class DataCollatorForWholeWordMask(DataCollatorForLanguageModeling):
             # For Chinese tokens, we need extra inf to mark sub-word, e.g [喜,欢]-> [喜,##欢]
             if "chinese_ref" in e:
                 ref_pos = tolist(e["chinese_ref"])
-                len_seq = e["input_ids"].size(0)
+                len_seq = len(e["input_ids"])
                 for i in range(len_seq):
                     if i in ref_pos:
                         ref_tokens[i] = "##" + ref_tokens[i]