Mirror of https://github.com/microsoft/MT-DNN.git
Feature/v.0.0.0 (#1)
* MT-DNN Feature release v1.0.0

  - ignore: IDE meta files
  - add: initial checkin
  - feat: make package pip installable
  - doc: add contribution steps and update readme with additional information
  - fix: wrong file references
  - add: package dependencies like PyTorch transformer support
  - fix: import statements
  - feat: add download utils
  - feat: add download shell script
  - doc: update readme with testing instructions
  - doc: add data downloading and processing step
  - feat: move fit and predict into MTDNN Model
  - feat: make logger create a new log file each run
  - remove: fit and predict functions from the pipeline class
  - update: remove stale references
  - doc: fit and predict now on the model object
  - formatting code snippet
  - doc: add pip install steps
  - ignore: checkpoints and log files
  - add: conda file for env generation
  - feat: docker file support
  - doc: README for example and data
  - add: sample data in json lines format
  - feat: jupyter example
  - remove: batch_size is controlled by configuration
  - feat: jupyter notebook
  - add: license and copyright
  - add: git lfs track sample data files
  - data: add sample data file with git lfs
  - doc: batch_size is now set in config
  - cleanup
This commit is contained in:
Parent
6b4215cd68
Commit
9dec8253f1
@@ -0,0 +1 @@
*.json filter=lfs diff=lfs merge=lfs -text
@@ -8,6 +8,18 @@
*.user
*.userosscache
*.sln.docstates
.vscode
.idea

# MT-DNN downloaded data and egg files
data/
mt_dnn_models/
checkpoint/
tensorboard_logdir/
*.egg-info

# User-specific files (MonoDevelop/Xamarin Studio)
*.userprefs
@@ -0,0 +1,98 @@
# Contribution Guidelines

Contributions are welcome! Here are a few things to know:

- [Contribution Guidelines](#contribution-guidelines)
  - [Microsoft Contributor License Agreement](#microsoft-contributor-license-agreement)
  - [Steps to Contributing](#steps-to-contributing)
  - [Coding Guidelines](#coding-guidelines)
  - [Code of Conduct](#code-of-conduct)
    - [Do not point fingers](#do-not-point-fingers)
    - [Provide code feedback based on evidence](#provide-code-feedback-based-on-evidence)
    - [Ask questions do not give answers](#ask-questions-do-not-give-answers)

## Microsoft Contributor License Agreement

Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com.

When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

## Steps to Contributing

Here are the basic steps to get started with your first contribution. Please reach out with any questions.
1. Use [open issues](https://github.com/Microsoft/Recommenders/issues) to discuss the proposed changes. Create an issue describing changes if necessary to collect feedback. Also, please use the provided labels to tag issues so everyone can easily sort issues of interest.
2. [Fork the repo](https://help.github.com/articles/fork-a-repo/) so you can make and test local changes.
3. Create a new branch for the issue. We suggest prefixing the branch with your username and then a descriptive title (e.g. gramhagen/update_contributing_docs).
4. Create a test that replicates the issue.
5. Make code changes.
6. Ensure unit tests pass and code style / formatting is consistent (see the [wiki](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines#python-and-docstrings-style) for more details).
7. We use the [pre-commit](https://pre-commit.com/) package to run our pre-commit hooks, which apply the black formatter and flake8 linting on each commit. To set up pre-commit on your machine, follow the steps below; note that you only need to run these steps the first time you use pre-commit for this project.

    * Update your conda environment (pre-commit is part of the yaml file), or just run:

    ```
    $ pip install pre-commit
    ```

    * Set up pre-commit by running the following command; this will put pre-commit under your .git/hooks directory.

    ```
    $ pre-commit install
    ```

    ```
    $ git commit -m "message"
    ```

    * Each time you commit, git will run the pre-commit hooks (black and flake8 for now) on any python files that are being committed and are part of the git index. If black modifies/formats a file, or if flake8 finds any linting errors, the commit will not succeed. You will need to stage the file again if black changed it, or fix the issues identified by flake8 and stage it again.

    * To run pre-commit on all files, just run:

    ```
    $ pre-commit run --all-files
    ```

8. Create a pull request against the <b>staging</b> branch.

Note: We use the staging branch to land all new features, so please remember to create the Pull Request against staging.

Once the features included in a milestone are complete, we will merge staging into master and make a release. See the wiki for more detail about our [merge strategy](https://github.com/Microsoft/Recommenders/wiki/Strategy-to-merge-the-code-to-master-branch).

## Coding Guidelines

We strive to maintain high quality code to make the utilities in the repository easy to understand, use, and extend. We also work hard to maintain a friendly and constructive environment. We've found that having clear expectations on the development process and consistent style helps to ensure everyone can contribute and collaborate effectively.

Please review the [coding guidelines](https://github.com/Microsoft/Recommenders/wiki/Coding-Guidelines) wiki page to see more details about the expectations for development approach and style.

## Code of Conduct

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).

For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

Apart from the official Code of Conduct developed by Microsoft, in the Recommenders team we adopt the following behaviors, to ensure a great working environment:

#### Do not point fingers
Let’s be constructive.

<details>
<summary><em>Click here to see some examples</em></summary>

"This method is missing docstrings" instead of "YOU forgot to put docstrings".

</details>

#### Provide code feedback based on evidence

When reviewing code, try to support your ideas with evidence (papers, library documentation, StackOverflow, etc.) rather than your personal preferences.

<details>
<summary><em>Click here to see some examples</em></summary>

"When reviewing this code, I saw that the Python implementation of the metrics is based on classes; however, [scikit-learn](https://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics) and [tensorflow](https://www.tensorflow.org/api_docs/python/tf/metrics) use functions. We should follow the standard in the industry."

</details>

#### Ask questions do not give answers
Try to be empathic.

<details>
<summary><em>Click here to see some examples</em></summary>

* Would it make more sense if ...?
* Have you considered this ... ?

</details>
README.md
@@ -1,14 +1,247 @@
# Multi-Task Deep Neural Networks for Natural Language Understanding

This PyTorch package implements the Multi-Task Deep Neural Networks (MT-DNN) for Natural Language Understanding, as described in:

Xiaodong Liu\*, Pengcheng He\*, Weizhu Chen and Jianfeng Gao<br/>
Multi-Task Deep Neural Networks for Natural Language Understanding<br/>
[ACL 2019](https://aclweb.org/anthology/papers/P/P19/P19-1441/) <br/>
\*: Equal contribution <br/>

Xiaodong Liu, Pengcheng He, Weizhu Chen and Jianfeng Gao<br/>
Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding <br/>
[arXiv version](https://arxiv.org/abs/1904.09482) <br/>

Pengcheng He, Xiaodong Liu, Weizhu Chen and Jianfeng Gao<br/>
Hybrid Neural Network Model for Commonsense Reasoning <br/>
[arXiv version](https://arxiv.org/abs/1907.11983) <br/>

Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao and Jiawei Han <br/>
On the Variance of the Adaptive Learning Rate and Beyond <br/>
[arXiv version](https://arxiv.org/abs/1908.03265) <br/>

Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao and Tuo Zhao <br/>
SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization <br/>
[arXiv version](https://arxiv.org/abs/1911.03437) <br/>

## Pip install package
A [setup.py](./setup.py) file is provided in order to simplify the installation of this package.

1. To install the package, please run the command below (from the directory root):

    ```bash
    pip install -e .
    ```

1. Running the command tells pip to install the `mt-dnn` package from source in development mode. This just means that any updates to the `mt-dnn` source directory will immediately be reflected in the installed package without needing to reinstall; a very useful practice for a package with constant updates.

1. It is also possible to install directly from Github, which is the best way to utilize the package in external projects (while still reflecting updates to the source, as it's installed as an editable '-e' package):

    ```bash
    pip install -e git+git@github.com:microsoft/mt-dnn.git@master#egg=mtdnn
    ```

1. Either command above makes `mt-dnn` available in your conda virtual environment. You can verify it was properly installed by running:

    ```bash
    pip list | grep mtdnn
    ```
## How To Use
1. Create a model configuration object, `MTDNNConfig`, with the necessary parameters to initialize the MT-DNN model. Initialization without any parameters will default to a similar configuration that initializes a BERT model. This configuration object can be initialized with training and learning parameters like `batch_size` and `learning_rate`. Please consult the class implementation for all parameters.

    ```Python
    BATCH_SIZE = 16
    config = MTDNNConfig(batch_size=BATCH_SIZE)
    ```

1. Define the task parameters to train for and initialize an `MTDNNTaskDefs` object.

    ```Python
    tasks_params = {
        "mnli": {
            "data_format": "PremiseAndOneHypothesis",
            "encoder_type": "BERT",
            "dropout_p": 0.3,
            "enable_san": True,
            "labels": ["contradiction", "neutral", "entailment"],
            "metric_meta": ["ACC"],
            "loss": "CeCriterion",
            "kd_loss": "MseCriterion",
            "n_class": 3,
            "split_names": [
                "train",
                "matched_dev",
                "mismatched_dev",
                "matched_test",
                "mismatched_test",
            ],
            "task_type": "Classification",
        },
    }
    task_defs = MTDNNTaskDefs(tasks_params)
    ```

1. Create a data preprocessing object, `MTDNNDataProcess`. This creates the training, test and development PyTorch dataloaders needed for training and testing. We also need to retrieve the necessary training options required to initialize the model correctly, for all tasks.

    ```Python
    data_processor = MTDNNDataProcess(
        config=config,
        task_defs=task_defs,
        data_dir="/home/useradmin/sources/mt-dnn/data/canonical_data/bert_uncased_lower",
        train_datasets_list=["mnli"],
        test_datasets_list=["mnli_mismatched", "mnli_matched"],
    )

    # Retrieve the multi-task train, dev and test dataloaders
    multitask_train_dataloader = data_processor.get_train_dataloader()
    dev_dataloaders_list = data_processor.get_dev_dataloaders()
    test_dataloaders_list = data_processor.get_test_dataloaders()

    # Get training options to initialize the model
    decoder_opts = data_processor.get_decoder_options_list()
    task_types = data_processor.get_task_types_list()
    dropout_list = data_processor.get_tasks_dropout_prob_list()
    loss_types = data_processor.get_loss_types_list()
    kd_loss_types = data_processor.get_kd_loss_types_list()
    tasks_nclass_list = data_processor.get_task_nclass_list()
    num_all_batches = data_processor.get_num_all_batches()
    ```

1. Now we can create an `MTDNNModel`.

    ```Python
    model = MTDNNModel(
        config,
        task_defs,
        pretrained_model_name="bert-base-uncased",
        num_train_step=num_all_batches,
        decoder_opts=decoder_opts,
        task_types=task_types,
        dropout_list=dropout_list,
        loss_types=loss_types,
        kd_loss_types=kd_loss_types,
        tasks_nclass_list=tasks_nclass_list,
        multitask_train_dataloader=multitask_train_dataloader,
        dev_dataloaders_list=dev_dataloaders_list,
        test_dataloaders_list=test_dataloaders_list,
    )
    ```

1. At this point we can fit the model and create predictions. The fit takes an optional `epochs` parameter that overrides the epochs set in the `MTDNNConfig` object.

    ```Python
    model.fit()
    model.predict()
    ```

1. The predict function can take an optional checkpoint, `trained_model_chckpt`. This can be used for inference and running evaluations on an already trained PyTorch MT-DNN model.

    ```Python
    # Predict using a PyTorch model checkpoint
    checkpt = "./model_0.pt"
    model.predict(trained_model_chckpt=checkpt)
    ```

## Pre-process your data in the correct format
Depending on what `data_format` you have set in the configuration object `MTDNNConfig`, please follow the detailed data format below to prepare your data; an illustrative sample follows the list:

- `PremiseOnly`: single text, i.e. premise. Data format is "id" \t "label" \t "premise".

- `PremiseAndOneHypothesis`: two texts, i.e. one premise and one hypothesis. Data format is "id" \t "label" \t "premise" \t "hypothesis".

- `PremiseAndMultiHypothesis`: one text as premise and multiple candidate texts as hypotheses. Data format is "id" \t "label" \t "premise" \t "hypothesis_1" \t "hypothesis_2" \t ... \t "hypothesis_n".

- `Sequence`: sequence tagging. Data format is "id" \t "label" \t "premise".
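For illustration, here is what a couple of `PremiseAndOneHypothesis` rows could look like (made-up values, tab-separated):

```
1	entailment	A soccer game with multiple males playing.	Some men are playing a sport.
2	contradiction	A man inspects a uniform.	The man is sleeping.
```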
## FAQ

### Did you share the pretrained MT-DNN models?
Yes, we released the pretrained shared embeddings via MTL which are aligned to BERT base/large models: ```mt_dnn_base.pt``` and ```mt_dnn_large.pt```. </br>

### How can we obtain the data and pre-trained models to try out?
We have provided a [download script](./scripts/download.sh) to assist with this.

### Why do SciTail/SNLI not enable SAN?
For SciTail/SNLI tasks, the purpose is to test the generalization of the learned embedding and how easily it is adapted to a new domain, rather than complicated model structures, for a direct comparison with BERT. Thus, we use a linear projection in all **domain adaptation** settings.

### What is the difference between V1 and V2?
The difference is in the QNLI dataset. Please refer to the GLUE official homepage for more details. If you want to formulate QNLI as a pair-wise ranking task as in our paper, make sure that you use the old QNLI data. </br>
Then run the prepro script with flags: ```> sh experiments/glue/prepro.sh --old_glue``` </br>
If you have issues accessing the old version of the data, please contact the GLUE team.

### Did you fine-tune single tasks for your GLUE leaderboard submission?
We can use the multi-task refinement model to run the prediction and produce a reasonable result. But to achieve a better result, it requires fine-tuning on each task. It is worth noting that the paper on arXiv is a little outdated and based on the old GLUE dataset. We will update the paper accordingly.

## Notes and Acknowledgments
BERT pytorch is from: https://github.com/huggingface/pytorch-pretrained-BERT <br/>
BERT: https://github.com/google-research/bert <br/>
We also used some code from: https://github.com/kevinduh/san_mrc <br/>

## Related Projects/Codebase
1. Pretrained UniLM: https://github.com/microsoft/unilm <br/>
2. Pretrained Response Generation Model: https://github.com/microsoft/DialoGPT <br/>
3. Internal MT-DNN repo: https://github.com/microsoft/mt-dnn <br/>

### How do I cite MT-DNN?

```
@inproceedings{liu2019mt-dnn,
    title = "Multi-Task Deep Neural Networks for Natural Language Understanding",
    author = "Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng",
    booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2019",
    address = "Florence, Italy",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/P19-1441",
    pages = "4487--4496"
}

@article{liu2019mt-dnn-kd,
  title={Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding},
  author={Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
  journal={arXiv preprint arXiv:1904.09482},
  year={2019}
}

@article{he2019hnn,
  title={A Hybrid Neural Network Model for Commonsense Reasoning},
  author={He, Pengcheng and Liu, Xiaodong and Chen, Weizhu and Gao, Jianfeng},
  journal={arXiv preprint arXiv:1907.11983},
  year={2019}
}

@article{liu2019radam,
  title={On the Variance of the Adaptive Learning Rate and Beyond},
  author={Liu, Liyuan and Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Han, Jiawei},
  journal={arXiv preprint arXiv:1908.03265},
  year={2019}
}

@article{jiang2019smart,
  title={SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization},
  author={Jiang, Haoming and He, Pengcheng and Chen, Weizhu and Liu, Xiaodong and Gao, Jianfeng and Zhao, Tuo},
  journal={arXiv preprint arXiv:1911.03437},
  year={2019}
}
```

### Contact Information

For help or issues using MT-DNN, please submit a GitHub issue.

For personal communication related to this package, please contact Xiaodong Liu (`xiaodl@microsoft.com`), Yu Wang (`yuwan@microsoft.com`), Pengcheng He (`penhe@microsoft.com`), Weizhu Chen (`wzchen@microsoft.com`), Jianshu Ji (`jianshuj@microsoft.com`), Emmanuel Awa (`Emmanuel.Awa@microsoft.com`) or Jianfeng Gao (`jfgao@microsoft.com`).

# Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

For the complete steps to contributing to this repo, please see [CONTRIBUTION.md](./CONTRIBUTION.md).
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# Adapted from https://github.com/microsoft/nlp-recipes/blob/master/docker/Dockerfile

FROM nvidia/cuda

# Install Anaconda
# Non-interactive installation instructions can be found at:
# https://hub.docker.com/r/continuumio/anaconda/dockerfile
# https://hub.docker.com/r/continuumio/miniconda/dockerfile
ENV PATH /opt/conda/bin:$PATH
SHELL ["/bin/bash", "-c"]

RUN apt-get update --fix-missing && apt-get install -y wget bzip2 ca-certificates \
    libglib2.0-0 libxext6 libsm6 libxrender1 \
    git mercurial subversion

RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
    echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
    echo "conda activate base" >> ~/.bashrc

# Get the latest staging version of the repository
WORKDIR /root
RUN apt-get install -y zip && \
    wget --quiet https://github.com/microsoft/mt-dnn/archive/staging.zip -O staging.zip && \
    unzip staging.zip && rm staging.zip

# Install the packages
WORKDIR /root/mt-dnn-staging
RUN python /root/mt-dnn-staging/scripts/generate_conda_file.py --gpu && \
    conda env create -n mtdnn_gpu -f mtdnn_gpu.yaml
RUN source activate mtdnn_gpu && \
    pip install -e . && \
    python -m ipykernel install --user --name mtdnn_gpu --display-name "Python (mtdnn_gpu)"

# Run the notebook server
EXPOSE 8888/tcp
WORKDIR /root/mt-dnn-staging
CMD source activate mtdnn_gpu && \
    jupyter notebook --allow-root --ip 0.0.0.0 --port 8888 --no-browser --notebook-dir .
@@ -0,0 +1,8 @@
# Examples

This folder contains examples and best practices, written in Jupyter notebooks, for building Multi-Task Deep Neural Networks for Natural Language Understanding.

|Category|Applications|Method(s)|Languages|
|:---:| :------------------------: | :-------------------: | :---: |
|[Classification](classification)|Topic Classification|MT-DNN|en|
@@ -0,0 +1,386 @@
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
        "### Licensed under the MIT License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Multi-Task Deep Neural Networks for Natural Language Understanding \n",
        "\n",
        "\n",
        "This PyTorch package implements the Multi-Task Deep Neural Networks (MT-DNN) for Natural Language Understanding. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### The data \n",
        "\n",
        "This notebook assumes you have data already pre-processed in the MT-DNN format and accessible in a local directory. \n",
        "\n",
        "\n",
        "For the purposes of this example we have added sample data that is already processed in MT-DNN format, which can be found in the __sample_data__ folder. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "The autoreload extension is already loaded. To reload it, use:\n",
            "  %reload_ext autoreload\n"
          ]
        }
      ],
      "source": [
        "%load_ext autoreload"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {},
      "outputs": [],
      "source": [
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "from mtdnn.common.types import EncoderModelType\n",
        "from mtdnn.configuration_mtdnn import MTDNNConfig\n",
        "from mtdnn.modeling_mtdnn import MTDNNModel\n",
        "from mtdnn.process_mtdnn import MTDNNDataProcess\n",
        "from mtdnn.tasks.config import MTDNNTaskDefs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Define Configuration, Tasks and Model Objects"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {},
      "outputs": [],
      "source": [
        "DATA_DIR = \"../../sample_data/bert_uncased_lower/mnli/\"\n",
        "BATCH_SIZE = 16"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Define a Configuration Object \n",
        "\n",
        "Create a model configuration object, `MTDNNConfig`, with the necessary parameters to initialize the MT-DNN model. Initialization without any parameters will default to a similar configuration that initializes a BERT model. \n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {},
      "outputs": [],
      "source": [
        "config = MTDNNConfig(batch_size=BATCH_SIZE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "### Create Task Definition Object \n",
        "\n",
        "Define the task parameters to train for and initialize an `MTDNNTaskDefs` object. Create a task parameter dictionary. The definition can be a single task or multiple tasks to train. `MTDNNTaskDefs` can take a python dict, yaml or json file with the task(s) definition."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO - Mapping Task attributes\n",
            "INFO - Configured task definitions - ['mnli']\n"
          ]
        }
      ],
      "source": [
        "tasks_params = {\n",
        "    \"mnli\": {\n",
        "        \"data_format\": \"PremiseAndOneHypothesis\",\n",
        "        \"encoder_type\": \"BERT\",\n",
        "        \"dropout_p\": 0.3,\n",
        "        \"enable_san\": True,\n",
        "        \"labels\": [\"contradiction\", \"neutral\", \"entailment\"],\n",
        "        \"metric_meta\": [\"ACC\"],\n",
        "        \"loss\": \"CeCriterion\",\n",
        "        \"kd_loss\": \"MseCriterion\",\n",
        "        \"n_class\": 3,\n",
        "        \"split_names\": [\n",
        "            \"train\",\n",
        "            \"matched_dev\",\n",
        "            \"mismatched_dev\",\n",
        "            \"matched_test\",\n",
        "            \"mismatched_test\",\n",
        "        ],\n",
        "        \"task_type\": \"Classification\",\n",
        "    },\n",
        "}\n",
        "\n",
        "# Define the tasks\n",
        "task_defs = MTDNNTaskDefs(tasks_params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n",
        "### Create the Data Processing Object \n",
        "\n",
        "Create a data preprocessing object, `MTDNNDataProcess`. This creates the training, test and development PyTorch dataloaders needed for training and testing. We also need to retrieve the necessary training options required to initialize the model correctly, for all tasks. \n",
        "\n",
        "Define a data process that handles creating the training, test and development PyTorch dataloaders."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {},
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO - Starting to process the training data sets\n",
            "INFO - Loading ../../sample_data/bert_uncased_lower/mnli/mnli_train.json as task 0\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loaded 392702 samples out of 392702\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "INFO - Starting to process the testing data sets\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Loaded 9832 samples out of 9832\n",
            "Loaded 9847 samples out of 9847\n",
            "Loaded 9815 samples out of 9815\n",
            "Loaded 9796 samples out of 9796\n"
          ]
        }
      ],
      "source": [
        "# Make the Data Preprocess step and update the config with training data updates\n",
        "data_processor = MTDNNDataProcess(\n",
        "    config=config,\n",
        "    task_defs=task_defs,\n",
        "    data_dir=DATA_DIR,\n",
        "    train_datasets_list=[\"mnli\"],\n",
        "    test_datasets_list=[\"mnli_mismatched\", \"mnli_matched\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Retrieve the processed multitask batch data loaders for training, development and test."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {},
      "outputs": [],
      "source": [
        "multitask_train_dataloader = data_processor.get_train_dataloader()\n",
        "dev_dataloaders_list = data_processor.get_dev_dataloaders()\n",
        "test_dataloaders_list = data_processor.get_test_dataloaders()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Get the training options to initialize the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {},
      "outputs": [],
      "source": [
        "decoder_opts = data_processor.get_decoder_options_list()\n",
        "task_types = data_processor.get_task_types_list()\n",
        "dropout_list = data_processor.get_tasks_dropout_prob_list()\n",
        "loss_types = data_processor.get_loss_types_list()\n",
        "kd_loss_types = data_processor.get_kd_loss_types_list()\n",
        "tasks_nclass_list = data_processor.get_task_nclass_list()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let us update the batch steps."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {},
      "outputs": [],
      "source": [
        "num_all_batches = data_processor.get_num_all_batches()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Instantiate the MTDNN Model\n",
        "\n",
        "Now we can go ahead and create an `MTDNNModel` model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "idx: 0, number of task labels: 3\n"
          ]
        }
      ],
      "source": [
        "model = MTDNNModel(\n",
        "    config,\n",
        "    task_defs,\n",
        "    pretrained_model_name=\"bert-base-uncased\",\n",
        "    num_train_step=num_all_batches,\n",
        "    decoder_opts=decoder_opts,\n",
        "    task_types=task_types,\n",
        "    dropout_list=dropout_list,\n",
        "    loss_types=loss_types,\n",
        "    kd_loss_types=kd_loss_types,\n",
        "    tasks_nclass_list=tasks_nclass_list,\n",
        "    multitask_train_dataloader=multitask_train_dataloader,\n",
        "    dev_dataloaders_list=dev_dataloaders_list,\n",
        "    test_dataloaders_list=test_dataloaders_list,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Fit on one epoch and predict using the training and test data \n",
        "\n",
        "At this point we can fit the model and create predictions. The fit takes an optional `epochs` parameter that overrides the epochs set in the `MTDNNConfig` object. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "model.fit(epoch=1)\n",
        "model.predict()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Obtain predictions with a previously trained model checkpoint"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The predict function can take an optional checkpoint, `trained_model_chckpt`. This can be used for inference and running evaluations on an already trained PyTorch MT-DNN model. \n",
        "\n",
        "```Python\n",
        "# Predict using an MT-DNN model checkpoint\n",
        "checkpt = \"<path_to_existing_model_checkpoint>\"\n",
        "model.predict(trained_model_chckpt=checkpt)\n",
        "```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python (nlp_gpu)",
      "language": "python",
      "name": "nlp_gpu"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.6.8"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
@@ -0,0 +1,15 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

__title__ = "Microsoft MT-DNN"
__author__ = "Microsoft Research AI"
__license__ = "MIT"
__copyright__ = "Copyright 2018-present Microsoft Corporation"
__version__ = "0.0.0"

# Synonyms
TITLE = __title__
AUTHOR = __author__
LICENSE = __license__
COPYRIGHT = __copyright__
VERSION = __version__
@@ -0,0 +1,50 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import math

import torch
from torch.nn.functional import elu, leaky_relu, prelu, relu, selu, sigmoid, tanh
from torch.nn.init import (
    eye,
    kaiming_normal,
    kaiming_uniform,
    normal,
    orthogonal,
    uniform,
    xavier_normal,
    xavier_uniform,
)


def linear(x):
    return x


def swish(x):
    return x * sigmoid(x)


def bertgelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def gptgelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


# default gelu: the BERT variant
gelu = bertgelu


def activation(func_a):
    """Activation function wrapper: resolve a function by name, falling back to identity."""
    try:
        f = eval(func_a)
    except Exception:
        f = linear
    return f


def init_wrapper(init="xavier_uniform"):
    return eval(init)
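Since `activation` resolves a name with `eval` against this module's namespace, any function imported or defined above can be selected by string, with unknown names falling back to the identity. A minimal usage sketch, assuming the module above is importable:

```Python
import torch

act = activation("relu")             # resolves to torch.nn.functional.relu
x = torch.tensor([-1.0, 0.5])
print(act(x))                        # tensor([0.0000, 0.5000])

fallback = activation("no_such_fn")  # NameError inside eval -> falls back to linear
print(fallback(x))                   # tensor([-1.0000, 0.5000]), unchanged
```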
@@ -0,0 +1,16 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.

PRETRAINED_MODEL_ARCHIVE_MAP = {
    "mtdnn-base-uncased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_base.pt",
    "mtdnn-large-uncased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_large.pt",
    "mtdnn-kd-large-cased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_kd_large_cased.pt",
}

# TODO - Create these files and upload to blob next to model
PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "mtdnn-base-uncased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_base.json",
    "mtdnn-large-uncased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_large.json",
    "mtdnn-kd-large-cased": "https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_kd_large_cased.json",
}
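These maps pair shorthand model names with downloadable checkpoint and (eventually) config URLs. A minimal sketch of how they could be consumed; the `resolve_pretrained` helper below is hypothetical, not part of the package:

```Python
# Hypothetical helper: look up the URLs for a shorthand model name.
def resolve_pretrained(name: str):
    if name not in PRETRAINED_MODEL_ARCHIVE_MAP:
        raise KeyError(f"Unknown pretrained model: {name}")
    weights_url = PRETRAINED_MODEL_ARCHIVE_MAP[name]
    config_url = PRETRAINED_CONFIG_ARCHIVE_MAP.get(name)  # config may not be uploaded yet (see TODO)
    return weights_url, config_url

weights_url, config_url = resolve_pretrained("mtdnn-base-uncased")
```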
@@ -0,0 +1,17 @@
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
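A minimal usage sketch (the losses and batch sizes below are illustrative): the meter keeps the latest value in `val` and a count-weighted running mean in `avg`.

```Python
# Track a running average of per-batch losses, weighted by batch size.
meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(batch_loss, n=batch_size)

print(meter.val)  # 0.5, the most recent value
print(meter.avg)  # (0.9*32 + 0.7*32 + 0.5*16) / 80 = 0.74
```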
@@ -0,0 +1,335 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import math

import torch
from torch.optim import Optimizer
from torch.nn.utils import clip_grad_norm_
from pytorch_pretrained_bert.optimization import warmup_constant, warmup_cosine, warmup_linear


def warmup_linear_xdl(x, warmup=0.002):
    if x < warmup:
        return x / warmup
    return (1.0 - x) / (1.0 - warmup)


def schedule_func(sch):
    try:
        f = eval(sch)
    except Exception:
        f = warmup_linear
    return f


class Adamax(Optimizer):
    """Implements the BERT version of the Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate. Default: -1
        schedule: schedule to use for the warmup (see above). Default: 'warmup_linear'
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    by xiaodl
    """

    def __init__(
        self,
        params,
        lr,
        warmup=-1,
        t_total=-1,
        schedule="warmup_linear",
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.01,
        max_grad_norm=1.0,
    ):
        if not lr >= 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            schedule=schedule,
            warmup=warmup,
            t_total=t_total,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
        )
        super(Adamax, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                if group["t_total"] != -1:
                    schedule_fct = schedule_func(group["schedule"])
                    lr_scheduled = group["lr"] * schedule_fct(
                        state["step"] / group["t_total"], group["warmup"]
                    )
                else:
                    lr_scheduled = group["lr"]
                lr.append(lr_scheduled)
        return lr

    def to(self, device):
        """Move the optimizer state to a specified device."""
        for state in self.state.values():
            state["exp_avg"].to(device)
            state["exp_inf"].to(device)

    def initialize_step(self, initial_step):
        """Initialize state with a defined step (but without stored averages).
        Arguments:
            initial_step (int): Initial step number.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # State initialization
                state["step"] = initial_step
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state["exp_inf"] = torch.zeros_like(p.data)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_inf"] = torch.zeros_like(p.data)

                exp_avg, exp_inf = state["exp_avg"], state["exp_inf"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                # Add grad clipping
                if group["max_grad_norm"] > 0:
                    clip_grad_norm_(p, group["max_grad_norm"])

                # Update biased first moment estimate.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                # Update the exponentially weighted infinity norm.
                norm_buf = torch.cat(
                    [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)], 0
                )
                torch.max(norm_buf, 0, keepdim=False, out=(exp_inf, exp_inf.new().long()))
                update = exp_avg / (exp_inf + eps)

                if group["weight_decay"] > 0.0:
                    update += group["weight_decay"] * p.data

                if group["t_total"] != -1:
                    schedule_fct = schedule_func(group["schedule"])
                    lr_scheduled = group["lr"] * schedule_fct(
                        state["step"] / group["t_total"], group["warmup"]
                    )
                else:
                    lr_scheduled = group["lr"]

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)
                state["step"] += 1

        return loss


class RAdam(Optimizer):
    """Modified from: https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam.py"""

    def __init__(
        self,
        params,
        lr,
        warmup=-1,
        t_total=-1,
        schedule="warmup_linear",
        betas=(0.9, 0.999),
        eps=1e-6,
        weight_decay=0.001,
        max_grad_norm=1.0,
    ):
        if not lr >= 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            schedule=schedule,
            warmup=warmup,
            t_total=t_total,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
        )
        self.buffer = [[None, None, None] for ind in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                if group["t_total"] != -1:
                    schedule_fct = schedule_func(group["schedule"])
                    lr_scheduled = group["lr"] * schedule_fct(
                        state["step"] / group["t_total"], group["warmup"]
                    )
                else:
                    lr_scheduled = group["lr"]
                lr.append(lr_scheduled)
        return lr

    def to(self, device):
        """Move the optimizer state to a specified device."""
        for state in self.state.values():
            state["exp_avg"].to(device)
            state["exp_avg_sq"].to(device)

    def initialize_step(self, initial_step):
        """Initialize state with a defined step (but without stored averages).
        Arguments:
            initial_step (int): Initial step number.
        """
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                # State initialization
                state["step"] = initial_step
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(p.data)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(p.data)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("RAdam does not support sparse gradients")

                p_data_fp32 = p.data.float()

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    # Exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                else:
                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                # Add grad clipping
                if group["max_grad_norm"] > 0:
                    clip_grad_norm_(p, group["max_grad_norm"])

                # Update biased first moment estimate.
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                state["step"] += 1

                if group["t_total"] != -1:
                    schedule_fct = schedule_func(group["schedule"])
                    lr_scheduled = group["lr"] * schedule_fct(
                        state["step"] / group["t_total"], group["warmup"]
                    )
                else:
                    lr_scheduled = group["lr"]

                buffered = self.buffer[int(state["step"] % 10)]
                if state["step"] == buffered[0]:
                    N_sma, step_size = buffered[1], buffered[2]
                else:
                    buffered[0] = state["step"]
                    beta2_t = beta2 ** state["step"]
                    N_sma_max = 2 / (1 - beta2) - 1
                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
                    buffered[1] = N_sma

                    # more conservative since it's an approximated value
                    if N_sma >= 5:
                        step_size = (
                            lr_scheduled
                            * math.sqrt(
                                (1 - beta2_t)
                                * (N_sma - 4)
                                / (N_sma_max - 4)
                                * (N_sma - 2)
                                / N_sma
                                * N_sma_max
                                / (N_sma_max - 2)
                            )
                            / (1 - beta1 ** state["step"])
                        )
                    else:
                        step_size = lr_scheduled / (1 - beta1 ** state["step"])
                    buffered[2] = step_size

                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group["eps"])
                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
                else:
                    p_data_fp32.add_(-step_size, exp_avg)

                if group["weight_decay"] != 0:
                    p_data_fp32.add_(-group["weight_decay"] * lr_scheduled, p_data_fp32)

                p.data.copy_(p_data_fp32)

        return loss
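A minimal usage sketch for the `Adamax` optimizer above (the model and step counts are illustrative, and assume this module is importable): with `t_total` set, the learning rate follows the named warmup schedule.

```Python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2)  # illustrative model
optimizer = Adamax(
    model.parameters(),
    lr=5e-5,
    warmup=0.1,                 # ramp the LR up over the first 10% of steps
    t_total=1000,               # total training steps driving the schedule
    schedule="warmup_linear",
)

x, y = torch.randn(4, 10), torch.randn(4, 2)
loss = F.mse_loss(model(x), y)
loss.backward()
optimizer.step()                # scheduled, gradient-clipped Adamax update
optimizer.zero_grad()
```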
@@ -0,0 +1,30 @@
# Copyright (c) Microsoft. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DropoutWrapper(nn.Module):
    """A dropout wrapper which supports fixed-mask (variational) dropout."""

    def __init__(self, dropout_p=0, enable_vbp=True):
        super(DropoutWrapper, self).__init__()
        """variational dropout means a fixed dropout mask
        ref: https://discuss.pytorch.org/t/dropout-for-rnns/633/11
        """
        self.enable_variational_dropout = enable_vbp
        self.dropout_p = dropout_p

    def forward(self, x):
        """
        :param x: batch * len * input_size
        """
        if not self.training or self.dropout_p == 0:
            return x

        if len(x.size()) == 3:
            # Sample one mask per (batch, feature) position and reuse it across the
            # sequence dimension, scaling by 1/(1-p) to keep expectations unchanged.
            mask = (
                1.0
                / (1 - self.dropout_p)
                * torch.bernoulli((1 - self.dropout_p) * (x.data.new(x.size(0), x.size(2)).zero_() + 1))
            )
            mask.requires_grad = False
            return mask.unsqueeze(1).expand_as(x) * x
        else:
            return F.dropout(x, p=self.dropout_p, training=self.training)
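A minimal sketch of the fixed-mask behavior (assuming the module above is importable): with a 3D input, one mask is sampled per (batch, feature) position and reused across the sequence dimension.

```Python
import torch

dropout = DropoutWrapper(dropout_p=0.4)
dropout.train()  # dropout is a no-op in eval mode

x = torch.ones(2, 5, 8)  # batch=2, len=5, input_size=8
out = dropout(x)

# The same feature columns are zeroed at every sequence position:
assert torch.equal(out[:, 0, :], out[:, 1, :])
```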
@@ -0,0 +1,63 @@
import yaml

from experiments.glue.glue_label_map import TaskType, DATA_TYPE, GLOBAL_MAP, TASK_TYPE, DATA_META, METRIC_META, SAN_META
from data_utils.task_def import DataFormat
from data_utils.metrics import Metric

task_def_dic = {}
dropout_p_map = {
    "mnli": 0.3,
    "cola": 0.05
}
for task in TASK_TYPE.keys():
    task_type = TASK_TYPE[task]
    if task == "qnnli":
        task_type = TaskType.Ranking
    elif task_type == 0:
        task_type = TaskType.Classification
    elif task_type == 1:
        task_type = TaskType.Regression
    else:
        raise ValueError(task_type)

    data_format = DATA_TYPE[task]
    if task == "qnnli":
        data_format = DataFormat.PremiseAndMultiHypothesis
    elif data_format == 0:
        data_format = DataFormat.PremiseAndOneHypothesis
    elif data_format == 1:
        data_format = DataFormat.PremiseOnly
    else:
        raise ValueError(data_format)

    labels = None
    if task in GLOBAL_MAP:
        labels = GLOBAL_MAP[task].get_vocab_list()

    split_names = None
    if task == "mnli":
        split_names = ["train", "matched_dev", "mismatched_dev", "matched_test", "mismatched_test"]

    dropout_p = dropout_p_map.get(task, None)

    n_class = DATA_META[task]
    metric_meta = tuple(Metric(metric_no).name for metric_no in METRIC_META[task])
    enable_san = bool(SAN_META[task])

    task_def = {
        "task_type": task_type.name,
        "data_format": data_format.name,
        "n_class": n_class,
        "metric_meta": metric_meta,
        "enable_san": enable_san,
    }
    if labels is not None:
        task_def["labels"] = labels
    if split_names is not None:
        task_def["split_names"] = split_names
    if dropout_p is not None:
        task_def["dropout_p"] = dropout_p

    if task not in ["diag", "qnnli"]:
        task_def_dic[task] = task_def

yaml.safe_dump(task_def_dic, open("experiments/glue/glue_task_def.yml", "w"))
@@ -0,0 +1,161 @@
# Copyright (c) Microsoft. All rights reserved.

from enum import Enum

from data_utils.vocab import Vocabulary
from data_utils.metrics import compute_acc, compute_f1, compute_mcc, compute_pearson, compute_spearman

# scitail
ScitailLabelMapper = Vocabulary(True)
ScitailLabelMapper.add('neutral')
ScitailLabelMapper.add('entails')

# label map
SNLI_LabelMapper = Vocabulary(True)
SNLI_LabelMapper.add('contradiction')
SNLI_LabelMapper.add('neutral')
SNLI_LabelMapper.add('entailment')

# qnli
QNLILabelMapper = Vocabulary(True)
QNLILabelMapper.add('not_entailment')
QNLILabelMapper.add('entailment')

GLOBAL_MAP = {
    'scitail': ScitailLabelMapper,
    'mnli': SNLI_LabelMapper,
    'snli': SNLI_LabelMapper,
    'qnli': QNLILabelMapper,
    'qnnli': QNLILabelMapper,
    'rte': QNLILabelMapper,
    'diag': SNLI_LabelMapper,
}

# number of classes
DATA_META = {
    'mnli': 3,
    'snli': 3,
    'scitail': 2,
    'qqp': 2,
    'qnli': 2,
    'qnnli': 1,
    'wnli': 2,
    'rte': 2,
    'mrpc': 2,
    'diag': 3,
    'sst': 2,
    'stsb': 1,
    'cola': 2,
}

DATA_TYPE = {
    'mnli': 0,
    'snli': 0,
    'scitail': 0,
    'qqp': 0,
    'qnli': 0,
    'qnnli': 0,
    'wnli': 0,
    'rte': 0,
    'mrpc': 0,
    'diag': 0,
    'sst': 1,
    'stsb': 0,
    'cola': 1,
}

DATA_SWAP = {
    'mnli': 0,
    'snli': 0,
    'scitail': 0,
    'qqp': 1,
    'qnli': 0,
    'qnnli': 0,
    'wnli': 0,
    'rte': 0,
    'mrpc': 0,
    'diag': 0,
    'sst': 0,
    'stsb': 0,
    'cola': 0,
}

# classification/regression
TASK_TYPE = {
    'mnli': 0,
    'snli': 0,
    'scitail': 0,
    'qqp': 0,
    'qnli': 0,
    'qnnli': 0,
    'wnli': 0,
    'rte': 0,
    'mrpc': 0,
    'diag': 0,
    'sst': 0,
    'stsb': 1,
    'cola': 0,
}

METRIC_META = {
    'mnli': [0],
    'snli': [0],
    'scitail': [0],
    'qqp': [0, 1],
    'qnli': [0],
    'qnnli': [0],
    'wnli': [0],
    'rte': [0],
    'mrpc': [0, 1],
    'diag': [0],
    'sst': [0],
    'stsb': [3, 4],
    'cola': [0, 2],
}

METRIC_NAME = {
    0: 'ACC',
    1: 'F1',
    2: 'MCC',
    3: 'Pearson',
    4: 'Spearman',
}

METRIC_FUNC = {
    0: compute_acc,
    1: compute_f1,
    2: compute_mcc,
    3: compute_pearson,
    4: compute_spearman,
}

SAN_META = {
    'mnli': 1,
    'snli': 1,
    'scitail': 1,
    'qqp': 1,
    'qnli': 1,
    'qnnli': 1,
    'wnli': 1,
    'rte': 1,
    'mrpc': 1,
    'diag': 0,
    'sst': 0,
    'stsb': 0,
    'cola': 0,
}


def generate_decoder_opt(task, max_opt):
    assert task in SAN_META
    opt_v = 0
    if SAN_META[task] and max_opt < 3:
        opt_v = max_opt
    return opt_v


class TaskType(Enum):
    Classification = 0
    Regression = 1
    Ranking = 2
@ -0,0 +1,286 @@
|
|||
import os
import argparse
import random
from sys import path

path.append(os.getcwd())
from experiments.common_utils import dump_rows
from data_utils.task_def import DataFormat
from data_utils.log_wrapper import create_logger
from experiments.glue.glue_utils import *

logger = create_logger(__name__, to_disk=True, log_file='glue_prepro.log')


def parse_args():
    parser = argparse.ArgumentParser(description='Preprocessing GLUE/SNLI/SciTail dataset.')
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--root_dir', type=str, default='data')
    parser.add_argument('--old_glue', action='store_true',
                        help='whether this is the old GLUE release; refer to the official GLUE webpage for details')
    args = parser.parse_args()
    return args


def main(args):
    is_old_glue = args.old_glue
    root = args.root_dir
    assert os.path.exists(root)

    ######################################
    # SNLI/SciTail Tasks
    ######################################
    scitail_train_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_train.tsv')
    scitail_dev_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_dev.tsv')
    scitail_test_path = os.path.join(root, 'SciTail/tsv_format/scitail_1.0_test.tsv')

    snli_train_path = os.path.join(root, 'SNLI/train.tsv')
    snli_dev_path = os.path.join(root, 'SNLI/dev.tsv')
    snli_test_path = os.path.join(root, 'SNLI/test.tsv')

    ######################################
    # GLUE tasks
    ######################################
    multi_train_path = os.path.join(root, 'MNLI/train.tsv')
    multi_dev_matched_path = os.path.join(root, 'MNLI/dev_matched.tsv')
    multi_dev_mismatched_path = os.path.join(root, 'MNLI/dev_mismatched.tsv')
    multi_test_matched_path = os.path.join(root, 'MNLI/test_matched.tsv')
    multi_test_mismatched_path = os.path.join(root, 'MNLI/test_mismatched.tsv')

    mrpc_train_path = os.path.join(root, 'MRPC/train.tsv')
    mrpc_dev_path = os.path.join(root, 'MRPC/dev.tsv')
    mrpc_test_path = os.path.join(root, 'MRPC/test.tsv')

    qnli_train_path = os.path.join(root, 'QNLI/train.tsv')
    qnli_dev_path = os.path.join(root, 'QNLI/dev.tsv')
    qnli_test_path = os.path.join(root, 'QNLI/test.tsv')

    qqp_train_path = os.path.join(root, 'QQP/train.tsv')
    qqp_dev_path = os.path.join(root, 'QQP/dev.tsv')
    qqp_test_path = os.path.join(root, 'QQP/test.tsv')

    rte_train_path = os.path.join(root, 'RTE/train.tsv')
    rte_dev_path = os.path.join(root, 'RTE/dev.tsv')
    rte_test_path = os.path.join(root, 'RTE/test.tsv')

    wnli_train_path = os.path.join(root, 'WNLI/train.tsv')
    wnli_dev_path = os.path.join(root, 'WNLI/dev.tsv')
    wnli_test_path = os.path.join(root, 'WNLI/test.tsv')

    stsb_train_path = os.path.join(root, 'STS-B/train.tsv')
    stsb_dev_path = os.path.join(root, 'STS-B/dev.tsv')
    stsb_test_path = os.path.join(root, 'STS-B/test.tsv')

    sst_train_path = os.path.join(root, 'SST-2/train.tsv')
    sst_dev_path = os.path.join(root, 'SST-2/dev.tsv')
    sst_test_path = os.path.join(root, 'SST-2/test.tsv')

    cola_train_path = os.path.join(root, 'CoLA/train.tsv')
    cola_dev_path = os.path.join(root, 'CoLA/dev.tsv')
    cola_test_path = os.path.join(root, 'CoLA/test.tsv')

    ######################################
    # Loading DATA
    ######################################
    scitail_train_data = load_scitail(scitail_train_path)
    scitail_dev_data = load_scitail(scitail_dev_path)
    scitail_test_data = load_scitail(scitail_test_path)
    logger.info('Loaded {} SciTail train samples'.format(len(scitail_train_data)))
    logger.info('Loaded {} SciTail dev samples'.format(len(scitail_dev_data)))
    logger.info('Loaded {} SciTail test samples'.format(len(scitail_test_data)))

    snli_train_data = load_snli(snli_train_path)
    snli_dev_data = load_snli(snli_dev_path)
    snli_test_data = load_snli(snli_test_path)
    logger.info('Loaded {} SNLI train samples'.format(len(snli_train_data)))
    logger.info('Loaded {} SNLI dev samples'.format(len(snli_dev_data)))
    logger.info('Loaded {} SNLI test samples'.format(len(snli_test_data)))

    multinli_train_data = load_mnli(multi_train_path)
    multinli_matched_dev_data = load_mnli(multi_dev_matched_path)
    multinli_mismatched_dev_data = load_mnli(multi_dev_mismatched_path)
    multinli_matched_test_data = load_mnli(multi_test_matched_path, is_train=False)
    multinli_mismatched_test_data = load_mnli(multi_test_mismatched_path, is_train=False)

    logger.info('Loaded {} MNLI train samples'.format(len(multinli_train_data)))
    logger.info('Loaded {} MNLI matched dev samples'.format(len(multinli_matched_dev_data)))
    logger.info('Loaded {} MNLI mismatched dev samples'.format(len(multinli_mismatched_dev_data)))
    logger.info('Loaded {} MNLI matched test samples'.format(len(multinli_matched_test_data)))
    logger.info('Loaded {} MNLI mismatched test samples'.format(len(multinli_mismatched_test_data)))

    mrpc_train_data = load_mrpc(mrpc_train_path)
    mrpc_dev_data = load_mrpc(mrpc_dev_path)
    mrpc_test_data = load_mrpc(mrpc_test_path, is_train=False)
    logger.info('Loaded {} MRPC train samples'.format(len(mrpc_train_data)))
    logger.info('Loaded {} MRPC dev samples'.format(len(mrpc_dev_data)))
    logger.info('Loaded {} MRPC test samples'.format(len(mrpc_test_data)))

    qnli_train_data = load_qnli(qnli_train_path)
    qnli_dev_data = load_qnli(qnli_dev_path)
    qnli_test_data = load_qnli(qnli_test_path, is_train=False)
    logger.info('Loaded {} QNLI train samples'.format(len(qnli_train_data)))
    logger.info('Loaded {} QNLI dev samples'.format(len(qnli_dev_data)))
    logger.info('Loaded {} QNLI test samples'.format(len(qnli_test_data)))

    if is_old_glue:
        random.seed(args.seed)
        qnnli_train_data = load_qnnli(qnli_train_path)
        qnnli_dev_data = load_qnnli(qnli_dev_path)
        qnnli_test_data = load_qnnli(qnli_test_path, is_train=False)
        logger.info('Loaded {} QNNLI train samples'.format(len(qnnli_train_data)))
        logger.info('Loaded {} QNNLI dev samples'.format(len(qnnli_dev_data)))
        logger.info('Loaded {} QNNLI test samples'.format(len(qnnli_test_data)))

    qqp_train_data = load_qqp(qqp_train_path)
    qqp_dev_data = load_qqp(qqp_dev_path)
    qqp_test_data = load_qqp(qqp_test_path, is_train=False)
    logger.info('Loaded {} QQP train samples'.format(len(qqp_train_data)))
    logger.info('Loaded {} QQP dev samples'.format(len(qqp_dev_data)))
    logger.info('Loaded {} QQP test samples'.format(len(qqp_test_data)))

    rte_train_data = load_rte(rte_train_path)
    rte_dev_data = load_rte(rte_dev_path)
    rte_test_data = load_rte(rte_test_path, is_train=False)
    logger.info('Loaded {} RTE train samples'.format(len(rte_train_data)))
    logger.info('Loaded {} RTE dev samples'.format(len(rte_dev_data)))
    logger.info('Loaded {} RTE test samples'.format(len(rte_test_data)))

    wnli_train_data = load_wnli(wnli_train_path)
    wnli_dev_data = load_wnli(wnli_dev_path)
    wnli_test_data = load_wnli(wnli_test_path, is_train=False)
    logger.info('Loaded {} WNLI train samples'.format(len(wnli_train_data)))
    logger.info('Loaded {} WNLI dev samples'.format(len(wnli_dev_data)))
    logger.info('Loaded {} WNLI test samples'.format(len(wnli_test_data)))

    sst_train_data = load_sst(sst_train_path)
    sst_dev_data = load_sst(sst_dev_path)
    sst_test_data = load_sst(sst_test_path, is_train=False)
    logger.info('Loaded {} SST train samples'.format(len(sst_train_data)))
    logger.info('Loaded {} SST dev samples'.format(len(sst_dev_data)))
    logger.info('Loaded {} SST test samples'.format(len(sst_test_data)))

    cola_train_data = load_cola(cola_train_path, header=False)
    cola_dev_data = load_cola(cola_dev_path, header=False)
    cola_test_data = load_cola(cola_test_path, is_train=False)
    logger.info('Loaded {} COLA train samples'.format(len(cola_train_data)))
    logger.info('Loaded {} COLA dev samples'.format(len(cola_dev_data)))
    logger.info('Loaded {} COLA test samples'.format(len(cola_test_data)))

    stsb_train_data = load_sts(stsb_train_path)
    stsb_dev_data = load_sts(stsb_dev_path)
    stsb_test_data = load_sts(stsb_test_path, is_train=False)
    logger.info('Loaded {} STS-B train samples'.format(len(stsb_train_data)))
    logger.info('Loaded {} STS-B dev samples'.format(len(stsb_dev_data)))
    logger.info('Loaded {} STS-B test samples'.format(len(stsb_test_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    # BUILD SciTail
    scitail_train_fout = os.path.join(canonical_data_root, 'scitail_train.tsv')
    scitail_dev_fout = os.path.join(canonical_data_root, 'scitail_dev.tsv')
    scitail_test_fout = os.path.join(canonical_data_root, 'scitail_test.tsv')
    dump_rows(scitail_train_data, scitail_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(scitail_dev_data, scitail_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(scitail_test_data, scitail_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with scitail')

    # BUILD SNLI
    snli_train_fout = os.path.join(canonical_data_root, 'snli_train.tsv')
    snli_dev_fout = os.path.join(canonical_data_root, 'snli_dev.tsv')
    snli_test_fout = os.path.join(canonical_data_root, 'snli_test.tsv')
    dump_rows(snli_train_data, snli_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(snli_dev_data, snli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(snli_test_data, snli_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with snli')

    # BUILD MNLI
    multinli_train_fout = os.path.join(canonical_data_root, 'mnli_train.tsv')
    multinli_matched_dev_fout = os.path.join(canonical_data_root, 'mnli_matched_dev.tsv')
    multinli_mismatched_dev_fout = os.path.join(canonical_data_root, 'mnli_mismatched_dev.tsv')
    multinli_matched_test_fout = os.path.join(canonical_data_root, 'mnli_matched_test.tsv')
    multinli_mismatched_test_fout = os.path.join(canonical_data_root, 'mnli_mismatched_test.tsv')
    dump_rows(multinli_train_data, multinli_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(multinli_matched_dev_data, multinli_matched_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(multinli_mismatched_dev_data, multinli_mismatched_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(multinli_matched_test_data, multinli_matched_test_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(multinli_mismatched_test_data, multinli_mismatched_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with mnli')

    mrpc_train_fout = os.path.join(canonical_data_root, 'mrpc_train.tsv')
    mrpc_dev_fout = os.path.join(canonical_data_root, 'mrpc_dev.tsv')
    mrpc_test_fout = os.path.join(canonical_data_root, 'mrpc_test.tsv')
    dump_rows(mrpc_train_data, mrpc_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(mrpc_dev_data, mrpc_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(mrpc_test_data, mrpc_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with mrpc')

    qnli_train_fout = os.path.join(canonical_data_root, 'qnli_train.tsv')
    qnli_dev_fout = os.path.join(canonical_data_root, 'qnli_dev.tsv')
    qnli_test_fout = os.path.join(canonical_data_root, 'qnli_test.tsv')
    dump_rows(qnli_train_data, qnli_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qnli_dev_data, qnli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qnli_test_data, qnli_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with qnli')

    if is_old_glue:
        qnli_train_fout = os.path.join(canonical_data_root, 'qnnli_train.tsv')
        qnli_dev_fout = os.path.join(canonical_data_root, 'qnnli_dev.tsv')
        qnli_test_fout = os.path.join(canonical_data_root, 'qnnli_test.tsv')
        dump_rows(qnnli_train_data, qnli_train_fout, DataFormat.PremiseAndMultiHypothesis)
        dump_rows(qnnli_dev_data, qnli_dev_fout, DataFormat.PremiseAndMultiHypothesis)
        dump_rows(qnnli_test_data, qnli_test_fout, DataFormat.PremiseAndMultiHypothesis)
        logger.info('done with qnnli')

    qqp_train_fout = os.path.join(canonical_data_root, 'qqp_train.tsv')
    qqp_dev_fout = os.path.join(canonical_data_root, 'qqp_dev.tsv')
    qqp_test_fout = os.path.join(canonical_data_root, 'qqp_test.tsv')
    dump_rows(qqp_train_data, qqp_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qqp_dev_data, qqp_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(qqp_test_data, qqp_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with qqp')

    rte_train_fout = os.path.join(canonical_data_root, 'rte_train.tsv')
    rte_dev_fout = os.path.join(canonical_data_root, 'rte_dev.tsv')
    rte_test_fout = os.path.join(canonical_data_root, 'rte_test.tsv')
    dump_rows(rte_train_data, rte_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(rte_dev_data, rte_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(rte_test_data, rte_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with rte')

    wnli_train_fout = os.path.join(canonical_data_root, 'wnli_train.tsv')
    wnli_dev_fout = os.path.join(canonical_data_root, 'wnli_dev.tsv')
    wnli_test_fout = os.path.join(canonical_data_root, 'wnli_test.tsv')
    dump_rows(wnli_train_data, wnli_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(wnli_dev_data, wnli_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(wnli_test_data, wnli_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with wnli')

    sst_train_fout = os.path.join(canonical_data_root, 'sst_train.tsv')
    sst_dev_fout = os.path.join(canonical_data_root, 'sst_dev.tsv')
    sst_test_fout = os.path.join(canonical_data_root, 'sst_test.tsv')
    dump_rows(sst_train_data, sst_train_fout, DataFormat.PremiseOnly)
    dump_rows(sst_dev_data, sst_dev_fout, DataFormat.PremiseOnly)
    dump_rows(sst_test_data, sst_test_fout, DataFormat.PremiseOnly)
    logger.info('done with sst')

    cola_train_fout = os.path.join(canonical_data_root, 'cola_train.tsv')
    cola_dev_fout = os.path.join(canonical_data_root, 'cola_dev.tsv')
    cola_test_fout = os.path.join(canonical_data_root, 'cola_test.tsv')
    dump_rows(cola_train_data, cola_train_fout, DataFormat.PremiseOnly)
    dump_rows(cola_dev_data, cola_dev_fout, DataFormat.PremiseOnly)
    dump_rows(cola_test_data, cola_test_fout, DataFormat.PremiseOnly)
    logger.info('done with cola')

    stsb_train_fout = os.path.join(canonical_data_root, 'stsb_train.tsv')
    stsb_dev_fout = os.path.join(canonical_data_root, 'stsb_dev.tsv')
    stsb_test_fout = os.path.join(canonical_data_root, 'stsb_test.tsv')
    dump_rows(stsb_train_data, stsb_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(stsb_dev_data, stsb_dev_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(stsb_test_data, stsb_test_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with stsb')


if __name__ == '__main__':
    args = parse_args()
    main(args)
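
# Usage sketch (illustrative; assumes this script lives at experiments/glue/glue_prepro.py
# and that the GLUE/SNLI/SciTail data was downloaded under ./data):
#   python experiments/glue/glue_prepro.py --root_dir data
#   python experiments/glue/glue_prepro.py --root_dir data --old_glue   # old QNLI ranking format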
@@ -0,0 +1,403 @@
# Copyright (c) Microsoft. All rights reserved.
from random import shuffle

from mtdnn.common.metrics import calc_metrics


def load_scitail(file):
    """Load the SciTail dataset."""
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            blocks = line.strip().split("\t")
            assert len(blocks) > 2
            if blocks[0] == "-":
                continue
            sample = {
                "uid": str(cnt),
                "premise": blocks[0],
                "hypothesis": blocks[1],
                "label": blocks[2],
            }
            rows.append(sample)
            cnt += 1
    return rows
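
# Usage sketch (illustrative; assumes the SciTail tsv_format files are in place):
#   rows = load_scitail('data/SciTail/tsv_format/scitail_1.0_train.tsv')
#   # each row is a dict: {'uid': '0', 'premise': ..., 'hypothesis': ..., 'label': ...}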


def load_snli(file, header=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 10
            if blocks[-1] == "-":
                continue
            lab = blocks[-1]
            if lab is None:
                import pdb

                pdb.set_trace()
            sample = {
                "uid": blocks[0],
                "premise": blocks[7],
                "hypothesis": blocks[8],
                "label": lab,
            }
            rows.append(sample)
            cnt += 1
    return rows


def load_mnli(file, header=True, multi_snli=False, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 9
            if blocks[-1] == "-":
                continue
            lab = "contradiction"
            if is_train:
                lab = blocks[-1]
            if lab is None:
                import pdb

                pdb.set_trace()
            sample = {
                "uid": blocks[0],
                "premise": blocks[8],
                "hypothesis": blocks[9],
                "label": lab,
            }
            rows.append(sample)
            cnt += 1
    return rows


def load_mrpc(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 4
            lab = 0
            if is_train:
                lab = int(blocks[0])
            sample = {
                "uid": cnt,
                "premise": blocks[-2],
                "hypothesis": blocks[-1],
                "label": lab,
            }
            rows.append(sample)
            cnt += 1
    return rows


def load_qnli(file, header=True, is_train=True):
    """QNLI for classification"""
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 2
            lab = "not_entailment"
            if is_train:
                lab = blocks[-1]
            if lab is None:
                import pdb

                pdb.set_trace()
            sample = {
                "uid": blocks[0],
                "premise": blocks[1],
                "hypothesis": blocks[2],
                "label": lab,
            }
            rows.append(sample)
            cnt += 1
    return rows


def load_qqp(file, header=True, is_train=True):
    rows = []
    cnt = 0
    skipped = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            if is_train and len(blocks) < 6:
                skipped += 1
                continue
            if not is_train:
                assert len(blocks) == 3
            lab = 0
            if is_train:
                lab = int(blocks[-1])
                sample = {
                    "uid": cnt,
                    "premise": blocks[-3],
                    "hypothesis": blocks[-2],
                    "label": lab,
                }
            else:
                sample = {
                    "uid": int(blocks[0]),
                    "premise": blocks[-2],
                    "hypothesis": blocks[-1],
                    "label": lab,
                }
            rows.append(sample)
            cnt += 1
    return rows


def load_rte(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            if is_train and len(blocks) < 4:
                continue
            if not is_train:
                assert len(blocks) == 3
            lab = "not_entailment"
            if is_train:
                lab = blocks[-1]
                sample = {
                    "uid": int(blocks[0]),
                    "premise": blocks[-3],
                    "hypothesis": blocks[-2],
                    "label": lab,
                }
            else:
                sample = {
                    "uid": int(blocks[0]),
                    "premise": blocks[-2],
                    "hypothesis": blocks[-1],
                    "label": lab,
                }
            rows.append(sample)
            cnt += 1
    return rows


def load_wnli(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            if is_train and len(blocks) < 4:
                continue
            if not is_train:
                assert len(blocks) == 3
            lab = 0
            if is_train:
                lab = int(blocks[-1])
                sample = {
                    "uid": cnt,
                    "premise": blocks[-3],
                    "hypothesis": blocks[-2],
                    "label": lab,
                }
            else:
                sample = {
                    "uid": cnt,
                    "premise": blocks[-2],
                    "hypothesis": blocks[-1],
                    "label": lab,
                }
            rows.append(sample)
            cnt += 1
    return rows


def load_diag(file, header=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 3
            sample = {
                "uid": cnt,
                "premise": blocks[-3],
                "hypothesis": blocks[-2],
                "label": blocks[-1],
            }
            rows.append(sample)
            cnt += 1
    return rows


def load_sst(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            if is_train and len(blocks) < 2:
                continue
            lab = 0
            if is_train:
                lab = int(blocks[-1])
                sample = {"uid": cnt, "premise": blocks[0], "label": lab}
            else:
                sample = {"uid": int(blocks[0]), "premise": blocks[1], "label": lab}

            cnt += 1
            rows.append(sample)
    return rows


def load_cola(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            if is_train and len(blocks) < 2:
                continue
            lab = 0
            if is_train:
                lab = int(blocks[1])
                sample = {"uid": cnt, "premise": blocks[-1], "label": lab}
            else:
                sample = {"uid": cnt, "premise": blocks[-1], "label": lab}
            rows.append(sample)
            cnt += 1
    return rows


def load_sts(file, header=True, is_train=True):
    rows = []
    cnt = 0
    with open(file, encoding="utf8") as f:
        for line in f:
            if header:
                header = False
                continue
            blocks = line.strip().split("\t")
            assert len(blocks) > 8
            score = "0.0"
            if is_train:
                score = blocks[-1]
                sample = {
                    "uid": cnt,
                    "premise": blocks[-3],
                    "hypothesis": blocks[-2],
                    "label": score,
                }
            else:
                sample = {
                    "uid": cnt,
                    "premise": blocks[-2],
                    "hypothesis": blocks[-1],
                    "label": score,
                }
            rows.append(sample)
            cnt += 1
    return rows


def load_qnnli(file, header=True, is_train=True):
    """QNLI for ranking"""
    rows = []
    mis_matched_cnt = 0
    cnt = 0
    with open(file, encoding="utf8") as f:
        lines = f.readlines()
        if header:
            lines = lines[1:]

        assert len(lines) % 2 == 0
        for idx in range(0, len(lines), 2):
            block1 = lines[idx].strip().split("\t")
            block2 = lines[idx + 1].strip().split("\t")
            # train shuffle
            assert len(block1) > 2 and len(block2) > 2
            if is_train and block1[1] != block2[1]:
                mis_matched_cnt += 1
                continue
            assert block1[1] == block2[1]
            lab1, lab2 = "entailment", "entailment"
            if is_train:
                blocks = [block1, block2]
                shuffle(blocks)
                block1 = blocks[0]
                block2 = blocks[1]
                lab1 = block1[-1]
                lab2 = block2[-1]
                if lab1 == lab2:
                    mis_matched_cnt += 1
                    continue
            assert "," not in lab1
            assert "," not in lab2
            assert "," not in block1[0]
            assert "," not in block2[0]
            sample = {
                "uid": cnt,
                "ruid": "%s,%s" % (block1[0], block2[0]),
                "premise": block1[1],
                "hypothesis": [block1[2], block2[2]],
                "label": "%s,%s" % (lab1, lab2),
            }
            cnt += 1
            rows.append(sample)
    return rows


def submit(path, data, label_dict=None):
    header = "index\tprediction"
    with open(path, "w") as writer:
        predictions, uids = data["predictions"], data["uids"]
        writer.write("{}\n".format(header))
        assert len(predictions) == len(uids)
        # sort predictions by uid
        paired = [(int(uid), predictions[idx]) for idx, uid in enumerate(uids)]
        paired = sorted(paired, key=lambda item: item[0])
        for uid, pred in paired:
            if label_dict is None:
                writer.write("{}\t{}\n".format(uid, pred))
            else:
                assert type(pred) is int
                writer.write("{}\t{}\n".format(uid, label_dict[pred]))
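
# Usage sketch (illustrative; `results` holds parallel 'predictions' and 'uids'
# lists as produced by evaluation):
#   submit('submission.tsv', results, label_dict=['entailment', 'not_entailment'])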

@@ -0,0 +1,17 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.

from torch import nn


class LinearPooler(nn.Module):
    def __init__(self, hidden_size):
        super(LinearPooler, self).__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
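
# Usage sketch (illustrative only): pool the first-token ([CLS]) representation.
#   pooler = LinearPooler(hidden_size=768)
#   pooled = pooler(sequence_output)  # [batch, seq_len, 768] -> [batch, 768]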

@@ -0,0 +1,119 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.

import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
from enum import IntEnum


class Criterion(_Loss):
    def __init__(self, alpha=1.0, name='criterion'):
        """Alpha is used to weight each loss term."""
        super().__init__()
        self.alpha = alpha
        self.name = name

    def forward(self, input, target, weight=None, ignore_index=-1):
        """weight: sample weight"""
        return


class CeCriterion(Criterion):
    def __init__(self, alpha=1.0, name='Cross Entropy Criterion'):
        super().__init__()
        self.alpha = alpha
        self.name = name

    def forward(self, input, target, weight=None, ignore_index=-1):
        """weight: sample weight"""
        if weight is not None:
            loss = torch.mean(F.cross_entropy(input, target, reduce=False, ignore_index=ignore_index) * weight)
        else:
            loss = F.cross_entropy(input, target, ignore_index=ignore_index)
        loss = loss * self.alpha
        return loss


class SeqCeCriterion(CeCriterion):
    def __init__(self, alpha=1.0, name='Seq Cross Entropy Criterion'):
        super().__init__(alpha, name)

    def forward(self, input, target, weight=None, ignore_index=-1):
        target = target.view(-1)
        if weight is not None:
            loss = torch.mean(F.cross_entropy(input, target, reduce=False, ignore_index=ignore_index) * weight)
        else:
            loss = F.cross_entropy(input, target, ignore_index=ignore_index)
        loss = loss * self.alpha
        return loss


class MseCriterion(Criterion):
    def __init__(self, alpha=1.0, name='MSE Regression Criterion'):
        super().__init__()
        self.alpha = alpha
        self.name = name

    def forward(self, input, target, weight=None, ignore_index=-1):
        """weight: sample weight"""
        if weight is not None:
            loss = torch.mean(F.mse_loss(input.squeeze(), target, reduce=False) * weight)
        else:
            loss = F.mse_loss(input.squeeze(), target)
        loss = loss * self.alpha
        return loss


class RankCeCriterion(Criterion):
    def __init__(self, alpha=1.0, name='Rank Cross Entropy Criterion'):
        super().__init__()
        self.alpha = alpha
        self.name = name

    def forward(self, input, target, weight=None, ignore_index=-1, pairwise_size=1):
        input = input.view(-1, pairwise_size)
        target = target.contiguous().view(-1, pairwise_size)[:, 0]
        if weight is not None:
            loss = torch.mean(F.cross_entropy(input, target, reduce=False, ignore_index=ignore_index) * weight)
        else:
            loss = F.cross_entropy(input, target, ignore_index=ignore_index)
        loss = loss * self.alpha
        return loss


class SpanCeCriterion(Criterion):
    def __init__(self, alpha=1.0, name='Span Cross Entropy Criterion'):
        """This is for extractive MRC, e.g., SQuAD, ReCoRD, etc."""
        super().__init__()
        self.alpha = alpha
        self.name = name

    def forward(self, input, target, weight=None, ignore_index=-1):
        """weight: sample weight"""
        assert len(input) == 2
        start_input, end_input = input
        start_target, end_target = target
        if weight is not None:
            b = torch.mean(F.cross_entropy(start_input, start_target, reduce=False, ignore_index=ignore_index) * weight)
            e = torch.mean(F.cross_entropy(end_input, end_target, reduce=False, ignore_index=ignore_index) * weight)
        else:
            b = F.cross_entropy(start_input, start_target, ignore_index=ignore_index)
            e = F.cross_entropy(end_input, end_target, ignore_index=ignore_index)
        loss = 0.5 * (b + e) * self.alpha
        return loss


class LossCriterion(IntEnum):
    CeCriterion = 0
    MseCriterion = 1
    RankCeCriterion = 2
    SpanCeCriterion = 3
    SeqCeCriterion = 4


LOSS_REGISTRY = {
    LossCriterion.CeCriterion: CeCriterion,
    LossCriterion.MseCriterion: MseCriterion,
    LossCriterion.RankCeCriterion: RankCeCriterion,
    LossCriterion.SpanCeCriterion: SpanCeCriterion,
    LossCriterion.SeqCeCriterion: SeqCeCriterion,
}
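
# Usage sketch (illustrative only): look up a loss by its enum id and apply it.
#   loss_fn = LOSS_REGISTRY[LossCriterion.CeCriterion](alpha=1.0)
#   loss = loss_fn(logits, targets)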

@@ -0,0 +1,111 @@
# Copyright (c) Microsoft. All rights reserved.
from enum import Enum

from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import roc_auc_score
from scipy.stats import pearsonr, spearmanr
from seqeval.metrics import classification_report
from mtdnn.common.squad_eval import evaluate_func


def compute_acc(predicts, labels):
    return 100.0 * accuracy_score(labels, predicts)


def compute_f1(predicts, labels):
    return 100.0 * f1_score(labels, predicts)


def compute_mcc(predicts, labels):
    return 100.0 * matthews_corrcoef(labels, predicts)


def compute_pearson(predicts, labels):
    pcof = pearsonr(labels, predicts)[0]
    return 100.0 * pcof


def compute_spearman(predicts, labels):
    scof = spearmanr(labels, predicts)[0]
    return 100.0 * scof


def compute_auc(predicts, labels):
    auc = roc_auc_score(labels, predicts)
    return 100.0 * auc


def compute_seqacc(predicts, labels, label_mapper):
    y_true, y_pred = [], []

    def trim(predict, label):
        temp_1 = []
        temp_2 = []
        for j, m in enumerate(predict):
            if j == 0:
                continue
            if label_mapper[label[j]] != "X":
                temp_1.append(label_mapper[label[j]])
                temp_2.append(label_mapper[m])
        temp_1.pop()
        temp_2.pop()
        y_true.append(temp_1)
        y_pred.append(temp_2)

    for predict, label in zip(predicts, labels):
        trim(predict, label)
    report = classification_report(y_true, y_pred, digits=4)
    return report


def compute_emf1(predicts, labels):
    return evaluate_func(labels, predicts)


class Metric(Enum):
    ACC = 0
    F1 = 1
    MCC = 2
    Pearson = 3
    Spearman = 4
    AUC = 5
    SeqEval = 7
    EmF1 = 8


METRIC_FUNC = {
    Metric.ACC: compute_acc,
    Metric.F1: compute_f1,
    Metric.MCC: compute_mcc,
    Metric.Pearson: compute_pearson,
    Metric.Spearman: compute_spearman,
    Metric.AUC: compute_auc,
    Metric.SeqEval: compute_seqacc,
    Metric.EmF1: compute_emf1,
}


def calc_metrics(metric_meta, golds, predictions, scores, label_mapper=None):
    """Label Mapper is used for NER/POS etc.
    TODO: a better refactor, by xiaodl
    """
    metrics = {}
    for mm in metric_meta:
        metric_name = mm.name
        metric_func = METRIC_FUNC[mm]
        if mm in (Metric.ACC, Metric.F1, Metric.MCC):
            metric = metric_func(predictions, golds)
        elif mm == Metric.SeqEval:
            metric = metric_func(predictions, golds, label_mapper)
        elif mm == Metric.EmF1:
            metric = metric_func(predictions, golds)
        else:
            if mm == Metric.AUC:
                assert len(scores) == 2 * len(
                    golds
                ), "AUC is only valid for binary classification problems"
                scores = scores[1::2]
            metric = metric_func(scores, golds)
        metrics[metric_name] = metric
    return metrics
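
# Usage sketch (illustrative only; metric_meta comes from the task definition):
#   metrics = calc_metrics([Metric.ACC, Metric.MCC], golds, predictions, scores)
#   # -> {'ACC': ..., 'MCC': ...}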

@@ -0,0 +1,113 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
from copy import deepcopy
from functools import wraps

import torch
from torch.nn import Parameter


class EMA:
    def __init__(self, gamma, model):
        super(EMA, self).__init__()
        self.gamma = gamma
        self.shadow = {}
        self.model = model
        self.setup()

    def setup(self):
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                self.shadow[name] = para.clone()

    def cuda(self):
        for k, v in self.shadow.items():
            self.shadow[k] = v.cuda()

    def update(self):
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                self.shadow[name] = (1.0 - self.gamma) * para + self.gamma * self.shadow[name]

    def swap_parameters(self):
        for name, para in self.model.named_parameters():
            if para.requires_grad:
                temp_data = para.data
                para.data = self.shadow[name].data
                self.shadow[name].data = temp_data

    def state_dict(self):
        return self.shadow
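
# Usage sketch (illustrative only; `model` and the training loop are hypothetical):
#   ema = EMA(gamma=0.995, model=model)
#   ...                      # after each optimizer step:
#   ema.update()
#   ema.swap_parameters()    # evaluate with the averaged weights
#   ema.swap_parameters()    # swap back before resuming training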


# Adapted from
# https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/weight_norm.py
# and https://github.com/salesforce/awd-lstm-lm/blob/master/weight_drop.py
def _norm(p, dim):
    """Computes the norm over all dimensions except dim"""
    if dim is None:
        return p.norm()
    elif dim == 0:
        output_size = (p.size(0),) + (1,) * (p.dim() - 1)
        return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size)
    elif dim == p.dim() - 1:
        output_size = (1,) * (p.dim() - 1) + (p.size(-1),)
        return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size)
    else:
        return _norm(p.transpose(0, dim), 0).transpose(0, dim)


def _dummy(*args, **kwargs):
    # We need to replace flatten_parameters with a no-op function
    return


class WeightNorm(torch.nn.Module):
    def __init__(self, weights, dim):
        super(WeightNorm, self).__init__()
        self.weights = weights
        self.dim = dim

    def compute_weight(self, module, name):
        g = getattr(module, name + "_g")
        v = getattr(module, name + "_v")
        return v * (g / _norm(v, self.dim))

    @staticmethod
    def apply(module, weights, dim):
        # Temporary workaround for an issue with compacting weights
        # re: CUDNN RNN
        if issubclass(type(module), torch.nn.RNNBase):
            module.flatten_parameters = _dummy
        if weights is None:  # do for all weight params
            weights = [w for w in module._parameters.keys() if "weight" in w]
        fn = WeightNorm(weights, dim)
        for name in weights:
            if hasattr(module, name):
                print("Applying weight norm to {} - {}".format(str(module), name))
                weight = getattr(module, name)
                del module._parameters[name]
                module.register_parameter(name + "_g", Parameter(_norm(weight, dim).data))
                module.register_parameter(name + "_v", Parameter(weight.data))
                setattr(module, name, fn.compute_weight(module, name))

        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module):
        for name in self.weights:
            weight = self.compute_weight(module, name)
            delattr(module, name)
            del module._parameters[name + "_g"]
            del module._parameters[name + "_v"]
            module.register_parameter(name, Parameter(weight.data))

    def __call__(self, module, inputs):
        for name in self.weights:
            setattr(module, name, self.compute_weight(module, name))


def weight_norm(module, weights=None, dim=0):
    WeightNorm.apply(module, weights, dim)
    return module
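
# Usage sketch (illustrative only): reparameterize a module's weights as g * v / ||v||.
#   cell = torch.nn.GRUCell(128, 128)
#   cell = weight_norm(cell, dim=0)  # registers weight_g/weight_v plus a pre-forward hook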

@@ -0,0 +1,290 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import random
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.models.roberta import RobertaModel as FairseqRobertModel
from pytorch_pretrained_bert.modeling import BertConfig, BertLayerNorm, BertModel
from torch.nn.parameter import Parameter
from torch.nn.utils import weight_norm

from mtdnn.common.dropout_wrapper import DropoutWrapper
from mtdnn.common.optimizer import weight_norm as WN
from mtdnn.common.similarity import FlatSimilarityWrapper, SelfAttnWrapper
from mtdnn.common.types import EncoderModelType, TaskType
from mtdnn.configuration_mtdnn import MTDNNConfig

SMALL_POS_NUM = 1.0e-30


class Classifier(nn.Module):
    def __init__(self, x_size, y_size, opt, prefix="decoder", dropout=None):
        super(Classifier, self).__init__()
        self.opt = opt
        if dropout is None:
            self.dropout = DropoutWrapper(opt.get("{}_dropout_p".format(prefix), 0))
        else:
            self.dropout = dropout
        self.merge_opt = opt.get("{}_merge_opt".format(prefix), 0)
        self.weight_norm_on = opt.get("{}_weight_norm_on".format(prefix), False)

        if self.merge_opt == 1:
            self.proj = nn.Linear(x_size * 4, y_size)
        else:
            self.proj = nn.Linear(x_size * 2, y_size)

        if self.weight_norm_on:
            self.proj = weight_norm(self.proj)

    def forward(self, x1, x2, mask=None):
        if self.merge_opt == 1:
            x = torch.cat([x1, x2, (x1 - x2).abs(), x1 * x2], 1)
        else:
            x = torch.cat([x1, x2], 1)
        x = self.dropout(x)
        scores = self.proj(x)
        return scores


class SANClassifier(nn.Module):
    """Implementation of Stochastic Answer Networks for Natural Language Inference, Xiaodong Liu, Kevin Duh and Jianfeng Gao
    https://arxiv.org/abs/1804.07888
    """

    def __init__(
        self, x_size, h_size, label_size, opt={}, prefix="decoder", dropout=None
    ):
        super(SANClassifier, self).__init__()
        self.prefix = prefix
        if dropout is None:
            self.dropout = DropoutWrapper(
                opt.get("{}_dropout_p".format(self.prefix), 0)
            )
        else:
            self.dropout = dropout
        self.query_wsum = SelfAttnWrapper(
            x_size, prefix="mem_cum", opt=opt, dropout=self.dropout
        )
        self.attn = FlatSimilarityWrapper(x_size, h_size, prefix, opt, self.dropout)
        self.rnn_type = "{}{}".format(
            opt.get("{}_rnn_type".format(prefix), "gru").upper(), "Cell"
        )
        self.rnn = getattr(nn, self.rnn_type)(x_size, h_size)
        self.num_turn = opt.get("{}_num_turn".format(prefix), 5)
        self.opt = opt
        self.mem_random_drop = opt.get("{}_mem_drop_p".format(prefix), 0)
        self.mem_type = opt.get("{}_mem_type".format(prefix), 0)
        self.weight_norm_on = opt.get("{}_weight_norm_on".format(prefix), False)
        self.label_size = label_size
        self.dump_state = opt.get("dump_state_on", False)
        self.alpha = Parameter(torch.zeros(1, 1), requires_grad=False)
        if self.weight_norm_on:
            self.rnn = WN(self.rnn)

        self.classifier = Classifier(
            x_size, self.label_size, opt, prefix=prefix, dropout=self.dropout
        )

    def _generate_mask(self, new_data, dropout_p=0.0, is_training=False):
        if not is_training:
            dropout_p = 0.0
        new_data = (1 - dropout_p) * (new_data.zero_() + 1)
        for i in range(new_data.size(0)):
            one = random.randint(0, new_data.size(1) - 1)
            new_data[i][one] = 1
        mask = 1.0 / (1 - dropout_p) * torch.bernoulli(new_data)
        mask.requires_grad = False
        return mask

    def forward(self, x, h0, x_mask=None, h_mask=None):
        h0 = self.query_wsum(h0, h_mask)
        if type(self.rnn) is nn.LSTMCell:
            c0 = h0.new(h0.size()).zero_()
        scores_list = []
        for turn in range(self.num_turn):
            att_scores = self.attn(x, h0, x_mask)
            x_sum = torch.bmm(F.softmax(att_scores, 1).unsqueeze(1), x).squeeze(1)
            scores = self.classifier(x_sum, h0)
            scores_list.append(scores)
            # next turn
            if self.rnn is not None:
                h0 = self.dropout(h0)
                if type(self.rnn) is nn.LSTMCell:
                    h0, c0 = self.rnn(x_sum, (h0, c0))
                else:
                    h0 = self.rnn(x_sum, h0)
        if self.mem_type == 1:
            mask = self._generate_mask(
                self.alpha.data.new(x.size(0), self.num_turn),
                self.mem_random_drop,
                self.training,
            )
            mask = [m.contiguous() for m in torch.unbind(mask, 1)]
            tmp_scores_list = [
                mask[idx].view(x.size(0), 1).expand_as(inp) * F.softmax(inp, 1)
                for idx, inp in enumerate(scores_list)
            ]
            scores = torch.stack(tmp_scores_list, 2)
            scores = torch.mean(scores, 2)
            scores = torch.log(scores)
        else:
            scores = scores_list[-1]
        if self.dump_state:
            return scores, scores_list
        else:
            return scores
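
# Usage sketch (illustrative only; `opt` carries the decoder hyperparameters):
#   san = SANClassifier(x_size=768, h_size=768, label_size=3, opt={}, prefix='answer')
#   scores = san(sequence_output, hypothesis_states, x_mask, h_mask)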


class SANBERTNetwork(nn.Module):
    """Implementation of Stochastic Answer Networks for Natural Language Inference, Xiaodong Liu, Kevin Duh and Jianfeng Gao
    https://arxiv.org/abs/1804.07888
    """

    def __init__(
        self,
        init_checkpoint_model: Union[BertModel, FairseqRobertModel],
        pooler,
        config: MTDNNConfig,
    ):
        super(SANBERTNetwork, self).__init__()
        self.config = config
        self.bert = init_checkpoint_model
        self.pooler = pooler
        self.dropout_list = nn.ModuleList()
        self.encoder_type = config.encoder_type
        self.hidden_size = self.config.hidden_size

        # Dump other features if value is set to true
        if config.dump_feature:
            return

        # Freeze the BERT parameters when update_bert_opt is set
        if config.update_bert_opt > 0:
            for param in self.bert.parameters():
                param.requires_grad = False

        # Set decoder and scoring list parameters
        self.decoder_opts = config.decoder_opts
        self.scoring_list = nn.ModuleList()

        # Set task specific parameters
        self.task_types = config.task_types
        self.task_dropout_p = config.tasks_dropout_p
        self.tasks_nclass_list = config.tasks_nclass_list

        # TODO - Move to training
        # Generate tasks decoding and scoring lists
        self._generate_tasks_decoding_scoring_options()

        # Initialize weights

        # self._my_init()

    def _my_init(self):
        def init_weights(module):
            if isinstance(module, (nn.Linear, nn.Embedding)):
                # Slightly different from the TF version, which uses truncated_normal for initialization
                # cf https://github.com/pytorch/pytorch/pull/5617
                module.weight.data.normal_(mean=0.0, std=0.02 * self.config.init_ratio)
            elif isinstance(module, BertLayerNorm):
                # Slightly different from the BERT pytorch version, which should be a bug.
                # Note that it only affects training from scratch. For detailed discussions, please contact xiaodl@.
                # Layer normalization (https://arxiv.org/abs/1607.06450)
                # support both old/latest version
                if "beta" in dir(module) and "gamma" in dir(module):
                    module.beta.data.zero_()
                    module.gamma.data.fill_(1.0)
                else:
                    module.bias.data.zero_()
                    module.weight.data.fill_(1.0)
            if isinstance(module, nn.Linear):
                module.bias.data.zero_()

        self.apply(init_weights)

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        premise_mask=None,
        hyp_mask=None,
        task_id=0,
    ):
        if self.encoder_type == EncoderModelType.ROBERTA:
            sequence_output = self.bert.extract_features(input_ids)
            pooled_output = self.pooler(sequence_output)
        else:
            all_encoder_layers, pooled_output = self.bert(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
            )
            sequence_output = all_encoder_layers[-1]

        decoder_opt = self.decoder_opts[task_id]
        task_type = self.task_types[task_id]
        if task_type == TaskType.Span:
            assert decoder_opt != 1
            sequence_output = self.dropout_list[task_id](sequence_output)
            logits = self.scoring_list[task_id](sequence_output)
            start_scores, end_scores = logits.split(1, dim=-1)
            start_scores = start_scores.squeeze(-1)
            end_scores = end_scores.squeeze(-1)
            return start_scores, end_scores
        elif task_type == TaskType.SequenceLabeling:
            pooled_output = all_encoder_layers[-1]
            pooled_output = self.dropout_list[task_id](pooled_output)
            pooled_output = pooled_output.contiguous().view(-1, pooled_output.size(2))
            logits = self.scoring_list[task_id](pooled_output)
            return logits
        else:
            if decoder_opt == 1:
                max_query = hyp_mask.size(1)
                assert max_query > 0
                assert premise_mask is not None
                assert hyp_mask is not None
                hyp_mem = sequence_output[:, :max_query, :]
                logits = self.scoring_list[task_id](
                    sequence_output, hyp_mem, premise_mask, hyp_mask
                )
            else:
                pooled_output = self.dropout_list[task_id](pooled_output)
                logits = self.scoring_list[task_id](pooled_output)
            return logits

    # TODO - Move to training step
    def _generate_tasks_decoding_scoring_options(self):
        """Enumerate over tasks and set up the decoding and scoring lists for training."""
        assert (
            len(self.tasks_nclass_list) > 0
        ), "Number of classes to train for cannot be 0"
        for idx, task_num_labels in enumerate(self.tasks_nclass_list):
            print(f"idx: {idx}, number of task labels: {task_num_labels}")
            decoder_opt = self.decoder_opts[idx]
            task_type = self.task_types[idx]
            dropout = DropoutWrapper(
                self.task_dropout_p[idx], self.config.enable_variational_dropout
            )
            self.dropout_list.append(dropout)
            if task_type == TaskType.Span:
                assert decoder_opt != 1
                out_proj = nn.Linear(self.hidden_size, 2)
            elif task_type == TaskType.SequenceLabeling:
                out_proj = nn.Linear(self.hidden_size, task_num_labels)
            else:
                if decoder_opt == 1:
                    out_proj = SANClassifier(
                        self.hidden_size,
                        self.hidden_size,
                        task_num_labels,
                        self.config.to_dict(),
                        prefix="answer",
                        dropout=dropout,
                    )
                else:
                    out_proj = nn.Linear(self.hidden_size, task_num_labels)
            self.scoring_list.append(out_proj)

@@ -0,0 +1,700 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.utils import weight_norm

from mtdnn.common.activation_functions import activation, init_wrapper
from mtdnn.common.dropout_wrapper import DropoutWrapper


class DotProduct(nn.Module):
    def __init__(self, x1_dim, x2_dim, prefix="sim", opt={}, dropout=None):
        super(DotProduct, self).__init__()
        assert x1_dim == x2_dim
        self.opt = opt
        self.prefix = prefix
        self.scale_on = opt.get("{}_scale".format(self.prefix), False)
        self.scalor = 1.0 / numpy.power(x2_dim, 0.5)

    def forward(self, x1, x2):
        assert x1.size(2) == x2.size(2)
        scores = x1.bmm(x2.transpose(1, 2))
        if self.scale_on:
            scores *= self.scalor
        return scores


class DotProductProject(nn.Module):
    def __init__(self, x1_dim, x2_dim, prefix="sim", opt={}, dropout=None):
        super(DotProductProject, self).__init__()
        self.prefix = prefix
        self.opt = opt
        self.hidden_size = opt.get("{}_hidden_size".format(self.prefix), 64)
        self.residual_on = opt.get("{}_residual_on".format(self.prefix), False)
        self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
        self.share = opt.get("{}_share".format(self.prefix), False)
        self.f = activation(opt.get("{}_activation".format(self.prefix), "relu"))
        self.scale_on = opt.get("{}_scale_on".format(self.prefix), False)
        self.dropout = dropout
        x1_in_dim = x1_dim
        x2_in_dim = x2_dim
        out_dim = self.hidden_size
        self.proj_1 = nn.Linear(x1_in_dim, out_dim, bias=False)
        if self.layer_norm_on:
            self.proj_1 = weight_norm(self.proj_1)
        if self.share and x1_in_dim == x2_in_dim:
            self.proj_2 = self.proj_1
        else:
            self.proj_2 = nn.Linear(x2_in_dim, out_dim)
            if self.layer_norm_on:
                self.proj_2 = weight_norm(self.proj_2)

        if self.scale_on:
            self.scalar = Parameter(
                torch.ones(1, 1, 1) / (self.hidden_size ** 0.5), requires_grad=False
            )
        else:
            self.scalar = Parameter(
                torch.ones(1, 1, self.hidden_size), requires_grad=True
            )

    def forward(self, x1, x2):
        assert x1.size(2) == x2.size(2)
        if self.dropout:
            x1 = self.dropout(x1)
            x2 = self.dropout(x2)
        x1_flat = x1.contiguous().view(-1, x1.size(2))
        x2_flat = x2.contiguous().view(-1, x2.size(2))
        x1_o = self.f(self.proj_1(x1_flat)).view(x1.size(0), x1.size(1), -1)
        x2_o = self.f(self.proj_2(x2_flat)).view(x2.size(0), x2.size(1), -1)
        if self.scale_on:
            scalar = self.scalar.expand_as(x2_o)
            x2_o = scalar * x2_o
        scores = x1_o.bmm(x2_o.transpose(1, 2))
        return scores


class Bilinear(nn.Module):
    def __init__(self, x1_dim, x2_dim, prefix="sim", opt={}, dropout=None):
        super(Bilinear, self).__init__()
        self.opt = opt
        self.prefix = prefix
        self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
        self.transform_on = opt.get("{}_proj_on".format(self.prefix), False)
        # self.init = init_wrapper(opt.get('{}_init'.format(self.prefix), ''))
        self.dropout = dropout
        if self.transform_on:
            self.proj = nn.Linear(x1_dim, x2_dim)
            # self.init(self.proj.weight)
            if self.layer_norm_on:
                self.proj = weight_norm(self.proj)

    def forward(self, x, y):
        """
        x = batch * len * h1
        y = batch * h2
        x_mask = batch * len
        """
        if self.dropout:
            x = self.dropout(x)
            y = self.dropout(y)

        proj = self.proj(y) if self.transform_on else y
        if self.dropout:
            proj = self.dropout(proj)
        scores = x.bmm(proj.unsqueeze(2)).squeeze(2)
        return scores


class BilinearSum(nn.Module):
    def __init__(self, x1_dim, x2_dim, prefix="sim", opt={}, dropout=None):
        super(BilinearSum, self).__init__()
        self.prefix = prefix
        self.x_linear = nn.Linear(x1_dim, 1, bias=False)
        self.y_linear = nn.Linear(x2_dim, 1, bias=False)
        self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
        self.init = init_wrapper(opt.get("{}_init".format(self.prefix), False))
        if self.layer_norm_on:
            self.x_linear = weight_norm(self.x_linear)
            self.y_linear = weight_norm(self.y_linear)

        self.init(self.x_linear.weight)
        self.init(self.y_linear.weight)
        self.dropout = dropout

    def forward(self, x1, x2):
        """
        x1: batch * len1 * input_size
        x2: batch * len2 * input_size
        score: batch * len1 * len2
        """
        if self.dropout:
            x1 = self.dropout(x1)
            x2 = self.dropout(x2)

        x1_logits = self.x_linear(x1.contiguous().view(-1, x1.size(-1))).view(
            x1.size(0), -1, 1
        )
        x2_logits = self.y_linear(x2.contiguous().view(-1, x2.size(-1))).view(
            x2.size(0), 1, -1
        )

        shape = (x1.size(0), x1.size(1), x2.size(1))
        scores = x1_logits.expand(*shape) + x2_logits.expand(*shape)
        return scores


class Trilinear(nn.Module):
    """Function used in BiDAF"""

    def __init__(self, x1_dim, x2_dim, prefix="sim", opt={}, dropout=None):
        super(Trilinear, self).__init__()
        self.prefix = prefix
        self.x_linear = nn.Linear(x1_dim, 1, bias=False)
        self.x_dot_linear = nn.Linear(x1_dim, 1, bias=False)
        self.y_linear = nn.Linear(x2_dim, 1, bias=False)
        self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
        self.init = init_wrapper(
            opt.get("{}_init".format(self.prefix), "xavier_uniform")
        )
        if self.layer_norm_on:
            self.x_linear = weight_norm(self.x_linear)
            self.x_dot_linear = weight_norm(self.x_dot_linear)
            self.y_linear = weight_norm(self.y_linear)

        self.init(self.x_linear.weight)
        self.init(self.x_dot_linear.weight)
        self.init(self.y_linear.weight)
        self.dropout = dropout

    def forward(self, x1, x2):
        """
        x1: batch * len1 * input_size
        x2: batch * len2 * input_size
        score: batch * len1 * len2
        """
        if self.dropout:
            x1 = self.dropout(x1)
            x2 = self.dropout(x2)

        x1_logits = self.x_linear(x1.contiguous().view(-1, x1.size(-1))).view(
            x1.size(0), -1, 1
        )
        x2_logits = self.y_linear(x2.contiguous().view(-1, x2.size(-1))).view(
            x2.size(0), 1, -1
        )
        x1_dot = (
            self.x_dot_linear(x1.contiguous().view(-1, x1.size(-1)))
            .view(x1.size(0), -1, 1)
            .expand_as(x1)
        )
        x1_dot = x1 * x1_dot

        scores = x1_dot.bmm(x2.transpose(1, 2))
        scores += x1_logits.expand_as(scores) + x2_logits.expand_as(scores)
        return scores


class SimilarityWrapper(nn.Module):
    def __init__(self, x1_dim, x2_dim, prefix="attention", opt={}, dropout=None):
        super(SimilarityWrapper, self).__init__()
        self.score_func_str = opt.get(
            "{}_sim_func".format(prefix), "dotproductproject"
        ).lower()
        self.score_func = None
        if self.score_func_str == "dotproduct":
            self.score_func = DotProduct(
                x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
            )
        elif self.score_func_str == "dotproductproject":
            self.score_func = DotProductProject(
                x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
            )
        elif self.score_func_str == "bilinear":
            self.score_func = Bilinear(
                x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
            )
        elif self.score_func_str == "bilinearsum":
            self.score_func = BilinearSum(
                x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
            )
        elif self.score_func_str == "trilinear":
            self.score_func = Trilinear(
                x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
            )
        else:
            raise NotImplementedError

    def forward(self, x1, x2):
        scores = self.score_func(x1, x2)
        return scores
|
||||
|
||||
|
||||
class AttentionWrapper(nn.Module):
|
||||
def __init__(
|
||||
self, x1_dim, x2_dim, x3_dim=None, prefix="attention", opt={}, dropout=None
|
||||
):
|
||||
super(AttentionWrapper, self).__init__()
|
||||
self.prefix = prefix
|
||||
self.att_dropout = opt.get("{}_att_dropout".format(self.prefix), 0)
|
||||
self.score_func = SimilarityWrapper(
|
||||
x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
|
||||
)
|
||||
self.drop_diagonal = opt.get("{}_drop_diagonal".format(self.prefix), False)
|
||||
self.output_size = x2_dim if x3_dim is None else x3_dim
|
||||
|
||||
def forward(self, query, key, value, key_padding_mask=None, return_scores=False):
|
||||
logits = self.score_func(query, key)
|
||||
key_mask = key_padding_mask.unsqueeze(1).expand_as(logits)
|
||||
logits.data.masked_fill_(key_mask.data, -float("inf"))
|
||||
if self.drop_diagonal:
|
||||
assert logits.size(1) == logits.size(2)
|
||||
diag_mask = (
|
||||
torch.diag(logits.data.new(logits.size(1)).zero_() + 1)
|
||||
.byte()
|
||||
.unsqueeze(0)
|
||||
.expand_as(logits)
|
||||
)
|
||||
logits.data.masked_fill_(diag_mask, -float("inf"))
|
||||
|
||||
prob = F.softmax(logits.view(-1, key.size(1)), 1)
|
||||
prob = prob.view(-1, query.size(1), key.size(1))
|
||||
if self.att_dropout > 0:
|
||||
prob = self.dropout(prob)
|
||||
|
||||
if value is None:
|
||||
value = key
|
||||
attn = prob.bmm(value)
|
||||
if return_scores:
|
||||
return attn, prob, logits
|
||||
else:
|
||||
return attn
|
||||
|
||||
|
||||
class LinearSelfAttn(nn.Module):
|
||||
"""Self attention over a sequence:
|
||||
* o_i = softmax(Wx_i) for x_i in X.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, dropout=None):
|
||||
super(LinearSelfAttn, self).__init__()
|
||||
self.linear = nn.Linear(input_size, 1)
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.dropout(x)
|
||||
x_flat = x.contiguous().view(-1, x.size(-1))
|
||||
scores = self.linear(x_flat).view(x.size(0), x.size(1))
|
||||
scores.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
alpha = F.softmax(scores, 1)
|
||||
return alpha.unsqueeze(1).bmm(x).squeeze(1)
|
||||
|
||||
|
||||
class MLPSelfAttn(nn.Module):
|
||||
def __init__(self, input_size, opt={}, prefix="attn_sum", dropout=None):
|
||||
super(MLPSelfAttn, self).__init__()
|
||||
self.prefix = prefix
|
||||
self.FC = nn.Linear(input_size, input_size)
|
||||
self.linear = nn.Linear(input_size, 1)
|
||||
self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
|
||||
self.f = activation(opt.get("{}_activation".format(self.prefix), "relu"))
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
if self.layer_norm_on:
|
||||
self.FC = weight_norm(self.FC)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
x = self.dropout(x)
|
||||
x_flat = x.contiguous().view(-1, x.size(-1))
|
||||
scores = self.linear(self.f(self.FC(x_flat))).view(x.size(0), x.size(1))
|
||||
scores.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
alpha = F.softmax(scores)
|
||||
return alpha.unsqueeze(1).bmm(x).squeeze(1)
|
||||
|
||||
|
||||
class SelfAttnWrapper(nn.Module):
|
||||
def __init__(self, input_size, prefix="attn_sum", opt={}, dropout=None):
|
||||
super(SelfAttnWrapper, self).__init__()
|
||||
"""
|
||||
Self att wrapper, support linear and MLP
|
||||
"""
|
||||
attn_type = opt.get("{}_type".format(prefix), "linear")
|
||||
if attn_type == "mlp":
|
||||
self.att = MLPSelfAttn(input_size, prefix, opt, dropout)
|
||||
else:
|
||||
self.att = LinearSelfAttn(input_size, dropout)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
return self.att(x, x_mask)
|
||||
|
||||
|
||||
class DeepAttentionWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
x1_dim,
|
||||
x2_dim,
|
||||
x3_dims,
|
||||
att_cnt,
|
||||
prefix="deep_att",
|
||||
opt=None,
|
||||
dropout=None,
|
||||
):
|
||||
super(DeepAttentionWrapper, self).__init__()
|
||||
self.opt = {} if opt is None else opt
|
||||
self.prefix = prefix
|
||||
self.x1_dim = x1_dim
|
||||
self.x2_dim = x2_dim
|
||||
self.x3_dims = x3_dims
|
||||
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
|
||||
self.attn_list = nn.ModuleList()
|
||||
for i in range(0, att_cnt):
|
||||
if opt["multihead_on"]:
|
||||
attention = MultiheadAttentionWrapper(
|
||||
self.x1_dim,
|
||||
self.x2_dim,
|
||||
self.x3_dims[i],
|
||||
prefix,
|
||||
opt,
|
||||
dropout=dropout,
|
||||
)
|
||||
else:
|
||||
attention = AttentionWrapper(
|
||||
self.x1_dim, self.x2_dim, self.x3_dims[i], prefix, opt, self.dropout
|
||||
)
|
||||
self.attn_list.append(attention)
|
||||
|
||||
def forward(self, x1, x2, x3, x2_mask):
|
||||
rvl = []
|
||||
for i in range(0, len(x3)):
|
||||
hiddens = self.attn_list[i](x1, x2, x3[i], x2_mask)
|
||||
rvl.append(hiddens)
|
||||
|
||||
return torch.cat(rvl, 2)
|
||||
|
||||
|
||||
class BilinearFlatSim(nn.Module):
|
||||
"""A bilinear attention layer over a sequence X w.r.t y:
|
||||
* o_i = x_i'Wy for x_i in X.
|
||||
"""
|
||||
|
||||
def __init__(self, x_size, y_size, opt={}, prefix="seqatt", dropout=None):
|
||||
super(BilinearFlatSim, self).__init__()
|
||||
self.opt = opt
|
||||
self.weight_norm_on = opt.get("{}_weight_norm_on".format(prefix), False)
|
||||
self.linear = nn.Linear(y_size, x_size)
|
||||
if self.weight_norm_on:
|
||||
self.linear = weight_norm(self.linear)
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, y, x_mask):
|
||||
"""
|
||||
x = batch * len * h1
|
||||
y = batch * h2
|
||||
x_mask = batch * len
|
||||
"""
|
||||
x = self.dropout(x)
|
||||
y = self.dropout(y)
|
||||
|
||||
Wy = self.linear(y)
|
||||
xWy = x.bmm(Wy.unsqueeze(2)).squeeze(2)
|
||||
xWy.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
return xWy
|
||||
|
||||
|
||||
class SimpleFlatSim(nn.Module):
|
||||
def __init__(self, x_size, y_size, opt={}, prefix="seqatt", dropout=None):
|
||||
super(SimpleFlatSim, self).__init__()
|
||||
self.opt = opt
|
||||
self.weight_norm_on = opt.get("{}_norm_on".format(prefix), False)
|
||||
self.linear = nn.Linear(y_size + x_size, 1)
|
||||
if self.weight_norm_on:
|
||||
self.linear = weight_norm(self.linear)
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, y, x_mask):
|
||||
"""
|
||||
x = batch * len * h1
|
||||
y = batch * h2
|
||||
x_mask = batch * len
|
||||
"""
|
||||
x = self.dropout(x)
|
||||
y = self.dropout(y)
|
||||
y = y.unsqueeze(1).expand_as(x)
|
||||
flat_x = torch.cat([x, y], 2).contiguous().view(x.size(0) * x.size(1), -1)
|
||||
flat_scores = self.linear(flat_x)
|
||||
scores = flat_scores.contiguous().view(x.size(0), -1)
|
||||
scores.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
return scores
|
||||
|
||||
|
||||
class FlatSim(nn.Module):
|
||||
def __init__(self, x_size, y_size, opt={}, prefix="seqatt", dropout=None):
|
||||
super(FlatSim, self).__init__()
|
||||
assert x_size == y_size
|
||||
self.opt = opt
|
||||
self.weight_norm_on = opt.get("{}_weight_norm_on".format(prefix), False)
|
||||
self.linear = nn.Linear(x_size * 3, 1)
|
||||
if self.weight_norm_on:
|
||||
self.linear = weight_norm(self.linear)
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, y, x_mask):
|
||||
"""
|
||||
x = batch * len * h1
|
||||
y = batch * h2
|
||||
x_mask = batch * len
|
||||
"""
|
||||
x = self.dropout(x)
|
||||
y = self.dropout(y)
|
||||
y = y.unsqueeze(1).expand_as(x)
|
||||
|
||||
flat_x = (
|
||||
torch.cat([x, y, x * y], 2).contiguous().view(x.size(0) * x.size(1), -1)
|
||||
)
|
||||
flat_scores = self.linear(flat_x)
|
||||
scores = flat_scores.contiguous().view(x.size(0), -1)
|
||||
scores.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class FlatSimV2(nn.Module):
|
||||
def __init__(self, x_size, y_size, opt={}, prefix="seqatt", dropout=None):
|
||||
super(FlatSimV2, self).__init__()
|
||||
assert x_size == y_size
|
||||
self.opt = opt
|
||||
self.weight_norm_on = opt.get("{}_weight_norm_on".format(prefix), False)
|
||||
self.linear = nn.Linear(x_size * 4, 1)
|
||||
if self.weight_norm_on:
|
||||
self.linear = weight_norm(self.linear)
|
||||
if dropout is None:
|
||||
self.dropout = DropoutWrapper(
|
||||
opt.get("{}_dropout_p".format(self.prefix), 0)
|
||||
)
|
||||
else:
|
||||
self.dropout = dropout
|
||||
|
||||
def forward(self, x, y, x_mask):
|
||||
"""
|
||||
x = batch * len * h1
|
||||
y = batch * h2
|
||||
x_mask = batch * len
|
||||
"""
|
||||
x = self.dropout(x)
|
||||
y = self.dropout(y)
|
||||
y = y.unsqueeze(1).expand_as(x)
|
||||
|
||||
flat_x = (
|
||||
torch.cat([x, y, x * y, torch.abs(x - y)], 2)
|
||||
.contiguous()
|
||||
.view(x.size(0) * x.size(1), -1)
|
||||
)
|
||||
flat_scores = self.linear(flat_x)
|
||||
scores = flat_scores.contiguous().view(x.size(0), -1)
|
||||
scores.data.masked_fill_(x_mask.data, -float("inf"))
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class FlatSimilarityWrapper(nn.Module):
|
||||
def __init__(self, x1_dim, x2_dim, prefix="attention", opt={}, dropout=None):
|
||||
super(FlatSimilarityWrapper, self).__init__()
|
||||
self.score_func_str = opt.get("{}_att_type".format(prefix), "none").lower()
|
||||
self.att_dropout = DropoutWrapper(opt.get("{}_att_dropout".format(prefix), 0))
|
||||
self.score_func = None
|
||||
if self.score_func_str == "bilinear":
|
||||
self.score_func = BilinearFlatSim(
|
||||
x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
|
||||
)
|
||||
elif self.score_func_str == "simple":
|
||||
self.score_func = SimpleFlatSim(
|
||||
x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
|
||||
)
|
||||
elif self.score_func_str == "flatsim":
|
||||
self.score_func = FlatSim(
|
||||
x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
|
||||
)
|
||||
else:
|
||||
self.score_func = FlatSimV2(
|
||||
x1_dim, x2_dim, prefix=prefix, opt=opt, dropout=dropout
|
||||
)
|
||||
|
||||
def forward(self, x1, x2, mask):
|
||||
scores = self.score_func(x1, x2, mask)
|
||||
return scores
|
||||
|
||||
|
||||
class MultiheadAttentionWrapper(nn.Module):
|
||||
"""Multi-headed attention.
|
||||
See "Attention Is All You Need" for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, query_dim, key_dim, value_dim, prefix="attention", opt={}, dropout=None
|
||||
):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
|
||||
self.num_heads = opt.get("{}_head".format(self.prefix), 1)
|
||||
self.dropout = (
|
||||
DropoutWrapper(opt.get("{}_dropout".format(self.prefix), 0))
|
||||
if dropout is None
|
||||
else dropout
|
||||
)
|
||||
|
||||
self.qkv_dim = [query_dim, key_dim, value_dim]
|
||||
assert query_dim == key_dim, "query dim must equal with key dim"
|
||||
|
||||
self.hidden_size = opt.get("{}_hidden_size".format(self.prefix), 64)
|
||||
|
||||
self.proj_on = opt.get("{}_proj_on".format(prefix), False)
|
||||
self.share = opt.get("{}_share".format(self.prefix), False)
|
||||
self.layer_norm_on = opt.get("{}_norm_on".format(self.prefix), False)
|
||||
self.scale_on = opt.get("{}_scale_on".format(self.prefix), False)
|
||||
|
||||
if self.proj_on:
|
||||
self.proj_modules = nn.ModuleList(
|
||||
[nn.Linear(dim, self.hidden_size) for dim in self.qkv_dim[0:2]]
|
||||
)
|
||||
if self.layer_norm_on:
|
||||
for proj in self.proj_modules:
|
||||
proj = weight_norm(proj)
|
||||
if self.share and self.qkv_dim[0] == self.qkv_dim[1]:
|
||||
self.proj_modules[1] = self.proj_modules[0]
|
||||
self.f = activation(opt.get("{}_activation".format(self.prefix), "relu"))
|
||||
|
||||
self.qkv_head_dim = [self.hidden_size // self.num_heads] * 3
|
||||
self.qkv_head_dim[2] = value_dim // self.num_heads
|
||||
assert (
|
||||
self.qkv_head_dim[0] * self.num_heads == self.hidden_size
|
||||
), "hidden size must be divisible by num_heads"
|
||||
assert (
|
||||
self.qkv_head_dim[2] * self.num_heads == value_dim
|
||||
), "value size must be divisible by num_heads"
|
||||
|
||||
else:
|
||||
self.qkv_head_dim = [emb // self.num_heads for emb in self.qkv_dim]
|
||||
# import pdb; pdb.set_trace()
|
||||
assert (
|
||||
self.qkv_head_dim[0] * self.num_heads == self.qkv_dim[0]
|
||||
), "query size must be divisible by num_heads"
|
||||
assert (
|
||||
self.qkv_head_dim[1] * self.num_heads == self.qkv_dim[1]
|
||||
), "key size must be divisible by num_heads"
|
||||
assert (
|
||||
self.qkv_head_dim[2] * self.num_heads == self.qkv_dim[2]
|
||||
), "value size must be divisible by num_heads"
|
||||
|
||||
if self.scale_on:
|
||||
self.scaling = self.qkv_head_dim[0] ** -0.5
|
||||
self.drop_diagonal = opt.get("{}_drop_diagonal".format(self.prefix), False)
|
||||
self.output_size = self.qkv_dim[2]
|
||||
|
||||
def forward(self, query, key, value, key_padding_mask=None):
|
||||
query = query.transpose(0, 1)
|
||||
key = key.transpose(0, 1)
|
||||
value = value.transpose(0, 1)
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == self.qkv_dim[0]
|
||||
|
||||
q, k, v = query, key, value
|
||||
if self.proj_on:
|
||||
if self.dropout:
|
||||
q, k = self.dropout(q), self.dropout(k)
|
||||
q, k = [
|
||||
self.f(proj(input))
|
||||
for input, proj in zip([query, key], self.proj_modules)
|
||||
]
|
||||
|
||||
src_len = k.size(0)
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.scale_on:
|
||||
q *= self.scaling
|
||||
|
||||
q = (
|
||||
q.contiguous()
|
||||
.view(tgt_len, bsz * self.num_heads, self.qkv_head_dim[0])
|
||||
.transpose(0, 1)
|
||||
)
|
||||
k = (
|
||||
k.contiguous()
|
||||
.view(src_len, bsz * self.num_heads, self.qkv_head_dim[1])
|
||||
.transpose(0, 1)
|
||||
)
|
||||
v = (
|
||||
v.contiguous()
|
||||
.view(src_len, bsz * self.num_heads, self.qkv_head_dim[2])
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if key_padding_mask is not None:
|
||||
# don't attend to padding symbols
|
||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_weights = (
|
||||
attn_weights.float()
|
||||
.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"),)
|
||||
.type_as(attn_weights)
|
||||
) # FP16 support: cast to float and back
|
||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
if self.drop_diagonal:
|
||||
assert attn_weights.size(1) == attn_weights.size(2)
|
||||
diag_mask = (
|
||||
torch.diag(attn_weights.data.new(attn_weights.size(1)).zero_() + 1)
|
||||
.byte()
|
||||
.unsqueeze(0)
|
||||
.expand_as(attn_weights)
|
||||
)
|
||||
attn_weights.data.masked_fill_(diag_mask, -float("inf"))
|
||||
|
||||
attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
|
||||
attn_weights = self.dropout(attn_weights)
|
||||
|
||||
attn = torch.bmm(attn_weights, v)
|
||||
assert list(attn.size()) == [
|
||||
bsz * self.num_heads,
|
||||
tgt_len,
|
||||
self.qkv_head_dim[2],
|
||||
]
|
||||
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, -1)
|
||||
|
||||
# output_shape: Batch * Time * Channel
|
||||
attn = attn.transpose(0, 1)
|
||||
|
||||
return attn
|
|
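

# --- Editor's addition: a hedged usage sketch, not part of the original diff.
# Assumes DropoutWrapper (defined in this module) acts as identity at rate 0.
def _editor_demo_self_attn_pooling():
    """Masked self-attentive pooling: batch * len * dim -> batch * dim."""
    batch, seq_len, dim = 2, 5, 8
    pool = SelfAttnWrapper(dim, dropout=DropoutWrapper(0))
    x = torch.randn(batch, seq_len, dim)
    x_mask = torch.zeros(batch, seq_len).byte()
    x_mask[:, -1] = 1  # treat the last position as padding
    pooled = pool(x, x_mask)
    assert pooled.size() == (batch, dim)
    return pooled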
@@ -0,0 +1,112 @@
""" Official evaluation script for v1.1 of the SQuAD dataset.
|
||||
Credit from: https://worksheets.codalab.org/rest/bundles/0xbcd57bee090b421c982906709c8c27e1/contents/blob/
|
||||
"""
|
||||
from __future__ import print_function
|
||||
from collections import Counter
|
||||
import string
|
||||
import re
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
|
||||
def normalize_answer(s):
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
def remove_articles(text):
|
||||
return re.sub(r'\b(a|an|the)\b', ' ', text)
|
||||
|
||||
def white_space_fix(text):
|
||||
return ' '.join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return ''.join(ch for ch in text if ch not in exclude)
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
|
||||
def f1_score(prediction, ground_truth):
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def exact_match_score(prediction, ground_truth):
|
||||
return (normalize_answer(prediction) == normalize_answer(ground_truth))
|
||||
|
||||
|
||||
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
|
||||
scores_for_ground_truths = []
|
||||
for ground_truth in ground_truths:
|
||||
score = metric_fn(prediction, ground_truth)
|
||||
scores_for_ground_truths.append(score)
|
||||
return max(scores_for_ground_truths)
|
||||
|
||||
def evaluate(dataset, predictions):
|
||||
f1 = exact_match = total = 0
|
||||
for article in dataset:
|
||||
for paragraph in article['paragraphs']:
|
||||
for qa in paragraph['qas']:
|
||||
total += 1
|
||||
if qa['id'] not in predictions:
|
||||
message = 'Unanswered question ' + qa['id'] + \
|
||||
' will receive score 0.'
|
||||
print(message, file=sys.stderr)
|
||||
continue
|
||||
ground_truths = list(map(lambda x: x['text'], qa['answers']))
|
||||
prediction = predictions[qa['id']]
|
||||
exact_match += metric_max_over_ground_truths(
|
||||
exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(
|
||||
f1_score, prediction, ground_truths)
|
||||
|
||||
exact_match = 100.0 * exact_match / total
|
||||
f1 = 100.0 * f1 / total
|
||||
return {'exact_match': exact_match, 'f1': f1}
|
||||
|
||||
def evaluate_func(human, predictions):
|
||||
f1 = exact_match = total = 0
|
||||
for uid, ground_truths in human.items():
|
||||
total += 1
|
||||
if uid not in predictions:
|
||||
message = 'Unanswered question ' + uid + \
|
||||
' will receive score 0.'
|
||||
print(message, file=sys.stderr)
|
||||
continue
|
||||
prediction = predictions[uid]
|
||||
exact_match += metric_max_over_ground_truths(
|
||||
exact_match_score, prediction, ground_truths)
|
||||
f1 += metric_max_over_ground_truths(
|
||||
f1_score, prediction, ground_truths)
|
||||
|
||||
exact_match = 100.0 * exact_match / total
|
||||
f1 = 100.0 * f1 / total
|
||||
return str({'exact_match': exact_match, 'f1': f1})
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
expected_version = '1.1'
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Evaluation for SQuAD ' + expected_version)
|
||||
parser.add_argument('dataset_file', help='Dataset file')
|
||||
parser.add_argument('prediction_file', help='Prediction File')
|
||||
args = parser.parse_args()
|
||||
with open(args.dataset_file) as dataset_file:
|
||||
dataset_json = json.load(dataset_file)
|
||||
if (dataset_json['version'] != expected_version):
|
||||
print('Evaluation expects v-' + expected_version +
|
||||
', but got dataset with v-' + dataset_json['version'],
|
||||
file=sys.stderr)
|
||||
dataset = dataset_json['data']
|
||||
with open(args.prediction_file) as prediction_file:
|
||||
predictions = json.load(prediction_file)
|
||||
print(json.dumps(evaluate(dataset, predictions)))
|
|
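

# --- Editor's addition: a minimal sketch of evaluate(), not part of the
# original script. `dataset` is the parsed `data` list from the SQuAD JSON;
# `predictions` maps question ids to answer strings.
def _editor_demo_evaluate():
    dataset = [{'paragraphs': [{'qas': [
        {'id': 'q1', 'answers': [{'text': 'Denver Broncos'}]}]}]}]
    # normalize_answer() strips articles and punctuation, so the leading
    # "the" still earns full exact-match and F1 credit.
    return evaluate(dataset, {'q1': 'the Denver Broncos'})
    # -> {'exact_match': 100.0, 'f1': 100.0}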
@@ -0,0 +1,104 @@
import os
import argparse
from sys import path
import json
path.append(os.getcwd())
from data_utils.log_wrapper import create_logger
from experiments.common_utils import dump_rows
from data_utils import DataFormat

logger = create_logger(__name__, to_disk=True, log_file='squad_prepro.log')

END = 'EOSEOS'  # fix: END was commented out but is still used when v2_on=True


def normalize_qa_field(s: str, replacement_list):
    for replacement in replacement_list:
        s = s.replace(replacement, " " * len(replacement))  # keep answer_start and answer_end valid
    return s


def load_data(path, is_train=True, v2_on=False):
    rows = []
    with open(path, encoding="utf8") as f:
        data = json.load(f)['data']
    for article in data:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            if v2_on:
                context = '{} {}'.format(context, END)
            for qa in paragraph['qas']:
                uid, question = qa['id'], qa['question']
                answers = qa.get('answers', [])
                # used for v2.0
                is_impossible = qa.get('is_impossible', False)
                label = 1 if is_impossible else 0
                if (v2_on and label < 1 and len(answers) < 1) or ((not v2_on) and len(answers) < 1):
                    # detect inconsistent data
                    # * for v2, the row is possible but has no answer
                    # * for v1, all questions should have an answer
                    continue
                if len(answers) > 0:
                    answer = answers[0]['text']
                    answer_start = answers[0]['answer_start']
                    answer_end = answer_start + len(answer)
                else:
                    # for questions without answers, give a fake answer
                    answer = ''
                    answer_start = -1
                    answer_end = -1
                answer = normalize_qa_field(answer, ["\n", "\t", ":::"])
                context = normalize_qa_field(context, ["\n", "\t"])
                question = normalize_qa_field(question, ["\n", "\t"])
                sample = {'uid': uid, 'premise': context, 'hypothesis': question,
                          'label': "%s:::%s:::%s:::%s" % (answer_start, answer_end, label, answer)}
                rows.append(sample)
    return rows


def parse_args():
    parser = argparse.ArgumentParser(description='Preprocessing SQUAD data.')
    parser.add_argument('--root_dir', type=str, default='data')
    args = parser.parse_args()
    return args


def main(args):
    root = args.root_dir
    assert os.path.exists(root)

    squad_train_path = os.path.join(root, 'squad/train.json')
    squad_dev_path = os.path.join(root, 'squad/dev.json')
    squad_v2_train_path = os.path.join(root, 'squad_v2/train.json')
    squad_v2_dev_path = os.path.join(root, 'squad_v2/dev.json')

    squad_train_data = load_data(squad_train_path)
    squad_dev_data = load_data(squad_dev_path, is_train=False)
    logger.info('Loaded {} squad train samples'.format(len(squad_train_data)))
    logger.info('Loaded {} squad dev samples'.format(len(squad_dev_data)))

    squad_v2_train_data = load_data(squad_v2_train_path, v2_on=True)
    squad_v2_dev_data = load_data(squad_v2_dev_path, is_train=False, v2_on=True)
    logger.info('Loaded {} squad_v2 train samples'.format(len(squad_v2_train_data)))
    logger.info('Loaded {} squad_v2 dev samples'.format(len(squad_v2_dev_data)))

    canonical_data_suffix = "canonical_data"
    canonical_data_root = os.path.join(root, canonical_data_suffix)
    if not os.path.isdir(canonical_data_root):
        os.mkdir(canonical_data_root)

    squad_train_fout = os.path.join(canonical_data_root, 'squad_train.tsv')
    squad_dev_fout = os.path.join(canonical_data_root, 'squad_dev.tsv')
    dump_rows(squad_train_data, squad_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(squad_dev_data, squad_dev_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with squad')

    squad_v2_train_fout = os.path.join(canonical_data_root, 'squad-v2_train.tsv')
    squad_v2_dev_fout = os.path.join(canonical_data_root, 'squad-v2_dev.tsv')
    dump_rows(squad_v2_train_data, squad_v2_train_fout, DataFormat.PremiseAndOneHypothesis)
    dump_rows(squad_v2_dev_data, squad_v2_dev_fout, DataFormat.PremiseAndOneHypothesis)
    logger.info('done with squad_v2')


if __name__ == '__main__':
    args = parse_args()
    main(args)
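

# --- Editor's addition: a small sketch (not in the original script) of the
# 'start:::end:::is_impossible:::answer' label packing used by load_data().
def _editor_demo_label_packing():
    context = "The quick brown fox jumps over the lazy dog."
    answer = "brown fox"
    answer_start = context.index(answer)      # 10
    answer_end = answer_start + len(answer)   # 19
    label = "%s:::%s:::%s:::%s" % (answer_start, answer_end, 0, answer)
    assert label == "10:::19:::0:::brown fox"
    return label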
@@ -0,0 +1,620 @@
import collections
import json
import math
import os
import string

import numpy as np
import six
import torch
import torch.nn.functional as F
import tqdm  # fix: used by load_squad_label below but was never imported
from pytorch_pretrained_bert.tokenization import BertTokenizer

from mtdnn.common.types import EncoderModelType

LARGE_NEG_NUM = -1.0e5
tokenizer = None


def remove_punc(text):
    exclude = set(string.punctuation)
    return "".join(ch for ch in text if ch not in exclude)


def calc_tokenized_span_range(
    context,
    question,
    answer,
    answer_start,
    answer_end,
    tokenizer,
    encoderModelType,
    verbose=False,
):
    """
    :param context:
    :param question:
    :param answer:
    :param answer_start:
    :param answer_end:
    :param tokenizer:
    :param encoderModelType:
    :param verbose:
    :return: span_start, span_end
    """
    assert encoderModelType == EncoderModelType.BERT
    prefix = context[:answer_start]
    prefix_tokens = tokenizer.tokenize(prefix)
    full = context[:answer_end]
    full_tokens = tokenizer.tokenize(full)
    span_start = len(prefix_tokens)
    span_end = len(full_tokens)
    span_tokens = full_tokens[span_start:span_end]
    recovered_answer = " ".join(span_tokens).replace(" ##", "")
    cleaned_answer = " ".join(tokenizer.basic_tokenizer.tokenize(answer))
    if verbose:
        try:
            assert recovered_answer == cleaned_answer, (
                "answer: %s, recovered_answer: %s, question: %s, select:%s ext_select:%s context: %s"
                % (
                    cleaned_answer,
                    recovered_answer,
                    question,
                    context[answer_start:answer_end],
                    context[answer_start - 5 : answer_end + 5],
                    context,
                )
            )
        except Exception as e:
            print(e)  # fix: dropped a dead `pass` that preceded this print
    return span_start, span_end


def is_valid_sample(context, answer_start, answer_end, answer):
    valid = True
    constructed = context[answer_start:answer_end]
    if constructed.lower() != answer.lower():
        valid = False
        return valid
    # check if it is inside of a token
    if answer_start > 0 and answer_end < len(context) - 1:
        prefix = context[answer_start - 1 : answer_start]
        suffix = context[answer_end : answer_end + 1]
        if len(remove_punc(prefix)) > 0 or len(remove_punc(suffix)) > 0:
            valid = False
    return valid


def is_whitespace(c):
    if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
        return True
    return False


def parse_squad_label(label):
    """
    :param label:
    :return: answer_start, answer_end, answer, is_impossible
    """
    answer_start, answer_end, is_impossible, answer = label.split(":::")
    answer_start = int(answer_start)
    answer_end = int(answer_end)
    is_impossible = int(is_impossible)
    return answer_start, answer_end, answer, is_impossible


def _improve_answer_span(
    doc_tokens, input_start, input_end, tokenizer, orig_answer_text
):
    """Returns tokenized answer spans that better match the annotated answer."""
    # It is copied from: https://github.com/google-research/bert/blob/master/run_squad.py
    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    # Question: What year was John Smith born?
    # Context: The leader was John Smith (1895-1943).
    # Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    # Question: What country is the top exporter of electronics?
    # Context: The Japanese electronics industry is the largest in the world.
    # Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)


def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""
    # It is copied from: https://github.com/google-research/bert/blob/master/run_squad.py
    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def doc_split(doc_subwords, doc_stride=180, max_tokens_for_doc=384):
    _DocSpan = collections.namedtuple("DocSpan", ["start", "length"])
    doc_spans = []
    start_offset = 0
    while start_offset < len(doc_subwords):
        length = len(doc_subwords) - start_offset
        if length > max_tokens_for_doc:
            length = max_tokens_for_doc
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == len(doc_subwords):
            break
        start_offset += min(length, doc_stride)
    return doc_spans
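

# --- Editor's addition: a hedged sketch (not in the original file) of the
# sliding-window behaviour of doc_split() above.
def _editor_demo_doc_split():
    toy_subwords = ["tok%d" % i for i in range(10)]
    spans = doc_split(toy_subwords, doc_stride=3, max_tokens_for_doc=5)
    # Overlapping windows over 10 subwords: [(0, 5), (3, 5), (6, 4)]
    return [(s.start, s.length) for s in spans]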


def recompute_span(answer, answer_offset, char_to_word_offset):
    answer_length = len(answer)
    start_position = char_to_word_offset[answer_offset]
    end_position = char_to_word_offset[answer_offset + answer_length - 1]
    return start_position, end_position


def is_valid_answer(context, answer_start, answer_end, answer):
    valid = True
    constructed = " ".join(context[answer_start : answer_end + 1]).lower()
    cleaned_answer_text = " ".join(answer.split()).lower()
    if constructed.find(cleaned_answer_text) == -1:
        valid = False
    return valid


def token_doc(paragraph_text):
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in paragraph_text:
        if is_whitespace(c):
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)
    return doc_tokens, char_to_word_offset


class InputFeatures(object):
    def __init__(
        self,
        unique_id,
        example_index,
        doc_span_index,
        tokens,
        token_to_orig_map,
        token_is_max_context,
        input_ids,
        input_mask,
        segment_ids,
        start_position=None,
        end_position=None,
        is_impossible=None,
        doc_offset=0,
    ):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible
        self.doc_offset = doc_offset

    def __str__(self):
        return json.dumps(
            {
                "unique_id": self.unique_id,
                "example_index": self.example_index,
                "doc_span_index": self.doc_span_index,
                "tokens": self.tokens,
                "token_to_orig_map": self.token_to_orig_map,
                "token_is_max_context": self.token_is_max_context,
                "input_ids": self.input_ids,
                "input_mask": self.input_mask,
                "segment_ids": self.segment_ids,
                "start_position": self.start_position,
                "end_position": self.end_position,
                "is_impossible": self.is_impossible,
                "doc_offset": self.doc_offset,
            }
        )


def mrc_feature(
    tokenizer,
    unique_id,
    example_index,
    query,
    doc_tokens,
    answer_start_adjusted,
    answer_end_adjusted,
    is_impossible,
    max_seq_len,
    max_query_len,
    doc_stride,
    answer_text=None,
    is_training=True,
):
    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    query_ids = tokenizer.tokenize(query)
    query_ids = (
        query_ids[0:max_query_len] if len(query_ids) > max_query_len else query_ids
    )
    max_tokens_for_doc = max_seq_len - len(query_ids) - 3
    unique_id_cp = unique_id
    for (i, token) in enumerate(doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = tokenizer.tokenize(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
    tok_start_position = None
    tok_end_position = None
    if is_training and is_impossible:
        tok_start_position = -1
        tok_end_position = -1
    if is_training and not is_impossible:
        tok_start_position = orig_to_tok_index[answer_start_adjusted]
        if answer_end_adjusted < len(doc_tokens) - 1:
            tok_end_position = orig_to_tok_index[answer_end_adjusted + 1] - 1
        else:
            tok_end_position = len(all_doc_tokens) - 1
        (tok_start_position, tok_end_position) = _improve_answer_span(
            all_doc_tokens, tok_start_position, tok_end_position, tokenizer, answer_text
        )

    doc_spans = doc_split(
        all_doc_tokens, doc_stride=doc_stride, max_tokens_for_doc=max_tokens_for_doc
    )
    feature_list = []
    for (doc_span_index, doc_span) in enumerate(doc_spans):
        tokens = ["[CLS]"] + query_ids + ["[SEP]"]
        token_to_orig_map = {}
        token_is_max_context = {}
        segment_ids = [0 for i in range(len(tokens))]

        for i in range(doc_span.length):
            split_token_index = doc_span.start + i
            token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

            is_max_context = _check_is_max_context(
                doc_spans, doc_span_index, split_token_index
            )
            token_is_max_context[len(tokens)] = is_max_context
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(1)
        tokens.append("[SEP]")
        segment_ids.append(1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)
        doc_offset = len(query_ids) + 2

        start_position = None
        end_position = None
        if is_training and not is_impossible:
            # For training, if our document chunk does not contain an annotation
            # we throw it out, since there is nothing to predict.
            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1
            out_of_span = False
            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                out_of_span = True
            if out_of_span:
                start_position = 0
                end_position = 0
            else:
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset

        if is_training and is_impossible:
            start_position = 0
            end_position = 0
        is_impossible = True if is_impossible else False
        feature = InputFeatures(
            unique_id=unique_id_cp,
            example_index=example_index,
            doc_span_index=doc_span_index,
            tokens=tokens,
            token_to_orig_map=token_to_orig_map,
            token_is_max_context=token_is_max_context,
            input_ids=input_ids,
            input_mask=input_mask,
            segment_ids=segment_ids,
            start_position=start_position,
            end_position=end_position,
            is_impossible=is_impossible,
            doc_offset=doc_offset,
        )
        feature_list.append(feature)
        unique_id_cp += 1
    return feature_list


def gen_gold_name(dir, path, version, suffix="json"):
    fname = "{}-{}.{}".format(path, version, suffix)
    return os.path.join(dir, fname)


def load_squad_label(path):
    rows = {}
    with open(path, encoding="utf8") as f:
        data = json.load(f)["data"]
    for article in tqdm.tqdm(data, total=len(data)):
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                uid, question = qa["id"], qa["question"]
                is_impossible = qa.get("is_impossible", False)
                label = 1 if is_impossible else 0
                rows[uid] = label
    return rows


def position_encoding(m, threshold=4):
    encoding = np.ones((m, m), dtype=np.float32)
    for i in range(m):
        for j in range(i, m):
            if j - i > threshold:
                encoding[i][j] = float(1.0 / math.log(j - i + 1))
    return torch.from_numpy(encoding)


def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
    """Project the tokenized prediction back to the original text."""

    # When we created the data, we kept track of the alignment between original
    # (whitespace tokenized) tokens and our WordPiece tokenized tokens. So
    # now `orig_text` contains the span of our original text corresponding to the
    # span that we predicted.
    #
    # However, `orig_text` may contain extra characters that we don't want in
    # our prediction.
    #
    # For example, let's say:
    #   pred_text = steve smith
    #   orig_text = Steve Smith's
    #
    # We don't want to return `orig_text` because it contains the extra "'s".
    #
    # We don't want to return `pred_text` because it's already been normalized
    # (the SQuAD eval script also does punctuation stripping/lower casing but
    # our tokenizer does additional normalization like stripping accent
    # characters).
    #
    # What we really want to return is "Steve Smith".
    #
    # Therefore, we have to apply a semi-complicated alignment heuristic between
    # `pred_text` and `orig_text` to get a character-to-character alignment. This
    # can fail in certain cases in which case we just return `orig_text`.

    def _strip_spaces(text):
        ns_chars = []
        ns_to_s_map = collections.OrderedDict()
        for (i, c) in enumerate(text):
            if c == " ":
                continue
            ns_to_s_map[len(ns_chars)] = i
            ns_chars.append(c)
        ns_text = "".join(ns_chars)
        return (ns_text, ns_to_s_map)

    # We first tokenize `orig_text`, strip whitespace from the result
    # and `pred_text`, and check if they are the same length. If they are
    # NOT the same length, the heuristic has failed. If they are the same
    # length, we assume the characters are one-to-one aligned.
    global tokenizer
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    tok_text = " ".join(tokenizer.tokenize(orig_text))

    start_position = tok_text.find(pred_text)
    if start_position == -1:
        return orig_text
    end_position = start_position + len(pred_text) - 1

    (orig_ns_text, orig_ns_to_s_map) = _strip_spaces(orig_text)
    (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)

    if len(orig_ns_text) != len(tok_ns_text):
        return orig_text

    # We then project the characters in `pred_text` back to `orig_text` using
    # the character-to-character alignment.
    tok_s_to_ns_map = {}
    for (i, tok_index) in six.iteritems(tok_ns_to_s_map):
        tok_s_to_ns_map[tok_index] = i

    orig_start_position = None
    if start_position in tok_s_to_ns_map:
        ns_start_position = tok_s_to_ns_map[start_position]
        if ns_start_position in orig_ns_to_s_map:
            orig_start_position = orig_ns_to_s_map[ns_start_position]

    if orig_start_position is None:
        return orig_text

    orig_end_position = None
    if end_position in tok_s_to_ns_map:
        ns_end_position = tok_s_to_ns_map[end_position]
        if ns_end_position in orig_ns_to_s_map:
            orig_end_position = orig_ns_to_s_map[ns_end_position]

    if orig_end_position is None:
        return orig_text

    output_text = orig_text[orig_start_position : (orig_end_position + 1)]
    return output_text


def masking_score(mask, batch_meta, start, end, keep_first_token=False):
    """For MRC, e.g., SQuAD"""
    start = start.data.cpu()
    end = end.data.cpu()
    score_mask = start.new(mask.size()).zero_()
    score_mask = score_mask.data.cpu()
    token_is_max_contexts = batch_meta["token_is_max_context"]
    doc_offsets = batch_meta["doc_offset"]
    word_maps = batch_meta["token_to_orig_map"]
    batch_size = score_mask.size(0)
    doc_len = score_mask.size(1)
    for i in range(batch_size):
        doc_offset = doc_offsets[i]
        if keep_first_token:
            score_mask[i][1:doc_offset] = 1.0
        else:
            score_mask[i][:doc_offset] = 1.0
        for j in range(doc_len):
            sj = str(j)
            if mask[i][j] == 0:
                score_mask[i][j] = 1.0  # fix: was `==`, a no-op comparison
            if sj in token_is_max_contexts[i] and (not token_is_max_contexts[i][sj]):
                score_mask[i][j] = 1.0  # fix: was `==`, a no-op comparison
    score_mask = score_mask * LARGE_NEG_NUM
    start = start + score_mask
    end = end + score_mask
    start = F.softmax(start, 1)
    end = F.softmax(end, 1)
    return start, end


def extract_answer(
    batch_meta, batch_data, start, end, keep_first_token=False, max_len=5
):
    doc_len = start.size(1)
    pos_enc = position_encoding(doc_len, max_len)
    token_is_max_contexts = batch_meta["token_is_max_context"]
    doc_offsets = batch_meta["doc_offset"]
    word_maps = batch_meta["token_to_orig_map"]
    tokens = batch_meta["tokens"]
    contexts = batch_meta["doc"]
    uids = batch_meta["uids"]
    mask = batch_data[batch_meta["mask"]].data.cpu()
    # fill the score mask before picking spans
    start, end = masking_score(mask, batch_meta, start, end)
    predictions = []
    answer_scores = []

    for i in range(start.size(0)):
        uid = uids[i]
        scores = torch.ger(start[i], end[i])
        scores = scores * pos_enc
        scores.triu_()
        scores = scores.numpy()
        best_idx = np.argpartition(scores, -1, axis=None)[-1]
        best_score = np.partition(scores, -1, axis=None)[-1]
        s_idx, e_idx = np.unravel_index(best_idx, scores.shape)
        s_idx, e_idx = int(s_idx), int(e_idx)
        tok_tokens = tokens[i][s_idx : (e_idx + 1)]
        tok_text = " ".join(tok_tokens)
        # De-tokenize WordPieces that have been split off.
        tok_text = tok_text.replace(" ##", "")
        tok_text = tok_text.replace("##", "")
        # Clean whitespace
        tok_text = tok_text.strip()
        tok_text = " ".join(tok_text.split())
        context = contexts[i].split()
        rs = word_maps[i][str(s_idx)]
        re = word_maps[i][str(e_idx)]
        raw_answer = " ".join(context[rs : re + 1])
        # extract final answer
        answer = get_final_text(tok_text, raw_answer, True, False)
        predictions.append(answer)
        answer_scores.append(float(best_score))
    return predictions, answer_scores


def select_answers(ids, predictions, scores):
    assert len(ids) == len(predictions)
    predictions_list = {}
    for idx, uid in enumerate(ids):
        score = scores[idx]
        ans = predictions[idx]
        lst = predictions_list.get(uid, [])
        lst.append((ans, score))
        predictions_list[uid] = lst
    final = {}
    scores = {}
    for key, val in predictions_list.items():
        idx = np.argmax([v[1] for v in val])
        # fix: the answer text and score were swapped between the two dicts
        final[key] = val[idx][0]
        scores[key] = val[idx][1]
    return final, scores


def merge_answers(ids, golds):
    gold_list = {}
    for idx, uid in enumerate(ids):
        gold = golds[idx]
        if uid not in gold_list:
            gold_list[uid] = gold
    return gold_list
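

# --- Editor's addition: a sketch (not in the original file) of how
# select_answers() collapses per-doc-span candidates to one answer per uid.
def _editor_demo_select_answers():
    ids = ["q1", "q1", "q2"]
    predictions = ["in 1895", "1895", "Japan"]
    scores = [0.2, 0.7, 0.9]
    final, best_scores = select_answers(ids, predictions, scores)
    # final == {"q1": "1895", "q2": "Japan"}; best_scores == {"q1": 0.7, "q2": 0.9}
    return final, best_scores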
@@ -0,0 +1,27 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class LayerNorm(nn.Module):
    # ref: https://github.com/pytorch/pytorch/issues/1959
    # and: https://arxiv.org/pdf/1607.06450.pdf
    def __init__(self, hidden_size, eps=1e-4):
        super(LayerNorm, self).__init__()
        self.alpha = Parameter(torch.ones(1, 1, hidden_size))  # gain g
        self.beta = Parameter(torch.zeros(1, 1, hidden_size))  # bias b
        self.eps = eps

    def forward(self, x):
        """
        Args:
            :param x: batch * len * input_size

        Returns:
            normalized x
        """
        mu = torch.mean(x, 2, keepdim=True).expand_as(x)
        sigma = torch.std(x, 2, keepdim=True).expand_as(x)
        return (x - mu) / (sigma + self.eps) * self.alpha.expand_as(x) + self.beta.expand_as(x)
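

# --- Editor's addition: a numeric sanity sketch, not part of the original file.
def _editor_demo_layer_norm():
    ln = LayerNorm(hidden_size=4)
    x = torch.randn(2, 3, 4)
    y = ln(x)
    # With alpha=1 and beta=0 at init, every (batch, position) slice of y has
    # roughly zero mean and unit std over the hidden dimension.
    return y.mean(2), y.std(2)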
@@ -0,0 +1,23 @@
# Copyright (c) Microsoft. All rights reserved.

from enum import IntEnum


class TaskType(IntEnum):
    Classification = 1
    Regression = 2
    Ranking = 3
    Span = 4
    SequenceLabeling = 5


class DataFormat(IntEnum):
    PremiseOnly = 1
    PremiseAndOneHypothesis = 2
    PremiseAndMultiHypothesis = 3
    Sequence = 4


class EncoderModelType(IntEnum):
    BERT = 1
    ROBERTA = 2
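

# --- Editor's addition: a short usage sketch, not part of the original file.
def _editor_demo_enums():
    """IntEnum members compare and serialize as plain ints."""
    assert TaskType.Span == 4
    assert DataFormat(2) is DataFormat.PremiseAndOneHypothesis
    assert EncoderModelType["BERT"] is EncoderModelType.BERT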
@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
# Some code referenced from https://github.com/microsoft/nlp-recipes
import json
import logging
import math
import os
import random  # fix: used in set_environment but was never imported
import subprocess
import tarfile
import zipfile
from contextlib import contextmanager
from logging import Logger
from tempfile import TemporaryDirectory

import numpy  # fix: used in set_environment but was never imported
import requests
import torch
from tqdm import tqdm

log = logging.getLogger(__name__)  # fix: maybe_download logged to an undefined `log`


class MTDNNCommonUtils:
    @staticmethod
    def set_environment(seed, set_cuda=False):
        random.seed(seed)
        numpy.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available() and set_cuda:
            torch.cuda.manual_seed_all(seed)

    @staticmethod
    def patch_var(v, cuda=True):
        if cuda:
            v = v.cuda(non_blocking=True)
        return v

    @staticmethod
    def get_gpu_memory_map():
        result = subprocess.check_output(
            ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
            encoding="utf-8",
        )
        gpu_memory = [int(x) for x in result.strip().split("\n")]
        gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
        return gpu_memory_map

    @staticmethod
    def get_pip_env():
        result = subprocess.call(["pip", "freeze"])
        return result

    @staticmethod
    def load_pytorch_model(local_model_path: str = ""):
        state_dict = None
        assert os.path.exists(local_model_path), "Model File path doesn't exist"
        state_dict = torch.load(local_model_path)
        return state_dict

    @staticmethod
    def dump(path, data):
        with open(path, "w") as f:
            json.dump(data, f)

    @staticmethod
    def generate_decoder_opt(enable_san, max_opt):
        opt_v = 0
        if enable_san and max_opt < 3:
            opt_v = max_opt
        return opt_v

    @staticmethod
    def setup_logging(filename="run.log", mode="w") -> Logger:
        logger = logging.getLogger(__name__)
        log_file_handler = logging.FileHandler(
            filename=filename, mode=mode, encoding="utf-8"
        )
        log_formatter = logging.Formatter(
            "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        log_file_handler.setFormatter(log_formatter)
        do_add_handler = True
        for handler in logger.handlers:
            if isinstance(handler, logging.FileHandler):
                do_add_handler = False
        if do_add_handler:
            logger.addHandler(log_file_handler)
        logger.setLevel(logging.DEBUG)
        return logger

    @staticmethod
    def create_directory_if_not_exists(dir_path: str):
        os.makedirs(dir_path, exist_ok=True)

    @staticmethod
    @contextmanager
    def download_path(path=None):
        tmp_dir = TemporaryDirectory()
        if not path:
            path = tmp_dir.name
        else:
            path = os.path.realpath(path)

        try:
            yield path
        finally:
            tmp_dir.cleanup()

    @staticmethod
    def maybe_download(url, filename=None, work_directory=".", expected_bytes=None):
        """Download a file if it is not already downloaded.

        Args:
            filename (str): File name.
            work_directory (str): Working directory.
            url (str): URL of the file to download.
            expected_bytes (int): Expected file size in bytes.
        Returns:
            str: File path of the file downloaded.
        """
        if filename is None:
            filename = url.split("/")[-1]
        os.makedirs(work_directory, exist_ok=True)
        filepath = os.path.join(work_directory, filename)
        if not os.path.exists(filepath):
            if not os.path.isdir(work_directory):
                os.makedirs(work_directory)
            r = requests.get(url, stream=True)
            total_size = int(r.headers.get("content-length", 0))
            block_size = 1024
            num_iterables = math.ceil(total_size / block_size)

            with open(filepath, "wb") as file:
                for data in tqdm(
                    r.iter_content(block_size),
                    total=num_iterables,
                    unit="KB",
                    unit_scale=True,
                ):
                    file.write(data)
        else:
            log.debug("File {} already downloaded".format(filepath))
        if expected_bytes is not None:
            statinfo = os.stat(filepath)
            if statinfo.st_size != expected_bytes:
                os.remove(filepath)
                raise IOError("Failed to verify {}".format(filepath))

        return filepath
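

# --- Editor's addition: a hedged sketch, not part of the original file.
# The URL is a placeholder, not a real MT-DNN asset; note that a file
# downloaded into the default temp dir is removed when the context exits.
def _editor_demo_maybe_download():
    url = "https://example.com/sample.txt"  # hypothetical
    with MTDNNCommonUtils.download_path() as path:
        local_file = MTDNNCommonUtils.maybe_download(url, work_directory=path)
        print("downloaded to", local_file)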
@@ -0,0 +1,15 @@
from pytorch_pretrained_bert import BertTokenizer
from data_utils.task_def import EncoderModelType
from experiments.squad.squad_utils import calc_tokenized_span_range, parse_squad_label

model = "bert-base-uncased"
do_lower_case = True
tokenizer = BertTokenizer.from_pretrained(model, do_lower_case=do_lower_case)

for no, line in enumerate(open(r"data\canonical_data\squad_v2_train.tsv", encoding="utf-8")):
    if no % 1000 == 0:
        print(no)
    uid, label, context, question = line.strip().split("\t")
    answer_start, answer_end, answer, is_impossible = parse_squad_label(label)
    calc_tokenized_span_range(context, question, answer, answer_start, answer_end, tokenizer, EncoderModelType.BERT,
                              verbose=True)
@@ -0,0 +1,80 @@
# Copyright (c) Microsoft. All rights reserved.
import tqdm
import unicodedata

PAD = 'PADPAD'
UNK = 'UNKUNK'
STA = 'BOSBOS'
END = 'EOSEOS'

PAD_ID = 0
UNK_ID = 1
STA_ID = 2
END_ID = 3


class Vocabulary(object):
    INIT_LEN = 4

    def __init__(self, neat=False):
        self.neat = neat
        if not neat:
            self.tok2ind = {PAD: PAD_ID, UNK: UNK_ID, STA: STA_ID, END: END_ID}
            self.ind2tok = {PAD_ID: PAD, UNK_ID: UNK, STA_ID: STA, END_ID: END}
        else:
            self.tok2ind = {}
            self.ind2tok = {}

    def __len__(self):
        return len(self.tok2ind)

    def __iter__(self):
        return iter(self.tok2ind)

    def __contains__(self, key):
        if type(key) == int:
            return key in self.ind2tok
        elif type(key) == str:
            return key in self.tok2ind

    def __getitem__(self, key):
        if type(key) == int:
            return self.ind2tok.get(key, -1) if self.neat else self.ind2tok.get(key, UNK)
        if type(key) == str:
            return (self.tok2ind.get(key, None) if self.neat
                    else self.tok2ind.get(key, self.tok2ind.get(UNK)))

    def __setitem__(self, key, item):
        if type(key) == int and type(item) == str:
            self.ind2tok[key] = item
        elif type(key) == str and type(item) == int:
            self.tok2ind[key] = item
        else:
            raise RuntimeError('Invalid (key, item) types.')

    def add(self, token):
        if token not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[token] = index
            self.ind2tok[index] = token

    def get_vocab_list(self, with_order=True):
        if with_order:
            words = [self[k] for k in range(0, len(self))]
        else:
            words = [k for k in self.tok2ind.keys()
                     if k not in {PAD, UNK, STA, END}]
        return words

    def toidx(self, tokens):
        return [self[tok] for tok in tokens]

    def copy(self):
        """Deep copy"""
        new_vocab = Vocabulary(self.neat)
        for w in self:
            new_vocab.add(w)
        return new_vocab


def build(words, neat=False):
    vocab = Vocabulary(neat)
    for w in words:
        vocab.add(w)
    return vocab
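

# --- Editor's addition: a usage sketch, not part of the original file.
def _editor_demo_vocabulary():
    vocab = build(["hello", "world"])
    assert vocab["hello"] == 4          # PAD/UNK/STA/END occupy ids 0-3
    assert vocab[4] == "hello"
    assert vocab["missing"] == UNK_ID   # unseen tokens fall back to UNK
    return vocab.get_vocab_list()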
@@ -0,0 +1,199 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


# This script reuses some code from
# https://github.com/huggingface/transformers

import torch
from transformers import BertConfig, PretrainedConfig

from mtdnn.common.types import EncoderModelType
from mtdnn.common.archive_maps import PRETRAINED_CONFIG_ARCHIVE_MAP

"""MTDNN model configuration"""

encoder_checkpoint_map = {1: "bert", 2: "roberta"}


class MTDNNConfig(PretrainedConfig):
    r"""
    :class:`~MTDNNConfig` is the configuration class to store the configuration of a
    `MTDNNModel`.


    Arguments:
        vocab_size: Vocabulary size of `inputs_ids` in `MTDNNModel`.
        hidden_size: Size of the encoder layers and the pooler layer.
        num_hidden_layers: Number of hidden layers in the Transformer encoder.
        num_attention_heads: Number of attention heads for each attention layer in
            the Transformer encoder.
        intermediate_size: The size of the "intermediate" (i.e., feed-forward)
            layer in the Transformer encoder.
        hidden_act: The non-linear activation function (function or string) in the
            encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
        hidden_dropout_prob: The dropout probability for all fully connected
            layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob: The dropout ratio for the attention
            probabilities.
        max_position_embeddings: The maximum sequence length that this model might
            ever be used with. Typically set this to something large just in case
            (e.g., 512 or 1024 or 2048).
        type_vocab_size: The vocabulary size of the `token_type_ids` passed into
            `MTDNNModel`.
        initializer_range: The stddev of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
    """

    # TODO - Not needed
    pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(
        self,
        use_pretrained_model=False,
        encoder_type=EncoderModelType.BERT,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        dump_feature=False,
        update_bert_opt=0,
        answer_opt=0,  # answer --> decoder opts flag
        enable_variational_dropout=True,
        init_ratio=1.0,
        init_checkpoint="bert-base-uncased",
        mtl_opt=0,
        ratio=0.0,
        mix_opt=0,
        max_seq_len=512,
        # Training config
        cuda=torch.cuda.is_available(),
        cuda_device=0,
        multi_gpu_on=False,
        log_per_updates=500,
        save_per_updates=10000,
        save_per_updates_on=False,
        epochs=5,
        use_tensor_board=False,
        tensorboard_logdir="tensorboard_logdir",
        batch_size=8,
        batch_size_eval=8,
        optimizer="adamax",
        grad_clipping=0.0,
        global_grad_clipping=1.0,
        weight_decay=0.0,
        learning_rate=5e-5,
        momentum=0.0,
        warmup=0.1,
        warmup_schedule="warmup_linear",
        adam_eps=1e-6,
        pooler=None,
        bert_dropout_p=0.1,
        dropout_p=0.1,
        dropout_w=0.0,
        vb_dropout=True,
        use_glue_format=False,
        # loading config
        model_ckpt="checkpoints/model_0.pt",
        resume=False,
        # Scheduler config
        have_lr_scheduler=True,
        multi_step_lr="10,20,30",
        freeze_layers=1,
        embedding_opt=0,
|
||||
lr_gamma=0.5,
|
||||
bert_l2norm=0.0,
|
||||
scheduler_type="ms",
|
||||
seed=2018,
|
||||
grad_accumulation_step=1,
|
||||
# fp16
|
||||
fp16=False,
|
||||
fp16_opt_level="01",
|
||||
mkd_opt=0,
|
||||
weighted_on=False,
|
||||
**kwargs,
|
||||
):
|
||||
# basic Configuration validation
|
||||
# assert inital checkpoint and encoder type are same
|
||||
assert init_checkpoint.startswith(
|
||||
encoder_checkpoint_map[encoder_type]
|
||||
), """Encoder type and initial checkpoint mismatch.
|
||||
1 - Bert models
|
||||
2 - Roberta models
|
||||
"""
|
||||
super(MTDNNConfig, self).__init__(**kwargs)
|
||||
self.use_pretrained_model = use_pretrained_model
|
||||
self.encoder_type = encoder_type
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.dump_feature = dump_feature
|
||||
self.update_bert_opt = update_bert_opt
|
||||
self.answer_opt = answer_opt
|
||||
self.enable_variational_dropout = enable_variational_dropout
|
||||
self.init_ratio = init_ratio
|
||||
self.init_checkpoint = init_checkpoint
|
||||
self.mtl_opt = mtl_opt
|
||||
self.ratio = ratio
|
||||
self.mix_opt = mix_opt
|
||||
self.max_seq_len = max_seq_len
|
||||
self.cuda = cuda
|
||||
self.cuda_device = cuda_device
|
||||
self.multi_gpu_on = multi_gpu_on
|
||||
self.log_per_updates = log_per_updates
|
||||
self.save_per_updates = save_per_updates
|
||||
self.save_per_updates_on = save_per_updates_on
|
||||
self.epochs = epochs
|
||||
self.use_tensor_board = use_tensor_board
|
||||
self.batch_size = batch_size
|
||||
self.batch_size_eval = batch_size_eval
|
||||
self.optimizer = optimizer
|
||||
self.grad_clipping = grad_clipping
|
||||
self.global_grad_clipping = global_grad_clipping
|
||||
self.weight_decay = weight_decay
|
||||
self.learning_rate = learning_rate
|
||||
self.momentum = momentum
|
||||
self.warmup = warmup
|
||||
self.warmup_schedule = warmup_schedule
|
||||
self.pooler = pooler
|
||||
self.adam_eps = adam_eps
|
||||
self.bert_dropout_p = bert_dropout_p
|
||||
self.dropout_p = dropout_p
|
||||
self.dropout_w = dropout_w
|
||||
self.vb_dropout = vb_dropout
|
||||
self.use_glue_format = use_glue_format
|
||||
self.model_ckpt = model_ckpt
|
||||
self.resume = resume
|
||||
self.have_lr_scheduler = have_lr_scheduler
|
||||
self.multi_step_lr = multi_step_lr
|
||||
self.freeze_layers = freeze_layers
|
||||
self.embedding_opt = embedding_opt
|
||||
self.lr_gamma = lr_gamma
|
||||
self.bert_l2norm = bert_l2norm
|
||||
self.scheduler_type = scheduler_type
|
||||
self.seed = seed
|
||||
self.grad_accumulation_step = grad_accumulation_step
|
||||
self.fp16 = fp16
|
||||
self.fp16_opt_level = fp16_opt_level
|
||||
self.mkd_opt = mkd_opt
|
||||
self.weighted_on = weighted_on
|
||||
self.kwargs = kwargs
|
||||
|
|
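A minimal usage sketch of MTDNNConfig (not part of the diff); it assumes only the defaults defined above, and shows that any keyword parameter can be overridden at construction time:

# Minimal sketch, assuming only the MTDNNConfig defaults defined above.
config = MTDNNConfig(batch_size=16, epochs=3)      # override any keyword above
assert config.init_checkpoint.startswith("bert")   # matches EncoderModelType.BERT
print(config.hidden_size, config.learning_rate)    # 768 5e-05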
@ -0,0 +1,360 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import json
import random
import sys
from shutil import copyfile

import numpy as np  # NOTE: missing from the original file, but required by _gen_task_indices below
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset

from mtdnn.common.types import DataFormat, EncoderModelType, TaskType

UNK_ID = 100
BOS_ID = 101


class MTDNNMultiTaskBatchSampler(BatchSampler):
    def __init__(self, datasets, batch_size, mix_opt, extra_task_ratio):
        self._datasets = datasets
        self._batch_size = batch_size
        self._mix_opt = mix_opt
        self._extra_task_ratio = extra_task_ratio
        train_data_list = []
        for dataset in datasets:
            train_data_list.append(
                self._get_shuffled_index_batches(len(dataset), batch_size)
            )
        self._train_data_list = train_data_list

    @staticmethod
    def _get_shuffled_index_batches(dataset_len, batch_size):
        index_batches = [
            list(range(i, min(i + batch_size, dataset_len)))
            for i in range(0, dataset_len, batch_size)
        ]
        random.shuffle(index_batches)
        return index_batches

    def __len__(self):
        return sum(len(train_data) for train_data in self._train_data_list)

    def __iter__(self):
        all_iters = [iter(item) for item in self._train_data_list]
        all_indices = self._gen_task_indices(
            self._train_data_list, self._mix_opt, self._extra_task_ratio
        )
        for local_task_idx in all_indices:
            task_id = self._datasets[local_task_idx].get_task_id()
            batch = next(all_iters[local_task_idx])
            yield [(task_id, sample_id) for sample_id in batch]

    @staticmethod
    def _gen_task_indices(train_data_list, mix_opt, extra_task_ratio):
        all_indices = []
        if len(train_data_list) > 1 and extra_task_ratio > 0:
            main_indices = [0] * len(train_data_list[0])
            extra_indices = []
            for i in range(1, len(train_data_list)):
                extra_indices += [i] * len(train_data_list[i])
            random_picks = int(
                min(len(train_data_list[0]) * extra_task_ratio, len(extra_indices))
            )
            extra_indices = np.random.choice(extra_indices, random_picks, replace=False)
            if mix_opt > 0:
                extra_indices = extra_indices.tolist()
                random.shuffle(extra_indices)
                all_indices = extra_indices + main_indices
            else:
                all_indices = main_indices + extra_indices.tolist()

        else:
            for i in range(1, len(train_data_list)):
                all_indices += [i] * len(train_data_list[i])
            if mix_opt > 0:
                random.shuffle(all_indices)
            all_indices += [0] * len(train_data_list[0])
        if mix_opt < 1:
            random.shuffle(all_indices)
        return all_indices


class MTDNNMultiTaskDataset(Dataset):
    def __init__(self, datasets):
        self._datasets = datasets
        task_id_2_data_set_dic = {}
        for dataset in datasets:
            task_id = dataset.get_task_id()
            assert task_id not in task_id_2_data_set_dic, (
                "Duplicate task_id %s" % task_id
            )
            task_id_2_data_set_dic[task_id] = dataset

        self._task_id_2_data_set_dic = task_id_2_data_set_dic

    def __len__(self):
        return sum(len(dataset) for dataset in self._datasets)

    def __getitem__(self, idx):
        task_id, sample_id = idx
        return self._task_id_2_data_set_dic[task_id][sample_id]


class MTDNNSingleTaskDataset(Dataset):
    def __init__(
        self,
        path,
        is_train=True,
        maxlen=128,
        factor=1.0,
        task_id=0,
        task_type=TaskType.Classification,
        data_type=DataFormat.PremiseOnly,
    ):
        self._data = self.load(path, is_train, maxlen, factor, task_type)
        self._task_id = task_id
        self._task_type = task_type
        self._data_type = data_type

    def get_task_id(self):
        return self._task_id

    @staticmethod
    def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
        assert task_type is not None
        with open(path, "r", encoding="utf-8") as reader:
            data = []
            cnt = 0
            for line in reader:
                sample = json.loads(line)
                sample["factor"] = factor
                cnt += 1
                if is_train:
                    if (task_type == TaskType.Ranking) and (
                        len(sample["token_id"][0]) > maxlen
                        or len(sample["token_id"][1]) > maxlen
                    ):
                        continue
                    if (task_type != TaskType.Ranking) and (
                        len(sample["token_id"]) > maxlen
                    ):
                        continue
                data.append(sample)
            print("Loaded {} samples out of {}".format(len(data), cnt))
            return data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        return {
            "task": {
                "task_id": self._task_id,
                "task_type": self._task_type,
                "data_type": self._data_type,
            },
            "sample": self._data[idx],
        }


class MTDNNCollater:
    def __init__(
        self,
        is_train=True,
        dropout_w=0.005,
        soft_label=False,
        encoder_type=EncoderModelType.BERT,
    ):
        self.is_train = is_train
        self.dropout_w = dropout_w
        self.soft_label_on = soft_label
        self.encoder_type = encoder_type
        self.pairwise_size = 1

    def __random_select__(self, arr):
        if self.dropout_w > 0:
            return [UNK_ID if random.uniform(0, 1) < self.dropout_w else e for e in arr]
        else:
            return arr

    @staticmethod
    def patch_data(gpu, batch_info, batch_data):
        if gpu:
            for i, part in enumerate(batch_data):
                if isinstance(part, torch.Tensor):
                    batch_data[i] = part.pin_memory().cuda(non_blocking=True)
                elif isinstance(part, tuple):
                    batch_data[i] = tuple(
                        sub_part.pin_memory().cuda(non_blocking=True)
                        for sub_part in part
                    )
                elif isinstance(part, list):
                    batch_data[i] = [
                        sub_part.pin_memory().cuda(non_blocking=True)
                        for sub_part in part
                    ]
                else:
                    raise TypeError("unknown batch data type at %s: %s" % (i, part))

            if "soft_label" in batch_info:
                batch_info["soft_label"] = (
                    batch_info["soft_label"].pin_memory().cuda(non_blocking=True)
                )

        return batch_info, batch_data

    def rebatch(self, batch):
        newbatch = []
        for sample in batch:
            size = len(sample["token_id"])
            self.pairwise_size = size
            assert size == len(sample["type_id"])
            for idx in range(0, size):
                token_id = sample["token_id"][idx]
                type_id = sample["type_id"][idx]
                uid = sample["ruid"][idx]
                olab = sample["olabel"][idx]
                newbatch.append(
                    {
                        "uid": uid,
                        "token_id": token_id,
                        "type_id": type_id,
                        "label": sample["label"],
                        "true_label": olab,
                    }
                )
        return newbatch

    def __if_pair__(self, data_type):
        return data_type in [
            DataFormat.PremiseAndOneHypothesis,
            DataFormat.PremiseAndMultiHypothesis,
        ]

    def collate_fn(self, batch):
        task_id = batch[0]["task"]["task_id"]
        task_type = batch[0]["task"]["task_type"]
        data_type = batch[0]["task"]["data_type"]
        new_batch = []
        for sample in batch:
            assert sample["task"]["task_id"] == task_id
            assert sample["task"]["task_type"] == task_type
            assert sample["task"]["data_type"] == data_type
            new_batch.append(sample["sample"])
        batch = new_batch

        if task_type == TaskType.Ranking:
            batch = self.rebatch(batch)

        # prepare model input
        batch_info, batch_data = self._prepare_model_input(batch, data_type)
        batch_info["task_id"] = task_id  # used to select the correct decoding head
        batch_info["input_len"] = len(batch_data)  # used to select model inputs
        # selects different loss functions and other differences between training and testing
        batch_info["task_type"] = task_type
        batch_info["pairwise_size"] = self.pairwise_size  # needed for ranking tasks

        # add label
        labels = [sample["label"] for sample in batch]
        if self.is_train:
            # in training mode, the label is consumed by PyTorch, so it must be a tensor
            if task_type == TaskType.Regression:
                batch_data.append(torch.FloatTensor(labels))
                batch_info["label"] = len(batch_data) - 1
            elif task_type in (TaskType.Classification, TaskType.Ranking):
                batch_data.append(torch.LongTensor(labels))
                batch_info["label"] = len(batch_data) - 1
            elif task_type == TaskType.Span:
                start = [sample["start_position"] for sample in batch]
                end = [sample["end_position"] for sample in batch]
                batch_data.append((torch.LongTensor(start), torch.LongTensor(end)))
                # unify to one type of label
                batch_info["label"] = len(batch_data) - 1
                # batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
            elif task_type == TaskType.SeqenceLabeling:
                # NOTE: "SeqenceLabeling" here vs. "SequenceLabeling" in modeling_mtdnn.py;
                # the spelling must match the TaskType enum in mtdnn.common.types.
                batch_size = self._get_batch_size(batch)
                tok_len = self._get_max_len(batch, key="token_id")
                tlab = torch.LongTensor(batch_size, tok_len).fill_(-1)
                for i, label in enumerate(labels):
                    ll = len(label)
                    tlab[i, :ll] = torch.LongTensor(label)
                batch_data.append(tlab)
                batch_info["label"] = len(batch_data) - 1

            # soft labels generated by ensemble models for knowledge distillation
            if self.soft_label_on and (batch[0].get("softlabel", None) is not None):
                assert (
                    task_type != TaskType.Span
                )  # Span tasks don't support soft labels yet.
                soft_labels = [sample["softlabel"] for sample in batch]
                soft_labels = torch.FloatTensor(soft_labels)
                batch_info["soft_label"] = soft_labels
        else:
            # in test mode, the label is used for evaluation
            batch_info["label"] = labels
            if task_type == TaskType.Ranking:
                batch_info["true_label"] = [sample["true_label"] for sample in batch]
            if task_type == TaskType.Span:
                batch_info["token_to_orig_map"] = [
                    sample["token_to_orig_map"] for sample in batch
                ]
                batch_info["token_is_max_context"] = [
                    sample["token_is_max_context"] for sample in batch
                ]
                batch_info["doc_offset"] = [sample["doc_offset"] for sample in batch]
                batch_info["doc"] = [sample["doc"] for sample in batch]
                batch_info["tokens"] = [sample["tokens"] for sample in batch]
                batch_info["answer"] = [sample["answer"] for sample in batch]

        batch_info["uids"] = [sample["uid"] for sample in batch]  # used in scoring
        return batch_info, batch_data

    def _get_max_len(self, batch, key="token_id"):
        tok_len = max(len(x[key]) for x in batch)
        return tok_len

    def _get_batch_size(self, batch):
        return len(batch)

    def _prepare_model_input(self, batch, data_type):
        batch_size = self._get_batch_size(batch)
        tok_len = self._get_max_len(batch, key="token_id")
        # tok_len = max(len(x['token_id']) for x in batch)
        hypothesis_len = max(len(x["type_id"]) - sum(x["type_id"]) for x in batch)
        if self.encoder_type == EncoderModelType.ROBERTA:
            token_ids = torch.LongTensor(batch_size, tok_len).fill_(1)
            type_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            masks = torch.LongTensor(batch_size, tok_len).fill_(0)
        else:
            token_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            type_ids = torch.LongTensor(batch_size, tok_len).fill_(0)
            masks = torch.LongTensor(batch_size, tok_len).fill_(0)
        if self.__if_pair__(data_type):
            premise_masks = torch.ByteTensor(batch_size, tok_len).fill_(1)
            hypothesis_masks = torch.ByteTensor(batch_size, hypothesis_len).fill_(1)
        for i, sample in enumerate(batch):
            select_len = min(len(sample["token_id"]), tok_len)
            tok = sample["token_id"]
            if self.is_train:
                tok = self.__random_select__(tok)
            token_ids[i, :select_len] = torch.LongTensor(tok[:select_len])
            type_ids[i, :select_len] = torch.LongTensor(sample["type_id"][:select_len])
            masks[i, :select_len] = torch.LongTensor([1] * select_len)
            if self.__if_pair__(data_type):
                hlen = len(sample["type_id"]) - sum(sample["type_id"])
                hypothesis_masks[i, :hlen] = torch.LongTensor([0] * hlen)
                for j in range(hlen, select_len):
                    premise_masks[i, j] = 0
        if self.__if_pair__(data_type):
            batch_info = {
                "token_id": 0,
                "segment_id": 1,
                "mask": 2,
                "premise_mask": 3,
                "hypothesis_mask": 4,
            }
            batch_data = [token_ids, type_ids, masks, premise_masks, hypothesis_masks]
        else:
            batch_info = {"token_id": 0, "segment_id": 1, "mask": 2}
            batch_data = [token_ids, type_ids, masks]
        return batch_info, batch_data
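A minimal wiring sketch for the dataset classes above (not part of the diff). It assumes a preprocessed JSON-lines file at the hypothetical path "data/mnli_train.json"; the class names, TaskType/DataFormat/EncoderModelType values, and constructor signatures come from this file.

# Minimal sketch, assuming "data/mnli_train.json" is a preprocessed JSON-lines file.
from torch.utils.data import DataLoader

train_set = MTDNNSingleTaskDataset(
    "data/mnli_train.json", is_train=True, maxlen=128,
    task_id=0, task_type=TaskType.Classification,
    data_type=DataFormat.PremiseAndOneHypothesis,
)
sampler = MTDNNMultiTaskBatchSampler([train_set], batch_size=8, mix_opt=0, extra_task_ratio=0)
collater = MTDNNCollater(is_train=True, encoder_type=EncoderModelType.BERT)
loader = DataLoader(MTDNNMultiTaskDataset([train_set]),
                    batch_sampler=sampler, collate_fn=collater.collate_fn)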
@ -0,0 +1,766 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


# This script reuses some code from
# https://github.com/huggingface/transformers

import logging
import os
import pathlib
import sys
from datetime import datetime

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from fairseq.models.roberta import RobertaModel as FairseqRobertModel
from tensorboardX import SummaryWriter
from torch import nn
from torch.optim.lr_scheduler import *
from torch.utils.data import DataLoader
from transformers import (
    BertConfig,
    BertModel,
    BertPreTrainedModel,
    PretrainedConfig,
    PreTrainedModel,
    RobertaModel,
)

from mtdnn.common.archive_maps import PRETRAINED_MODEL_ARCHIVE_MAP
from mtdnn.common.average_meter import AverageMeter
from mtdnn.common.bert_optim import Adamax, RAdam
from mtdnn.common.glue.glue_utils import submit  # NOTE: missing from the original file, but predict() calls submit()
from mtdnn.common.linear_pooler import LinearPooler
from mtdnn.common.loss import LOSS_REGISTRY
from mtdnn.common.metrics import calc_metrics
from mtdnn.common.san import SANBERTNetwork, SANClassifier
from mtdnn.common.squad_utils import extract_answer, merge_answers, select_answers
from mtdnn.common.types import DataFormat, EncoderModelType, TaskType
from mtdnn.common.utils import MTDNNCommonUtils
from mtdnn.configuration_mtdnn import MTDNNConfig
from mtdnn.dataset_mtdnn import MTDNNCollater
from mtdnn.tasks.config import MTDNNTaskDefs

logger = MTDNNCommonUtils.setup_logging()


class MTDNNPretrainedModel(nn.Module):
    config_class = MTDNNConfig
    pretrained_model_archive_map = PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = lambda model, config, path: None
    base_model_prefix = "mtdnn"

    def __init__(self, config):
        super(MTDNNPretrainedModel, self).__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config


class MTDNNModel(MTDNNPretrainedModel):
    """Instance of an MTDNN Model

    Arguments:
        MTDNNPretrainedModel {BertPretrainedModel} -- Inherited from Bert Pretrained
        config {MTDNNConfig} -- MTDNN Configuration Object
        pretrained_model_name {str} -- Name of the pretrained model for the initial checkpoint
        num_train_step {int} -- Number of steps to take each training

    Raises:
        RuntimeError: [description]
        ImportError: [description]

    Returns:
        MTDNNModel -- An instance of an MTDNN Model
    """

    def __init__(
        self,
        config: MTDNNConfig,
        task_defs: MTDNNTaskDefs,
        pretrained_model_name: str = "mtdnn-base-uncased",
        num_train_step: int = -1,
        decoder_opts: list = None,
        task_types: list = None,
        dropout_list: list = None,
        loss_types: list = None,
        kd_loss_types: list = None,
        tasks_nclass_list: list = None,
        multitask_train_dataloader: DataLoader = None,
        dev_dataloaders_list: list = None,  # list of dataloaders
        test_dataloaders_list: list = None,  # list of dataloaders
        test_datasets_list: list = ["mnli_mismatched", "mnli_matched"],
        output_dir: str = "checkpoint",
        log_dir: str = "tensorboard_logdir",
    ):

        # Input validation
        assert (
            config.init_checkpoint in self.supported_init_checkpoints()
        ), f"Initial checkpoint must be in {self.supported_init_checkpoints()}"

        assert decoder_opts, "Decoder options list is required!"
        assert task_types, "Task types list is required!"
        assert dropout_list, "Task dropout list is required!"
        assert loss_types, "Loss types list is required!"
        assert kd_loss_types, "KD Loss types list is required!"
        assert tasks_nclass_list, "Tasks nclass list is required!"
        assert (
            multitask_train_dataloader
        ), "DataLoader for multiple tasks cannot be None"
        assert test_datasets_list, "Pass a list of test dataset prefixes"

        super(MTDNNModel, self).__init__(config)

        # Initialize model config and update with training options
        self.config = config
        self.update_config_with_training_opts(
            decoder_opts,
            task_types,
            dropout_list,
            loss_types,
            kd_loss_types,
            tasks_nclass_list,
        )
        self.task_defs = task_defs
        self.multitask_train_dataloader = multitask_train_dataloader
        self.dev_dataloaders_list = dev_dataloaders_list
        self.test_dataloaders_list = test_dataloaders_list
        self.test_datasets_list = test_datasets_list
        self.output_dir = output_dir
        self.log_dir = log_dir

        # Create the output_dir if it doesn't exist
        MTDNNCommonUtils.create_directory_if_not_exists(self.output_dir)
        self.tensor_board = SummaryWriter(log_dir=self.log_dir)

        self.pooler = None

        # Resume from model checkpoint
        if self.config.resume and self.config.model_ckpt:
            assert os.path.exists(
                self.config.model_ckpt
            ), "Model checkpoint does not exist"
            logger.info(f"loading model from {self.config.model_ckpt}")
            self.load(self.config.model_ckpt)
            return

        # Setup the baseline network
        # - Define the encoder based on config options
        # - Set state dictionary based on configuration setting
        # - Download pretrained model if flag is set
        # TODO - Use Model.pretrained_model() after configuration file is hosted.
        if self.config.use_pretrained_model:
            with MTDNNCommonUtils.download_path() as file_path:
                path = pathlib.Path(file_path)
                self.local_model_path = MTDNNCommonUtils.maybe_download(
                    url=self.pretrained_model_archive_map[pretrained_model_name]
                )
            self.bert_model = MTDNNCommonUtils.load_pytorch_model(self.local_model_path)
            self.state_dict = self.bert_model["state"]
        else:
            # Set the config based on the encoder type set for the initial checkpoint
            if config.encoder_type == EncoderModelType.BERT:
                self.bert_config = BertConfig.from_dict(self.config.to_dict())
                self.bert_model = BertModel.from_pretrained(self.config.init_checkpoint)
                self.state_dict = self.bert_model.state_dict()
                self.config.hidden_size = self.bert_config.hidden_size
            if config.encoder_type == EncoderModelType.ROBERTA:
                # Download and extract from PyTorch hub if not downloaded before
                self.bert_model = torch.hub.load(
                    "pytorch/fairseq", config.init_checkpoint
                )
                self.config.hidden_size = self.bert_model.args.encoder_embed_dim
                self.pooler = LinearPooler(self.config.hidden_size)
                new_state_dict = {}
                for key, val in self.bert_model.state_dict().items():
                    if key.startswith(
                        "model.decoder.sentence_encoder"
                    ) or key.startswith("model.classification_heads"):
                        key = f"bert.{key}"
                        new_state_dict[key] = val
                    # backward compatibility PyTorch <= 1.0.0
                    if key.startswith("classification_heads"):
                        key = f"bert.model.{key}"
                        new_state_dict[key] = val
                self.state_dict = new_state_dict

        self.updates = (
            self.state_dict["updates"]
            if self.state_dict and "updates" in self.state_dict
            else 0
        )
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBERTNetwork(
            init_checkpoint_model=self.bert_model,
            pooler=self.pooler,
            config=self.config,
        )
        if self.state_dict:
            self.network.load_state_dict(self.state_dict, strict=False)
        self.mnetwork = (
            nn.DataParallel(self.network) if self.config.multi_gpu_on else self.network
        )
        self.total_param = sum(
            [p.nelement() for p in self.network.parameters() if p.requires_grad]
        )

        # Move network to GPU if device available and flag set
        if self.config.cuda:
            self.network.cuda(device=self.config.cuda_device)
        self.optimizer_parameters = self._get_param_groups()
        self._setup_optim(self.optimizer_parameters, self.state_dict, num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap()

    def _get_param_groups(self):
        no_decay = ["bias", "gamma", "beta", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_parameters = [
            {
                "params": [
                    p
                    for n, p in self.network.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.01,
            },
            {
                "params": [
                    p
                    for n, p in self.network.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        return optimizer_parameters

    def _setup_optim(
        self, optimizer_parameters, state_dict: dict = None, num_train_step: int = -1
    ):

        # Setup optimizer parameters
        if self.config.optimizer == "sgd":
            self.optimizer = optim.SGD(
                optimizer_parameters,
                self.config.learning_rate,
                weight_decay=self.config.weight_decay,
            )
        elif self.config.optimizer == "adamax":
            self.optimizer = Adamax(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        elif self.config.optimizer == "radam":
            self.optimizer = RAdam(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                eps=self.config.adam_eps,
                weight_decay=self.config.weight_decay,
            )

            # The current radam does not support FP16.
            self.config.fp16 = False
        elif self.config.optimizer == "adam":
            # NOTE: `Adam` is not imported above; this branch expects a BERT-style
            # optimizer exposing warmup/t_total (torch.optim.Adam does not).
            self.optimizer = Adam(
                optimizer_parameters,
                lr=self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        else:
            raise RuntimeError(f"Unsupported optimizer: {self.config.optimizer}")

        # Clear scheduler for certain optimizer choices
        if self.config.optimizer in ["adam", "adamax", "radam"]:
            if self.config.have_lr_scheduler:
                self.config.have_lr_scheduler = False

        if state_dict and "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])

        if self.config.fp16:
            try:
                from apex import amp

                global amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                self.network, self.optimizer, opt_level=self.config.fp16_opt_level
            )
            self.network = model
            self.optimizer = optimizer

        if self.config.have_lr_scheduler:
            if self.config.scheduler_type == "rop":
                self.scheduler = ReduceLROnPlateau(
                    self.optimizer, mode="max", factor=self.config.lr_gamma, patience=3
                )
            elif self.config.scheduler_type == "exp":
                self.scheduler = ExponentialLR(
                    self.optimizer, gamma=self.config.lr_gamma or 0.95
                )
            else:
                milestones = [
                    int(step)
                    for step in (self.config.multi_step_lr or "10,20,30").split(",")
                ]
                self.scheduler = MultiStepLR(
                    self.optimizer, milestones=milestones, gamma=self.config.lr_gamma
                )
        else:
            self.scheduler = None

    def _setup_lossmap(self):
        self.task_loss_criterion = []
        for idx, cs in enumerate(self.config.loss_types):
            assert cs is not None, "Loss type must be defined."
            lc = LOSS_REGISTRY[cs](name=f"Loss func of task {idx}: {cs}")
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self):
        loss_types = self.config.kd_loss_types
        self.kd_task_loss_criterion = []
        if self.config.mkd_opt > 0:  # fixed: the original referenced an undefined `config`
            for idx, cs in enumerate(loss_types):
                assert cs, "Loss type must be defined."
                lc = LOSS_REGISTRY[cs](name="Loss func of task {}: {}".format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _to_cuda(self, tensor):
        # Move tensor(s) to gpu (non-blocking) if given a PyTorch tensor
        if tensor is None:
            return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [
                e.cuda(device=self.config.cuda_device, non_blocking=True)
                for e in tensor
            ]
            for t in y:
                t.requires_grad = False
        else:
            y = tensor.cuda(device=self.config.cuda_device, non_blocking=True)
            y.requires_grad = False
        return y

    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        target = batch_data[batch_meta["label"]]
        soft_labels = None

        task_type = batch_meta["task_type"]
        target = self._to_cuda(target) if self.config.cuda else target

        task_id = batch_meta["task_id"]
        inputs = batch_data[: batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = None
        if self.config.weighted_on:
            if self.config.cuda:
                weight = batch_data[batch_meta["factor"]].cuda(
                    device=self.config.cuda_device, non_blocking=True
                )
            else:
                weight = batch_data[batch_meta["factor"]]
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (target is not None):
            loss = self.task_loss_criterion[task_id](
                logits, target, weight, ignore_index=-1
            )

        # compute kd loss
        if self.config.mkd_opt > 0 and ("soft_label" in batch_meta):
            soft_labels = batch_meta["soft_label"]
            soft_labels = (
                self._to_cuda(soft_labels) if self.config.cuda else soft_labels
            )
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = (
                kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
            )
            loss = loss + kd_loss

        self.train_loss.update(loss.item(), batch_data[batch_meta["token_id"]].size(0))
        # scale loss
        loss = loss / (self.config.grad_accumulation_step or 1)
        if self.config.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.grad_accumulation_step == 0:
            if self.config.global_grad_clipping > 0:
                if self.config.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        self.config.global_grad_clipping,
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.network.parameters(), self.config.global_grad_clipping
                    )
            self.updates += 1
            # reset the gradients after each accumulation window
            self.optimizer.step()
            self.optimizer.zero_grad()

    def eval_mode(
        self,
        data: DataLoader,
        metric_meta,
        use_cuda=True,
        with_label=True,
        label_mapper=None,
        task_type=TaskType.Classification,
    ):
        if use_cuda:
            self.cuda()
        predictions = []
        golds = []
        scores = []
        ids = []
        metrics = {}
        for idx, (batch_info, batch_data) in enumerate(data):
            if idx % 100 == 0:
                logger.info(f"predicting {idx}")
            batch_info, batch_data = MTDNNCollater.patch_data(
                use_cuda, batch_info, batch_data
            )
            score, pred, gold = self._predict_batch(batch_info, batch_data)
            predictions.extend(pred)
            golds.extend(gold)
            scores.extend(score)
            ids.extend(batch_info["uids"])

        if task_type == TaskType.Span:
            golds = merge_answers(ids, golds)
            predictions, scores = select_answers(ids, predictions, scores)
        if with_label:
            metrics = calc_metrics(
                metric_meta, golds, predictions, scores, label_mapper
            )
        return metrics, predictions, scores, golds, ids

    def _predict_batch(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta["task_id"]
        task_type = batch_meta["task_type"]
        inputs = batch_data[: batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta["pairwise_size"])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta["true_label"]
        elif task_type == TaskType.SequenceLabeling:
            mask = batch_data[batch_meta["mask"]]
            score = score.contiguous()
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_length = mask.sum(1).tolist()
            final_predict = []
            for idx, p in enumerate(predict):
                final_predict.append(p[: valid_length[idx]])
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta["label"]
        elif task_type == TaskType.Span:
            start, end = score
            predictions = []
            if self.config.encoder_type == EncoderModelType.BERT:
                scores, predictions = extract_answer(
                    batch_meta,
                    batch_data,
                    start,
                    end,
                    self.config.get("max_answer_len", 5),
                )
            return scores, predictions, batch_meta["answer"]
        else:
            if task_type == TaskType.Classification:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta["label"]

    def fit(self, epochs=0):
        """ Fit model to training datasets """
        epochs = epochs or self.config.epochs
        logger.info(f"Total number of params: {self.total_param}")
        for epoch in range(epochs):
            logger.info(f"At epoch {epoch}")
            logger.info(
                f"Amount of data to go over: {len(self.multitask_train_dataloader)}"
            )

            start = datetime.now()
            # Create batches and train
            for idx, (batch_meta, batch_data) in enumerate(
                self.multitask_train_dataloader
            ):
                batch_meta, batch_data = MTDNNCollater.patch_data(
                    self.config.cuda, batch_meta, batch_data
                )

                task_id = batch_meta["task_id"]
                self.update(batch_meta, batch_data)
                if (
                    self.local_updates == 1
                    or (self.local_updates)
                    % (self.config.log_per_updates * self.config.grad_accumulation_step)
                    == 0
                ):

                    time_left = str(
                        (datetime.now() - start)
                        / (idx + 1)
                        * (len(self.multitask_train_dataloader) - idx - 1)
                    ).split(".")[0]
                    logger.info(
                        "Task - [{0:2}] Updates - [{1:6}] Training Loss - [{2:.5f}] Time Remaining - [{3}]".format(
                            task_id, self.updates, self.train_loss.avg, time_left,
                        )
                    )
                    if self.config.use_tensor_board:
                        self.tensor_board.add_scalar(
                            "train/loss", self.train_loss.avg, global_step=self.updates,
                        )

                if self.config.save_per_updates_on and (
                    (self.local_updates)
                    % (
                        self.config.save_per_updates
                        * self.config.grad_accumulation_step
                    )
                    == 0
                ):
                    model_file = os.path.join(
                        self.output_dir, "model_{}_{}.pt".format(epoch, self.updates),
                    )
                    logger.info(f"Saving mt-dnn model to {model_file}")
                    self.save(model_file)

            # TODO: Alternatively, we need to refactor the save function
            # and move it into prediction.
            # Save a checkpoint after each training epoch
            model_file = os.path.join(self.output_dir, "model_{}.pt".format(epoch))
            logger.info(f"Saving mt-dnn model to {model_file}")
            self.save(model_file)

    def predict(self, trained_model_chckpt: str = None, saved_epoch_idx: int = 0):
        """
        Inference of model on test datasets
        """

        # Load a trained checkpoint if given a valid model checkpoint
        if trained_model_chckpt and os.path.exists(trained_model_chckpt):
            logger.info(f"Running predictions using: {trained_model_chckpt}")
            self.load(trained_model_chckpt)

        # Create batches and evaluate
        start = datetime.now()
        for idx, dataset in enumerate(self.test_datasets_list):
            prefix = dataset.split("_")[0]
            label_dict = self.task_defs.global_map.get(prefix, None)
            dev_data: DataLoader = self.dev_dataloaders_list[idx]
            if dev_data is not None:
                with torch.no_grad():
                    (
                        dev_metrics,
                        dev_predictions,
                        scores,
                        golds,
                        dev_ids,
                    ) = self.eval_mode(
                        dev_data,
                        metric_meta=self.task_defs.metric_meta_map[prefix],
                        use_cuda=self.config.cuda,
                        label_mapper=label_dict,
                        task_type=self.task_defs.task_type_map[prefix],
                    )
                for key, val in dev_metrics.items():
                    if self.config.use_tensor_board:
                        self.tensor_board.add_scalar(
                            f"dev/{dataset}/{key}", val, global_step=saved_epoch_idx
                        )
                    if isinstance(val, str):
                        logger.info(
                            f"Task {dataset} -- epoch {saved_epoch_idx} -- Dev {key}:\n {val}"
                        )
                    else:
                        logger.info(
                            f"Task {dataset} -- epoch {saved_epoch_idx} -- Dev {key}: {val:.3f}"
                        )
                score_file = os.path.join(
                    self.output_dir, f"{dataset}_dev_scores_{saved_epoch_idx}.json"
                )
                results = {
                    "metrics": dev_metrics,
                    "predictions": dev_predictions,
                    "uids": dev_ids,
                    "scores": scores,
                }

                # Save results to file
                MTDNNCommonUtils.dump(score_file, results)
                if self.config.use_glue_format:
                    official_score_file = os.path.join(
                        self.output_dir,
                        "{}_dev_scores_{}.tsv".format(dataset, saved_epoch_idx),
                    )
                    submit(official_score_file, results, label_dict)

            # test eval
            test_data: DataLoader = self.test_dataloaders_list[idx]
            if test_data is not None:
                with torch.no_grad():
                    (
                        test_metrics,
                        test_predictions,
                        scores,
                        golds,
                        test_ids,
                    ) = self.eval_mode(
                        test_data,
                        metric_meta=self.task_defs.metric_meta_map[prefix],
                        use_cuda=self.config.cuda,
                        with_label=False,
                        label_mapper=label_dict,
                        task_type=self.task_defs.task_type_map[prefix],
                    )
                score_file = os.path.join(
                    self.output_dir, f"{dataset}_test_scores_{saved_epoch_idx}.json"
                )
                results = {
                    "metrics": test_metrics,
                    "predictions": test_predictions,
                    "uids": test_ids,
                    "scores": scores,
                }
                MTDNNCommonUtils.dump(score_file, results)
                if self.config.use_glue_format:
                    official_score_file = os.path.join(
                        self.output_dir, f"{dataset}_test_scores_{saved_epoch_idx}.tsv"
                    )
                    submit(official_score_file, results, label_dict)
                logger.info("[new test scores saved.]")

        # Close tensorboard connection if opened
        self.close_connections()

    def close_connections(self):
        # Close tensor board connection
        if self.config.use_tensor_board:
            self.tensor_board.close()

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        network_state = dict(
            [(k, v.cpu()) for k, v in self.network.state_dict().items()]
        )
        params = {
            "state": network_state,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config,
        }
        torch.save(params, filename)
        logger.info("model saved to {}".format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        self.network.load_state_dict(model_state_dict["state"], strict=False)
        self.optimizer.load_state_dict(model_state_dict["optimizer"])
        self.config = model_state_dict["config"]

    def cuda(self):
        self.network.cuda(device=self.config.cuda_device)

    def supported_init_checkpoints(self):
        """List of allowed checkpoints
        """
        return [
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "mtdnn-base-uncased",
            "mtdnn-large-uncased",
            "roberta.base",
            "roberta.large",
        ]

    def update_config_with_training_opts(
        self,
        decoder_opts,
        task_types,
        dropout_list,
        loss_types,
        kd_loss_types,
        tasks_nclass_list,
    ):
        # Update configurations with options obtained from preprocessing training data
        setattr(self.config, "decoder_opts", decoder_opts)
        setattr(self.config, "task_types", task_types)
        setattr(self.config, "tasks_dropout_p", dropout_list)
        setattr(self.config, "loss_types", loss_types)
        setattr(self.config, "kd_loss_types", kd_loss_types)
        setattr(self.config, "tasks_nclass_list", tasks_nclass_list)
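A minimal end-to-end sketch of MTDNNModel (not part of the diff). The variable names config, task_defs, train_loader, dev_loaders, test_loaders, num_all_batches and the per-task option lists are assumptions here; in practice they would come from MTDNNConfig, MTDNNTaskDefs and the MTDNNDataProcess class in the next file.

# Minimal sketch; `config`, `task_defs` and the loaders/lists are assumed to have
# been produced by MTDNNConfig, MTDNNTaskDefs and MTDNNDataProcess (see below).
model = MTDNNModel(
    config,
    task_defs,
    pretrained_model_name="mtdnn-base-uncased",
    num_train_step=num_all_batches,
    decoder_opts=decoder_opts,
    task_types=task_types,
    dropout_list=dropout_list,
    loss_types=loss_types,
    kd_loss_types=kd_loss_types,
    tasks_nclass_list=nclass_list,
    multitask_train_dataloader=train_loader,
    dev_dataloaders_list=dev_loaders,
    test_dataloaders_list=test_loaders,
)
model.fit(epochs=1)   # train, saving a checkpoint per epoch
model.predict()       # evaluate on dev/test and dump score files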
@ -0,0 +1,275 @@
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.
import logging
import os
import random
from datetime import datetime

import torch
from tensorboardX import SummaryWriter
from torch.utils.data import BatchSampler, DataLoader, Dataset

from mtdnn.common.glue.glue_utils import submit
from mtdnn.common.types import TaskType
from mtdnn.common.utils import MTDNNCommonUtils
from mtdnn.configuration_mtdnn import MTDNNConfig
from mtdnn.dataset_mtdnn import (
    MTDNNCollater,
    MTDNNMultiTaskBatchSampler,
    MTDNNMultiTaskDataset,
    MTDNNSingleTaskDataset,
)
from mtdnn.modeling_mtdnn import MTDNNModel
from mtdnn.tasks.config import MTDNNTaskDefs

logger = MTDNNCommonUtils.setup_logging(mode="w")


class MTDNNDataProcess:
    def __init__(
        self,
        config: MTDNNConfig,
        task_defs: MTDNNTaskDefs,
        data_dir: str,
        train_datasets_list: list = ["mnli"],
        test_datasets_list: list = ["mnli_mismatched", "mnli_matched"],  # fixed: the original packed both names into one string
        glue_format: bool = False,
        data_sort: bool = False,
    ):
        assert len(train_datasets_list) >= 1, "Train dataset list cannot be empty"
        assert len(test_datasets_list) >= 1, "Test dataset list cannot be empty"

        # Initialize class members
        self.config = config
        self.task_defs = task_defs
        self.train_datasets = train_datasets_list
        self.test_datasets = test_datasets_list
        self.data_dir = data_dir
        self.glue_format = glue_format
        self.data_sort = data_sort
        self.tasks = {}
        self.tasks_class = {}
        self.nclass_list = []
        self.decoder_opts = []
        self.task_types = []
        self.dropout_list = []
        self.loss_types = []
        self.kd_loss_types = []
        self._multitask_train_dataloader = self._process_train_datasets()
        (
            self._dev_dataloaders_list,
            self._test_dataloaders_list,
        ) = self._process_dev_test_datasets()
        self._num_all_batches = (
            self.config.epochs
            * len(self._multitask_train_dataloader)
            // self.config.grad_accumulation_step
        )

    def _process_train_datasets(self):
        """Preprocess the training sets and generate the decoding and task-specific
        training options needed to update the config object

        Returns:
            [DataLoader] -- Multiple tasks train data ready for training
        """
        logger.info("Starting to process the training data sets")

        train_datasets = []
        for dataset in self.train_datasets:
            prefix = dataset.split("_")[0]
            if prefix in self.tasks:
                continue
            assert (
                prefix in self.task_defs.n_class_map
            ), f"[ERROR] - {prefix} does not exist in {self.task_defs.n_class_map}"
            assert (
                prefix in self.task_defs.data_type_map
            ), f"[ERROR] - {prefix} does not exist in {self.task_defs.data_type_map}"
            data_type = self.task_defs.data_type_map[prefix]
            nclass = self.task_defs.n_class_map[prefix]
            task_id = len(self.tasks)
            if self.config.mtl_opt > 0:
                task_id = (
                    self.tasks_class[nclass]
                    if nclass in self.tasks_class
                    else len(self.tasks_class)
                )

            task_type = self.task_defs.task_type_map[prefix]

            dopt = self.generate_decoder_opt(
                self.task_defs.enable_san_map[prefix], self.config.answer_opt
            )
            if task_id < len(self.decoder_opts):
                self.decoder_opts[task_id] = min(self.decoder_opts[task_id], dopt)
            else:
                self.decoder_opts.append(dopt)
            self.task_types.append(task_type)
            self.loss_types.append(self.task_defs.loss_map[prefix])
            self.kd_loss_types.append(self.task_defs.kd_loss_map[prefix])

            if prefix not in self.tasks:
                self.tasks[prefix] = len(self.tasks)
                if self.config.mtl_opt < 1:
                    self.nclass_list.append(nclass)

            if nclass not in self.tasks_class:
                self.tasks_class[nclass] = len(self.tasks_class)
                if self.config.mtl_opt > 0:
                    self.nclass_list.append(nclass)

            dropout_p = self.task_defs.dropout_p_map.get(prefix, self.config.dropout_p)
            self.dropout_list.append(dropout_p)

            train_path = os.path.join(self.data_dir, f"{dataset}_train.json")
            assert os.path.exists(
                train_path
            ), f"[ERROR] - Training dataset does not exist: {train_path}"
            logger.info(f"Loading {train_path} as task {task_id}")
            train_data_set = MTDNNSingleTaskDataset(
                train_path,
                True,
                maxlen=self.config.max_seq_len,
                task_id=task_id,
                task_type=task_type,
                data_type=data_type,
            )
            train_datasets.append(train_data_set)
        train_collater = MTDNNCollater(
            dropout_w=self.config.dropout_w, encoder_type=self.config.encoder_type
        )
        multitask_train_dataset = MTDNNMultiTaskDataset(train_datasets)
        multitask_batch_sampler = MTDNNMultiTaskBatchSampler(
            train_datasets,
            self.config.batch_size,
            self.config.mix_opt,
            self.config.ratio,
        )
        multitask_train_data = DataLoader(
            multitask_train_dataset,
            batch_sampler=multitask_batch_sampler,
            collate_fn=train_collater.collate_fn,
            pin_memory=self.config.cuda,
        )
        return multitask_train_data

    def _process_dev_test_datasets(self):
        """Preprocess the dev and test sets

        Returns:
            [List] -- Multiple tasks test data ready for inference
        """
        logger.info("Starting to process the testing data sets")
        dev_dataloaders_list = []
        test_dataloaders_list = []
        test_collater = MTDNNCollater(
            is_train=False, encoder_type=self.config.encoder_type
        )
        for dataset in self.test_datasets:
            prefix = dataset.split("_")[0]
            task_id = (
                self.tasks_class[self.task_defs.n_class_map[prefix]]
                if self.config.mtl_opt > 0
                else self.tasks[prefix]
            )
            task_type = self.task_defs.task_type_map[prefix]

            pw_task = False
            if task_type == TaskType.Ranking:
                pw_task = True

            assert prefix in self.task_defs.data_type_map
            data_type = self.task_defs.data_type_map[prefix]

            dev_path = os.path.join(self.data_dir, f"{dataset}_dev.json")
            assert os.path.exists(
                dev_path
            ), f"[ERROR] - Dev dataset does not exist: {dev_path}"
            dev_data = None
            if os.path.exists(dev_path):
                dev_data_set = MTDNNSingleTaskDataset(
                    dev_path,
                    False,
                    maxlen=self.config.max_seq_len,
                    task_id=task_id,
                    task_type=task_type,
                    data_type=data_type,
                )
                dev_data = DataLoader(
                    dev_data_set,
                    batch_size=self.config.batch_size_eval,
                    collate_fn=test_collater.collate_fn,
                    pin_memory=self.config.cuda,
                )
            dev_dataloaders_list.append(dev_data)

            test_path = os.path.join(self.data_dir, f"{dataset}_test.json")
            test_data = None
            if os.path.exists(test_path):
                test_data_set = MTDNNSingleTaskDataset(
                    test_path,
                    False,
                    maxlen=self.config.max_seq_len,
                    task_id=task_id,
                    task_type=task_type,
                    data_type=data_type,
                )
                test_data = DataLoader(
                    test_data_set,
                    batch_size=self.config.batch_size_eval,
                    collate_fn=test_collater.collate_fn,
                    pin_memory=self.config.cuda,
                )
            test_dataloaders_list.append(test_data)

        # Return tuple of dev and test dataloaders
        return dev_dataloaders_list, test_dataloaders_list

    def get_train_dataloader(self) -> DataLoader:
        """Returns a dataloader for multiple tasks

        Returns:
            DataLoader -- Multiple tasks batch dataloader
        """
        return self._multitask_train_dataloader

    def get_dev_dataloaders(self) -> list:
        """Returns a list of dev dataloaders for multiple tasks

        Returns:
            list -- List of dev dataloaders
        """
        return self._dev_dataloaders_list

    def get_test_dataloaders(self) -> list:
        """Returns a list of test dataloaders for multiple tasks

        Returns:
            list -- List of test dataloaders
        """
        return self._test_dataloaders_list

    def generate_decoder_opt(self, enable_san, max_opt):
        return max_opt if enable_san and max_opt < 3 else 0

    # Getters for Model training configuration
    def get_decoder_options_list(self) -> list:
        return self.decoder_opts

    def get_task_types_list(self) -> list:
        return self.task_types

    def get_tasks_dropout_prob_list(self) -> list:
        return self.dropout_list

    def get_loss_types_list(self) -> list:
        return self.loss_types

    def get_kd_loss_types_list(self) -> list:
        return self.kd_loss_types

    def get_task_nclass_list(self) -> list:
        return self.nclass_list

    def get_num_all_batches(self) -> int:
        return self._num_all_batches
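A minimal usage sketch of MTDNNDataProcess (not part of the diff), showing how its getters supply the inputs the MTDNNModel constructor above expects. The `config` and `task_defs` objects and the data directory path are assumptions here.

# Minimal sketch, assuming `config` and `task_defs` were built as above and the
# hypothetical directory "data/canonical_data/bert_uncased_lower" holds the task files.
data_processor = MTDNNDataProcess(
    config=config,
    task_defs=task_defs,
    data_dir="data/canonical_data/bert_uncased_lower",
    train_datasets_list=["mnli"],
    test_datasets_list=["mnli_mismatched", "mnli_matched"],
)
train_loader = data_processor.get_train_dataloader()
dev_loaders = data_processor.get_dev_dataloaders()
test_loaders = data_processor.get_test_dataloaders()
num_all_batches = data_processor.get_num_all_batches()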
@ -0,0 +1,740 @@
|
|||
# coding=utf-8
# Copyright (c) Microsoft. All rights reserved.

# This script reuses some code from https://github.com/huggingface/transformers


""" Model configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import json
import logging
import os
from typing import Union

import yaml

from mtdnn.common.loss import LossCriterion
from mtdnn.common.metrics import Metric
from mtdnn.common.types import DataFormat, EncoderModelType, TaskType
from mtdnn.common.vocab import Vocabulary
from mtdnn.common.utils import MTDNNCommonUtils


logger = MTDNNCommonUtils.setup_logging()


class TaskConfig(object):
    """Base class for task configurations

    Handles parameters that are common to all task configurations.

    Arguments:
        **kwargs {dict} -- Task options mapped directly onto instance attributes
    """

    def __init__(self, **kwargs: dict):
        """ Define a generic task configuration """
        logger.info("Mapping Task attributes")

        # Map keyword arguments onto instance attributes
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(
                    f"[ERROR] - Unable to set {key} with value {value} for {self}"
                )
                raise err

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        return copy.deepcopy(self.__dict__)
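
    # A minimal usage sketch with illustrative values (not a registered
    # task): attributes come straight from kwargs, and to_dict() returns a
    # deep copy, so mutating the snapshot leaves the config untouched.
    #
    #     cfg = TaskConfig(task_name="demo", n_class=2, metric_meta=["ACC"])
    #     snapshot = cfg.to_dict()
    #     snapshot["n_class"] = 99
    #     assert cfg.n_class == 2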


class COLATaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "cola",
                "data_format": "PremiseOnly",
                "encoder_type": "BERT",
                "dropout_p": 0.05,
                "enable_san": False,
                "metric_meta": ["ACC", "MCC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(COLATaskConfig, self).__init__(**kwargs)
        self.dropout_p = kwargs.pop("dropout_p", 0.05)


class MNLITaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "mnli",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "dropout_p": 0.3,
                "enable_san": True,
                "labels": ["contradiction", "neutral", "entailment"],
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 3,
                "split_names": [
                    "train",
                    "matched_dev",
                    "mismatched_dev",
                    "matched_test",
                    "mismatched_test",
                ],
                "task_type": "Classification",
            }
        super(MNLITaskConfig, self).__init__(**kwargs)
        self.dropout_p = kwargs.pop("dropout_p", 0.3)
        self.split_names = kwargs.pop(
            "split_names",
            [
                "train",
                "matched_dev",
                "mismatched_dev",
                "matched_test",
                "mismatched_test",
            ],
        )


class MRPCTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "mrpc",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "metric_meta": ["ACC", "F1"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(MRPCTaskConfig, self).__init__(**kwargs)


class QNLITaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "qnli",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "labels": ["not_entailment", "entailment"],
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(QNLITaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop("labels", ["not_entailment", "entailment"])


class QQPTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "qqp",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "metric_meta": ["ACC", "F1"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(QQPTaskConfig, self).__init__(**kwargs)


class RTETaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "rte",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "labels": ["not_entailment", "entailment"],
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(RTETaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop("labels", ["not_entailment", "entailment"])


class SCITAILTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "scitail",
                "encoder_type": "BERT",
                "data_format": "PremiseAndOneHypothesis",
                "enable_san": True,
                "labels": ["neutral", "entails"],
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(SCITAILTaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop("labels", ["neutral", "entails"])


class SNLITaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "snli",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "labels": ["contradiction", "neutral", "entailment"],
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 3,
                "task_type": "Classification",
            }
        super(SNLITaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop("labels", ["contradiction", "neutral", "entailment"])


class SSTTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "sst",
                "data_format": "PremiseOnly",
                "encoder_type": "BERT",
                "enable_san": False,
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(SSTTaskConfig, self).__init__(**kwargs)


class STSBTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "stsb",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": False,
                "metric_meta": ["Pearson", "Spearman"],
                "n_class": 1,
                "loss": "MseCriterion",
                "kd_loss": "MseCriterion",
                "task_type": "Regression",
            }
        super(STSBTaskConfig, self).__init__(**kwargs)


class WNLITaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "wnli",
                "data_format": "PremiseAndOneHypothesis",
                "encoder_type": "BERT",
                "enable_san": True,
                "metric_meta": ["ACC"],
                "loss": "CeCriterion",
                "kd_loss": "MseCriterion",
                "n_class": 2,
                "task_type": "Classification",
            }
        super(WNLITaskConfig, self).__init__(**kwargs)


class NERTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "ner",
                # "Seqence" is kept as-is: the value must match the spelling
                # of the DataFormat enum member it is looked up against.
                "data_format": "Seqence",
                "encoder_type": "BERT",
                "dropout_p": 0.3,
                "enable_san": False,
                "labels": [
                    "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG",
                    "I-ORG", "B-LOC", "I-LOC", "X", "CLS", "SEP",
                ],
                "metric_meta": ["SeqEval"],
                "n_class": 12,
                "loss": "SeqCeCriterion",
                "split_names": ["train", "dev", "test"],
                "task_type": "SequenceLabeling",
            }
        super(NERTaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop(
            "labels",
            [
                "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG",
                "I-ORG", "B-LOC", "I-LOC", "X", "CLS", "SEP",
            ],
        )
        self.split_names = kwargs.pop("split_names", ["train", "dev", "test"])


class POSTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "pos",
                "data_format": "Seqence",
                "encoder_type": "BERT",
                "dropout_p": 0.1,
                "enable_san": False,
                "labels": [
                    ",", "\\", ":", ".", "''", '"', "(", ")", "$",
                    "CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS",
                    "LS", "MD", "NN", "NNP", "NNPS", "NNS", "NN|SYM", "PDT",
                    "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM",
                    "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
                    "WDT", "WP", "WP$", "WRB", "X", "CLS", "SEP",
                ],
                "metric_meta": ["SeqEval"],
                "n_class": 49,
                "loss": "SeqCeCriterion",
                "split_names": ["train", "dev", "test"],
                "task_type": "SequenceLabeling",
            }
        super(POSTaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop(
            "labels",
            [
                ",", "\\", ":", ".", "''", '"', "(", ")", "$",
                "CC", "CD", "DT", "EX", "FW", "IN", "JJ", "JJR", "JJS",
                "LS", "MD", "NN", "NNP", "NNPS", "NNS", "NN|SYM", "PDT",
                "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM",
                "TO", "UH", "VB", "VBD", "VBG", "VBN", "VBP", "VBZ",
                "WDT", "WP", "WP$", "WRB", "X", "CLS", "SEP",
            ],
        )
        self.split_names = kwargs.pop("split_names", ["train", "dev", "test"])


class CHUNKTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "chunk",
                "data_format": "Seqence",
                "encoder_type": "BERT",
                "dropout_p": 0.1,
                "enable_san": False,
                "labels": [
                    "B-ADJP", "B-ADVP", "B-CONJP", "B-INTJ", "B-LST",
                    "B-NP", "B-PP", "B-PRT", "B-SBAR", "B-VP",
                    "I-ADJP", "I-ADVP", "I-CONJP", "I-INTJ", "I-LST",
                    "I-NP", "I-PP", "I-SBAR", "I-VP",
                    "O", "X", "CLS", "SEP",
                ],
                "metric_meta": ["SeqEval"],
                "n_class": 23,
                "loss": "SeqCeCriterion",
                "split_names": ["train", "dev", "test"],
                "task_type": "SequenceLabeling",
            }
        super(CHUNKTaskConfig, self).__init__(**kwargs)
        self.labels = kwargs.pop(
            "labels",
            [
                "B-ADJP", "B-ADVP", "B-CONJP", "B-INTJ", "B-LST",
                "B-NP", "B-PP", "B-PRT", "B-SBAR", "B-VP",
                "I-ADJP", "I-ADVP", "I-CONJP", "I-INTJ", "I-LST",
                "I-NP", "I-PP", "I-SBAR", "I-VP",
                "O", "X", "CLS", "SEP",
            ],
        )
        self.split_names = kwargs.pop("split_names", ["train", "dev", "test"])


class SQUADTaskConfig(TaskConfig):
    def __init__(self, kwargs: dict = {}):
        if not kwargs:
            kwargs = {
                "task_name": "squad",
                "data_format": "MRC",
                "encoder_type": "BERT",
                "dropout_p": 0.1,
                "enable_san": False,
                "metric_meta": ["EmF1"],
                "n_class": 2,
                "task_type": "Span",
                "loss": "SpanCeCriterion",
                "split_names": ["train", "dev"],
            }
        super(SQUADTaskConfig, self).__init__(**kwargs)
        self.split_names = kwargs.pop("split_names", ["train", "dev"])
        self.dropout_p = kwargs.pop("dropout_p", 0.1)


# Map of supported tasks
SUPPORTED_TASKS_MAP = {
    "cola": COLATaskConfig,
    "mnli": MNLITaskConfig,
    "mrpc": MRPCTaskConfig,
    "qnli": QNLITaskConfig,
    "qqp": QQPTaskConfig,
    "rte": RTETaskConfig,
    "scitail": SCITAILTaskConfig,
    "snli": SNLITaskConfig,
    "sst": SSTTaskConfig,
    "stsb": STSBTaskConfig,
    "wnli": WNLITaskConfig,
    "ner": NERTaskConfig,
    "pos": POSTaskConfig,
    "chunk": CHUNKTaskConfig,
    "squad": SQUADTaskConfig,
    "squad-v2": SQUADTaskConfig,
}


class MTDNNTaskConfig:
    supported_tasks_map = SUPPORTED_TASKS_MAP

    def from_dict(self, task_name: str, opts: dict = {}):
        """ Create a task configuration from a dictionary of options """
        assert opts, "Configuration dictionary cannot be empty"
        task = self.supported_tasks_map[task_name]
        opts.update({"task_name": f"{task_name}"})
        return task(kwargs=opts)

    def get_supported_tasks(self) -> list:
        """Return the list of supported tasks

        Returns:
            list -- Supported list of tasks
        """
        return list(self.supported_tasks_map.keys())
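
    # Usage sketch: build a single task configuration from a plain options
    # dict (values mirror the CoLA defaults above; any subset can differ).
    #
    #     task_creator = MTDNNTaskConfig()
    #     cola_config = task_creator.from_dict(
    #         task_name="cola",
    #         opts={
    #             "data_format": "PremiseOnly",
    #             "encoder_type": "BERT",
    #             "dropout_p": 0.05,
    #             "enable_san": False,
    #             "metric_meta": ["ACC", "MCC"],
    #             "loss": "CeCriterion",
    #             "kd_loss": "MseCriterion",
    #             "n_class": 2,
    #             "task_type": "Classification",
    #         },
    #     )
    #     assert "cola" in task_creator.get_supported_tasks()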


class MTDNNTaskDefs:
    """Definition of single or multiple tasks to train.
    Takes a task dictionary or a definition file in yaml or json format.

    Arguments:
        task_dict_or_def_file {str or dict} -- Task dictionary or definition file (yaml or json)

    Example:

        JSON:
            {
                "cola": {
                    "data_format": "PremiseOnly",
                    "encoder_type": "BERT",
                    "dropout_p": 0.05,
                    "enable_san": false,
                    "metric_meta": ["ACC", "MCC"],
                    "loss": "CeCriterion",
                    "kd_loss": "MseCriterion",
                    "n_class": 2,
                    "task_type": "Classification"
                }
                ...
            }

        or

        Python dict:
            {
                "cola": {
                    "data_format": "PremiseOnly",
                    "encoder_type": "BERT",
                    "dropout_p": 0.05,
                    "enable_san": False,
                    "metric_meta": ["ACC", "MCC"],
                    "loss": "CeCriterion",
                    "kd_loss": "MseCriterion",
                    "n_class": 2,
                    "task_type": "Classification"
                }
                ...
            }
    """

    def __init__(self, task_dict_or_file: Union[str, dict]):
        assert (
            task_dict_or_file
        ), "Please pass in a task dict or definition file in yaml or json"
        self._task_def_dic = {}
        self._configured_tasks = []  # list of configured tasks
        if isinstance(task_dict_or_file, dict):
            self._task_def_dic = task_dict_or_file
        elif isinstance(task_dict_or_file, str):
            assert os.path.exists(
                task_dict_or_file
            ), "Task definition file does not exist"
            assert os.path.isfile(task_dict_or_file), "Task definition must be a file"

            task_def_filepath, ext = os.path.splitext(task_dict_or_file)
            ext = ext[1:].lower()
            assert ext in [
                "json",
                "yml",
                "yaml",
            ], "Definition file must be in JSON or YAML format"

            self._task_def_dic = (
                yaml.safe_load(open(task_dict_or_file))
                if ext in ["yaml", "yml"]
                else json.load(open(task_dict_or_file))
            )

        global_map = {}
        n_class_map = {}
        data_type_map = {}
        task_type_map = {}
        metric_meta_map = {}
        enable_san_map = {}
        dropout_p_map = {}
        encoderType_map = {}
        loss_map = {}
        kd_loss_map = {}

        # Create an instance of the task creator
        task_creator = MTDNNTaskConfig()

        uniq_encoderType = set()
        for name, params in self._task_def_dic.items():
            assert (
                "_" not in name
            ), f"task name should not contain '_', current task name: {name}"

            # Build the per-task configuration object
            task = task_creator.from_dict(task_name=name, opts=params)

            n_class_map[name] = task.n_class
            data_type_map[name] = DataFormat[task.data_format]
            task_type_map[name] = TaskType[task.task_type]
            metric_meta_map[name] = tuple(
                Metric[metric_name] for metric_name in task.metric_meta
            )
            enable_san_map[name] = task.enable_san
            uniq_encoderType.add(EncoderModelType[task.encoder_type])

            if hasattr(task, "labels"):
                labels = task.labels
                label_mapper = Vocabulary(True)
                for label in labels:
                    label_mapper.add(label)
                global_map[name] = label_mapper

            # dropout
            if hasattr(task, "dropout_p"):
                dropout_p_map[name] = task.dropout_p

            # loss map
            if hasattr(task, "loss"):
                t_loss = task.loss
                loss_crt = LossCriterion[t_loss]
                loss_map[name] = loss_crt
            else:
                loss_map[name] = None

            if hasattr(task, "kd_loss"):
                t_loss = task.kd_loss
                loss_crt = LossCriterion[t_loss]
                kd_loss_map[name] = loss_crt
            else:
                kd_loss_map[name] = None

            # Track configured tasks for downstream use
            self._configured_tasks.append(task.to_dict())

        logger.info(
            f"Configured task definitions - {[obj['task_name'] for obj in self.get_configured_tasks()]}"
        )

        assert len(uniq_encoderType) == 1, "The shared encoder has to be the same."
        self.global_map = global_map
        self.n_class_map = n_class_map
        self.data_type_map = data_type_map
        self.task_type_map = task_type_map
        self.metric_meta_map = metric_meta_map
        self.enable_san_map = enable_san_map
        self.dropout_p_map = dropout_p_map
        self.encoderType = uniq_encoderType.pop()
        self.loss_map = loss_map
        self.kd_loss_map = kd_loss_map

    def get_configured_tasks(self) -> list:
        """Returns the list of tasks configured by this class from the input configuration

        Returns:
            list -- List of configured task classes
        """
        return self._configured_tasks
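
    # End-to-end sketch: task definitions can come from an in-memory dict
    # of the same shape as the docstring example, rather than a yaml/json
    # file; the resulting maps feed the data-processing and model classes.
    #
    #     task_defs = MTDNNTaskDefs(
    #         {
    #             "cola": {
    #                 "data_format": "PremiseOnly",
    #                 "encoder_type": "BERT",
    #                 "dropout_p": 0.05,
    #                 "enable_san": False,
    #                 "metric_meta": ["ACC", "MCC"],
    #                 "loss": "CeCriterion",
    #                 "kd_loss": "MseCriterion",
    #                 "n_class": 2,
    #                 "task_type": "Classification",
    #             }
    #         }
    #     )
    #     task_defs.task_type_map["cola"]   # -> TaskType.Classification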
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:42ae01c7f32f50ddcef3d1c09ee589a384596ef9634e2c3778cb59204bd84fec
size 4002207

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4361fb059fbcdb3727398ba99754f27e6f445f27a02344c86456c23a3b55de61
size 4021462

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a1430dfbfb26254a5bc35fadf15839b48a099a827b0e2fbba0b0bcf2d92dcb0a
size 4156909

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fcaa5a7ad081253619ae9f628c3e9e773d5bfba6178017c04092d015737a963
size 4144939

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94ae34f18390c9d0f0a92e00d0312d09015c39ba4bfd6ff1c6ff6ec51b7728c6
size 163466473

@@ -0,0 +1,103 @@
#!/usr/bin/env bash
##############################################################
# This script is used to download resources for MT-DNN experiments
##############################################################

BERT_DIR=$(pwd)/../mt_dnn_models
if [ ! -d ${BERT_DIR} ]; then
  echo "Create a folder ${BERT_DIR}"
  mkdir ${BERT_DIR}
fi

## Download BERT models
wget https://mrc.blob.core.windows.net/mt-dnn-model/bert_model_base_v2.pt -O "${BERT_DIR}/bert_model_base_uncased.pt"
wget https://mrc.blob.core.windows.net/mt-dnn-model/bert_model_large_v2.pt -O "${BERT_DIR}/bert_model_large_uncased.pt"
wget https://mrc.blob.core.windows.net/mt-dnn-model/bert_base_chinese.pt -O "${BERT_DIR}/bert_model_base_chinese.pt"

## Download MT-DNN models
wget https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_base.pt -O "${BERT_DIR}/mt_dnn_base_uncased.pt"
wget https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_large.pt -O "${BERT_DIR}/mt_dnn_large_uncased.pt"

## MT-DNN-KD
wget https://mrc.blob.core.windows.net/mt-dnn-model/mt_dnn_kd_large_cased.pt -O "${BERT_DIR}/mt_dnn_kd_large_cased.pt"

## Download XLNet model
wget https://storage.googleapis.com/xlnet/released_models/cased_L-24_H-1024_A-16.zip -O "xlnet_cased_large.zip"
unzip xlnet_cased_large.zip
mv xlnet_cased_L-24_H-1024_A-16/spiece.model "${BERT_DIR}/xlnet_large_cased_spiece.model"
rm -rf *.zip xlnet_cased_L-24_H-1024_A-16
## Download the converted XLNet PyTorch model
wget https://mrc.blob.core.windows.net/mt-dnn-model/xlnet_model_large_cased.pt -O "${BERT_DIR}/xlnet_model_large_cased.pt"


## Download RoBERTa
wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz -O "roberta.base.tar.gz"
wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz -O "roberta.large.tar.gz"
tar xvf roberta.base.tar.gz
mv "roberta.base" "${BERT_DIR}/"
tar xvf roberta.large.tar.gz
mv "roberta.large" "${BERT_DIR}/"
rm "roberta.base.tar.gz"
rm "roberta.large.tar.gz"

mkdir "${BERT_DIR}/roberta"
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json' -O "${BERT_DIR}/roberta/encoder.json"
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe' -O "${BERT_DIR}/roberta/vocab.bpe"
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt' -O "${BERT_DIR}/roberta/dict.txt"

# Stop here (successfully) when only the models were requested
if [ "$1" == "model_only" ]; then
  exit 0
fi

DATA_DIR=$(pwd)/data
if [ ! -d ${DATA_DIR} ]; then
  echo "Create a folder $DATA_DIR"
  mkdir ${DATA_DIR}
fi

## DOWNLOAD GLUE DATA
## Please refer to the jiant (glue-baseline) repo for install requirements and other issues.
git clone https://github.com/jsalt18-sentence-repl/jiant.git
cd jiant
python scripts/download_glue_data.py --data_dir $DATA_DIR --tasks all

cd ..
rm -rf jiant
#########################

## DOWNLOAD SciTail
cd $DATA_DIR
wget http://data.allenai.org.s3.amazonaws.com/downloads/SciTailV1.1.zip
unzip SciTailV1.1.zip
mv SciTailV1.1 SciTail
# remove zip files
rm *.zip

## Download preprocessed SciTail/SNLI data for domain adaptation
cd $DATA_DIR
DOMAIN_ADP="domain_adaptation"
echo "Create a folder $DOMAIN_ADP"
mkdir ${DOMAIN_ADP}

wget https://mrc.blob.core.windows.net/mt-dnn-model/data.zip
unzip data.zip
mv data/* ${DOMAIN_ADP}
rm -rf data.zip
rm -rf data

## Download SQuAD & SQuAD v2.0 data
cd $DATA_DIR
mkdir "squad"
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json -O $DATA_DIR/squad/train.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json -O $DATA_DIR/squad/dev.json

mkdir "squad_v2"
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O $DATA_DIR/squad_v2/train.json
wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O $DATA_DIR/squad_v2/dev.json

# NER
cd $DATA_DIR
mkdir "ner"
wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.train -O "ner/train.txt"
wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testa -O "ner/valid.txt"
wget https://raw.githubusercontent.com/synalp/NER/master/corpus/CoNLL-2003/eng.testb -O "ner/test.txt"
@@ -0,0 +1,186 @@
#!/usr/bin/python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# Code adapted from https://github.com/microsoft/nlp-recipes/blob/master/tools/generate_conda_file.py

# This script creates yaml files to build conda environments.
# For generating a conda file for running only python code:
# $ python generate_conda_file.py
#
# For generating a conda file for running python gpu:
# $ python generate_conda_file.py --gpu


import argparse
import textwrap
from sys import platform


HELP_MSG = """
To create the conda environment:
$ conda env create -f {conda_env}.yaml

To update the conda environment:
$ conda env update -f {conda_env}.yaml

To register the conda environment in Jupyter:
$ conda activate {conda_env}
$ python -m ipykernel install --user --name {conda_env} \
--display-name "Python ({conda_env})"
"""


CHANNELS = ["defaults", "conda-forge", "pytorch"]

CONDA_BASE = {
    "python": "python==3.6.8",
    "pip": "pip>=19.1.1",
    "ipykernel": "ipykernel>=4.6.1",
    "jupyter": "jupyter>=1.0.0",
    "matplotlib": "matplotlib>=2.2.2",
    "numpy": "numpy>=1.13.3",
    "pandas": "pandas>=0.24.2",
    "pytest": "pytest>=3.6.4",
    "pytorch": "pytorch-cpu>=1.0.0",
    "scipy": "scipy>=1.0.0",
    "h5py": "h5py>=2.8.0",
    "tensorflow": "tensorflow==1.15.0",
    "tensorflow-hub": "tensorflow-hub==0.7.0",
    "dask": "dask[dataframe]==1.2.2",
    "papermill": "papermill>=1.0.1",
}

CONDA_GPU = {
    "numba": "numba>=0.38.1",
    "cudatoolkit": "cudatoolkit==10.2.89",
}

PIP_BASE = {
    "allennlp": "allennlp==0.8.4",
    "black": "black>=18.6b4",
    "cached-property": "cached-property==1.5.1",
    "jsonlines": "jsonlines>=1.2.0",
    "nteract-scrapbook": "nteract-scrapbook>=0.2.1",
    "pytorch-pretrained-bert": "pytorch-pretrained-bert>=0.6",
    "tqdm": "tqdm==4.32.2",
    "pyemd": "pyemd==0.5.1",
    "ipywebrtc": "ipywebrtc==0.4.3",
    "pre-commit": "pre-commit>=1.14.4",
    "scikit-learn": "scikit-learn>=0.19.0,<=0.20.3",
    "seaborn": "seaborn>=0.9.0",
    "sklearn-crfsuite": "sklearn-crfsuite>=0.3.6",
    "spacy": "spacy==2.1.8",
    "spacy-models": (
        "https://github.com/explosion/spacy-models/releases/download/"
        "en_core_web_sm-2.1.0/en_core_web_sm-2.1.0.tar.gz"
    ),
    "transformers": "transformers>=2.1.1",
    "gensim": "gensim>=3.7.0",
    "nltk": "nltk>=3.4",
    "seqeval": "seqeval>=0.0.12",
    "bertsum": "git+https://github.com/daden-ms/BertSum.git@030c139c97bc57d0c31f6515b8bf9649f999a443#egg=BertSum",
    "pyrouge": "pyrouge>=0.1.3",
    "py-rouge": "py-rouge>=1.1",
    "torchtext": "torchtext>=0.4.0",
    "multiprocess": "multiprocess==0.70.9",
    "tensorboardX": "tensorboardX==1.8",
}

PIP_GPU = {
    "torch": "torch==1.4.0",
}

PIP_DARWIN = {}
PIP_DARWIN_GPU = {}

PIP_LINUX = {}
PIP_LINUX_GPU = {}

PIP_WIN32 = {}
PIP_WIN32_GPU = {}

CONDA_DARWIN = {}
CONDA_DARWIN_GPU = {}

CONDA_LINUX = {}
CONDA_LINUX_GPU = {}

CONDA_WIN32 = {}
CONDA_WIN32_GPU = {"pytorch": "pytorch==1.0.0", "cudatoolkit": "cuda90"}

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description=textwrap.dedent(
            """
            This script generates a conda file for different environments.
            Plain python is the default,
            but flags can be used to support GPU functionality."""
        ),
        epilog=HELP_MSG,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--name", help="specify name of conda environment")
    parser.add_argument(
        "--gpu", action="store_true", help="include packages for GPU support"
    )
    args = parser.parse_args()

    # set name for environment and output yaml file
    conda_env = "mtdnn_cpu"
    if args.gpu:
        conda_env = "mtdnn_gpu"

    # overwrite environment name with user input
    if args.name is not None:
        conda_env = args.name

    # add conda and pip base packages
    conda_packages = CONDA_BASE
    pip_packages = PIP_BASE

    # update conda and pip packages based on flags provided
    if args.gpu:
        conda_packages.update(CONDA_GPU)
        pip_packages.update(PIP_GPU)

    # update conda and pip packages based on os platform support
    if platform == "darwin":
        conda_packages.update(CONDA_DARWIN)
        pip_packages.update(PIP_DARWIN)
        if args.gpu:
            conda_packages.update(CONDA_DARWIN_GPU)
            pip_packages.update(PIP_DARWIN_GPU)
    elif platform.startswith("linux"):
        conda_packages.update(CONDA_LINUX)
        pip_packages.update(PIP_LINUX)
        if args.gpu:
            conda_packages.update(CONDA_LINUX_GPU)
            pip_packages.update(PIP_LINUX_GPU)
    elif platform == "win32":
        conda_packages.update(CONDA_WIN32)
        pip_packages.update(PIP_WIN32)
        if args.gpu:
            conda_packages.update(CONDA_WIN32_GPU)
            pip_packages.update(PIP_WIN32_GPU)
    else:
        raise Exception("Unsupported platform. Must be Windows, Linux, or macOS")

    # write out yaml file
    conda_file = "{}.yaml".format(conda_env)
    with open(conda_file, "w") as f:
        for line in HELP_MSG.format(conda_env=conda_env).split("\n"):
            f.write("# {}\n".format(line))
        f.write("name: {}\n".format(conda_env))
        f.write("channels:\n")
        for channel in CHANNELS:
            f.write("- {}\n".format(channel))
        f.write("dependencies:\n")
        for conda_package in conda_packages.values():
            f.write("- {}\n".format(conda_package))
        f.write("- pip:\n")
        for pip_package in pip_packages.values():
            f.write("  - {}\n".format(pip_package))

    print("Generated conda file: {}".format(conda_file))
    print(HELP_MSG.format(conda_env=conda_env))
@@ -0,0 +1,101 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import absolute_import, print_function

import io
import os
import re
from os.path import dirname, join

from setuptools import setup

from mtdnn import AUTHOR, LICENSE, TITLE, VERSION


def read(*names, **kwargs):
    with io.open(
        join(dirname(__file__), *names), encoding=kwargs.get("encoding", "utf8")
    ) as fh:
        return fh.read()


setup(
    name="mtdnn",
    version=VERSION,
    license=LICENSE,
    description="Multi-Task Deep Neural Networks for Natural Language Understanding. Developed by Microsoft Research AI",
    long_description="%s\n%s"
    % (
        re.compile("^.. start-badges.*^.. end-badges", re.M | re.S).sub(
            "", read("README.md")
        ),
        re.sub(":[a-z]+:`~?(.*?)`", r"``\1``", read("CONTRIBUTION.md")),
    ),
    author=AUTHOR,
    author_email="xiaodl@microsoft.com",
    url="https://github.com/microsoft/mt-dnn",
    packages=["mtdnn"],
    include_package_data=True,
    zip_safe=True,
    classifiers=[
        # complete classifier list: http://pypi.python.org/pypi?%3Aaction=list_classifiers
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: MIT License",
        "Operating System :: Unix",
        "Operating System :: POSIX",
        "Operating System :: Microsoft :: Windows",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: Implementation :: CPython",
        "Programming Language :: Python :: Implementation :: PyPy",
        "Topic :: Text Processing :: Linguistic",
        "Topic :: Utilities",
        "Intended Audience :: Science/Research",
        "Intended Audience :: Education",
        "Intended Audience :: Financial and Insurance Industry",
        "Intended Audience :: Healthcare Industry",
        "Intended Audience :: Information Technology",
        "Intended Audience :: Telecommunications Industry",
    ],
    project_urls={
        "Documentation": "https://github.com/microsoft/mt-dnn/",
        "Issue Tracker": "https://github.com/microsoft/mt-dnn/issues",
    },
    keywords=[
        "Microsoft NLP",
        "Microsoft MT-DNN",
        "Multi-Task Deep Neural Network for Natural Language Understanding",
        "Natural Language Processing",
        "Text Processing",
        "Word Embedding",
        "Multi-Task DNN",
    ],
    python_requires=">=3.6",
    install_requires=[
        "numpy",
        "torch==1.4.0",
        "tqdm",
        "colorlog",
        "boto3",
        "pytorch-pretrained-bert==0.6.0",
        "regex",
        "scikit-learn",
        "pyyaml",
        "pytest",
        "sentencepiece",
        "tensorboardX",
        "tensorboard",
        "future",
        "fairseq==0.8.0",
        "seqeval==0.0.12",
        "transformers==2.3.0",
    ],
    dependency_links=[],
    extras_require={},
    use_scm_version=False,
    setup_requires=[],
)