зеркало из https://github.com/microsoft/LoRA.git
initial commit
This commit is contained in:
Коммит
2eb74c2dc8
|
@ -0,0 +1,135 @@
|
|||
# all files with SCRATCH prefix and large model / metric files
|
||||
**/SCRATCH*
|
||||
/pretrained_checkpoints
|
||||
/trained_models
|
||||
eval/e2e
|
||||
eval/GenerationEval
|
||||
venv
|
||||
trained_models
|
||||
pretrained_checkpoints
|
||||
.*
|
||||
*~*
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# celery beat schedule file
|
||||
celerybeat-schedule
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
.vscode/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
.idea/
|
||||
toy*.py
|
||||
.DS_Store
|
||||
toy*.py
|
||||
post/
|
||||
*.user
|
||||
*.nupkg
|
||||
/Packages
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# Microsoft Open Source Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
|
||||
Resources:
|
||||
|
||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE
|
|
@ -0,0 +1,182 @@
|
|||
# Adapting GPT-2 using LoRA
|
||||
|
||||
This folder contains the implementation of LoRA in GPT-2 using the Python package `lora` and steps to replicate the results in our recent paper
|
||||
|
||||
**LoRA: Low-Rank Adaptation of Large Language Models** <br>
|
||||
*Edward J. Hu\*, Yelong Shen\*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Weizhu Chen* <br>
|
||||
Paper: https://arxiv.org/abs/2106.09685 <br>
|
||||
|
||||
<p>
|
||||
<img src="figures/LoRA_GPT2.PNG" width="800" >
|
||||
</p>
|
||||
|
||||
This repo reproduces our experiments on GPT-2.
|
||||
|
||||
## Repository Overview
|
||||
|
||||
Our implementation is based on the fine-tuning code for GPT-2 in [Hugging Face](https://huggingface.co/).
|
||||
There are several directories in this repo:
|
||||
* [src/](src) contains the source code used for data processing, training, and decoding.
|
||||
* [eval/](eval) contains the code for task-specific evaluation scripts.
|
||||
* [data/](data) contains the raw data we used in our experiments.
|
||||
* [vocab/](vocab) contains the GPT-2 vocabulary files.
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. You can start with the following docker image: `nvcr.io/nvidia/pytorch:20.03-py3` on a GPU-capable machine, but any generic PyTorch image should work.
|
||||
```
|
||||
docker run -it nvcr.io/nvidia/pytorch:20.03-py3
|
||||
```
|
||||
|
||||
2. Clone the repo and install dependencies in a virtual environment (remove sudo if running in docker container):
|
||||
```
|
||||
sudo apt-get update
|
||||
sudo apt-get -y install git jq virtualenv
|
||||
git clone https://github.com/microsoft/LoRA.git; cd LoRA
|
||||
virtualenv -p `which python3` ./venv
|
||||
. ./venv/bin/activate
|
||||
pip install -r requirement.txt
|
||||
bash download_pretrained_checkpoints.sh
|
||||
bash create_datasets.sh
|
||||
cd ./eval
|
||||
bash download_evalscript.sh
|
||||
cd ..
|
||||
```
|
||||
|
||||
#### Now we are ready to replicate the results in our paper.
|
||||
|
||||
## Replicating Our Result on E2E
|
||||
|
||||
1. Train GPT-2 Medium with LoRA (see our paper for hyperparameters for GPT-2 Medium)
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_ft.py \
|
||||
--train_data ./data/e2e/train.jsonl \
|
||||
--valid_data ./data/e2e/valid.jsonl \
|
||||
--train_batch_size 8 \
|
||||
--grad_acc 1 \
|
||||
--valid_batch_size 4 \
|
||||
--seq_len 512 \
|
||||
--model_card gpt2.md \
|
||||
--init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
|
||||
--platform local \
|
||||
--clip 0.0 \
|
||||
--lr 0.0002 \
|
||||
--weight_decay 0.01 \
|
||||
--correct_bias \
|
||||
--adam_beta2 0.999 \
|
||||
--scheduler linear \
|
||||
--warmup_step 500 \
|
||||
--max_epoch 5 \
|
||||
--save_interval 1000 \
|
||||
--lora_dim 4 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.1 \
|
||||
--label_smooth 0.1 \
|
||||
--work_dir ./trained_models/GPT2_M/e2e \
|
||||
--random_seed 110
|
||||
```
|
||||
|
||||
2. Generate outputs from the trained model using beam search:
|
||||
```
|
||||
python -m torch.distributed.launch --nproc_per_node=1 src/gpt2_beam.py \
|
||||
--data ./data/e2e/test.jsonl \
|
||||
--batch_size 1 \
|
||||
--seq_len 512 \
|
||||
--eval_len 64 \
|
||||
--model_card gpt2.md \
|
||||
--init_checkpoint ./trained_models/GPT2_M/e2e/model.26289.pt \
|
||||
--platform local \
|
||||
--lora_dim 4 \
|
||||
--lora_alpha 32 \
|
||||
--beam 10 \
|
||||
--length_penalty 0.8 \
|
||||
--no_repeat_ngram_size 4 \
|
||||
--repetition_penalty 1.0 \
|
||||
--eos_token_id 628 \
|
||||
--work_dir ./trained_models/GPT2_M/e2e \
|
||||
--output_file predict.26289.b10p08r4.jsonl
|
||||
```
|
||||
|
||||
3. Decode outputs from step (2)
|
||||
```
|
||||
python src/gpt2_decode.py \
|
||||
--vocab ./vocab \
|
||||
--sample_file ./trained_models/GPT2_M/e2e/predict.26289.b10p08r4.jsonl \
|
||||
--input_file ./data/e2e/test_formatted.jsonl \
|
||||
--output_ref_file e2e_ref.txt \
|
||||
--output_pred_file e2e_pred.txt
|
||||
```
|
||||
|
||||
4. Run evaluation on E2E test set
|
||||
|
||||
```
|
||||
python eval/e2e/measure_scores.py e2e_ref.txt e2e_pred.txt -p
|
||||
```
|
||||
|
||||
## Replicating Our Result on WebNLG
|
||||
|
||||
1. Follow steps 1 and 2 from E2E pipeline by replacing references to E2E with webnlg (see our paper for hyperparameters)
|
||||
|
||||
2. Decode outputs from beam search (step 2 above)
|
||||
```
|
||||
python src/gpt2_decode.py \
|
||||
--vocab ./vocab \
|
||||
--sample_file ./trained_models/GPT2_M/webnlg/predict.20000.b10p08.jsonl \
|
||||
--input_file ./data/webnlg_challenge_2017/test_formatted.jsonl \
|
||||
--ref_type webnlg \
|
||||
--ref_num 6 \
|
||||
--output_ref_file eval/GenerationEval/data/references_webnlg \
|
||||
--output_pred_file eval/GenerationEval/data/hypothesis_webnlg \
|
||||
--tokenize --lower
|
||||
```
|
||||
|
||||
3. Run evaluation on WebNLG test set
|
||||
```
|
||||
cd ./eval/GenerationEval/
|
||||
python eval.py \
|
||||
-R data/references_webnlg/reference \
|
||||
-H data/hypothesis_webnlg \
|
||||
-nr 6 \
|
||||
-m bleu,meteor,ter
|
||||
cd ../..
|
||||
```
|
||||
|
||||
## Replicating Our Result on DART
|
||||
|
||||
1. Follow steps 1 and 2 from E2E pipeline by replacing references to E2E with dart (see our paper for hyperparameters)
|
||||
|
||||
2. Decode outputs from beam search (step 2 above)
|
||||
```
|
||||
python src/gpt2_decode.py \
|
||||
--vocab ./vocab \
|
||||
--sample_file ./trained_models/GPT2_M/dart/predict.20000.b10p08.jsonl \
|
||||
--input_file ./data/dart/test_formatted.jsonl \
|
||||
--ref_type dart \
|
||||
--ref_num 6 \
|
||||
--output_ref_file eval/GenerationEval/data/references_dart \
|
||||
--output_pred_file eval/GenerationEval/data/hypothesis_dart \
|
||||
--tokenize --lower
|
||||
```
|
||||
|
||||
3. Run evaluation on Dart test set
|
||||
```
|
||||
cd ./eval/GenerationEval/
|
||||
python eval.py \
|
||||
-R data/references_dart/reference \
|
||||
-H data/hypothesis_dart \
|
||||
-nr 6 \
|
||||
-m bleu,meteor,ter
|
||||
cd ../..
|
||||
```
|
||||
|
||||
## Citation
|
||||
```
|
||||
@misc{hu2021lora,
|
||||
title={LoRA: Low-Rank Adaptation of Large Language Models},
|
||||
author={Hu, Edward and Shen, Yelong and Wallis, Phil and Allen-Zhu, Zeyuan and Li, Yuanzhi and Chen, Weizhu},
|
||||
year={2021},
|
||||
eprint={2106.09685},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
|
@ -0,0 +1,41 @@
|
|||
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->
|
||||
|
||||
## Security
|
||||
|
||||
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
|
||||
|
||||
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
|
||||
|
||||
## Reporting Security Issues
|
||||
|
||||
**Please do not report security vulnerabilities through public GitHub issues.**
|
||||
|
||||
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
|
||||
|
||||
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
|
||||
|
||||
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
|
||||
|
||||
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
|
||||
|
||||
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
|
||||
* Full paths of source file(s) related to the manifestation of the issue
|
||||
* The location of the affected source code (tag/branch/commit or direct URL)
|
||||
* Any special configuration required to reproduce the issue
|
||||
* Step-by-step instructions to reproduce the issue
|
||||
* Proof-of-concept or exploit code (if possible)
|
||||
* Impact of the issue, including how an attacker might exploit the issue
|
||||
|
||||
This information will help us triage your report more quickly.
|
||||
|
||||
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
|
||||
|
||||
## Preferred Languages
|
||||
|
||||
We prefer all communications to be in English.
|
||||
|
||||
## Policy
|
||||
|
||||
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
|
||||
|
||||
<!-- END MICROSOFT SECURITY.MD BLOCK -->
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
|
||||
echo "creating e2e datasets..."
|
||||
path=data/e2e
|
||||
echo "train..."
|
||||
python src/format_converting_e2e.py $path/train.txt $path/train_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/train_formatted.jsonl --output $path/train.jsonl --add_bos --add_eos
|
||||
echo "test..."
|
||||
python src/format_converting_e2e.py $path/test.txt $path/test_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/test_formatted.jsonl --output $path/test.jsonl --add_bos --add_eos
|
||||
|
||||
echo "valid..."
|
||||
python src/format_converting_e2e.py $path/valid.txt $path/valid_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/valid_formatted.jsonl --output $path/valid.jsonl --add_bos --add_eos
|
||||
|
||||
echo "creating webnlg datasets..."
|
||||
path=data/webnlg_challenge_2017
|
||||
echo "train..."
|
||||
python src/format_converting_webnlg.py $path/train.json $path/train_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/train_formatted.jsonl --output $path/train.jsonl --add_bos --add_eos
|
||||
|
||||
echo "test..."
|
||||
python src/format_converting_webnlg.py $path/test.json $path/test_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/test_formatted.jsonl --output $path/test.jsonl --add_bos --add_eos
|
||||
|
||||
echo "valid..."
|
||||
python src/format_converting_webnlg.py $path/dev.json $path/valid_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/valid_formatted.jsonl --output $path/valid.jsonl --add_bos --add_eos
|
||||
|
||||
echo "creating dart datasets..."
|
||||
path=data/dart
|
||||
echo "train..."
|
||||
python src/format_converting_dart.py data/dart/dart-v1.1.1-full-train.json data/dart/train_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/train_formatted.jsonl --output $path/train.jsonl --add_bos --add_eos
|
||||
|
||||
echo "test..."
|
||||
python src/format_converting_dart.py data/dart/dart-v1.1.1-full-test.json data/dart/test_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/test_formatted.jsonl --output $path/test.jsonl --add_bos --add_eos
|
||||
|
||||
echo "valid..."
|
||||
python src/format_converting_dart.py data/dart/dart-v1.1.1-full-dev.json data/dart/valid_formatted.jsonl
|
||||
python src/gpt2_encode.py --vocab vocab --input $path/valid_formatted.jsonl --output $path/valid.jsonl --add_bos --add_eos
|
||||
|
||||
echo "script complete!"
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,11 @@
|
|||
#!/bin/bash
|
||||
|
||||
echo "downloading pretrained model checkpoints..."
|
||||
mkdir pretrained_checkpoints
|
||||
cd pretrained_checkpoints
|
||||
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin
|
||||
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin
|
||||
wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin
|
||||
cd ..
|
||||
|
||||
echo "script complete!"
|
|
@ -0,0 +1,7 @@
|
|||
### Evaluation code for E2E, WebNLG and Dart
|
||||
|
||||
* Code for evaluating E2E https://github.com/tuetschek/e2e-metrics
|
||||
* Code for evaluating WebNLG and Dart https://github.com/WebNLG/GenerationEval.git
|
||||
|
||||
Before running evaluation for the first time you must run
|
||||
`bash download_evalscript.sh`
|
|
@ -0,0 +1,18 @@
|
|||
#!/bin/bash
|
||||
|
||||
cd eval
|
||||
echo "installing evaluation dependencies"
|
||||
echo "downloading e2e-metrics..."
|
||||
git clone https://github.com/tuetschek/e2e-metrics e2e
|
||||
pip install -r e2e/requirements.txt
|
||||
|
||||
echo "downloading GenerationEval for webnlg and dart..."
|
||||
git clone https://github.com/WebNLG/GenerationEval.git
|
||||
cd GenerationEval
|
||||
./install_dependencies.sh
|
||||
rm -r data/en
|
||||
rm -r data/ru
|
||||
cd ..
|
||||
mv eval.py GenerationEval/
|
||||
|
||||
echo "script complete!"
|
|
@ -0,0 +1,364 @@
|
|||
__author__='thiagocastroferreira'
|
||||
|
||||
"""
|
||||
Author: Organizers of the 2nd WebNLG Challenge
|
||||
Date: 23/04/2020
|
||||
Description:
|
||||
This script aims to evaluate the output of data-to-text NLG models by
|
||||
computing popular automatic metrics such as BLEU (two implementations),
|
||||
METEOR, chrF++, TER and BERT-Score.
|
||||
|
||||
ARGS:
|
||||
usage: eval.py [-h] -R REFERENCE -H HYPOTHESIS [-lng LANGUAGE] [-nr NUM_REFS]
|
||||
[-m METRICS] [-nc NCORDER] [-nw NWORDER] [-b BETA]
|
||||
|
||||
optional arguments:
|
||||
-h, --help show this help message and exit
|
||||
-R REFERENCE, --reference REFERENCE
|
||||
reference translation
|
||||
-H HYPOTHESIS, --hypothesis HYPOTHESIS
|
||||
hypothesis translation
|
||||
-lng LANGUAGE, --language LANGUAGE
|
||||
evaluated language
|
||||
-nr NUM_REFS, --num_refs NUM_REFS
|
||||
number of references
|
||||
-m METRICS, --metrics METRICS
|
||||
evaluation metrics to be computed
|
||||
-nc NCORDER, --ncorder NCORDER
|
||||
chrF metric: character n-gram order (default=6)
|
||||
-nw NWORDER, --nworder NWORDER
|
||||
chrF metric: word n-gram order (default=2)
|
||||
-b BETA, --beta BETA chrF metric: beta parameter (default=2)
|
||||
|
||||
EXAMPLE:
|
||||
ENGLISH:
|
||||
python3 eval.py -R data/en/references/reference -H data/en/hypothesis -nr 4 -m bleu,meteor,chrf++,ter,bert,bleurt
|
||||
RUSSIAN:
|
||||
python3 eval.py -R data/ru/reference -H data/ru/hypothesis -lng ru -nr 1 -m bleu,meteor,chrf++,ter,bert
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import codecs
|
||||
import copy
|
||||
import os
|
||||
import pyter
|
||||
import logging
|
||||
import nltk
|
||||
import subprocess
|
||||
import re
|
||||
|
||||
from bert_score import score
|
||||
from metrics.chrF import computeChrF
|
||||
from metrics.bleurt.bleurt import score as bleurt_score
|
||||
|
||||
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
|
||||
from razdel import tokenize
|
||||
from tabulate import tabulate
|
||||
|
||||
BLEU_PATH = 'metrics/multi-bleu-detok.perl'
|
||||
METEOR_PATH = 'metrics/meteor-1.5/meteor-1.5.jar'
|
||||
|
||||
|
||||
def parse(refs_path, hyps_path, num_refs, lng='en'):
|
||||
logging.info('STARTING TO PARSE INPUTS...')
|
||||
print('STARTING TO PARSE INPUTS...')
|
||||
# references
|
||||
references = []
|
||||
for i in range(num_refs):
|
||||
fname = refs_path + str(i) if num_refs > 1 else refs_path
|
||||
with codecs.open(fname, 'r', 'utf-8') as f:
|
||||
texts = f.read().split('\n')
|
||||
for j, text in enumerate(texts):
|
||||
if len(references) <= j:
|
||||
references.append([text])
|
||||
else:
|
||||
references[j].append(text)
|
||||
|
||||
# references tokenized
|
||||
references_tok = copy.copy(references)
|
||||
for i, refs in enumerate(references_tok):
|
||||
if lng == 'ru':
|
||||
references_tok[i] = [' '.join([_.text for _ in tokenize(ref)]) for ref in refs]
|
||||
else:
|
||||
references_tok[i] = [' '.join(nltk.word_tokenize(ref)) for ref in refs]
|
||||
|
||||
# hypothesis
|
||||
with codecs.open(hyps_path, 'r', 'utf-8') as f:
|
||||
hypothesis = f.read().split('\n')
|
||||
|
||||
# hypothesis tokenized
|
||||
hypothesis_tok = copy.copy(hypothesis)
|
||||
if lng == 'ru':
|
||||
hypothesis_tok = [' '.join([_.text for _ in tokenize(hyp)]) for hyp in hypothesis_tok]
|
||||
else:
|
||||
hypothesis_tok = [' '.join(nltk.word_tokenize(hyp)) for hyp in hypothesis_tok]
|
||||
|
||||
|
||||
logging.info('FINISHING TO PARSE INPUTS...')
|
||||
print('FINISHING TO PARSE INPUTS...')
|
||||
return references, references_tok, hypothesis, hypothesis_tok
|
||||
|
||||
def bleu_score(refs_path, hyps_path, num_refs):
|
||||
logging.info('STARTING TO COMPUTE BLEU...')
|
||||
print('STARTING TO COMPUTE BLEU...')
|
||||
ref_files = []
|
||||
for i in range(num_refs):
|
||||
if num_refs == 1:
|
||||
ref_files.append(refs_path)
|
||||
else:
|
||||
ref_files.append(refs_path + str(i))
|
||||
|
||||
command = 'perl {0} {1} < {2}'.format(BLEU_PATH, ' '.join(ref_files), hyps_path)
|
||||
result = subprocess.check_output(command, shell=True)
|
||||
try:
|
||||
bleu = float(re.findall('BLEU = (.+?),', str(result))[0])
|
||||
except:
|
||||
logging.error('ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE PERL INSTALLED GLOBALLY ON YOUR MACHINE.')
|
||||
print('ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE PERL INSTALLED GLOBALLY ON YOUR MACHINE.')
|
||||
bleu = -1
|
||||
logging.info('FINISHING TO COMPUTE BLEU...')
|
||||
print('FINISHING TO COMPUTE BLEU...')
|
||||
return bleu
|
||||
|
||||
|
||||
def bleu_nltk(references, hypothesis):
|
||||
# check for empty lists
|
||||
references_, hypothesis_ = [], []
|
||||
for i, refs in enumerate(references):
|
||||
refs_ = [ref for ref in refs if ref.strip() != '']
|
||||
if len(refs_) > 0:
|
||||
references_.append([ref.split() for ref in refs_])
|
||||
hypothesis_.append(hypothesis[i].split())
|
||||
|
||||
chencherry = SmoothingFunction()
|
||||
return corpus_bleu(references_, hypothesis_, smoothing_function=chencherry.method3)
|
||||
|
||||
|
||||
def meteor_score(references, hypothesis, num_refs, lng='en'):
|
||||
logging.info('STARTING TO COMPUTE METEOR...')
|
||||
print('STARTING TO COMPUTE METEOR...')
|
||||
hyps_tmp, refs_tmp = 'hypothesis_meteor', 'reference_meteor'
|
||||
|
||||
with codecs.open(hyps_tmp, 'w', 'utf-8') as f:
|
||||
f.write('\n'.join(hypothesis))
|
||||
|
||||
linear_references = []
|
||||
for refs in references:
|
||||
for i in range(num_refs):
|
||||
linear_references.append(refs[i])
|
||||
|
||||
with codecs.open(refs_tmp, 'w', 'utf-8') as f:
|
||||
f.write('\n'.join(linear_references))
|
||||
|
||||
try:
|
||||
command = 'java -Xmx2G -jar {0} '.format(METEOR_PATH)
|
||||
command += '{0} {1} -l {2} -norm -r {3}'.format(hyps_tmp, refs_tmp, lng, num_refs)
|
||||
result = subprocess.check_output(command, shell=True)
|
||||
meteor = result.split(b'\n')[-2].split()[-1]
|
||||
except:
|
||||
logging.error('ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.')
|
||||
print('ERROR ON COMPUTING METEOR. MAKE SURE YOU HAVE JAVA INSTALLED GLOBALLY ON YOUR MACHINE.')
|
||||
meteor = -1
|
||||
|
||||
try:
|
||||
os.remove(hyps_tmp)
|
||||
os.remove(refs_tmp)
|
||||
except:
|
||||
pass
|
||||
logging.info('FINISHING TO COMPUTE METEOR...')
|
||||
print('FINISHING TO COMPUTE METEOR...')
|
||||
return float(meteor)
|
||||
|
||||
|
||||
def chrF_score(references, hypothesis, num_refs, nworder, ncorder, beta):
|
||||
logging.info('STARTING TO COMPUTE CHRF++...')
|
||||
print('STARTING TO COMPUTE CHRF++...')
|
||||
hyps_tmp, refs_tmp = 'hypothesis_chrF', 'reference_chrF'
|
||||
|
||||
# check for empty lists
|
||||
references_, hypothesis_ = [], []
|
||||
for i, refs in enumerate(references):
|
||||
refs_ = [ref for ref in refs if ref.strip() != '']
|
||||
if len(refs_) > 0:
|
||||
references_.append(refs_)
|
||||
hypothesis_.append(hypothesis[i])
|
||||
|
||||
with codecs.open(hyps_tmp, 'w', 'utf-8') as f:
|
||||
f.write('\n'.join(hypothesis_))
|
||||
|
||||
linear_references = []
|
||||
for refs in references_:
|
||||
linear_references.append('*#'.join(refs[:num_refs]))
|
||||
|
||||
with codecs.open(refs_tmp, 'w', 'utf-8') as f:
|
||||
f.write('\n'.join(linear_references))
|
||||
|
||||
rtxt = codecs.open(refs_tmp, 'r', 'utf-8')
|
||||
htxt = codecs.open(hyps_tmp, 'r', 'utf-8')
|
||||
|
||||
try:
|
||||
totalF, averageTotalF, totalPrec, totalRec = computeChrF(rtxt, htxt, nworder, ncorder, beta, None)
|
||||
except:
|
||||
logging.error('ERROR ON COMPUTING CHRF++.')
|
||||
print('ERROR ON COMPUTING CHRF++.')
|
||||
totalF, averageTotalF, totalPrec, totalRec = -1, -1, -1, -1
|
||||
try:
|
||||
os.remove(hyps_tmp)
|
||||
os.remove(refs_tmp)
|
||||
except:
|
||||
pass
|
||||
logging.info('FINISHING TO COMPUTE CHRF++...')
|
||||
print('FINISHING TO COMPUTE CHRF++...')
|
||||
return totalF, averageTotalF, totalPrec, totalRec
|
||||
|
||||
|
||||
def ter_score(references, hypothesis, num_refs):
|
||||
logging.info('STARTING TO COMPUTE TER...')
|
||||
print('STARTING TO COMPUTE TER...')
|
||||
ter_scores = []
|
||||
for hyp, refs in zip(hypothesis, references):
|
||||
candidates = []
|
||||
for ref in refs[:num_refs]:
|
||||
if len(ref) == 0:
|
||||
ter_score = 1
|
||||
else:
|
||||
try:
|
||||
ter_score = pyter.ter(hyp.split(), ref.split())
|
||||
except:
|
||||
ter_score = 1
|
||||
candidates.append(ter_score)
|
||||
|
||||
ter_scores.append(min(candidates))
|
||||
|
||||
logging.info('FINISHING TO COMPUTE TER...')
|
||||
print('FINISHING TO COMPUTE TER...')
|
||||
return sum(ter_scores) / len(ter_scores)
|
||||
|
||||
|
||||
def bert_score_(references, hypothesis, lng='en'):
|
||||
logging.info('STARTING TO COMPUTE BERT SCORE...')
|
||||
print('STARTING TO COMPUTE BERT SCORE...')
|
||||
for i, refs in enumerate(references):
|
||||
references[i] = [ref for ref in refs if ref.strip() != '']
|
||||
|
||||
try:
|
||||
P, R, F1 = score(hypothesis, references, lang=lng)
|
||||
logging.info('FINISHING TO COMPUTE BERT SCORE...')
|
||||
# print('FINISHING TO COMPUTE BERT SCORE...')
|
||||
P, R, F1 = list(P), list(R), list(F1)
|
||||
F1 = float(sum(F1) / len(F1))
|
||||
P = float(sum(P) / len(P))
|
||||
R = float(sum(R) / len(R))
|
||||
except:
|
||||
P, R, F1 = 0, 0, 0
|
||||
return P, R, F1
|
||||
|
||||
def bleurt(references, hypothesis, num_refs, checkpoint = "metrics/bleurt/bleurt-base-128"):
|
||||
refs, cands = [], []
|
||||
for i, hyp in enumerate(hypothesis):
|
||||
for ref in references[i][:num_refs]:
|
||||
cands.append(hyp)
|
||||
refs.append(ref)
|
||||
|
||||
scorer = bleurt_score.BleurtScorer(checkpoint)
|
||||
scores = scorer.score(refs, cands)
|
||||
scores = [max(scores[i:i+num_refs]) for i in range(0, len(scores), num_refs)]
|
||||
return round(sum(scores) / len(scores), 2)
|
||||
|
||||
|
||||
def run(refs_path, hyps_path, num_refs, lng='en', metrics='bleu,meteor,chrf++,ter,bert,bleurt',ncorder=6, nworder=2, beta=2):
|
||||
metrics = metrics.lower().split(',')
|
||||
references, references_tok, hypothesis, hypothesis_tok = parse(refs_path, hyps_path, num_refs, lng)
|
||||
|
||||
result = {}
|
||||
|
||||
logging.info('STARTING EVALUATION...')
|
||||
if 'bleu' in metrics:
|
||||
bleu = bleu_score(refs_path, hyps_path, num_refs)
|
||||
result['bleu'] = bleu
|
||||
|
||||
b = bleu_nltk(references_tok, hypothesis_tok)
|
||||
result['bleu_nltk'] = b
|
||||
if 'meteor' in metrics:
|
||||
meteor = meteor_score(references_tok, hypothesis_tok, num_refs, lng=lng)
|
||||
result['meteor'] = meteor
|
||||
if 'chrf++' in metrics:
|
||||
chrf, _, _, _ = chrF_score(references, hypothesis, num_refs, nworder, ncorder, beta)
|
||||
result['chrf++'] = chrf
|
||||
if 'ter' in metrics:
|
||||
ter = ter_score(references_tok, hypothesis_tok, num_refs)
|
||||
result['ter'] = ter
|
||||
if 'bert' in metrics:
|
||||
P, R, F1 = bert_score_(references, hypothesis, lng=lng)
|
||||
result['bert_precision'] = P
|
||||
result['bert_recall'] = R
|
||||
result['bert_f1'] = F1
|
||||
if 'bleurt' in metrics and lng == 'en':
|
||||
s = bleurt(references, hypothesis, num_refs)
|
||||
result['bleurt'] = s
|
||||
logging.info('FINISHING EVALUATION...')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
FORMAT = '%(levelname)s: %(asctime)-15s - %(message)s'
|
||||
logging.basicConfig(filename='eval.log', level=logging.INFO, format=FORMAT)
|
||||
|
||||
argParser = argparse.ArgumentParser()
|
||||
argParser.add_argument("-R", "--reference", help="reference translation", required=True)
|
||||
argParser.add_argument("-H", "--hypothesis", help="hypothesis translation", required=True)
|
||||
argParser.add_argument("-lng", "--language", help="evaluated language", default='en')
|
||||
argParser.add_argument("-nr", "--num_refs", help="number of references", type=int, default=4)
|
||||
argParser.add_argument("-m", "--metrics", help="evaluation metrics to be computed", default='bleu,meteor,ter,chrf++,bert,bleurt')
|
||||
argParser.add_argument("-nc", "--ncorder", help="chrF metric: character n-gram order (default=6)", type=int, default=6)
|
||||
argParser.add_argument("-nw", "--nworder", help="chrF metric: word n-gram order (default=2)", type=int, default=2)
|
||||
argParser.add_argument("-b", "--beta", help="chrF metric: beta parameter (default=2)", type=float, default=2.0)
|
||||
|
||||
args = argParser.parse_args()
|
||||
|
||||
logging.info('READING INPUTS...')
|
||||
refs_path = args.reference
|
||||
hyps_path = args.hypothesis
|
||||
lng = args.language
|
||||
num_refs = args.num_refs
|
||||
metrics = args.metrics#.lower().split(',')
|
||||
|
||||
nworder = args.nworder
|
||||
ncorder = args.ncorder
|
||||
beta = args.beta
|
||||
logging.info('FINISHING TO READ INPUTS...')
|
||||
|
||||
result = run(refs_path=refs_path, hyps_path=hyps_path, num_refs=num_refs, lng=lng, metrics=metrics, ncorder=ncorder, nworder=nworder, beta=beta)
|
||||
|
||||
metrics = metrics.lower().split(',')
|
||||
headers, values = [], []
|
||||
if 'bleu' in metrics:
|
||||
headers.append('BLEU')
|
||||
values.append(result['bleu'])
|
||||
|
||||
headers.append('BLEU NLTK')
|
||||
values.append(round(result['bleu_nltk'], 2))
|
||||
if 'meteor' in metrics:
|
||||
headers.append('METEOR')
|
||||
values.append(round(result['meteor'], 2))
|
||||
if 'chrf++' in metrics:
|
||||
headers.append('chrF++')
|
||||
values.append(round(result['chrf++'], 2))
|
||||
if 'ter' in metrics:
|
||||
headers.append('TER')
|
||||
values.append(round(result['ter'], 2))
|
||||
if 'bert' in metrics:
|
||||
headers.append('BERT-SCORE P')
|
||||
values.append(round(result['bert_precision'], 2))
|
||||
headers.append('BERT-SCORE R')
|
||||
values.append(round(result['bert_recall'], 2))
|
||||
headers.append('BERT-SCORE F1')
|
||||
values.append(round(result['bert_f1'], 2))
|
||||
if 'bleurt' in metrics and lng == 'en':
|
||||
headers.append('BLEURT')
|
||||
values.append(round(result['bleurt'], 2))
|
||||
|
||||
logging.info('PRINTING RESULTS...')
|
||||
print(tabulate([values], headers=headers))
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 266 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 98 KiB |
|
@ -0,0 +1,7 @@
|
|||
--find-links https://download.pytorch.org/whl/torch_stable.html
|
||||
torch==1.7.1+cu101
|
||||
transformers==3.3.1
|
||||
spacy
|
||||
tqdm
|
||||
tensorboard
|
||||
progress
|
|
@ -0,0 +1,269 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import os, sys
|
||||
import glob
|
||||
import random
|
||||
from collections import Counter, OrderedDict
|
||||
import numpy as np
|
||||
import torch
|
||||
import json
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class LMOrderedIterator(object):
|
||||
def __init__(self, data, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
|
||||
"""
|
||||
data -- LongTensor -- the LongTensor is strictly ordered
|
||||
"""
|
||||
self.data = data
|
||||
self.bsz = bsz
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.bptt = bptt # tgt_len
|
||||
# existing len.
|
||||
self.eval_len = bptt if eval_len is None else eval_len
|
||||
|
||||
self.device = device
|
||||
|
||||
self.global_bsz = bsz * world_size
|
||||
# Work out how cleanly we can divide the dataset into bsz parts.
|
||||
self.n_step = len(data) // self.global_bsz # bsz
|
||||
|
||||
self.split_data = torch.tensor(
|
||||
data[rank * self.n_step * bsz : (rank + 1) * self.n_step * bsz],
|
||||
dtype=torch.long, device=self.device
|
||||
) # data.view(-1)
|
||||
|
||||
self.split_data = self.split_data.view(bsz, -1)
|
||||
|
||||
def __iter__(self):
|
||||
return self.get_fixlen_iter()
|
||||
|
||||
def get_batch(self, i, bptt, eval_len):
|
||||
beg_idx = i
|
||||
end_idx = i + bptt # seq_len
|
||||
|
||||
# batch_size, lengh;
|
||||
_input = self.split_data[:, beg_idx : end_idx].contiguous()
|
||||
_target = self.split_data[:, beg_idx+1 : end_idx+1].contiguous()
|
||||
|
||||
_msk = torch.cat(
|
||||
[
|
||||
torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
|
||||
torch.ones(eval_len, dtype=torch.float, device=self.device)
|
||||
]
|
||||
)
|
||||
_msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
|
||||
return _input, _target, _msk
|
||||
|
||||
def get_fixlen_iter(self, start=0):
|
||||
self.data_len = self.split_data.size(1)
|
||||
_eval_cursor = 0
|
||||
for i in range(start, self.data_len - 1, self.eval_len):
|
||||
bptt = min(self.bptt, self.data_len - i - 1)
|
||||
_end_idx = i + bptt
|
||||
yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
|
||||
_eval_cursor = _end_idx
|
||||
|
||||
|
||||
class Corpus(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.num_words = 0
|
||||
self.tokens = []
|
||||
with open(self.path, "r") as reader:
|
||||
for line in reader:
|
||||
items = json.loads(line.strip())
|
||||
book = items['book']
|
||||
tokens = items['tokens']
|
||||
num_words = items['num_words']
|
||||
|
||||
self.num_words += num_words
|
||||
self.tokens.extend(tokens)
|
||||
|
||||
|
||||
class BinLMOrderedIterator(object):
|
||||
def __init__(self, corpus, bsz, bptt, eval_len=None, device='cpu', world_size=1, rank=0):
|
||||
"""
|
||||
data -- LongTensor -- the LongTensor is strictly ordered
|
||||
"""
|
||||
self.corpus = corpus
|
||||
self.bsz = bsz
|
||||
self.world_size = world_size
|
||||
self.rank = rank
|
||||
self.bptt = bptt # tgt_len
|
||||
# existing len.
|
||||
self.eval_len = bptt if eval_len is None else eval_len
|
||||
self.device = device
|
||||
self.global_bsz = bsz * world_size
|
||||
# Work out how cleanly we can divide the dataset into bsz parts.
|
||||
self.n_step = corpus.length // self.global_bsz # bsz
|
||||
|
||||
self.offset = [(rank * bsz + _b) * self.n_step for _b in range(bsz)]
|
||||
|
||||
def __iter__(self):
|
||||
return self.get_fixlen_iter()
|
||||
|
||||
def get_batch(self, i, bptt, eval_len):
|
||||
# batch_size, lengh;
|
||||
_inputs = []
|
||||
_targets = []
|
||||
for _b in range(0, self.bsz):
|
||||
_input = self.corpus.get_tokens(self.offset[_b] + i, bptt)
|
||||
_target = self.corpus.get_tokens(self.offset[_b] + i + 1, bptt)
|
||||
|
||||
_inputs.append(_input)
|
||||
_targets.append(_target)
|
||||
|
||||
_input = torch.tensor(_inputs, dtype=torch.int64, device=self.device).contiguous()
|
||||
_target = torch.tensor(_targets, dtype=torch.int64, device=self.device).contiguous()
|
||||
|
||||
_msk = torch.cat(
|
||||
[
|
||||
torch.zeros(bptt-eval_len, dtype=torch.float, device=self.device),
|
||||
torch.ones(eval_len, dtype=torch.float, device=self.device)
|
||||
]
|
||||
)
|
||||
_msk = _msk.unsqueeze(0).expand_as(_input) # .unsqueeze(-1) # length, 1;
|
||||
return _input, _target, _msk
|
||||
|
||||
def get_fixlen_iter(self, start=0):
|
||||
#self.data_len = self.split_data.size(1)
|
||||
_eval_cursor = 0
|
||||
for i in range(start, self.n_step - 1, self.eval_len):
|
||||
bptt = min(self.bptt, self.n_step - i - 1)
|
||||
_end_idx = i + bptt
|
||||
yield self.get_batch(i, bptt, _end_idx - _eval_cursor)
|
||||
_eval_cursor = _end_idx
|
||||
|
||||
|
||||
class BinCorpus(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
self.book_token_span = []
|
||||
self.book_token_span.append(0)
|
||||
tokens_sum = 0
|
||||
self.num_words = 0
|
||||
|
||||
with open(path+'.info', 'r') as info_reader:
|
||||
for line in info_reader:
|
||||
items = json.loads(line.strip())
|
||||
book = items['book']
|
||||
num_tokens = items['num_subtokens']
|
||||
num_words = items['num_words']
|
||||
|
||||
tokens_sum += num_tokens
|
||||
self.book_token_span.append(tokens_sum)
|
||||
self.num_words += num_words
|
||||
|
||||
self.length = self.book_token_span[-1]
|
||||
self.bin_reader = open(path+'.bin', 'rb')
|
||||
|
||||
def get_tokens(self, offset, count):
|
||||
INT64_SIZE = 8
|
||||
self.bin_reader.seek(offset * INT64_SIZE)
|
||||
x = np.fromfile(self.bin_reader, count=count, dtype=np.int)
|
||||
return x
|
||||
|
||||
|
||||
def get_lm_corpus(data):
|
||||
print('Producing dataset {}...'.format(data))
|
||||
corpus = Corpus(data)
|
||||
return corpus
|
||||
|
||||
|
||||
def padding_tokens(tokens, max_seq_length, pad_token, direct, max_context_length=0):
|
||||
|
||||
if max_context_length == 0:
|
||||
max_context_length = max_seq_length
|
||||
|
||||
if len(tokens) > max_context_length:
|
||||
if direct > 0:
|
||||
pad_tokens = tokens[:max_context_length]
|
||||
else:
|
||||
pad_tokens = tokens[-max_context_length:]
|
||||
else:
|
||||
pad_tokens = tokens
|
||||
token_len = len(pad_tokens)
|
||||
pad_tokens = pad_tokens + [pad_token for _ in range(max_seq_length - token_len)]
|
||||
return pad_tokens, token_len
|
||||
|
||||
|
||||
class FT_Dataset(Dataset):
|
||||
def __init__(self, ft_file, batch_size, max_seq_length,
|
||||
max_eval_length=0, joint_lm=False, prefix_len=0, infix_len=0,
|
||||
prefix_cursor=1000000, infix_cursor=2000000):
|
||||
self.ft_file = ft_file
|
||||
self.ft_samples = self.read_ft_file(ft_file)
|
||||
self.batch_size = batch_size
|
||||
self.num_examples = len(self.ft_samples)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_eval_length = max_eval_length
|
||||
self.rng = random.Random(911)
|
||||
self.joint_lm = joint_lm
|
||||
|
||||
self.num_batches = int((self.num_examples + self.batch_size - 1) / self.batch_size)
|
||||
|
||||
self.prefix_len = prefix_len
|
||||
self.infix_len = infix_len
|
||||
self.prefix_cursor = prefix_cursor
|
||||
self.infix_cursor = infix_cursor
|
||||
|
||||
def __len__(self):
|
||||
return self.num_batches * self.batch_size
|
||||
|
||||
def __getitem__(self, item):
|
||||
if(item >= self.num_examples):
|
||||
item = self.rng.randint(0, self.num_examples - 1)
|
||||
|
||||
example = self.ft_samples[item]
|
||||
context = example[0]
|
||||
completion = example[1]
|
||||
|
||||
pretokens = [i + self.prefix_cursor for i in range(0, self.prefix_len)]
|
||||
intokens = [i + self.infix_cursor for i in range(0, self.infix_len)]
|
||||
|
||||
conditions = pretokens + context + intokens
|
||||
_input, _input_len = padding_tokens(conditions + completion, self.max_seq_length, 0, 1)
|
||||
|
||||
pad_targets = [0 for i in range(0, self.prefix_len)] + context + [0 for i in range(0, self.infix_len)] + completion
|
||||
_target, _ = padding_tokens(pad_targets[1:], self.max_seq_length, 0, 1)
|
||||
|
||||
if not self.joint_lm:
|
||||
_msk = [0.0] * (len(conditions) - 1) + [1.0] * (_input_len - len(conditions))
|
||||
else:
|
||||
_msk = [1.0] * (_input_len - 1)
|
||||
|
||||
_msk, _ = padding_tokens(_msk, self.max_seq_length, 0.0, 1)
|
||||
|
||||
output = {}
|
||||
output["id"] = torch.tensor(item, dtype=torch.long)
|
||||
|
||||
_query, _query_len = padding_tokens(
|
||||
conditions, self.max_seq_length, 0, -1,
|
||||
max_context_length = self.max_seq_length - self.max_eval_length
|
||||
)
|
||||
output["query"] = torch.tensor(_query, dtype=torch.long)
|
||||
output["query_len"] = torch.tensor(_query_len, dtype=torch.long)
|
||||
|
||||
output["input"] = torch.tensor(_input, dtype=torch.long)
|
||||
output["target"] = torch.tensor(_target, dtype=torch.long)
|
||||
|
||||
output["mask"] = torch.tensor(_msk, dtype=torch.float)
|
||||
return output
|
||||
|
||||
def read_ft_file(self, ft_file):
|
||||
ft_samples = []
|
||||
with open(ft_file, 'r') as reader:
|
||||
for line in reader:
|
||||
items = json.loads(line.strip())
|
||||
context = items['context']
|
||||
completion = items['completion']
|
||||
ft_samples.append([context, completion])
|
||||
return ft_samples
|
|
@ -0,0 +1,132 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import os
|
||||
import json
|
||||
import regex as re
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class Encoder:
|
||||
|
||||
def __init__(self, encoder, bpe_merges, errors='replace'):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
try:
|
||||
import regex as re
|
||||
self.re = re
|
||||
except ImportError:
|
||||
raise ImportError('Please install regex with: pip install regex')
|
||||
|
||||
|
||||
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
tokens = []
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
if token:
|
||||
tokens.append(token)
|
||||
return bpe_tokens, tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
|
||||
def get_encoder(models_dir):
|
||||
with open(os.path.join(models_dir, 'encoder.json'), 'r') as f:
|
||||
encoder = json.load(f)
|
||||
with open(os.path.join(models_dir, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=bpe_merges,
|
||||
)
|
|
@ -0,0 +1,46 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import functools
|
||||
import os, shutil
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def logging(s, log_path, print_=True, log_=True):
|
||||
if print_:
|
||||
print(s)
|
||||
if log_:
|
||||
with open(log_path, 'a+') as f_log:
|
||||
f_log.write(s + '\n')
|
||||
|
||||
|
||||
def get_logger(log_path, **kwargs):
|
||||
return functools.partial(logging, log_path=log_path, **kwargs)
|
||||
|
||||
|
||||
def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
|
||||
if debug:
|
||||
print('Debug Mode : no experiment dir created')
|
||||
return functools.partial(logging, log_path=None, log_=False)
|
||||
|
||||
if not os.path.exists(dir_path):
|
||||
os.makedirs(dir_path)
|
||||
|
||||
print('Experiment dir : {}'.format(dir_path))
|
||||
if scripts_to_save is not None:
|
||||
script_path = os.path.join(dir_path, 'scripts')
|
||||
if not os.path.exists(script_path):
|
||||
os.makedirs(script_path)
|
||||
for script in scripts_to_save:
|
||||
dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
|
||||
shutil.copyfile(script, dst_file)
|
||||
|
||||
return get_logger(log_path=os.path.join(dir_path, 'log.txt'))
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, path, epoch):
|
||||
torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
|
||||
torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
|
|
@ -0,0 +1,43 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import sys
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
with open(sys.argv[1], 'r', encoding='utf8') as reader, \
|
||||
open(sys.argv[2], 'w', encoding='utf8') as writer :
|
||||
lines_dict = json.load(reader)
|
||||
|
||||
full_rela_lst = []
|
||||
full_src_lst = []
|
||||
full_tgt_lst = []
|
||||
unique_src = 0
|
||||
|
||||
for example in lines_dict:
|
||||
rela_lst = []
|
||||
temp_triples = ''
|
||||
for i, tripleset in enumerate(example['tripleset']):
|
||||
subj, rela, obj = tripleset
|
||||
rela = rela.lower()
|
||||
rela_lst.append(rela)
|
||||
if i > 0:
|
||||
temp_triples += ' | '
|
||||
temp_triples += '{} : {} : {}'.format(subj, rela, obj)
|
||||
|
||||
unique_src += 1
|
||||
|
||||
for sent in example['annotations']:
|
||||
full_tgt_lst.append(sent['text'])
|
||||
full_src_lst.append(temp_triples)
|
||||
full_rela_lst.append(rela_lst)
|
||||
|
||||
print('unique source is', unique_src)
|
||||
|
||||
for src, tgt in zip(full_src_lst, full_tgt_lst):
|
||||
x = {}
|
||||
x['context'] = src # context #+ '||'
|
||||
x['completion'] = tgt #completion
|
||||
writer.write(json.dumps(x)+'\n')
|
|
@ -0,0 +1,20 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import sys
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
with open(sys.argv[1], 'r', encoding='utf8') as reader, \
|
||||
open(sys.argv[2], 'w', encoding='utf8') as writer :
|
||||
for line in reader:
|
||||
items = line.strip().split('||')
|
||||
context = items[0]
|
||||
completion = items[1].strip('\n')
|
||||
x = {}
|
||||
x['context'] = context #+ '||'
|
||||
x['completion'] = completion
|
||||
writer.write(json.dumps(x)+'\n')
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import sys
|
||||
import io
|
||||
import json
|
||||
|
||||
|
||||
with open(sys.argv[1], 'r', encoding='utf8') as reader, \
|
||||
open(sys.argv[2], 'w', encoding='utf8') as writer :
|
||||
lines_dict = json.load(reader)
|
||||
|
||||
full_rela_lst = []
|
||||
full_src_lst = []
|
||||
full_tgt_lst = []
|
||||
full_cate_lst = []
|
||||
|
||||
seen = [
|
||||
'Airport',
|
||||
'Astronaut',
|
||||
'Building',
|
||||
'City',
|
||||
'ComicsCharacter',
|
||||
'Food',
|
||||
'Monument',
|
||||
'SportsTeam',
|
||||
'University',
|
||||
'WrittenWork'
|
||||
]
|
||||
|
||||
cate_dict = {}
|
||||
for i, example in enumerate(lines_dict['entries']):
|
||||
sents = example[str(i+1)]['lexicalisations']
|
||||
triples = example[str(i + 1)]['modifiedtripleset']
|
||||
cate = example[str(i + 1)]['category']
|
||||
|
||||
if not cate in cate_dict:
|
||||
cate_dict[cate] = 0
|
||||
cate_dict[cate] += 1
|
||||
|
||||
rela_lst = []
|
||||
temp_triples = ''
|
||||
for i, tripleset in enumerate(triples):
|
||||
subj, rela, obj = tripleset['subject'], tripleset['property'], tripleset['object']
|
||||
rela_lst.append(rela)
|
||||
if i > 0:
|
||||
temp_triples += ' | '
|
||||
temp_triples += '{} : {} : {}'.format(subj, rela, obj)
|
||||
|
||||
for sent in sents:
|
||||
if sent["comment"] == 'good':
|
||||
full_tgt_lst.append(sent['lex'])
|
||||
full_src_lst.append(temp_triples)
|
||||
full_rela_lst.append(rela_lst)
|
||||
full_cate_lst.append(cate)
|
||||
|
||||
for cate in cate_dict:
|
||||
print('cate', cate, cate_dict[cate])
|
||||
|
||||
#edited_sents = []
|
||||
for src, tgt, cate in zip(full_src_lst, full_tgt_lst, full_cate_lst):
|
||||
x = {}
|
||||
x['context'] = src # context #+ '||'
|
||||
x['completion'] = tgt #completion
|
||||
x['cate'] = cate in seen
|
||||
writer.write(json.dumps(x)+'\n')
|
||||
|
|
@ -0,0 +1,400 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import os, sys
|
||||
import json
|
||||
import itertools
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
torch.set_printoptions(threshold=100000)
|
||||
|
||||
import numpy as np
|
||||
|
||||
from gpu import (
|
||||
add_gpu_params,
|
||||
parse_gpu,
|
||||
distributed_opt,
|
||||
distributed_gather,
|
||||
distributed_sync,
|
||||
cleanup
|
||||
)
|
||||
|
||||
from exp_utils import create_exp_dir
|
||||
|
||||
from data_utils import FT_Dataset
|
||||
from model import GPT2Config, GPT2LMModel
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch GPT2 beam decoding')
|
||||
|
||||
add_gpu_params(parser)
|
||||
|
||||
parser.add_argument('--data', type=str, default='../data/wikitext-103',
|
||||
help='location of the data corpus')
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=10,
|
||||
help='batch size')
|
||||
|
||||
parser.add_argument('--seq_len', type=int, default=512,
|
||||
help='number of tokens to predict')
|
||||
|
||||
parser.add_argument('--eval_len', type=int, default=256,
|
||||
help='evaluation length')
|
||||
|
||||
parser.add_argument('--min_length', type=int, default=0,
|
||||
help='minimum generation length')
|
||||
|
||||
parser.add_argument('--model_card', default='gpt2.sm', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
|
||||
help='model names')
|
||||
|
||||
parser.add_argument('--init_checkpoint', default=None, type=str, help='initial checkpoint')
|
||||
|
||||
parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')
|
||||
|
||||
parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')
|
||||
|
||||
parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
|
||||
help='working folder')
|
||||
|
||||
parser.add_argument('--beam', type=int, default=1, help='beam search size')
|
||||
|
||||
parser.add_argument('--length_penalty', type=float, default=1.0, help='length penalty')
|
||||
|
||||
parser.add_argument('--no_repeat_ngram_size', type=int, default=4, help='no_repeat_ngram_size')
|
||||
|
||||
parser.add_argument('--repetition_penalty', type=float, default=1.0, help='repetition_penalty')
|
||||
|
||||
parser.add_argument('--eos_token_id', action='append', type=int, default=[50256],
|
||||
help='eos token id')
|
||||
|
||||
parser.add_argument('--output_file', type=str, default='beam_prediction.jsonl',
|
||||
help='output file name')
|
||||
|
||||
parser.add_argument('--prefix_len', default=0, type=int, help='prefix length')
|
||||
|
||||
parser.add_argument('--infix_len', default=0, type=int, help='infix length')
|
||||
|
||||
|
||||
def print_args(args):
|
||||
if args.rank == 0:
|
||||
print('=' * 100)
|
||||
for k, v in args.__dict__.items():
|
||||
print(' - {} : {}'.format(k, v))
|
||||
print('=' * 100)
|
||||
|
||||
|
||||
def _reorder_cache(past: Tuple, beam_idx: Tensor) -> Tuple[Tensor]:
|
||||
return tuple(layer_past.index_select(1, beam_idx).contiguous().detach() for layer_past in past)
|
||||
|
||||
|
||||
def _calc_banned_ngram_tokens(
|
||||
prev_input_ids: Tensor,
|
||||
num_hypos: int,
|
||||
no_repeat_ngram_size: int,
|
||||
cur_len: int
|
||||
) -> None:
|
||||
"""Copied from fairseq for no_repeat_ngram in beam_search"""
|
||||
if cur_len + 1 < no_repeat_ngram_size:
|
||||
# return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
return [[] for _ in range(num_hypos)]
|
||||
|
||||
generated_ngrams = [{} for _ in range(num_hypos)]
|
||||
for idx in range(num_hypos):
|
||||
gen_tokens = prev_input_ids[idx].tolist()
|
||||
generated_ngram = generated_ngrams[idx]
|
||||
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
||||
prev_ngram_tuple = tuple(ngram[:-1])
|
||||
generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
|
||||
|
||||
def _get_generated_ngrams(hypo_idx):
|
||||
# Before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
start_idx = cur_len + 1 - no_repeat_ngram_size
|
||||
ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
|
||||
return generated_ngrams[hypo_idx].get(ngram_idx, [])
|
||||
|
||||
banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
|
||||
return banned_tokens
|
||||
|
||||
|
||||
def _enforce_repetition_penalty_(
|
||||
lprobs,
|
||||
batch_size,
|
||||
num_beams,
|
||||
prev_output_tokens,
|
||||
repetition_penalty
|
||||
):
|
||||
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
|
||||
|
||||
for i in range(batch_size * num_beams):
|
||||
print('prev_output_tokens.shape', prev_output_tokens.shape)
|
||||
print('prev_output_tokens[i].shape', prev_output_tokens[i].shape)
|
||||
|
||||
for previous_token in set(prev_output_tokens[i].tolist()):
|
||||
# if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
|
||||
if lprobs[i, previous_token] < 0:
|
||||
lprobs[i, previous_token] *= repetition_penalty
|
||||
else:
|
||||
lprobs[i, previous_token] /= repetition_penalty
|
||||
|
||||
def _postprocess_next_token_scores(
|
||||
scores,
|
||||
history,
|
||||
cur_len,
|
||||
batch_size,
|
||||
num_beams,
|
||||
repetition_penalty=1.0,
|
||||
no_repeat_ngram_size=4,
|
||||
bad_words_ids=None,
|
||||
min_length=0,
|
||||
max_length=100,
|
||||
eos_token_id=None,
|
||||
):
|
||||
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
|
||||
if repetition_penalty != 1.0 and history is not None:
|
||||
_enforce_repetition_penalty_(scores, batch_size, num_beams, history, repetition_penalty)
|
||||
|
||||
# score: batch_size * beam, vocab
|
||||
# set eos token prob to zero if min_length is not reached
|
||||
if eos_token_id is not None and cur_len < min_length:
|
||||
for eos in eos_token_id:
|
||||
scores[:, eos] = -float("inf")
|
||||
|
||||
if no_repeat_ngram_size > 0 and history is not None:
|
||||
# calculate a list of banned tokens to prevent repetitively generating the same ngrams
|
||||
num_batch_hypotheses = batch_size * num_beams
|
||||
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
||||
banned_batch_tokens = _calc_banned_ngram_tokens(
|
||||
history, num_batch_hypotheses, no_repeat_ngram_size, cur_len
|
||||
)
|
||||
|
||||
for i, banned_tokens in enumerate(banned_batch_tokens):
|
||||
scores[i, banned_tokens] = -float("inf")
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
def _add_beam_candidate(
|
||||
best_score,
|
||||
best_sequence,
|
||||
batch_size,
|
||||
num_beams,
|
||||
beam_scores,
|
||||
history,
|
||||
eos_token_id=None
|
||||
):
|
||||
last_tokens = history[:, -1]
|
||||
for _i in range(batch_size * num_beams):
|
||||
if eos_token_id is None or last_tokens[_i] in eos_token_id:
|
||||
cur_len = history.shape[-1]
|
||||
_score = beam_scores.view(-1)[_i] / cur_len ** args.length_penalty
|
||||
|
||||
batch_id = _i // num_beams
|
||||
|
||||
if not batch_id in best_score or best_score[batch_id] < _score:
|
||||
best_score[batch_id] = _score
|
||||
best_sequence[batch_id][:cur_len] = history[_i]
|
||||
|
||||
beam_scores.view(-1)[_i] = -float("inf")
|
||||
|
||||
|
||||
def beam(model, data_iter, args):
|
||||
model.eval()
|
||||
total_loss = 0.
|
||||
start_time = time.time()
|
||||
|
||||
all_predictions = {}
|
||||
with torch.no_grad():
|
||||
for idx, data in enumerate(data_iter):
|
||||
data = {key: value for key, value in data.items()}
|
||||
|
||||
_id = data['id'].to(args.device)
|
||||
_query = data['query'].to(args.device)
|
||||
_query_len = data['query_len'].to(args.device)
|
||||
|
||||
## local adaptation start.
|
||||
|
||||
## local adaptation end.
|
||||
|
||||
|
||||
output = None
|
||||
score = None
|
||||
|
||||
batch_size = _id.size(0)
|
||||
num_beams = args.beam
|
||||
length_penalty = args.length_penalty
|
||||
|
||||
_batch = torch.arange(0, _id.size(0), device=args.device, dtype=torch.long)
|
||||
|
||||
past = None
|
||||
len_past = None
|
||||
|
||||
_query = _query.repeat(1, num_beams).view(batch_size * num_beams, -1)
|
||||
_query_len = _query_len.unsqueeze(-1).repeat(1, num_beams).view(-1)
|
||||
|
||||
_bbatch = _batch.unsqueeze(-1).repeat(1, num_beams).view(-1)
|
||||
|
||||
# scores for each sentence in the beam
|
||||
beam_scores = torch.zeros(
|
||||
(batch_size, num_beams), dtype=torch.float, device=_query.device
|
||||
)
|
||||
|
||||
best_sequence = torch.zeros(
|
||||
(batch_size, args.eval_len), dtype=torch.long, device=_query.device
|
||||
)
|
||||
best_score = {}
|
||||
|
||||
history = None
|
||||
with torch.no_grad():
|
||||
for i in range(0, args.eval_len):
|
||||
if i == 0:
|
||||
logits, past = model(_query)
|
||||
logits = logits[_bbatch, (_query_len-1).long(), :] # batch_size * beam, vocab
|
||||
else:
|
||||
#print('token_id.shape', token_id.shape, token_id)
|
||||
#print('past.shape', past[0].shape)
|
||||
#print('len_past.shape', len_past.shape, len_past)
|
||||
|
||||
logits, past = model(token_id, past=past, len_past=len_past)
|
||||
logits = logits[:, -1, :] # batch_size * beam, vocab
|
||||
|
||||
logits = _postprocess_next_token_scores(
|
||||
logits,
|
||||
history,
|
||||
i,
|
||||
batch_size,
|
||||
num_beams,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
no_repeat_ngram_size=args.no_repeat_ngram_size,
|
||||
min_length=args.min_length,
|
||||
eos_token_id=args.eos_token_id,
|
||||
)
|
||||
|
||||
softmax_probs = F.softmax(logits, dim=-1)
|
||||
##_prob, _w_idx = torch.topk(softmax_probs, num_beams) # batch_size, beam
|
||||
|
||||
vocab_size = softmax_probs.shape[-1]
|
||||
|
||||
|
||||
_logprob = torch.log(softmax_probs) # batch_size * beam, vocab
|
||||
if i == 0:
|
||||
next_scores = _logprob.view(batch_size, num_beams, -1)[:, 0, :] # batch_size, vocab
|
||||
|
||||
else:
|
||||
next_scores = beam_scores.unsqueeze(-1) + _logprob.view(batch_size, num_beams, -1)
|
||||
next_scores = next_scores.view(batch_size, -1) # batch_size, beam * vocab
|
||||
|
||||
next_scores, next_tokens = torch.topk(
|
||||
next_scores, num_beams, dim=1, largest=True, sorted=True
|
||||
) # batch_size, num_beams
|
||||
|
||||
beam_id = (next_tokens // vocab_size).view(-1) # batch_size * num_beams
|
||||
token_id = (next_tokens % vocab_size).view(-1).unsqueeze(-1) # batch_size, num_beams
|
||||
|
||||
beam_idx = beam_id.view(batch_size, num_beams) + (_batch * num_beams).unsqueeze(-1)
|
||||
past = _reorder_cache(past, beam_idx.view(-1))
|
||||
beam_scores = next_scores # batch_size, num_beams
|
||||
len_past = (_query_len + i).long()
|
||||
|
||||
if history is None:
|
||||
history = token_id.detach()
|
||||
else:
|
||||
history = torch.cat((history[beam_idx.view(-1)], token_id.detach()), dim=1).detach()
|
||||
|
||||
_add_beam_candidate(
|
||||
best_score, best_sequence, batch_size, num_beams, beam_scores, history,
|
||||
eos_token_id=args.eos_token_id
|
||||
)
|
||||
|
||||
_add_beam_candidate(
|
||||
best_score, best_sequence, batch_size, num_beams, beam_scores, history
|
||||
)
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
_id = distributed_gather(args, _id)
|
||||
output = distributed_gather(args, best_sequence)
|
||||
#score = distributed_gather(args, score)
|
||||
distributed_sync(args)
|
||||
|
||||
if args.rank == 0:
|
||||
_id = _id.view(-1).cpu()
|
||||
output = output.view(-1, output.shape[-1]).cpu()
|
||||
#score = score.view(-1, score.shape[-1]).cpu()
|
||||
|
||||
for _b in range(0, _id.shape[-1]):
|
||||
_i = int(_id[_b].item())
|
||||
all_predictions[_i] = {}
|
||||
all_predictions[_i]['id'] = _i
|
||||
all_predictions[_i]['predict'] = output[_b].tolist()
|
||||
#all_predictions[_i]['score'] = score[_b].tolist()
|
||||
|
||||
if idx % 10 == 0:
|
||||
print('inference samples', idx)
|
||||
|
||||
if args.rank == 0:
|
||||
pred_file = os.path.join(args.work_dir, args.output_file)
|
||||
print('saving prediction file', pred_file)
|
||||
with open(pred_file, 'w') as writer:
|
||||
for _i in all_predictions:
|
||||
writer.write(json.dumps(all_predictions[_i]) + '\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
parse_gpu(args)
|
||||
print_args(args)
|
||||
|
||||
if args.rank == 0:
|
||||
args.logging = create_exp_dir(args.work_dir)
|
||||
|
||||
valid_data = FT_Dataset(
|
||||
args.data, args.batch_size, args.seq_len, args.eval_len,
|
||||
prefix_len=args.prefix_len, infix_len=args.infix_len
|
||||
)
|
||||
valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_data)
|
||||
valid_loader = DataLoader(
|
||||
valid_data, batch_size=args.batch_size, num_workers=0, shuffle=False,
|
||||
pin_memory=False, drop_last=False, sampler=valid_sampler
|
||||
)
|
||||
|
||||
if args.model_card == 'gpt2.sm':
|
||||
config = GPT2Config(
|
||||
n_embd=768, n_layer=12, n_head=12,
|
||||
lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
|
||||
prefix_len=args.prefix_len, infix_len=args.infix_len
|
||||
)
|
||||
elif args.model_card == 'gpt2.md':
|
||||
config = GPT2Config(
|
||||
n_embd=1024, n_layer=24, n_head=16,
|
||||
lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
|
||||
prefix_len=args.prefix_len, infix_len=args.infix_len
|
||||
)
|
||||
elif args.model_card == 'gpt2.lg':
|
||||
config = GPT2Config(
|
||||
n_embd=1280, n_layer=36, n_head=20,
|
||||
lora_attn_dim=args.lora_dim, lora_attn_alpha=args.lora_alpha,
|
||||
prefix_len=args.prefix_len, infix_len=args.infix_len
|
||||
)
|
||||
|
||||
lm_net = GPT2LMModel(config)
|
||||
if args.init_checkpoint is not None:
|
||||
print('loading model pretrained weight.')
|
||||
cp = torch.load(args.init_checkpoint, map_location=torch.device('cpu'))
|
||||
lm_net.load_weight(cp)
|
||||
lm_net = lm_net.cuda()
|
||||
|
||||
print('model sampling ...')
|
||||
beam(lm_net, valid_loader, args)
|
||||
distributed_sync(args)
|
||||
print('cleanup dist ...')
|
||||
cleanup(args)
|
|
@ -0,0 +1,162 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import json
|
||||
import numpy as np
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import json
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
import encoder
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--vocab', type=str, default=None, help='vocab path')
|
||||
|
||||
parser.add_argument('--sample_file', default=None, type=str, help='ft sample file')
|
||||
parser.add_argument('--input_file', default=None, type=str, help='ft input file')
|
||||
|
||||
parser.add_argument('--output_ref_file', default=None, type=str, help='output reference file')
|
||||
parser.add_argument('--output_pred_file', default=None, type=str, help='output predicion file')
|
||||
|
||||
parser.add_argument('--ref_unique_file', default=None, type=str, help='reference unique id file')
|
||||
|
||||
parser.add_argument('--ref_type', default='e2e', choices=['e2e', 'webnlg', 'dart'],
|
||||
help='e2e style reference type; webnlg style reference type.')
|
||||
parser.add_argument('--ref_num', default=4, type=int, help='number of references.')
|
||||
|
||||
|
||||
parser.add_argument('--tokenize', action='store_true', help='')
|
||||
parser.add_argument('--lower', action='store_true', help='')
|
||||
|
||||
parser.add_argument('--filter', default='all', choices=['all', 'seen', 'unseen'],
|
||||
help='for webnlg only, filter categories that are seen during training, unseen, or all')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def stardard_tokenize(sent):
|
||||
sent = ' '.join(re.split('(\W)', sent))
|
||||
sent = sent.split()
|
||||
sent = ' '.join(sent)
|
||||
return sent
|
||||
|
||||
|
||||
def post_process(sent, is_tokenize, is_lower):
|
||||
if is_lower:
|
||||
sent = sent.lower()
|
||||
if is_tokenize:
|
||||
sent = stardard_tokenize(sent)
|
||||
|
||||
return sent
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
enc = encoder.get_encoder(args.vocab)
|
||||
|
||||
ref_unique = None
|
||||
|
||||
if args.ref_unique_file is not None:
|
||||
print('reading ref_unique_file.')
|
||||
ref_unique = []
|
||||
uniques = {}
|
||||
with open(args.ref_unique_file, 'r') as ref_unique_reader:
|
||||
for line in ref_unique_reader:
|
||||
_id = int(line.strip())
|
||||
ref_unique.append(_id)
|
||||
uniques[_id] = 1
|
||||
print('len refer dict', len(ref_unique), 'unique', len(uniques))
|
||||
|
||||
with open(args.sample_file, 'r') as sample_reader, \
|
||||
open(args.input_file, 'r', encoding='utf8') as input_reader, \
|
||||
open(args.output_pred_file, 'w', encoding='utf8') as pred_writer:
|
||||
|
||||
refer_dict = {}
|
||||
context_list = []
|
||||
line_id = 0
|
||||
for line in input_reader:
|
||||
items = json.loads(line.strip())
|
||||
context = items['context']
|
||||
completion = items['completion']
|
||||
|
||||
context_list.append(context)
|
||||
|
||||
keep = False
|
||||
|
||||
if args.filter == 'all':
|
||||
keep = True
|
||||
if args.filter == 'seen' and items['cate']:
|
||||
keep = True
|
||||
if args.filter == 'unseen' and not items['cate']:
|
||||
keep = True
|
||||
|
||||
if ref_unique is None:
|
||||
_key = context
|
||||
else:
|
||||
_key = ref_unique[line_id]
|
||||
|
||||
if keep:
|
||||
if not _key in refer_dict:
|
||||
refer_dict[_key] = {}
|
||||
refer_dict[_key]['references'] = []
|
||||
refer_dict[_key]['references'].append(completion.split('<|endoftext|>')[0].split('\n\n')[0].strip())
|
||||
|
||||
line_id += 1
|
||||
|
||||
print('unique refer dict', len(refer_dict))
|
||||
|
||||
for line in sample_reader:
|
||||
items = json.loads(line.strip())
|
||||
_id = items['id']
|
||||
_pred_tokens = items['predict']
|
||||
|
||||
if ref_unique is None:
|
||||
_key = context_list[_id]
|
||||
else:
|
||||
_key = ref_unique[_id]
|
||||
|
||||
#assert _key in refer_dict
|
||||
if _key in refer_dict:
|
||||
refer_dict[_key]['sample'] = enc.decode(_pred_tokens).split('<|endoftext|>')[0].split('\n\n')[0].strip()
|
||||
|
||||
references = [refer_dict[s]['references'] for s in refer_dict]
|
||||
hypothesis = [refer_dict[s]['sample'] for s in refer_dict]
|
||||
|
||||
if args.ref_type == 'e2e':
|
||||
with open(args.output_ref_file, 'w', encoding='utf8') as ref_writer:
|
||||
for ref, hyp in zip(references, hypothesis):
|
||||
for r in ref:
|
||||
ref_writer.write(post_process(r, args.tokenize, args.lower) + '\n')
|
||||
ref_writer.write('\n')
|
||||
pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
|
||||
|
||||
elif args.ref_type in ['webnlg', 'dart']:
|
||||
if not os.path.exists(args.output_ref_file):
|
||||
os.makedirs(args.output_ref_file)
|
||||
|
||||
reference_writers = [
|
||||
open(os.path.join(args.output_ref_file, f'reference{fid}'), 'w', encoding='utf8')
|
||||
for fid in range(0, args.ref_num)
|
||||
]
|
||||
|
||||
for ref, hyp in zip(references, hypothesis):
|
||||
for fid in range(0, args.ref_num):
|
||||
if len(ref) > fid:
|
||||
reference_writers[fid].write(post_process(ref[fid], args.tokenize, args.lower) + '\n')
|
||||
else:
|
||||
reference_writers[fid].write(post_process(ref[0], args.tokenize, args.lower) + '\n')
|
||||
pred_writer.write(post_process(hyp, args.tokenize, args.lower) + '\n')
|
||||
|
||||
for writer in reference_writers:
|
||||
writer.close()
|
|
@ -0,0 +1,70 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
import encoder
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.optim as optim
|
||||
import torch.utils.data
|
||||
|
||||
import numpy
|
||||
import io
|
||||
import sys
|
||||
import threading
|
||||
import math
|
||||
import random
|
||||
|
||||
import json
|
||||
import collections
|
||||
from collections import Counter
|
||||
from collections import OrderedDict
|
||||
from progress.bar import Bar as Bar
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', default=None, type=str, help='ft input file')
|
||||
parser.add_argument('--vocab', type=str, default=None, help='vocab path')
|
||||
parser.add_argument('--output', default=None, type=str, help='ft output file')
|
||||
parser.add_argument('--add_bos', action='store_true', help='')
|
||||
parser.add_argument('--add_eos', action='store_true', help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
enc = encoder.get_encoder(args.vocab)
|
||||
|
||||
writer = open(args.output, 'w')
|
||||
|
||||
with open(args.input, 'r') as reader:
|
||||
line_idx = 0
|
||||
for line in reader:
|
||||
items = json.loads(line.strip())
|
||||
context = items['context']
|
||||
completion = items['completion']
|
||||
|
||||
bos = 50256
|
||||
eos = 50256
|
||||
context_bpes, _ = enc.encode(context)
|
||||
context_bpes += [bos] if args.add_bos else []
|
||||
|
||||
completion_bpes, _ = enc.encode(' ' + completion)
|
||||
completion_bpes += [eos] if args.add_eos else []
|
||||
|
||||
ft_json = {}
|
||||
ft_json['context'] = context_bpes
|
||||
ft_json['completion'] = completion_bpes
|
||||
writer.write(json.dumps(ft_json)+'\n')
|
||||
|
||||
line_idx += 1
|
||||
|
||||
writer.close()
|
|
@ -0,0 +1,361 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import os, sys
|
||||
import numpy as np
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import random
|
||||
from torch.utils.data import DataLoader
|
||||
torch.set_printoptions(threshold=100000)
|
||||
|
||||
from gpu import (
|
||||
add_gpu_params,
|
||||
parse_gpu,
|
||||
distributed_opt,
|
||||
distributed_gather,
|
||||
distributed_sync,
|
||||
cleanup
|
||||
)
|
||||
from optimizer import (
|
||||
create_adam_optimizer,
|
||||
create_optimizer_scheduler,
|
||||
add_optimizer_params,
|
||||
create_adam_optimizer_from_args
|
||||
)
|
||||
|
||||
from data_utils import FT_Dataset
|
||||
from model import GPT2Config, GPT2LMModel
|
||||
from exp_utils import create_exp_dir
|
||||
|
||||
import loralib as lora
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch GPT2 ft script')
|
||||
|
||||
add_gpu_params(parser)
|
||||
add_optimizer_params(parser)
|
||||
|
||||
parser.add_argument('--train_data', required=True, help='location of training data corpus')
|
||||
|
||||
parser.add_argument('--valid_data', required=True, help='location of validation data corpus')
|
||||
|
||||
parser.add_argument('--train_batch_size', type=int, default=8, help='training batch size')
|
||||
|
||||
parser.add_argument('--valid_batch_size', type=int, default=4, help='validation batch size')
|
||||
|
||||
parser.add_argument('--grad_acc', type=int, default=1, help='gradient accumulation steps')
|
||||
|
||||
parser.add_argument('--clip', type=float, default=0.0, help='gradient clip')
|
||||
|
||||
parser.add_argument('--seq_len', type=int, default=512, help='number of tokens to predict.')
|
||||
|
||||
parser.add_argument('--model_card', default='gpt2.md', choices=['gpt2.sm', 'gpt2.md', 'gpt2.lg'],
|
||||
help='model names')
|
||||
|
||||
parser.add_argument('--init_checkpoint', default=None, help='pretrained checkpoint path')
|
||||
|
||||
parser.add_argument('--fp16', action='store_true', help='train model with fp16')
|
||||
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='log interval')
|
||||
|
||||
parser.add_argument('--eval_interval', type=int, default=2000, help='eval interval')
|
||||
|
||||
parser.add_argument('--save_interval', type=int, default=500, help='save interval')
|
||||
|
||||
parser.add_argument('--work_dir', type=str, default=os.getenv('PT_OUTPUT_DIR', 'gpt2_model'),
|
||||
help='working folder.')
|
||||
|
||||
parser.add_argument('--lora_dim', type=int, default=0, help='lora attn dimension')
|
||||
|
||||
parser.add_argument('--lora_alpha', type=int, default=128, help='lora attn alpha')
|
||||
|
||||
parser.add_argument('--obj', default='clm', choices=['jlm', 'clm'],
|
||||
help='language model training objective')
|
||||
|
||||
parser.add_argument('--lora_dropout', default=0.0, type=float,
|
||||
help='dropout probability for lora layers')
|
||||
|
||||
parser.add_argument('--label_smooth', default=0.0, type=float, help='label smoothing')
|
||||
|
||||
parser.add_argument('--roll_interval', type=int, default=-1, help='rolling interval')
|
||||
|
||||
parser.add_argument('--roll_lr', type=float, default=0.00001, help='rolling learning rate')
|
||||
|
||||
parser.add_argument('--roll_step', type=int, default=100, help='rolling step')
|
||||
|
||||
parser.add_argument('--eval_epoch', type=int, default=1, help='eval per number of epochs')
|
||||
|
||||
# influence model, calculate the influence score between two samples.
|
||||
def print_args(args):
|
||||
if args.rank == 0:
|
||||
print('=' * 100)
|
||||
for k, v in args.__dict__.items():
|
||||
print(f' - {k} : {v}')
|
||||
print('=' * 100)
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value
|
||||
Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
|
||||
"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
def optimizer_step(_loss, _optimizer, _model, _schedule, args, is_update=True):
|
||||
if args.fp16:
|
||||
with amp.scale_loss(_loss, _optimizer) as _scaled_loss:
|
||||
_scaled_loss.backward()
|
||||
else:
|
||||
_loss.backward()
|
||||
|
||||
if is_update:
|
||||
if args.clip > 0:
|
||||
if args.fp16:
|
||||
torch.nn.utils.clip_grad_norm_(amp.master_params(_optimizer), args.clip)
|
||||
else:
|
||||
torch.nn.utils.clip_grad_norm_(_model.parameters(), args.clip)
|
||||
|
||||
_optimizer.step()
|
||||
_optimizer.zero_grad()
|
||||
|
||||
if _schedule is not None:
|
||||
_schedule.step()
|
||||
|
||||
|
||||
def evaluate(model, valid_loader, args):
|
||||
model.eval()
|
||||
total_loss = 0.
|
||||
start_time = time.time()
|
||||
|
||||
avg_lm_loss = AverageMeter()
|
||||
|
||||
with torch.no_grad():
|
||||
for idx, data in enumerate(valid_loader):
|
||||
data = {key: value for key, value in data.items()}
|
||||
|
||||
_input = data['input'].to(args.device)
|
||||
_target = data['target'].to(args.device)
|
||||
_msk = data['mask'].to(args.device)
|
||||
|
||||
_lm_logits, _loss = model(_input, lm_labels=_target, lm_mask=_msk)
|
||||
loss = _loss.mean()
|
||||
|
||||
avg_lm_loss.update(loss.item())
|
||||
|
||||
if idx % 100 == 0:
|
||||
print('eval samples:', idx, 'loss:', loss.float())
|
||||
|
||||
total_time = time.time() - start_time
|
||||
print('average loss', avg_lm_loss.avg)
|
||||
return avg_lm_loss.avg, math.exp(avg_lm_loss.avg)
|
||||
|
||||
|
||||
def train_validate(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
train_loader,
|
||||
valid_loader,
|
||||
args,
|
||||
train_step=0,
|
||||
epoch=0
|
||||
):
|
||||
model.train()
|
||||
avg_lm_loss = AverageMeter()
|
||||
print('start to train the model................', epoch)
|
||||
log_start_time = time.time()
|
||||
best_val_ppl = None
|
||||
|
||||
train_loader.sampler.set_epoch(epoch)
|
||||
|
||||
for idx, data in enumerate(train_loader):
|
||||
data = {key: value for key, value in data.items()}
|
||||
|
||||
_input = data['input'].to(args.device)
|
||||
_target = data['target'].to(args.device)
|
||||
_msk = data['mask'].to(args.device)
|
||||
|
||||
_lm_logits, _lm_loss = model(
|
||||
_input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth
|
||||
)
|
||||
|
||||
_lm_loss = _lm_loss.mean()
|
||||
|
||||
train_step += 1
|
||||
is_update = True if train_step % args.grad_acc == 0 else False
|
||||
avg_lm_loss.update(_lm_loss.item())
|
||||
optimizer_step(
|
||||
_lm_loss/(args.grad_acc), optimizer, model, scheduler, args, is_update=is_update
|
||||
)
|
||||
|
||||
if train_step % args.log_interval == 0:
|
||||
elapsed = time.time() - log_start_time
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
log_str = f'| epoch {epoch:3d} step {train_step:>8d} | { idx + 1:>6d} batches | ' \
|
||||
f'lr {lr:.3g} | ms/batch {elapsed * 1000 / args.log_interval:5.2f} | ' \
|
||||
f'loss {avg_lm_loss.val:5.2f} | avg loss {avg_lm_loss.avg:5.2f} | ' \
|
||||
f'ppl {math.exp(avg_lm_loss.avg):5.2f}'
|
||||
|
||||
if args.rank == 0:
|
||||
print(log_str)
|
||||
log_start_time = time.time()
|
||||
avg_lm_loss.reset()
|
||||
|
||||
if train_step % args.save_interval == 0:
|
||||
if args.rank == 0:
|
||||
model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
|
||||
print('saving checkpoint', model_path)
|
||||
torch.save({'model_state_dict': lora.lora_state_dict(model)}, model_path)
|
||||
distributed_sync(args)
|
||||
|
||||
# evaluation interval
|
||||
if train_step % args.eval_interval == 0:
|
||||
eval_start_time = time.time()
|
||||
|
||||
valid_loss, valid_ppl = evaluate(model, valid_loader, args)
|
||||
|
||||
if best_val_ppl is None or valid_ppl < best_val_ppl:
|
||||
best_val_ppl = valid_ppl
|
||||
|
||||
log_str = f'| Eval {train_step // args.eval_interval:3d} at step {train_step:>8d} | ' \
|
||||
f'time: {time.time() - eval_start_time:5.2f}s | valid loss {valid_loss:5.2f} | ' \
|
||||
f'valid ppl {valid_ppl:5.2f} | best ppl {best_val_ppl:5.2f} '
|
||||
|
||||
if args.rank == 0:
|
||||
print('-' * 100)
|
||||
print(log_str)
|
||||
print('-' * 100)
|
||||
|
||||
model.train()
|
||||
distributed_sync(args)
|
||||
|
||||
if train_step == args.max_step:
|
||||
break
|
||||
|
||||
if args.rank == 0:
|
||||
model_path = os.path.join(args.work_dir, f'model.{train_step}.pt')
|
||||
print('saving checkpoint', model_path)
|
||||
torch.save({'model_state_dict': model.state_dict()}, model_path)
|
||||
distributed_sync(args)
|
||||
return train_step
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parser.parse_args()
|
||||
parse_gpu(args)
|
||||
print_args(args)
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except Exception as e:
|
||||
warnings.warn('Could not import amp, apex may not be installed')
|
||||
|
||||
torch.manual_seed(args.random_seed)
|
||||
random.seed(args.random_seed)
|
||||
|
||||
if args.rank == 0:
|
||||
args.logging = create_exp_dir(args.work_dir)
|
||||
|
||||
train_data = FT_Dataset(
|
||||
args.train_data, args.train_batch_size, args.seq_len,
|
||||
joint_lm=args.obj=='jlm'
|
||||
)
|
||||
|
||||
valid_data = FT_Dataset(
|
||||
args.valid_data, args.valid_batch_size, args.seq_len,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_data, batch_size=args.train_batch_size, num_workers=0,
|
||||
shuffle=False, pin_memory=False, drop_last=True,
|
||||
sampler=torch.utils.data.distributed.DistributedSampler(train_data, seed=args.random_seed)
|
||||
)
|
||||
|
||||
valid_loader = DataLoader(
|
||||
valid_data, batch_size=args.valid_batch_size, num_workers=0,
|
||||
shuffle=False, pin_memory=False, drop_last=False,
|
||||
sampler=torch.utils.data.distributed.DistributedSampler(valid_data, seed=args.random_seed)
|
||||
)
|
||||
|
||||
if args.model_card == 'gpt2.sm':
|
||||
config = GPT2Config(
|
||||
n_embd=768, n_layer=12, n_head=12,
|
||||
lora_attn_dim=args.lora_dim,
|
||||
lora_attn_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
elif args.model_card == 'gpt2.md':
|
||||
config = GPT2Config(
|
||||
n_embd=1024, n_layer=24, n_head=16,
|
||||
lora_attn_dim=args.lora_dim,
|
||||
lora_attn_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
elif args.model_card == 'gpt2.lg':
|
||||
config = GPT2Config(
|
||||
n_embd=1280, n_layer=36, n_head=20,
|
||||
lora_attn_dim=args.lora_dim,
|
||||
lora_attn_alpha=args.lora_alpha,
|
||||
lora_dropout=args.lora_dropout,
|
||||
)
|
||||
|
||||
lm_net = GPT2LMModel(config)
|
||||
if args.init_checkpoint is not None:
|
||||
print('loading model pretrained weight.')
|
||||
lm_net.load_weight(torch.load(args.init_checkpoint))
|
||||
|
||||
lm_net = lm_net.cuda()
|
||||
|
||||
if args.lora_dim > 0:
|
||||
lora.mark_only_lora_as_trainable(lm_net)
|
||||
optimizer = create_adam_optimizer_from_args(lm_net, args)
|
||||
|
||||
if args.max_step is None:
|
||||
args.max_step = (args.max_epoch * train_data.num_batches + args.world_size - 1) // args.world_size
|
||||
print('set max_step:', args.max_step)
|
||||
|
||||
scheduler = create_optimizer_scheduler(optimizer, args)
|
||||
if args.fp16:
|
||||
lm_net, optimizer = amp.initialize(lm_net, optimizer, opt_level="O1")
|
||||
lm_net, optimizer = distributed_opt(args, lm_net, optimizer, grad_acc=args.grad_acc)
|
||||
|
||||
try:
|
||||
train_step = 0
|
||||
for epoch in itertools.count(start=1):
|
||||
train_step = train_validate(
|
||||
lm_net, optimizer, scheduler, train_loader, valid_loader, args,
|
||||
train_step=train_step, epoch=epoch
|
||||
)
|
||||
|
||||
if train_step >= args.max_step or (args.max_epoch is not None and epoch >= args.max_epoch):
|
||||
if args.rank == 0:
|
||||
print('-' * 100)
|
||||
print('End of training')
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
if args.rank == 0:
|
||||
print('-' * 100)
|
||||
print('Exiting from training early')
|
||||
|
||||
distributed_sync(args)
|
||||
print('cleanup dist ...')
|
||||
cleanup(args)
|
|
@ -0,0 +1,126 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import argparse
|
||||
import time
|
||||
import math
|
||||
import os, sys
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def add_gpu_params(parser: argparse.ArgumentParser):
|
||||
parser.add_argument("--platform", default='k8s', type=str, help='platform cloud')
|
||||
parser.add_argument("--local_rank", default=0, type=int, help='local rank')
|
||||
parser.add_argument("--rank", default=0, type=int, help='rank')
|
||||
parser.add_argument("--device", default=0, type=int, help='device')
|
||||
parser.add_argument("--world_size", default=0, type=int, help='world size')
|
||||
parser.add_argument("--random_seed", default=10, type=int, help='random seed')
|
||||
|
||||
|
||||
def distributed_opt(args, model, opt, grad_acc=1):
|
||||
if args.platform == 'azure':
|
||||
args.hvd.broadcast_parameters(model.state_dict(), root_rank=0)
|
||||
opt = args.hvd.DistributedOptimizer(
|
||||
opt, named_parameters=model.named_parameters(), backward_passes_per_step=grad_acc
|
||||
)
|
||||
elif args.platform == 'philly' or args.platform == 'k8s' or args.platform == 'local':
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank,
|
||||
find_unused_parameters=False, broadcast_buffers=False
|
||||
)
|
||||
return model, opt
|
||||
|
||||
|
||||
def distributed_gather(args, tensor):
|
||||
g_y = [torch.zeros_like(tensor) for _ in range(args.world_size)]
|
||||
torch.distributed.all_gather(g_y, tensor, async_op=False)
|
||||
return torch.stack(g_y)
|
||||
|
||||
|
||||
def distributed_sync(args):
|
||||
if args.platform == 'azure':
|
||||
args.hvd.allreduce(torch.tensor(0), name='barrier')
|
||||
else:
|
||||
args.dist.barrier()
|
||||
|
||||
|
||||
def parse_gpu(args):
|
||||
torch.manual_seed(args.random_seed)
|
||||
|
||||
if args.platform == 'local':
|
||||
dist.init_process_group(backend='nccl')
|
||||
local_rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device('cuda', local_rank)
|
||||
args.rank = local_rank
|
||||
args.device = device
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.dist = dist
|
||||
|
||||
elif args.platform == 'azure':
|
||||
import horovod.torch as hvd
|
||||
hvd.init()
|
||||
print('azure hvd rank', hvd.rank(), 'local rank', hvd.local_rank())
|
||||
local_rank = hvd.local_rank()
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device('cuda', local_rank)
|
||||
rank = hvd.rank()
|
||||
world_size = hvd.size()
|
||||
|
||||
args.local_rank = local_rank
|
||||
args.rank = rank
|
||||
args.device = device
|
||||
args.world_size = world_size
|
||||
args.hvd = hvd
|
||||
|
||||
elif args.platform == 'philly':
|
||||
local_rank = args.local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(backend='nccl')
|
||||
rank = dist.get_rank()
|
||||
world_size = torch.distributed.get_world_size()
|
||||
device = torch.device('cuda', local_rank)
|
||||
|
||||
args.rank = rank
|
||||
args.device = device
|
||||
args.world_size = world_size
|
||||
args.dist = dist
|
||||
elif args.platform == 'k8s':
|
||||
master_uri = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
|
||||
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
|
||||
args.local_rank = local_rank
|
||||
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
|
||||
world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
|
||||
rank = world_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
dist.init_process_group(
|
||||
backend='nccl',
|
||||
init_method=master_uri,
|
||||
world_size=world_size,
|
||||
rank=world_rank,
|
||||
)
|
||||
device = torch.device("cuda", local_rank)
|
||||
args.rank = rank
|
||||
args.device = device
|
||||
args.world_size = world_size
|
||||
args.dist = dist
|
||||
print(
|
||||
'myrank:', args.rank,
|
||||
'local_rank:', args.local_rank,
|
||||
'device_count:', torch.cuda.device_count(),
|
||||
'world_size:', args.world_size
|
||||
)
|
||||
|
||||
|
||||
def cleanup(args):
|
||||
if args.platform == 'k8s' or args.platform == 'philly':
|
||||
args.dist.destroy_process_group()
|
|
@ -0,0 +1,450 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import copy
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import loralib as lora
|
||||
|
||||
|
||||
def gelu(x):
|
||||
return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
|
||||
|
||||
def gelu_fast(x):
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
|
||||
|
||||
|
||||
def gelu_new(x):
|
||||
""" Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
|
||||
def swish(x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def _gelu_python(x):
|
||||
""" Original Implementation of the gelu activation function in Google Bert repo when initially created.
|
||||
For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
|
||||
0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
|
||||
This is now written in C in torch.nn.functional
|
||||
Also see https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-12):
|
||||
"""Construct a layernorm module in the TF style (epsilon inside the square root)."""
|
||||
super(LayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.bias = nn.Parameter(torch.zeros(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, x):
|
||||
u = x.mean(-1, keepdim=True)
|
||||
s = (x - u).pow(2).mean(-1, keepdim=True)
|
||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
|
||||
return self.weight * x + self.bias
|
||||
|
||||
|
||||
class Conv1D(nn.Module):
|
||||
def __init__(self, nf, nx):
|
||||
super(Conv1D, self).__init__()
|
||||
self.nf = nf
|
||||
w = torch.empty(nx, nf)
|
||||
nn.init.normal_(w, std=0.02)
|
||||
self.weight = Parameter(w)
|
||||
self.bias = Parameter(torch.zeros(nf))
|
||||
|
||||
def forward(self, x):
|
||||
size_out = x.size()[:-1] + (self.nf,)
|
||||
x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
|
||||
x = x.view(*size_out)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, config, scale=False):
|
||||
super(Attention, self).__init__()
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
||||
|
||||
assert n_state % config.n_head == 0
|
||||
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
self.c_attn = Conv1D(n_state * 3, nx)
|
||||
self.c_attn = lora.MergedLinear(
|
||||
nx, n_state * 3,
|
||||
r=config.lora_attn_dim,
|
||||
lora_alpha=config.lora_attn_alpha,
|
||||
lora_dropout=config.lora_dropout,
|
||||
enable_lora=[True, False, True],
|
||||
fan_in_fan_out=True,
|
||||
merge_weights=False
|
||||
)
|
||||
self.c_proj = Conv1D(n_state, nx)
|
||||
|
||||
self.config = config
|
||||
|
||||
def _attn(self, q, k, v, len_kv=None):
|
||||
w = torch.matmul(q, k)
|
||||
if self.scale:
|
||||
w = w / math.sqrt(v.size(-1))
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
b = self.bias[:, :, ns-nd:ns, :ns]
|
||||
w = w * b - 1e10 * (1 - b)
|
||||
|
||||
# q : (batch, head, q_seq_length, head_features)
|
||||
# k : (batch, head, head_features, kv_seq_length)
|
||||
# w : (batch, head, q_seq_length, kv_seq_length)
|
||||
# v : (batch, head, kv_seq_length, head_features)
|
||||
if len_kv is not None:
|
||||
_len = torch.arange(k.size(-1), device=k.device)
|
||||
_input_msk = _len[None, :] >= (len_kv)[:, None]
|
||||
w = w.masked_fill(_input_msk.unsqueeze(1).unsqueeze(2), -1.0e10)
|
||||
|
||||
w = nn.Softmax(dim=-1)(w)
|
||||
return torch.matmul(w, v)
|
||||
|
||||
def merge_heads(self, x):
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
||||
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
||||
|
||||
def split_heads(self, x, k=False):
|
||||
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
||||
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
||||
if k:
|
||||
return x.permute(0, 2, 3, 1).contiguous() # (batch, head, head_features, seq_length)
|
||||
else:
|
||||
return x.permute(0, 2, 1, 3).contiguous() # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(self, x, history=None, layer_past=None, len_past=None):
|
||||
hidden_states = x
|
||||
|
||||
x = self.c_attn(x)
|
||||
query, key, value = x.split(self.split_size, dim=2)
|
||||
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
|
||||
#_input_msk = None
|
||||
|
||||
len_kv = None
|
||||
|
||||
if layer_past is not None:
|
||||
# key : (batch, head, head_features, seq_length)
|
||||
# value : (batch, head, seq_length, head_features)
|
||||
# layer_past, key : (batch, head, seq_length, head_features)
|
||||
if len_past is None:
|
||||
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
||||
key = torch.cat((past_key, key), dim=-1)
|
||||
value = torch.cat((past_value, value), dim=-2)
|
||||
else:
|
||||
key_seq = key.shape[-1]
|
||||
assert key_seq == 1
|
||||
|
||||
_batch = torch.arange(0, key.shape[0], dtype=torch.long, device=key.device)
|
||||
|
||||
past_key, past_value = layer_past[0], layer_past[1]
|
||||
|
||||
past_key[_batch,:,len_past,:] = key.squeeze(-1)
|
||||
past_value[_batch,:,len_past,:] = value.squeeze(-2)
|
||||
|
||||
key = past_key.transpose(-2, -1)
|
||||
value = past_value
|
||||
|
||||
len_kv = len_past + 1
|
||||
|
||||
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
||||
a = self._attn(query, key, value, len_kv = len_kv)
|
||||
a = self.merge_heads(a)
|
||||
a = self.c_proj(a)
|
||||
return a, present
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
||||
super(MLP, self).__init__()
|
||||
nx = config.n_embd
|
||||
self.c_fc = Conv1D(n_state, nx)
|
||||
self.c_proj = Conv1D(nx, n_state)
|
||||
self.act = gelu
|
||||
|
||||
def forward(self, x):
|
||||
h = self.act(self.c_fc(x))
|
||||
h2 = self.c_proj(h)
|
||||
return h2
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_ctx, config, scale=False):
|
||||
super(Block, self).__init__()
|
||||
nx = config.n_embd
|
||||
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.attn = Attention(nx, n_ctx, config, scale)
|
||||
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.mlp = MLP(4 * nx, config)
|
||||
|
||||
def forward(self, x, layer_past=None, len_past=None):
|
||||
a, present = self.attn(self.ln_1(x), layer_past=layer_past, len_past=len_past)
|
||||
x = x + a
|
||||
m = self.mlp(self.ln_2(x))
|
||||
x = x + m
|
||||
return x, present
|
||||
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(GPT2Model, self).__init__()
|
||||
self.n_layer = config.n_layer
|
||||
self.n_embd = config.n_embd
|
||||
self.n_vocab = config.vocab_size
|
||||
|
||||
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
||||
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
||||
block = Block(config.n_ctx, config, scale=True)
|
||||
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
|
||||
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.config = config
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
past=None,
|
||||
len_past=None
|
||||
):
|
||||
if past is None:
|
||||
past_length = 0
|
||||
past = [None] * len(self.h)
|
||||
elif len_past is None:
|
||||
# equal size for past. []
|
||||
past_length = past[0][0].size(-2)
|
||||
|
||||
if position_ids is None and len_past is None:
|
||||
position_ids = torch.arange(
|
||||
past_length, input_ids.size(-1) + past_length,
|
||||
dtype=torch.long, device=input_ids.device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
|
||||
elif len_past is not None:
|
||||
position_ids = (len_past).unsqueeze(1) #.long()
|
||||
|
||||
input_shape = input_ids.size()
|
||||
input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
position_ids = position_ids.view(-1, position_ids.size(-1))
|
||||
|
||||
inputs_embeds = self.wte(input_ids)
|
||||
|
||||
position_embeds = self.wpe(position_ids)
|
||||
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
|
||||
token_type_embeds = self.wte(token_type_ids)
|
||||
else:
|
||||
token_type_embeds = 0
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
presents = []
|
||||
for block, layer_past in zip(self.h, past):
|
||||
hidden_states, present = block(hidden_states, layer_past = layer_past, len_past=len_past)
|
||||
presents.append(present)
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
return hidden_states.view(*output_shape), presents
|
||||
|
||||
|
||||
class GPT2LMHead(nn.Module):
|
||||
def __init__(self, model_embeddings_weights, config):
|
||||
super(GPT2LMHead, self).__init__()
|
||||
self.n_embd = config.n_embd
|
||||
self.set_embeddings_weights(model_embeddings_weights)
|
||||
|
||||
def set_embeddings_weights(self, model_embeddings_weights):
|
||||
embed_shape = model_embeddings_weights.shape
|
||||
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
|
||||
self.decoder.weight = model_embeddings_weights # Tied weights
|
||||
|
||||
def forward(self, hidden_state):
|
||||
# Truncated Language modeling logits (we remove the last token)
|
||||
# h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
|
||||
lm_logits = self.decoder(hidden_state)
|
||||
return lm_logits
|
||||
|
||||
|
||||
class GPT2Config(object):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size_or_config_json_file=50257,
|
||||
n_positions=1024,
|
||||
n_ctx=1024,
|
||||
n_embd=768,
|
||||
n_layer=12,
|
||||
n_head=12,
|
||||
layer_norm_epsilon=1e-5,
|
||||
initializer_range=0.02,
|
||||
lora_attn_dim=0,
|
||||
lora_attn_alpha=128,
|
||||
lora_dropout=0.0,
|
||||
lora_r_dropout=0.0,
|
||||
fix_dropout=0.0,
|
||||
):
|
||||
self.vocab_size = vocab_size_or_config_json_file
|
||||
self.n_ctx = n_ctx
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
self.lora_attn_dim = lora_attn_dim
|
||||
self.lora_attn_alpha = lora_attn_alpha
|
||||
self.lora_dropout = lora_dropout
|
||||
self.lora_r_dropout = lora_r_dropout
|
||||
|
||||
self.fix_dropout = fix_dropout
|
||||
|
||||
|
||||
class GPT2LMModel(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(GPT2LMModel, self).__init__()
|
||||
self.transformer = GPT2Model(config)
|
||||
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def set_tied(self):
|
||||
""" Make sure we are sharing the embeddings"""
|
||||
self.lm_head.set_embeddings_weights(self.transformer.wte.weight)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
lm_labels=None,
|
||||
lm_mask=None,
|
||||
past=None,
|
||||
len_past=None,
|
||||
label_smooth=0.0,
|
||||
is_report_accuracy=False
|
||||
):
|
||||
_batch, _len = input_ids.shape
|
||||
hidden_states, presents = self.transformer(input_ids, past=past, len_past=len_past)
|
||||
|
||||
# batch, seq, vocab
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
|
||||
if lm_labels is not None:
|
||||
|
||||
if is_report_accuracy:
|
||||
_pred_token = torch.argmax(lm_logits, dim=-1)
|
||||
_hit = (_pred_token == lm_labels) * lm_mask
|
||||
|
||||
_t1_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
|
||||
_all_acc = torch.zeros(_batch, dtype=torch.float, device=input_ids.device)
|
||||
|
||||
for _b in range(0, _batch):
|
||||
for _i in range(0, _len):
|
||||
if lm_mask[_b, _i] >= 1.0:
|
||||
if _hit[_b, _i] > 0:
|
||||
_t1_acc[_b] = 1.0
|
||||
break
|
||||
|
||||
_is_succ = True
|
||||
for _i in range(0, _len):
|
||||
if lm_mask[_b, _i] >= 1.0:
|
||||
if _hit[_b, _i] <= 0:
|
||||
_is_succ = False
|
||||
break
|
||||
|
||||
if _is_succ:
|
||||
_all_acc[_b] = 1.0
|
||||
|
||||
#_t1_acc = _t1_acc * 1.0 / _batch
|
||||
#_all_acc = _all_acc * 1.0 / _batch
|
||||
|
||||
if label_smooth > 0.0001:
|
||||
logprobs = torch.nn.functional.log_softmax(lm_logits.view(-1, lm_logits.size(-1)), dim=-1)
|
||||
nll_loss = -logprobs.gather(dim=-1, index=lm_labels.view(-1).unsqueeze(1))
|
||||
nll_loss = nll_loss.squeeze(1)
|
||||
smooth_loss = -logprobs.mean(dim=-1)
|
||||
loss = (1.0 - label_smooth) * nll_loss + label_smooth * smooth_loss
|
||||
loss = loss.view(_batch, _len)
|
||||
else:
|
||||
loss_fct = nn.CrossEntropyLoss(ignore_index=-1, reduce=False)
|
||||
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), lm_labels.view(-1)).view(_batch, _len)
|
||||
|
||||
if lm_mask is None:
|
||||
lm_mask = torch.ones(loss.shape, dtype=loss.dtype, device=loss.device)
|
||||
loss = loss * lm_mask
|
||||
|
||||
loss = loss.sum() / (lm_mask.sum() + 0.0001)
|
||||
|
||||
if is_report_accuracy:
|
||||
return lm_logits, loss, _t1_acc, _all_acc
|
||||
else:
|
||||
return lm_logits, loss
|
||||
return lm_logits, presents
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def load_weight(self, state_dict):
|
||||
if 'model_state_dict' in state_dict:
|
||||
state_dict = state_dict['model_state_dict']
|
||||
|
||||
state_dict_tmp = copy.deepcopy(state_dict)
|
||||
old_keys = []
|
||||
new_keys = []
|
||||
for key in state_dict_tmp:
|
||||
new_key = None
|
||||
if key.endswith(".g"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
elif key.endswith(".b"):
|
||||
new_key = key[:-2] + ".bias"
|
||||
elif key.endswith(".w"):
|
||||
new_key = key[:-2] + ".weight"
|
||||
|
||||
if key.startswith("module.transformer."):
|
||||
new_key = key[len("module.transformer."):]
|
||||
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
for n, p in self.transformer.named_parameters():
|
||||
if n not in state_dict:
|
||||
state_dict[n] = p
|
||||
|
||||
self.transformer.load_state_dict(state_dict, strict=False)
|
||||
self.set_tied()
|
|
@ -0,0 +1,360 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
import torch.nn.functional as F
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR, _LRScheduler
|
||||
|
||||
|
||||
def add_optimizer_params(parser: argparse.ArgumentParser):
|
||||
parser.add_argument('--lr', default=0.00001, type=float, help='learning rate')
|
||||
parser.add_argument('--weight_decay', default=0.01, type=float, help='weight decay rate')
|
||||
parser.add_argument('--correct_bias', action='store_true', help='correct adam bias term')
|
||||
parser.add_argument('--adam_epislon', default=1e-6, type=float, help='adam epsilon')
|
||||
parser.add_argument('--no_decay_bias', action='store_true', help='no weight decay on bias weigh')
|
||||
parser.add_argument('--adam_beta1', default=0.9, type=float, help='adam beta1 term')
|
||||
parser.add_argument('--adam_beta2', default=0.98, type=float, help='adam beta2 term')
|
||||
|
||||
parser.add_argument('--scheduler', default='linear', type=str,
|
||||
choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant', 'linear', 'cycle'],
|
||||
help='lr scheduler to use.')
|
||||
|
||||
parser.add_argument('--max_step', type=int, default=None, help='upper epoch limit')
|
||||
|
||||
parser.add_argument('--max_epoch', type=int, default=None, help='max epoch of training')
|
||||
|
||||
parser.add_argument('--warmup_step', type=int, default=0, help='upper epoch limit')
|
||||
|
||||
parser.add_argument('--i_steps', type=str, default='0', help='interval_steps')
|
||||
parser.add_argument('--i_lrs', type=str, default='0.00025', help='interval_lrs')
|
||||
|
||||
|
||||
class AdamW(Optimizer):
|
||||
""" Implements Adam algorithm with weight decay fix.
|
||||
Parameters:
|
||||
lr (float): learning rate. Default 1e-3.
|
||||
betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.98)
|
||||
eps (float): Adams epsilon. Default: 1e-6
|
||||
weight_decay (float): Weight decay. Default: 0.0
|
||||
correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
|
||||
"""
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.0, correct_bias=True):
|
||||
if lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
|
||||
def reset_state(self):
|
||||
for group in param_groups:
|
||||
for p in group['params']:
|
||||
state = self.state[p]
|
||||
state['step'] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p.data)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p.data)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state["step"] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state["exp_avg"] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state["exp_avg_sq"] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = group["betas"]
|
||||
|
||||
state["step"] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
||||
denom = exp_avg_sq.sqrt().add_(group["eps"])
|
||||
|
||||
step_size = group["lr"]
|
||||
if 'correct_bias' in group and group["correct_bias"]: # No bias correction for Bert
|
||||
bias_correction1 = 1.0 - beta1 ** state["step"]
|
||||
bias_correction2 = 1.0 - beta2 ** state["step"]
|
||||
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(-step_size, exp_avg, denom)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group["weight_decay"] > 0.0:
|
||||
p.data.add_(p.data, alpha=-group["lr"] * group["weight_decay"])
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class CosineAnnealingWarmupRestarts(_LRScheduler):
|
||||
"""
|
||||
optimizer (Optimizer): Wrapped optimizer.
|
||||
first_cycle_steps (int): First cycle step size.
|
||||
cycle_mult(float): Cycle steps magnification. Default: -1.
|
||||
max_lr(float): First cycle's max learning rate. Default: 0.1.
|
||||
min_lr(float): Min learning rate. Default: 0.001.
|
||||
warmup_steps(int): Linear warmup step size. Default: 0.
|
||||
gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
|
||||
last_epoch (int): The index of last epoch. Default: -1.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
optimizer : torch.optim.Optimizer,
|
||||
max_lr : float = 0.1,
|
||||
min_lr : float = 0.0,
|
||||
warmup_steps : int = 0,
|
||||
max_steps : int = 1,
|
||||
alpha : float = 0.,
|
||||
last_epoch : int = -1
|
||||
):
|
||||
self.max_lr = max_lr # max learning rate in the current cycle
|
||||
self.min_lr = min_lr # min learning rate
|
||||
self.warmup_steps = warmup_steps # warmup step size
|
||||
|
||||
self.alpha = alpha # decrease rate of max learning rate by cycle
|
||||
self.max_steps = max_steps
|
||||
super(CosineAnnealingWarmupRestarts, self).__init__(optimizer, last_epoch)
|
||||
self.init_lr()
|
||||
|
||||
def init_lr(self):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = self.min_lr
|
||||
|
||||
def get_lr(self):
|
||||
if self.last_epoch < self.warmup_steps:
|
||||
curr_lr = self.max_lr * self.last_epoch / self.warmup_steps
|
||||
return curr_lr
|
||||
else:
|
||||
_step = min(self.last_epoch, self.max_steps)
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * _step / self.max_steps))
|
||||
decayed = (1 - self.alpha) * cosine_decay + self.alpha
|
||||
return self.max_lr * decayed # learning_rate * decayed
|
||||
|
||||
def step(self, epoch=None):
|
||||
if epoch is None:
|
||||
epoch = self.last_epoch + 1
|
||||
|
||||
self.last_epoch = math.floor(epoch)
|
||||
_lr = self.get_lr()
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = _lr
|
||||
|
||||
|
||||
class CyclicScheduler(_LRScheduler):
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
interval_steps = [],
|
||||
interval_lrs = [],
|
||||
last_epoch = -1,
|
||||
):
|
||||
self.optimizer = optimizer
|
||||
|
||||
self.interval_steps = interval_steps
|
||||
self.interval_lrs = interval_lrs
|
||||
|
||||
self.last_epoch = last_epoch
|
||||
|
||||
super(CyclicScheduler, self).__init__(optimizer, last_epoch)
|
||||
|
||||
self.init_lr()
|
||||
|
||||
def init_lr(self):
|
||||
for param_group in self.optimizer.param_groups:
|
||||
param_group['lr'] = self.interval_lrs[0]
|
||||
|
||||
def get_lr(self):
|
||||
for _i in range(0, len(self.interval_steps)-1):
|
||||
if self.last_epoch >= self.interval_steps[_i] and self.last_epoch < self.interval_steps[_i + 1]:
|
||||
_alpha = (self.last_epoch - self.interval_steps[_i]) / (self.interval_steps[_i + 1] - self.interval_steps[_i] + 1e-6)
|
||||
if _alpha < 0:
|
||||
_alpha = 0
|
||||
if _alpha >= 1:
|
||||
_alpha = 1
|
||||
curr_lr = _alpha * self.interval_lrs[_i + 1] + (1.0 - _alpha) * self.interval_lrs[_i]
|
||||
return curr_lr
|
||||
return self.interval_lrs[-1]
|
||||
|
||||
def step(self, epoch=None):
|
||||
if epoch is None:
|
||||
epoch = self.last_epoch + 1
|
||||
|
||||
#self.max_lr = self.base_max_lr * (self.gamma**self.cycle)
|
||||
self.last_epoch = math.floor(epoch)
|
||||
_lr = self.get_lr()
|
||||
for param_group in self.optimizer.param_groups: #, self.get_lr()):
|
||||
param_group['lr'] = _lr
|
||||
|
||||
|
||||
|
||||
def get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps,
|
||||
num_training_steps,
|
||||
last_epoch=-1
|
||||
):
|
||||
""" Create a schedule with a learning rate that decreases linearly after
|
||||
linearly increasing during a warmup period.
|
||||
"""
|
||||
def lr_lambda(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
|
||||
def get_constant_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps,
|
||||
num_training_steps,
|
||||
last_epoch=-1
|
||||
):
|
||||
""" Create a schedule with a learning rate that decreases linearly after
|
||||
linearly increasing during a warmup period.
|
||||
"""
|
||||
def lr_lambda(current_step):
|
||||
if current_step < num_warmup_steps:
|
||||
return float(current_step) / float(max(1, num_warmup_steps))
|
||||
return 1.0
|
||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||
|
||||
|
||||
def create_grouped_parameters(model, no_decay_bias): # args):
|
||||
if not no_decay_bias:
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters()], # if not any(nd in n for nd in no_decay)],
|
||||
}]
|
||||
else:
|
||||
no_decay = ["bias", "layer_norm.weight"]
|
||||
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
},
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0,
|
||||
}]
|
||||
return optimizer_grouped_parameters
|
||||
|
||||
|
||||
def create_adam_optimizer(
|
||||
model,
|
||||
lr,
|
||||
weight_decay,
|
||||
optimizer_grouped_parameters=None,
|
||||
beta1=0.9,
|
||||
beta2=0.98,
|
||||
correct_bias=True,
|
||||
adam_epislon=1e-6,
|
||||
no_decay_bias=False
|
||||
):
|
||||
if optimizer_grouped_parameters is None:
|
||||
optimizer_grouped_parameters = create_grouped_parameters(model, no_decay_bias)
|
||||
|
||||
optimizer = AdamW(
|
||||
optimizer_grouped_parameters,
|
||||
lr=lr,
|
||||
betas=(beta1, beta2),
|
||||
eps=adam_epislon,
|
||||
weight_decay=weight_decay,
|
||||
correct_bias=correct_bias
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_sgd_optimizer(model, lr):
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.0)
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_adam_optimizer_from_args(model, args, grouped_parameters=None):
|
||||
if grouped_parameters is None:
|
||||
grouped_parameters = create_grouped_parameters(model, args.no_decay_bias)
|
||||
|
||||
optimizer = AdamW(
|
||||
grouped_parameters,
|
||||
lr=args.lr,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
eps=args.adam_epislon,
|
||||
weight_decay=args.weight_decay,
|
||||
correct_bias=args.correct_bias
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
def create_optimizer_scheduler(optimizer, args):
|
||||
if args.scheduler == 'cosine':
|
||||
scheduler = CosineAnnealingWarmupRestarts(
|
||||
optimizer,
|
||||
max_lr=args.lr,
|
||||
min_lr=0.0,
|
||||
warmup_steps=args.warmup_step,
|
||||
max_steps=args.max_step, alpha=0
|
||||
)
|
||||
elif args.scheduler == 'linear':
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, args.warmup_step, args.max_step, last_epoch=-1
|
||||
)
|
||||
elif args.scheduler == 'cycle':
|
||||
if args.i_steps is not None:
|
||||
args.i_steps = [int(_i) for _i in args.i_steps.split(',')]
|
||||
args.i_lrs = [float(_i) for _i in args.i_lrs.split(',')]
|
||||
args.max_step = args.i_steps[-1]
|
||||
print('max_step is rest to', args.max_step)
|
||||
scheduler = CyclicScheduler(
|
||||
optimizer, interval_steps=args.i_steps, interval_lrs=args.i_lrs
|
||||
)
|
||||
elif args.scheduler == 'constant':
|
||||
scheduler = get_constant_schedule_with_warmup(
|
||||
optimizer, args.warmup_step, args.max_step, last_epoch=-1
|
||||
)
|
||||
else:
|
||||
# constant leanring rate.
|
||||
scheduler = None
|
||||
return scheduler
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"activation_function": "gelu_new",
|
||||
"architectures": [
|
||||
"GPT2LMHeadModel"
|
||||
],
|
||||
"attn_pdrop": 0.1,
|
||||
"bos_token_id": 50256,
|
||||
"embd_pdrop": 0.1,
|
||||
"eos_token_id": 50256,
|
||||
"initializer_range": 0.02,
|
||||
"layer_norm_epsilon": 1e-05,
|
||||
"model_type": "gpt2",
|
||||
"n_ctx": 1024,
|
||||
"n_embd": 1024,
|
||||
"n_head": 16,
|
||||
"n_layer": 24,
|
||||
"n_positions": 1024,
|
||||
"n_special": 0,
|
||||
"predict_special_tokens": true,
|
||||
"resid_pdrop": 0.1,
|
||||
"summary_activation": null,
|
||||
"summary_first_dropout": 0.1,
|
||||
"summary_proj_to_labels": true,
|
||||
"summary_type": "cls_index",
|
||||
"summary_use_proj": true,
|
||||
"task_specific_params": {
|
||||
"text-generation": {
|
||||
"do_sample": true,
|
||||
"max_length": 50
|
||||
}
|
||||
},
|
||||
"vocab_size": 50257
|
||||
}
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -0,0 +1,186 @@
|
|||
# LoRA: Low-Rank Adaptation of Large Language Models
|
||||
|
||||
|
||||
This repo contains the source code of the Python package `loralib` and several examples of how to integrate it with PyTorch models, such as those in HuggingFace.
|
||||
We only support PyTorch for now.
|
||||
See our paper for a detailed description of LoRA.
|
||||
|
||||
**LoRA: Low-Rank Adaptation of Large Language Models** <br>
|
||||
*Edward J. Hu\*, Yelong Shen\*, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Weizhu Chen* <br>
|
||||
Paper: https://arxiv.org/abs/2106.09685 <br>
|
||||
|
||||
LoRA reduces the number of trainable parameters by learning pairs of rank-decompostion matrices while freezing the original weights.
|
||||
This vastly reduces the storage requirement for large language models adapted to specific tasks and enables efficient task-switching during deployment all without introducing inference latency.
|
||||
LoRA also outperforms several other adaptation methods including adapter, prefix-tuning, and fine-tuning.
|
||||
|
||||
We obtain result comparable or superior to full finetuning on the GLUE benchmark using [RoBERTa (Liu et al., 2019)](https://arxiv.org/abs/1907.11692) base and large and [DeBERTa (He et al., 2020)](https://arxiv.org/abs/2006.03654) XXL 1.5B, while only training and storing a fraction of the parameters. Click the numbers below to download the RoBERTa and DeBERTa LoRA checkpoints.
|
||||
|
||||
| | | RoBERTa base <br> Fine-tune | RoBERTa base <br> LoRA | DeBERTa XXL <br> Fine-tune | DeBERTa XXL <br> LoRA |
|
||||
|---|-------------------------|----------------|--------------------------|-----------------|-----------------|
|
||||
| | # of Trainable Params. | 125M | 0.8M | 1.5B | 4.7M |
|
||||
| | MNLI (m-Acc/mm-Acc) | <b>87.6</b> | [<b>87.5</b>±.3/86.9±.3](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_mnli.bin) |91.7/<b>91.9</b>| [<b>91.9</b>±.1/<b>91.9</b>±.2](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_mnli.bin) |
|
||||
| | SST2 (Acc) | 94.8 | [<b>95.1</b>±.2](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_sst2.bin) | <b>97.2</b> | [96.9±.2](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_sst2.bin) |
|
||||
| | MRPC (Acc) | <b>90.2</b> | [<b>89.7</b>±.7](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_mrpc.bin) | 92.0 | [<b>92.6</b>±.6](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_mrpc.bin) |
|
||||
| | CoLA (Matthew's Corr) | <b>63.6</b> | [<b>63.4</b>±1.2](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_cola.bin) | <b>72.0</b> | [<b>72.4</b>±1.1](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_cola.bin) |
|
||||
| | QNLI (Acc) | 92.8 | [<b>93.3</b>±.3](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_qnli.bin) | <b>96.0</b> | [<b>96.0</b>±.1](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_qnli.bin) |
|
||||
| | QQP (Acc) | <b>91.9</b> | [90.8±.1](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_qqp.bin) | 92.7 | [<b>92.9</b>±.1](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_qqp.bin) |
|
||||
| | RTE (Acc) | 78.7 | [<b>86.6</b>±.7](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_rte.bin) | 93.9 | [<b>94.9</b>±.4](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_rte.bin) |
|
||||
| | STSB (Pearson/Spearman Corr) | 91.2 | [<b>91.5</b>±.2/<b>91.3</b>±.2](https://github.com/microsoft/LoRA/releases/download/RoBERTa/roberta_base_lora_stsb.bin) |<b>92.9</b>/92.6| [<b>93.0</b>±.2/<b>92.9</b>±.3](https://github.com/microsoft/LoRA/releases/download/DeBERTa/deberta_v2_xxlarge_lora_stsb.bin) |
|
||||
| | Average | 86.40 | <b>87.24</b> | 91.06 | <b>91.32</b> |
|
||||
|
||||
<i>Note: You still need the original pre-trained checkpoint from [HuggingFace](https://huggingface.co/) to use the LoRA checkpoints.</i>
|
||||
|
||||
Fine-tuning numbers are taken from [Liu et al. (2019)](https://arxiv.org/abs/1907.11692) and [He et al. (2020)](https://arxiv.org/abs/2006.03654). We include confidence intervals on results from our experiments. Please follow the instructions in `NLU/` to reproduce our results.
|
||||
|
||||
On GPT-2, LoRA compares favorably to both full finetuning and other efficient tuning methods, such as [adapter (Houlsby et al., 2019)](https://arxiv.org/abs/1902.00751) and [prefix tuning (Li and Liang, 2021)](https://arxiv.org/abs/2101.00190). We evaluated on E2E NLG Challenge, DART, and WebNLG:
|
||||
|
||||
| | Method | # of Trainable Params | E2E (BLEU) | DART (BLEU) | WebNLG (BLEU-U/S/A) |
|
||||
|---|---------------------|-----------------------|--------------|--------------|--------------------------------|
|
||||
| | GPT-2 M (Fine-Tune) | 354.92M | 68.2 | 46.0 | 30.4/<b>63.2</b>/47.6 |
|
||||
| | GPT-2 M (Adapter) | 0.37M | 66.3 | 42.4 | 45.1/54.5/50.2 |
|
||||
| | GPT-2 M (Prefix) | 0.35M | 69.7 | 45.7 | 44.1/63.1/54.4 |
|
||||
| | GPT-2 M (LoRA) | 0.35M |<b>70.4</b>±.1|<b>47.1</b>±.2| <b>46.7</b>±.4/62.1±.2/<b>55.3</b>±.2 |
|
||||
| | GPT-2 L (Fine-Tune) | 774.03M | 68.5 | 46.5 | 41.7/<b>64.6</b>/54.2 |
|
||||
| | GPT-2 L (Adapter) | 0.88M | 69.1±.1 | 45.7±.1 | <b>49.8</b>±.0/61.1±.0/56.0±.0 |
|
||||
| | GPT-2 L (Prefix) | 0.77M | 70.3 | 46.5 | 47.0/64.2/56.4 |
|
||||
| | GPT-2 L (LoRA) | 0.77M |<b>70.4</b>±.1|<b>47.5</b>±.1| 48.4±.3/<b>64.0</b>±.3/<b>57.0</b>±.1 |
|
||||
|
||||
Non-LoRA baselines, except for adapter on GPT-2 large, are taken from [Li and Liang (2021)](https://arxiv.org/abs/2101.00190). We include confidence intervals on results from our experiments.
|
||||
|
||||
Download the GPT-2 LoRA checkpoints:
|
||||
* [GPT-2 Medium E2E](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_md_lora_e2e.pt) (1.5 MB)
|
||||
* [GPT-2 Medium DART](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_md_lora_dart.pt) (1.5 MB)
|
||||
* [GPT-2 Medium WebNLG](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_md_lora_webnlg.pt) (1.5 MB)
|
||||
* [GPT-2 Large E2E](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_lg_lora_e2e.pt) (2.3 MB)
|
||||
* [GPT-2 Large DART](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_lg_lora_dart.pt) (2.3 MB)
|
||||
* [GPT-2 Large WebNLG](https://github.com/microsoft/LoRA/releases/download/GPT-2/gpt2_lg_lora_webnlg.pt) (2.3 MB)
|
||||
|
||||
Please follow the instructions in `NLG/` to reproduce our result.
|
||||
## Repository Overview
|
||||
|
||||
There are several directories in this repo:
|
||||
* [loralib/](loralib) contains the source code for the package `loralib`, which needs to be installed to run the examples we provide;
|
||||
* [NLG/](NLG) contains an example implementation of LoRA in GPT-2 using our package, which can be used to reproduce the result in our paper;
|
||||
* [NLU/](NLU) contains an example implementation of LoRA in RoBERTa and DeBERTa using our package, which produces competitive results on the GLUE benchmark;
|
||||
* See how we use `loralib` in [GPT-2](NLG/src/model.py), [RoBERTa](NLU/src/transformers/models/roberta/modeling_roberta.py), and [DeBERTa v2](NLU/src/transformers/models/deberta_v2/modeling_deberta_v2.py)
|
||||
|
||||
## Quickstart
|
||||
|
||||
1. Installing `loralib` is simply
|
||||
```
|
||||
pip install loralib
|
||||
# Alternatively
|
||||
# pip install git+https://github.com/microsoft/LoRA
|
||||
```
|
||||
|
||||
2. You can choose to adapt some layers by replacing them with counterparts implemented in `loralib`. We only support `nn.Linear` for now. We also support a `MergedLinear` for cases where a single `nn.Linear` represents more than one layers, such as in some implementations of the attention `qkv` projection (see Additional Notes for more).
|
||||
```
|
||||
# ===== Before =====
|
||||
# layer = nn.Linear(in_features, out_features)
|
||||
|
||||
# ===== After ======
|
||||
import loralib as lora
|
||||
# Add a pair of low-rank adaptation matrices with rank r=16
|
||||
layer = lora.Linear(in_features, out_features, r=16)
|
||||
```
|
||||
|
||||
3. Before the training loop begins, mark only LoRA parameters as trainable.
|
||||
```
|
||||
import loralib as lora
|
||||
model = BigModel()
|
||||
# This sets requires_grad to False for all parameters without the string "lora_" in their names
|
||||
lora.mark_only_lora_as_trainable(model)
|
||||
# Training loop
|
||||
for batch in dataloader:
|
||||
...
|
||||
```
|
||||
4. When saving a checkpoint, generate a `state_dict` that only contains LoRA parameters.
|
||||
```
|
||||
# ===== Before =====
|
||||
# torch.save(model.state_dict(), checkpoint_path)
|
||||
# ===== After =====
|
||||
torch.save(lora.lora_state_dict(model), checkpoint_path)
|
||||
```
|
||||
5. When loading a checkpoint using `load_state_dict`, be sure to set `strict=False`.
|
||||
```
|
||||
# Load the pretrained checkpoint first
|
||||
model.load_state_dict(torch.load('ckpt_pretrained.pt'), strict=False)
|
||||
# Then load the LoRA checkpoint
|
||||
model.load_state_dict(torch.load('ckpt_lora.pt'), strict=False)
|
||||
```
|
||||
|
||||
#### Now training can proceed as usual.
|
||||
|
||||
## Additional Notes
|
||||
|
||||
1. While we focus on a simple yet effect setup, namely adapting only the `q` and `v` projection in a Transformer, in our examples, LoRA can be apply to any subsets of pre-trained weights. We encourage you to explore different configurations, such as adapting the embedding layer by replacing `nn.Embedding` with `lora.Embedding` and/or adapting the MLP layers. It's very likely that the optimal configuration varies for different model architectures and tasks.
|
||||
|
||||
2. Some Transformer implementation uses a single `nn.Linear` for the projection matrices for query, key, and value. If one wishes to constrain the rank of the updates to the individual matrices, one has to either break it up into three separate matrices or use `lora.MergedLinear`. Make sure to modify the checkpoint accordingly if you choose to break up the layer.
|
||||
```
|
||||
# ===== Before =====
|
||||
# qkv_proj = nn.Linear(d_model, 3*d_model)
|
||||
# ===== After =====
|
||||
# Break it up (remember to modify the pretrained checkpoint accordingly)
|
||||
q_proj = lora.Linear(d_model, d_model, r=8)
|
||||
k_proj = nn.Linear(d_model, d_model)
|
||||
v_proj = lora.Linear(d_model, d_model, r=8)
|
||||
# Alternatively, use lora.MergedLinear (recommended)
|
||||
qkv_proj = lora.MergedLinear(d_model, 3*d_model, r=8, enable_lora=[True, False, True])
|
||||
```
|
||||
1. Training bias vectors in tandem with LoRA might be a cost-efficient way to squeeze out extra task performance (if you tune the learning rate carefully). While we did not study its effect thoroughly in our paper, we make it easy to try in `lora`. You can mark some biases as trainable by passing "all" or "lora_only" to `bias=` when calling `mark_only_lora_as_trainable`. Remember to pass the corresponding `bias=` argument to `lora_state_dict` when saving a checkpoint.
|
||||
```
|
||||
# ===== Before =====
|
||||
# lora.mark_only_lora_as_trainable(model) # Not training any bias vectors
|
||||
# ===== After =====
|
||||
# Training all bias vectors associated with modules we apply LoRA to
|
||||
lora.mark_only_lora_as_trainable(model, bias='lora_only')
|
||||
# Alternatively, we can train *all* bias vectors in the model, including LayerNorm biases
|
||||
lora.mark_only_lora_as_trainable(model, bias='all')
|
||||
# When saving a checkpoint, use the same bias= ('all' or 'lora_only')
|
||||
torch.save(lora.lora_state_dict(model, bias='all'), checkpoint_path)
|
||||
```
|
||||
4. Calling `model.eval()` will trigger the merging of LoRA parameters with the corresponding pretrained ones, which eliminates additional latency for subsequent forward passes. Calling `model.train()` again will undo the merge. This can be disabled by passing `merge_weights=False` to LoRA layers.
|
||||
|
||||
## Contact
|
||||
Please contact us or post an issue if you have any questions.
|
||||
|
||||
For questions related to the package `loralib`:
|
||||
* Edward Hu (edwardhu@microsoft.com)
|
||||
* Phillip Wallis (phwallis@microsoft.com)
|
||||
* Weizhu Chen (wzchen@microsoft.com)
|
||||
|
||||
The GPT-2 example:
|
||||
* Phillip Wallis (phwallis@microsoft.com)
|
||||
* Yelong Shen (yeshe@microsoft.com)
|
||||
|
||||
The DeBERTa example:
|
||||
* Lu Wang (luw@microsoft.com)
|
||||
|
||||
## Acknowledgements
|
||||
We thank in alphabetical order Jianfeng Gao, Jade Huang, Jiayuan Huang, Lisa Xiang Li, Xiaodong Liu, Yabin Liu, Benjamin Van Durme, Luis Vargas, Haoran Wei, Peter Welinder, and Greg Yang for providing valuable feedback.
|
||||
|
||||
## Citation
|
||||
```
|
||||
@misc{hu2021lora,
|
||||
title={LoRA: Low-Rank Adaptation of Large Language Models},
|
||||
author={Hu, Edward and Shen, Yelong and Wallis, Phil and Allen-Zhu, Zeyuan and Li, Yuanzhi and Chen, Weizhu},
|
||||
year={2021},
|
||||
eprint={2106.09685},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CL}
|
||||
}
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
|
@ -0,0 +1,4 @@
|
|||
name = "lora"
|
||||
|
||||
from .layers import *
|
||||
from .utils import *
|
|
@ -0,0 +1,322 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
class LoRALayer():
|
||||
def __init__(
|
||||
self,
|
||||
r: int,
|
||||
lora_alpha: int,
|
||||
lora_dropout: float,
|
||||
merge_weights: bool,
|
||||
):
|
||||
self.r = r
|
||||
self.lora_alpha = lora_alpha
|
||||
# Optional dropout
|
||||
if lora_dropout > 0.:
|
||||
self.lora_dropout = nn.Dropout(p=lora_dropout)
|
||||
else:
|
||||
self.lora_dropout = lambda x: x
|
||||
# Mark the weight as unmerged
|
||||
self.merged = False
|
||||
self.merge_weights = merge_weights
|
||||
|
||||
|
||||
class Embedding(nn.Embedding, LoRALayer):
|
||||
# LoRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
merge_weights: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
|
||||
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,
|
||||
merge_weights=merge_weights)
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Embedding.reset_parameters(self)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.zeros_(self.lora_A)
|
||||
nn.init.normal_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Embedding.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= (self.lora_B @ self.lora_A).T * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
nn.Linear.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.r > 0 and not self.merged:
|
||||
result = nn.Embedding.forward(self, x)
|
||||
if self.r > 0:
|
||||
after_A = F.embedding(
|
||||
x, self.lora_A.T, self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq, self.sparse
|
||||
)
|
||||
result += (after_A @ self.lora_B.T) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return nn.Embedding.forward(self, x)
|
||||
|
||||
|
||||
class Linear(nn.Linear, LoRALayer):
|
||||
# LoRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
nn.Linear.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
nn.Linear.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
if self.r > 0 and not self.merged:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result += (self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
|
||||
class MergedLinear(nn.Linear, LoRALayer):
|
||||
# LoRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
enable_lora: List[bool] = [False],
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
nn.Linear.__init__(self, in_features, out_features, **kwargs)
|
||||
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
assert out_features % len(enable_lora) == 0, \
|
||||
'The length of enable_lora must divide out_features'
|
||||
self.enable_lora = enable_lora
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0 and any(enable_lora):
|
||||
self.lora_A = nn.Parameter(
|
||||
self.weight.new_zeros((r * sum(enable_lora), in_features)))
|
||||
self.lora_B = nn.Parameter(
|
||||
self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
|
||||
) # weights for Conv1D with groups=sum(enable_lora)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
# Compute the indices
|
||||
self.lora_ind = self.weight.new_zeros(
|
||||
(out_features, ), dtype=torch.bool
|
||||
).view(len(enable_lora), -1)
|
||||
self.lora_ind[enable_lora, :] = True
|
||||
self.lora_ind = self.lora_ind.view(-1)
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Linear.reset_parameters(self)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def zero_pad(self, x):
|
||||
result = x.new_zeros((*x.shape[:-1], self.out_features))
|
||||
result = result.view(-1, self.out_features)
|
||||
result[:, self.lora_ind] = x.reshape(
|
||||
-1, self.out_features // len(self.enable_lora) * sum(self.enable_lora)
|
||||
)
|
||||
return result.view((*x.shape[:-1], self.out_features))
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
nn.Linear.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0 and any(self.enable_lora):
|
||||
delta_w = F.conv1d(
|
||||
self.lora_A.data.unsqueeze(0),
|
||||
self.lora_B.data.unsqueeze(-1),
|
||||
groups=sum(self.enable_lora)
|
||||
).squeeze(0)
|
||||
self.weight.data -= self.zero_pad(T(delta_w * self.scaling))
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
nn.Linear.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0 and any(self.enable_lora):
|
||||
delta_w = F.conv1d(
|
||||
self.lora_A.data.unsqueeze(0),
|
||||
self.lora_B.data.unsqueeze(-1),
|
||||
groups=sum(self.enable_lora)
|
||||
).squeeze(0)
|
||||
self.weight.data += self.zero_pad(T(delta_w * self.scaling))
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
if self.merged:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
else:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
if self.r > 0:
|
||||
after_A = F.linear(self.lora_dropout(x), self.lora_A)
|
||||
after_B = F.conv1d(
|
||||
after_A.transpose(-2, -1),
|
||||
self.lora_B.unsqueeze(-1),
|
||||
groups=sum(self.enable_lora)
|
||||
).transpose(-2, -1)
|
||||
result += self.zero_pad(after_B) * self.scaling
|
||||
return result
|
||||
|
||||
|
||||
class Conv2d(nn.Conv2d, LoRALayer):
|
||||
# LoRA implemented in a dense layer
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
merge_weights: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
|
||||
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
assert type(kernel_size) is int
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(
|
||||
self.weight.new_zeros((r*kernel_size, in_channels*kernel_size))
|
||||
)
|
||||
self.lora_B = nn.Parameter(
|
||||
self.weight.new_zeros((out_channels*kernel_size, r*kernel_size))
|
||||
)
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.Conv2d.reset_parameters(self)
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
nn.Conv2d.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
self.weight.data -= (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
nn.Conv2d.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
self.weight.data += (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.r > 0 and not self.merged:
|
||||
return F.conv2d(
|
||||
x,
|
||||
self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling,
|
||||
self.bias, self.stride, self.padding, self.dilation, self.groups
|
||||
)
|
||||
return nn.Conv2d.forward(self, x)
|
|
@ -0,0 +1,49 @@
|
|||
# ------------------------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
|
||||
# ------------------------------------------------------------------------------------------
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .layers import LoRALayer
|
||||
|
||||
|
||||
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
|
||||
for n, p in model.named_parameters():
|
||||
if 'lora_' not in n:
|
||||
p.requires_grad = False
|
||||
if bias == 'none':
|
||||
return
|
||||
elif bias == 'all':
|
||||
for n, p in model.named_parameters():
|
||||
if 'bias' in n:
|
||||
p.requires_grad = True
|
||||
elif bias == 'lora_only':
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALayer) and \
|
||||
hasattr(m, 'bias') and \
|
||||
m.bias is not None:
|
||||
m.bias.requires_grad = True
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
|
||||
my_state_dict = model.state_dict()
|
||||
if bias == 'none':
|
||||
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k}
|
||||
elif bias == 'all':
|
||||
return {k: my_state_dict[k] for k in my_state_dict if 'lora_' in k or 'bias' in k}
|
||||
elif bias == 'lora_only':
|
||||
to_return = {}
|
||||
for k in my_state_dict:
|
||||
if 'lora_' in k:
|
||||
to_return[k] = my_state_dict[k]
|
||||
bias_name = k.split('lora_')[0]+'bias'
|
||||
if bias_name in my_state_dict:
|
||||
to_return[bias_name] = my_state_dict[bias_name]
|
||||
return to_return
|
||||
else:
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,22 @@
|
|||
import setuptools
|
||||
|
||||
with open("README.md", "r", encoding="utf-8") as fh:
|
||||
long_description = fh.read()
|
||||
|
||||
setuptools.setup(
|
||||
name="loralib",
|
||||
version="0.1.0",
|
||||
author="Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen",
|
||||
author_email="edward.hu@microsoft.com",
|
||||
description="PyTorch implementation of low-rank adaptation (LoRA), a parameter-efficient approach to adapt a large pre-trained deep learning model which obtains performance on-par with full fine-tuning.",
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
url="https://github.com/microsoft/LoRA",
|
||||
packages=setuptools.find_packages(),
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
],
|
||||
python_requires='>=3.6',
|
||||
)
|
Загрузка…
Ссылка в новой задаче