зеркало из https://github.com/microsoft/muzic.git
upload museformer
This commit is contained in:
Родитель
8ad59b7ecb
Коммит
8ef9e354ae
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 191 KiB |
|
@ -0,0 +1,139 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
data-bin/
|
||||
cached_datasets/
|
||||
checkpoints*/
|
||||
log*/
|
||||
tb_log*/
|
||||
output*/
|
|
@ -0,0 +1,139 @@
|
|||
# Museformer
|
||||
|
||||
[Museformer: Transformer with Fine- and Coarse-Grained Attention for Music Generation](https://arxiv.org/abs/2210.10349), by Botao Yu, Peiling Lu, Rui Wang, Wei Hu, Xu Tan, Wei Ye, Shikun Zhang, Tao Qin, Tie-Yan Liu, NeurIPS 2022, is a Transformer with a novel fine- and coarse-grained attention (FC-Attention) for music generation. Specifically, with the fine-grained attention, a token of a specific bar directly attends to all the tokens of the bars that are most relevant to music structures (e.g., the previous 1st, 2nd, 4th and 8th bars, selected via similarity statistics); with the coarse-grained attention, a token only attends to the summarization of the other bars rather than each token of them so as to reduce the computational cost. The advantages are two-fold. First, it can capture both music structure-related correlations via the fine-grained attention, and other contextual information via the coarse-grained attention. Second, it is efficient and can model over 3X longer music sequences compared to its full-attention counterpart. Both objective and subjective experimental results demonstrate its ability to generate long music sequences with high quality and better structures.
|
||||
|
||||
demo: [link](https://ai-muzic.github.io/museformer)
|
||||
|
||||
<p align="center"><img src="../img/museformer_visualization.png" width="800"><br/>Information flow of FC-Attention. </p>
|
||||
|
||||
|
||||
|
||||
The following content describes the steps to run Museformer. All the commands are run at the root directory of Museformer (named as `root_dir`) unless specified.
|
||||
|
||||
## 1. Dataset
|
||||
|
||||
We use [the Lakh MIDI dataset](https://colinraffel.com/projects/lmd/) (LMD-full). Specifically, we first preprocess it as described in the Appendix of our paper. The final dataset (see the file lists [here](data/meta)) contains 29,940 MIDI files. Their time signatures are all 4/4, and the instruments are normalized to 6 basic ones: square synthesizer (80), piano (0), guitar (25), string (48), bass (43), drum, where in the parentheses are MIDI program IDs if applicable. We put all the MIDI files in `data/midi`.
|
||||
|
||||
Install [MidiProcessor](https://github.com/btyu/MidiProcessor). Then, encode the MIDI files into tokens:
|
||||
|
||||
```bash
|
||||
midi_dir=data/midi
|
||||
token_dir=data/token
|
||||
mp-batch-encoding $midi_dir $token_dir --encoding-method REMIGEN2 --normalize-pitch-value --remove-empty-bars --ignore-ts --sort-insts 6tracks_cst1
|
||||
```
|
||||
|
||||
where the arguments are explained as follows:
|
||||
|
||||
- `encoding-method`: the representation method. We use the one called REMIGEN2, a REMI-like representation method that represents musical information in separate tokens.
|
||||
- `normalize-pitch-value`: normalize pitches so as to make the song *C major* or *A minor*.
|
||||
- `remove-empty-bars`: remove empty bars at the beginning or the end of a song.
|
||||
- `ignore-ts`: do not add the tokens of time signature. Since the used data are all 4/4, we do not encode it.
|
||||
- `sort-insts`: designate a method that sorts the instruments. `6tracks_cst1` sorts the instruments in order: square synthesizer, drum, bass, guitar, piano, string.
|
||||
|
||||
After encoding, you should see the token representation of each MIDI file in `output_dir`.
|
||||
|
||||
Then, run the following command to gather the tokens for each split.
|
||||
|
||||
```bash
|
||||
token_dir=data/token
|
||||
split_dir=data/split
|
||||
for split in train valid test :
|
||||
do python tools/generate_token_data_by_file_list.py data/meta/${split}.txt $token_dir $split_dir ;
|
||||
done
|
||||
```
|
||||
|
||||
Next, use `fairseq-preprocess` to make binary data:
|
||||
|
||||
```bash
|
||||
split_dir=data/split
|
||||
data_bin_dir=data-bin/lmd6remi
|
||||
|
||||
mkdir -p data-bin
|
||||
|
||||
fairseq-preprocess \
|
||||
--only-source \
|
||||
--trainpref $split_dir/train.data \
|
||||
--validpref $split_dir/valid.data \
|
||||
--testpref $split_dir/test.data \
|
||||
--destdir $data_bin_dir \
|
||||
--srcdict data/meta/dict.txt
|
||||
```
|
||||
|
||||
Now, you should see the binary data in `data-bin/lmd6remi`.
|
||||
|
||||
## 2. Environment
|
||||
|
||||
The implementation of Museformer relies on specific hardware and software environment.
|
||||
|
||||
- Hardware environment: we recommend Nvidia V100 32GB. Other GPUs are not tested.
|
||||
|
||||
- Software environment:
|
||||
|
||||
please ensure these packages are installed:
|
||||
|
||||
```
|
||||
CUDA: 11.3.1
|
||||
Python: 3.8
|
||||
fairseq: 0.10.2
|
||||
tensorboardX: 2.2
|
||||
```
|
||||
|
||||
And install [triton](https://github.com/openai/triton) at an arbitrary directory except `root_dir`:
|
||||
|
||||
```bash
|
||||
cd /path/to/a/directory/for/triton
|
||||
git clone https://github.com/openai/triton.git
|
||||
cd triton/python
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## 3. Train
|
||||
|
||||
Run the following command to train Museformer:
|
||||
|
||||
```bash
|
||||
bash ttrain/mf-lmd6remi-1.sh
|
||||
```
|
||||
|
||||
In our experiment, we run it on 4 GPUs, and the batch size is set to 1, so the real batch size is 4. Current implementation only supports batch size = 1. You may change `UPDATE_FREQ` to modify the real batch size.
|
||||
|
||||
In your first run, it may take some time to build up auxiliary information and compile CUDA kernels, so you may take a cup of coffee at this moment.
|
||||
|
||||
## 4. Evaluation
|
||||
|
||||
You can obtain perplexity on the test set with the following command:
|
||||
|
||||
```bash
|
||||
bash tval/val__mf-lmd6remi-x.sh 1 checkpoint_best.pt 10240
|
||||
```
|
||||
|
||||
The number `10240` indicates the maximum sequence length for calculation.
|
||||
|
||||
## 5. Inference
|
||||
|
||||
Use the following command to generate 5 music pieces, with the random seed set to 1:
|
||||
|
||||
```bash
|
||||
mkdir -p output_log
|
||||
seed=1
|
||||
printf '\n\n\n\n\n' | bash tgen/generation__mf-lmd6remi-x.sh 1 checkpoint_best.pt ${seed} | tee output_log/generation.log
|
||||
```
|
||||
|
||||
The number of `\n` controls the number of generated music pieces. The generation would take a while. Once done, the generation log will be saved at `output_log/generation.log`.
|
||||
|
||||
Then, use the following command to extract the generated token sequences from the generation log:
|
||||
|
||||
```bash
|
||||
python tools/batch_extract_log.py output_log/generation.log output/generation --start_idx 1
|
||||
```
|
||||
|
||||
You should see token representation of each generated music piece in `output/generation`.
|
||||
|
||||
Finally, run the following command to convert the token sequences into MIDI files:
|
||||
|
||||
```bash
|
||||
python tools/batch_generate_midis.py --encoding-method REMIGEN2 --input-dir output/generation --output-dir output/generation
|
||||
```
|
||||
|
||||
You should see the MIDI files in `output/generation`.
|
||||
|
|
@ -0,0 +1,556 @@
|
|||
b-0 1
|
||||
b-1 1
|
||||
o-0 1
|
||||
o-1 1
|
||||
o-2 1
|
||||
o-3 1
|
||||
o-4 1
|
||||
o-5 1
|
||||
o-6 1
|
||||
o-7 1
|
||||
o-8 1
|
||||
o-9 1
|
||||
o-10 1
|
||||
o-11 1
|
||||
o-12 1
|
||||
o-13 1
|
||||
o-14 1
|
||||
o-15 1
|
||||
o-16 1
|
||||
o-17 1
|
||||
o-18 1
|
||||
o-19 1
|
||||
o-20 1
|
||||
o-21 1
|
||||
o-22 1
|
||||
o-23 1
|
||||
o-24 1
|
||||
o-25 1
|
||||
o-26 1
|
||||
o-27 1
|
||||
o-28 1
|
||||
o-29 1
|
||||
o-30 1
|
||||
o-31 1
|
||||
o-32 1
|
||||
o-33 1
|
||||
o-34 1
|
||||
o-35 1
|
||||
o-36 1
|
||||
o-37 1
|
||||
o-38 1
|
||||
o-39 1
|
||||
o-40 1
|
||||
o-41 1
|
||||
o-42 1
|
||||
o-43 1
|
||||
o-44 1
|
||||
o-45 1
|
||||
o-46 1
|
||||
o-47 1
|
||||
o-48 1
|
||||
o-49 1
|
||||
o-50 1
|
||||
o-51 1
|
||||
o-52 1
|
||||
o-53 1
|
||||
o-54 1
|
||||
o-55 1
|
||||
o-56 1
|
||||
o-57 1
|
||||
o-58 1
|
||||
o-59 1
|
||||
o-60 1
|
||||
o-61 1
|
||||
o-62 1
|
||||
o-63 1
|
||||
i-0 1
|
||||
i-25 1
|
||||
i-43 1
|
||||
i-48 1
|
||||
i-80 1
|
||||
i-128 1
|
||||
p-0 1
|
||||
p-1 1
|
||||
p-2 1
|
||||
p-3 1
|
||||
p-4 1
|
||||
p-5 1
|
||||
p-6 1
|
||||
p-7 1
|
||||
p-8 1
|
||||
p-9 1
|
||||
p-10 1
|
||||
p-11 1
|
||||
p-12 1
|
||||
p-13 1
|
||||
p-14 1
|
||||
p-15 1
|
||||
p-16 1
|
||||
p-17 1
|
||||
p-18 1
|
||||
p-19 1
|
||||
p-20 1
|
||||
p-21 1
|
||||
p-22 1
|
||||
p-23 1
|
||||
p-24 1
|
||||
p-25 1
|
||||
p-26 1
|
||||
p-27 1
|
||||
p-28 1
|
||||
p-29 1
|
||||
p-30 1
|
||||
p-31 1
|
||||
p-32 1
|
||||
p-33 1
|
||||
p-34 1
|
||||
p-35 1
|
||||
p-36 1
|
||||
p-37 1
|
||||
p-38 1
|
||||
p-39 1
|
||||
p-40 1
|
||||
p-41 1
|
||||
p-42 1
|
||||
p-43 1
|
||||
p-44 1
|
||||
p-45 1
|
||||
p-46 1
|
||||
p-47 1
|
||||
p-48 1
|
||||
p-49 1
|
||||
p-50 1
|
||||
p-51 1
|
||||
p-52 1
|
||||
p-53 1
|
||||
p-54 1
|
||||
p-55 1
|
||||
p-56 1
|
||||
p-57 1
|
||||
p-58 1
|
||||
p-59 1
|
||||
p-60 1
|
||||
p-61 1
|
||||
p-62 1
|
||||
p-63 1
|
||||
p-64 1
|
||||
p-65 1
|
||||
p-66 1
|
||||
p-67 1
|
||||
p-68 1
|
||||
p-69 1
|
||||
p-70 1
|
||||
p-71 1
|
||||
p-72 1
|
||||
p-73 1
|
||||
p-74 1
|
||||
p-75 1
|
||||
p-76 1
|
||||
p-77 1
|
||||
p-78 1
|
||||
p-79 1
|
||||
p-80 1
|
||||
p-81 1
|
||||
p-82 1
|
||||
p-83 1
|
||||
p-84 1
|
||||
p-85 1
|
||||
p-86 1
|
||||
p-87 1
|
||||
p-88 1
|
||||
p-89 1
|
||||
p-90 1
|
||||
p-91 1
|
||||
p-92 1
|
||||
p-93 1
|
||||
p-94 1
|
||||
p-95 1
|
||||
p-96 1
|
||||
p-97 1
|
||||
p-98 1
|
||||
p-99 1
|
||||
p-100 1
|
||||
p-101 1
|
||||
p-102 1
|
||||
p-103 1
|
||||
p-104 1
|
||||
p-105 1
|
||||
p-106 1
|
||||
p-107 1
|
||||
p-108 1
|
||||
p-109 1
|
||||
p-110 1
|
||||
p-111 1
|
||||
p-112 1
|
||||
p-113 1
|
||||
p-114 1
|
||||
p-115 1
|
||||
p-116 1
|
||||
p-117 1
|
||||
p-118 1
|
||||
p-119 1
|
||||
p-120 1
|
||||
p-121 1
|
||||
p-122 1
|
||||
p-123 1
|
||||
p-124 1
|
||||
p-125 1
|
||||
p-126 1
|
||||
p-127 1
|
||||
p-128 1
|
||||
p-129 1
|
||||
p-130 1
|
||||
p-131 1
|
||||
p-132 1
|
||||
p-133 1
|
||||
p-134 1
|
||||
p-135 1
|
||||
p-136 1
|
||||
p-137 1
|
||||
p-138 1
|
||||
p-139 1
|
||||
p-140 1
|
||||
p-141 1
|
||||
p-142 1
|
||||
p-143 1
|
||||
p-144 1
|
||||
p-145 1
|
||||
p-146 1
|
||||
p-147 1
|
||||
p-148 1
|
||||
p-149 1
|
||||
p-150 1
|
||||
p-151 1
|
||||
p-152 1
|
||||
p-153 1
|
||||
p-154 1
|
||||
p-155 1
|
||||
p-156 1
|
||||
p-157 1
|
||||
p-158 1
|
||||
p-159 1
|
||||
p-160 1
|
||||
p-161 1
|
||||
p-162 1
|
||||
p-163 1
|
||||
p-164 1
|
||||
p-165 1
|
||||
p-166 1
|
||||
p-167 1
|
||||
p-168 1
|
||||
p-169 1
|
||||
p-170 1
|
||||
p-171 1
|
||||
p-172 1
|
||||
p-173 1
|
||||
p-174 1
|
||||
p-175 1
|
||||
p-176 1
|
||||
p-177 1
|
||||
p-178 1
|
||||
p-179 1
|
||||
p-180 1
|
||||
p-181 1
|
||||
p-182 1
|
||||
p-183 1
|
||||
p-184 1
|
||||
p-185 1
|
||||
p-186 1
|
||||
p-187 1
|
||||
p-188 1
|
||||
p-189 1
|
||||
p-190 1
|
||||
p-191 1
|
||||
p-192 1
|
||||
p-193 1
|
||||
p-194 1
|
||||
p-195 1
|
||||
p-196 1
|
||||
p-197 1
|
||||
p-198 1
|
||||
p-199 1
|
||||
p-200 1
|
||||
p-201 1
|
||||
p-202 1
|
||||
p-203 1
|
||||
p-204 1
|
||||
p-205 1
|
||||
p-206 1
|
||||
p-207 1
|
||||
p-208 1
|
||||
p-209 1
|
||||
p-210 1
|
||||
p-211 1
|
||||
p-212 1
|
||||
p-213 1
|
||||
p-214 1
|
||||
p-215 1
|
||||
p-216 1
|
||||
p-217 1
|
||||
p-218 1
|
||||
p-219 1
|
||||
p-220 1
|
||||
p-221 1
|
||||
p-222 1
|
||||
p-223 1
|
||||
p-224 1
|
||||
p-225 1
|
||||
p-226 1
|
||||
p-227 1
|
||||
p-228 1
|
||||
p-229 1
|
||||
p-230 1
|
||||
p-231 1
|
||||
p-232 1
|
||||
p-233 1
|
||||
p-234 1
|
||||
p-235 1
|
||||
p-236 1
|
||||
p-237 1
|
||||
p-238 1
|
||||
p-239 1
|
||||
p-240 1
|
||||
p-241 1
|
||||
p-242 1
|
||||
p-243 1
|
||||
p-244 1
|
||||
p-245 1
|
||||
p-246 1
|
||||
p-247 1
|
||||
p-248 1
|
||||
p-249 1
|
||||
p-250 1
|
||||
p-251 1
|
||||
p-252 1
|
||||
p-253 1
|
||||
p-254 1
|
||||
p-255 1
|
||||
d-0 1
|
||||
d-1 1
|
||||
d-2 1
|
||||
d-3 1
|
||||
d-4 1
|
||||
d-5 1
|
||||
d-6 1
|
||||
d-7 1
|
||||
d-8 1
|
||||
d-9 1
|
||||
d-10 1
|
||||
d-11 1
|
||||
d-12 1
|
||||
d-13 1
|
||||
d-14 1
|
||||
d-15 1
|
||||
d-16 1
|
||||
d-17 1
|
||||
d-18 1
|
||||
d-19 1
|
||||
d-20 1
|
||||
d-21 1
|
||||
d-22 1
|
||||
d-23 1
|
||||
d-24 1
|
||||
d-25 1
|
||||
d-26 1
|
||||
d-27 1
|
||||
d-28 1
|
||||
d-29 1
|
||||
d-30 1
|
||||
d-31 1
|
||||
d-32 1
|
||||
d-33 1
|
||||
d-34 1
|
||||
d-35 1
|
||||
d-36 1
|
||||
d-37 1
|
||||
d-38 1
|
||||
d-39 1
|
||||
d-40 1
|
||||
d-41 1
|
||||
d-42 1
|
||||
d-43 1
|
||||
d-44 1
|
||||
d-45 1
|
||||
d-46 1
|
||||
d-47 1
|
||||
d-48 1
|
||||
d-49 1
|
||||
d-50 1
|
||||
d-51 1
|
||||
d-52 1
|
||||
d-53 1
|
||||
d-54 1
|
||||
d-55 1
|
||||
d-56 1
|
||||
d-57 1
|
||||
d-58 1
|
||||
d-59 1
|
||||
d-60 1
|
||||
d-61 1
|
||||
d-62 1
|
||||
d-63 1
|
||||
d-64 1
|
||||
d-65 1
|
||||
d-66 1
|
||||
d-67 1
|
||||
d-68 1
|
||||
d-69 1
|
||||
d-70 1
|
||||
d-71 1
|
||||
d-72 1
|
||||
d-73 1
|
||||
d-74 1
|
||||
d-75 1
|
||||
d-76 1
|
||||
d-77 1
|
||||
d-78 1
|
||||
d-79 1
|
||||
d-80 1
|
||||
d-81 1
|
||||
d-82 1
|
||||
d-83 1
|
||||
d-84 1
|
||||
d-85 1
|
||||
d-86 1
|
||||
d-87 1
|
||||
d-88 1
|
||||
d-89 1
|
||||
d-90 1
|
||||
d-91 1
|
||||
d-92 1
|
||||
d-93 1
|
||||
d-94 1
|
||||
d-95 1
|
||||
d-96 1
|
||||
d-97 1
|
||||
d-98 1
|
||||
d-99 1
|
||||
d-100 1
|
||||
d-101 1
|
||||
d-102 1
|
||||
d-103 1
|
||||
d-104 1
|
||||
d-105 1
|
||||
d-106 1
|
||||
d-107 1
|
||||
d-108 1
|
||||
d-109 1
|
||||
d-110 1
|
||||
d-111 1
|
||||
d-112 1
|
||||
d-113 1
|
||||
d-114 1
|
||||
d-115 1
|
||||
d-116 1
|
||||
d-117 1
|
||||
d-118 1
|
||||
d-119 1
|
||||
d-120 1
|
||||
d-121 1
|
||||
d-122 1
|
||||
d-123 1
|
||||
d-124 1
|
||||
d-125 1
|
||||
d-126 1
|
||||
d-127 1
|
||||
v-0 1
|
||||
v-1 1
|
||||
v-2 1
|
||||
v-3 1
|
||||
v-4 1
|
||||
v-5 1
|
||||
v-6 1
|
||||
v-7 1
|
||||
v-8 1
|
||||
v-9 1
|
||||
v-10 1
|
||||
v-11 1
|
||||
v-12 1
|
||||
v-13 1
|
||||
v-14 1
|
||||
v-15 1
|
||||
v-16 1
|
||||
v-17 1
|
||||
v-18 1
|
||||
v-19 1
|
||||
v-20 1
|
||||
v-21 1
|
||||
v-22 1
|
||||
v-23 1
|
||||
v-24 1
|
||||
v-25 1
|
||||
v-26 1
|
||||
v-27 1
|
||||
v-28 1
|
||||
v-29 1
|
||||
v-30 1
|
||||
v-31 1
|
||||
t-0 1
|
||||
t-1 1
|
||||
t-2 1
|
||||
t-3 1
|
||||
t-4 1
|
||||
t-5 1
|
||||
t-6 1
|
||||
t-7 1
|
||||
t-8 1
|
||||
t-9 1
|
||||
t-10 1
|
||||
t-11 1
|
||||
t-12 1
|
||||
t-13 1
|
||||
t-14 1
|
||||
t-15 1
|
||||
t-16 1
|
||||
t-17 1
|
||||
t-18 1
|
||||
t-19 1
|
||||
t-20 1
|
||||
t-21 1
|
||||
t-22 1
|
||||
t-23 1
|
||||
t-24 1
|
||||
t-25 1
|
||||
t-26 1
|
||||
t-27 1
|
||||
t-28 1
|
||||
t-29 1
|
||||
t-30 1
|
||||
t-31 1
|
||||
t-32 1
|
||||
t-33 1
|
||||
t-34 1
|
||||
t-35 1
|
||||
t-36 1
|
||||
t-37 1
|
||||
t-38 1
|
||||
t-39 1
|
||||
t-40 1
|
||||
t-41 1
|
||||
t-42 1
|
||||
t-43 1
|
||||
t-44 1
|
||||
t-45 1
|
||||
t-46 1
|
||||
t-47 1
|
||||
t-48 1
|
||||
<c> 1
|
||||
</c> 1
|
||||
u-0 1
|
||||
u-1 1
|
||||
u-2 1
|
||||
u-3 1
|
||||
u-4 1
|
||||
u-5 1
|
||||
u-6 1
|
||||
u-7 1
|
||||
u-8 1
|
||||
u-9 1
|
||||
u-10 1
|
||||
u-11 1
|
||||
u-12 1
|
||||
u-13 1
|
||||
u-14 1
|
||||
u-15 1
|
||||
u-16 1
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Двоичный файл не отображается.
|
@ -0,0 +1,2 @@
|
|||
from .museformer_lm_task import MuseformerLanguageModelingTask
|
||||
from .museformer_lm import MuseformerLanguageModel
|
|
@ -0,0 +1,39 @@
|
|||
from ..tools import arg_tools
|
||||
|
||||
|
||||
def add_args(parser):
|
||||
parser.add_argument('--attn-query-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-key-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-value-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-out-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
# valid: sum_then_reg, v2.1
|
||||
parser.add_argument('--attn-sum-key2-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-sum-value2-proj-bias', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-share-key2-value2-proj-weight', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
parser.add_argument('--add-different-kqv-bias-for-sum-and-reg', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--add-different-out-bias-for-sum-and-reg', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-share-query-proj', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-share-key-proj', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-share-value-proj', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--attn-share-out-proj', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
|
||||
def create_attention_v2_s1(
|
||||
*args, implementation='mask', block_size=64, **kwargs
|
||||
):
|
||||
if implementation == 'triton':
|
||||
from .self_attention_v2s1.blocksparse_rpe_self_attention_v2s1 import BlocksparseRpeSelfAttentionV2S1
|
||||
return BlocksparseRpeSelfAttentionV2S1(
|
||||
*args, block_size=block_size, **kwargs
|
||||
)
|
||||
return NotImplementedError(implementation)
|
||||
|
||||
|
||||
def create_attention(
|
||||
*args, attention_mode='v2s1', **kwargs
|
||||
):
|
||||
if attention_mode == 'v2s1': # v2.1
|
||||
return create_attention_v2_s1(*args, **kwargs)
|
||||
raise NotImplementedError(attention_mode)
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
from .....blocksparse import BlocksparseMatMul
|
||||
|
||||
|
||||
def do_sample_av_mul_base(self, sample_attn_weights, sample_v, sample_layout, real_part, sample_idx, tgt_len):
|
||||
if sample_layout is None:
|
||||
_, head, head_dim = sample_v.shape
|
||||
return sample_v.new_zeros(tgt_len, 1, head, head_dim)
|
||||
sample_v = sample_v.transpose(0, 1)[:, None] # (head, 1, reg_len, head_dim)
|
||||
dsd_matmul_key = (real_part, 'dsd_matmul', self.layer_sv, sample_idx)
|
||||
if dsd_matmul_key in self.instant_pocket:
|
||||
dsd_matmul = self.instant_pocket[dsd_matmul_key]
|
||||
else:
|
||||
dsd_matmul = BlocksparseMatMul(sample_layout, self.block_size, 'dsd',
|
||||
device=sample_v.device)
|
||||
self.instant_pocket[dsd_matmul_key] = dsd_matmul
|
||||
|
||||
sample_out = dsd_matmul(sample_attn_weights, sample_v) # (head, 1, tgt_len, head_dim)
|
||||
sample_out = sample_out.permute(2, 1, 0, 3) # (tgt_len, 1, head, head_dim)
|
||||
|
||||
return sample_out
|
||||
|
||||
|
||||
def do_av_mul_for_part(self, attn_weights_inc_part, v, attn_mask, real_part, tgt_len):
|
||||
attn_weights_for_part = attn_weights_inc_part[real_part]
|
||||
# samples list of (head, head_selected_blocks, block, block)
|
||||
bsz = len(attn_weights_for_part)
|
||||
attn_mask = attn_mask[real_part]
|
||||
result = []
|
||||
for sample_idx in range(bsz):
|
||||
sample_v = v[:, sample_idx]
|
||||
sample_attn_weights = attn_weights_for_part[sample_idx] # (head, head_selected_blocks, block, block)
|
||||
sample_layout = attn_mask[sample_idx][0]
|
||||
|
||||
sample_out = do_sample_av_mul_base(self, sample_attn_weights, sample_v, sample_layout, real_part,
|
||||
sample_idx, tgt_len)
|
||||
result.append(sample_out)
|
||||
|
||||
if bsz > 1:
|
||||
result = torch.cat(result, dim=1) # (tgt_len, bsz, num_heads, head_dim)
|
||||
else:
|
||||
result = result[0].contiguous()
|
||||
|
||||
return result
|
|
@ -0,0 +1,56 @@
|
|||
from .....blocksparse import BlocksparseMatMul
|
||||
|
||||
|
||||
def do_sample_qk_scores_base(
|
||||
self, sample_layout, sample_tgt, sample_src,
|
||||
tgt_len, src_len, tgt_label, sample_idx
|
||||
):
|
||||
if sample_layout is None:
|
||||
return None
|
||||
|
||||
# sample_layout: (1, tgt_block, src_block)
|
||||
assert sample_tgt.shape == (tgt_len, self.num_heads, self.head_dim)
|
||||
assert sample_src.shape == (src_len, self.num_heads, self.head_dim)
|
||||
assert sample_layout.shape == (1, tgt_len // self.block_size, src_len // self.block_size), \
|
||||
str(sample_layout.shape) + ' %d %d' % (tgt_len // self.block_size, src_len // self.block_size)
|
||||
|
||||
if tgt_len == 0 or src_len == 0:
|
||||
return sample_tgt.new_empty(self.num_heads, 0, self.block_size, self.block_size)
|
||||
|
||||
sdd_matmul_key = (tgt_label, 'sdd_matmul', self.layer_sv, sample_idx)
|
||||
if sdd_matmul_key in self.instant_pocket:
|
||||
sdd_matmul = self.instant_pocket[sdd_matmul_key]
|
||||
else:
|
||||
sdd_matmul = BlocksparseMatMul(sample_layout, self.block_size, 'sdd',
|
||||
device=sample_tgt.device, trans_b=True)
|
||||
self.instant_pocket[sdd_matmul_key] = sdd_matmul
|
||||
|
||||
sample_attn_scores = sdd_matmul(
|
||||
sample_tgt.transpose(0, 1)[:, None], # (heads, 1, sum_len, head_dim)
|
||||
sample_src.transpose(0, 1)[:, None], # (heads, 1, reg_len, head_dim)
|
||||
) # (heads, head_selected_blocks, block, block)
|
||||
assert sample_attn_scores.shape[1] == int(sample_layout[0].sum())
|
||||
|
||||
return sample_attn_scores
|
||||
|
||||
|
||||
def do_qk_scores_for_part(
|
||||
self,
|
||||
tgt, src,
|
||||
bsz, tgt_len, src_len,
|
||||
attn_mask, part_label,
|
||||
):
|
||||
# tgt: (tgt_len, bsz, num_heads, head_dim)
|
||||
# src: (src_len, bsz, num_heads, head_dim)
|
||||
part_attn_mask = attn_mask[part_label]
|
||||
attn_scores = []
|
||||
for idx in range(bsz):
|
||||
sample_layout = part_attn_mask[idx][0] # (1, tgt_block, src_block)
|
||||
sample_attn_scores = do_sample_qk_scores_base(
|
||||
self,
|
||||
sample_layout, tgt[:, idx], src[:, idx],
|
||||
tgt_len, src_len, part_label, idx
|
||||
)
|
||||
attn_scores.append(sample_attn_scores)
|
||||
attn_scores = {part_label: attn_scores}
|
||||
return attn_scores
|
|
@ -0,0 +1,55 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from .....kernels.range_fill import range_fill
|
||||
|
||||
|
||||
def select_and_do_r_proj(self, rel_indices):
|
||||
"""
|
||||
|
||||
:param rel_indices: relation list of real_part dict
|
||||
:return:
|
||||
"""
|
||||
|
||||
r_embed_list = []
|
||||
r_modified_indices = []
|
||||
for rel_idx, one_rel_indices in enumerate(rel_indices):
|
||||
# Collect those indices for one kind of relative positions that are used for all the samples and real_parts
|
||||
one_rel_used_indices = []
|
||||
for real_part in one_rel_indices: # ('sr', 'rx')
|
||||
one_rel_part_indices = one_rel_indices[real_part]
|
||||
if one_rel_part_indices is None:
|
||||
continue
|
||||
for sample_rel_indices in one_rel_part_indices:
|
||||
assert sample_rel_indices.shape[1:] == (self.block_size, self.block_size)
|
||||
sample_used_rel_indices = torch.unique(sample_rel_indices) # (num_unique,)
|
||||
one_rel_used_indices.append(sample_used_rel_indices)
|
||||
one_rel_used_indices = torch.cat(one_rel_used_indices).unique()
|
||||
|
||||
rel_selected_embed = F.embedding(one_rel_used_indices, self.rel_embeddings[rel_idx], padding_idx=0)
|
||||
|
||||
rel_proj = getattr(self, 'rel%d_proj' % rel_idx, None)
|
||||
if rel_proj is not None:
|
||||
rel_selected_embed = rel_proj(rel_selected_embed)
|
||||
|
||||
label_transform = range_fill(
|
||||
torch.stack((one_rel_used_indices, one_rel_used_indices + 1), dim=-1),
|
||||
torch.arange(len(one_rel_used_indices), device=one_rel_used_indices.device),
|
||||
self.num_rel_embeddings[rel_idx], 0
|
||||
)
|
||||
|
||||
one_r_indices = {}
|
||||
for real_part in one_rel_indices:
|
||||
one_rel_part_indices = one_rel_indices[real_part]
|
||||
if one_rel_part_indices is None:
|
||||
one_r_indices[real_part] = None
|
||||
continue
|
||||
samples_r_indices = []
|
||||
for sample_rel_indices in one_rel_part_indices:
|
||||
sample_r_indices = label_transform[sample_rel_indices]
|
||||
samples_r_indices.append(sample_r_indices)
|
||||
one_r_indices[real_part] = samples_r_indices
|
||||
|
||||
r_embed_list.append(rel_selected_embed)
|
||||
r_modified_indices.append(one_r_indices)
|
||||
|
||||
return r_embed_list, r_modified_indices
|
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
|
||||
|
||||
def indexing_sample_rpe_base(self, sample_r, sample_r_indices, sample_layout, tgt_label, tgt_len, sample_idx):
|
||||
# sample_r: (heads, tgt_len, num_selected_distance)
|
||||
# sample_r_indices: (head_selected_blocks, block, block)
|
||||
# sample_layout: (1, tgt_block, src_block)
|
||||
sample_r = sample_r.view(self.num_heads, tgt_len // self.block_size, self.block_size, -1)
|
||||
sample_tgt_block_ids = sample_layout[0].nonzero()[:, 0] # (head_selected_blocks,)
|
||||
|
||||
temp_rpe = sample_r[
|
||||
self.num_heads_arange, # (8, 1, 1, 1)
|
||||
sample_tgt_block_ids[None, :, None, None], # (1, 4, 1, 1)
|
||||
self.block_size_arange, # (1, 1, block_size, 1)
|
||||
sample_r_indices[None], # (1, head_selected_blocks, block, block)
|
||||
] # (head, head_selected_blocks, block, block)
|
||||
|
||||
return temp_rpe
|
||||
|
||||
|
||||
def add_rpe_for_part(self, attn_scores, qs, r_list, rel_indices, bsz, query_len, attn_mask, part_label):
|
||||
attn_mask = attn_mask[part_label] # samples list of tuple (layout, block_mask)
|
||||
attn_scores_for_part = attn_scores[part_label] # samples list of (heads, head_selected_blocks, block, block)
|
||||
attn_scores_for_part_with_rpe = [item for item in attn_scores_for_part]
|
||||
for rel_idx in range(self.num_relation_types):
|
||||
r_indices = rel_indices[rel_idx]
|
||||
r_indices = r_indices[part_label]
|
||||
|
||||
if r_indices is None:
|
||||
continue
|
||||
|
||||
r_embed = r_list[rel_idx].view(-1, self.num_heads, self.head_dim) \
|
||||
# (num_selected_pos, heads, head_dim)
|
||||
r_qs = qs[rel_idx] # (sum_len, bsz, heads, head_dim)
|
||||
temp_r = torch.einsum("ibhd,jhd->bhij", r_qs, r_embed) # (bsz, heads, sum_len, num_selected_distance)
|
||||
for sample_idx in range(bsz):
|
||||
sample_r = temp_r[sample_idx] # (heads, sum_len, num_selected_distance)
|
||||
sample_r_indices = r_indices[sample_idx] # (head_selected_blocks, block, block)
|
||||
|
||||
sample_layout = attn_mask[sample_idx][0]
|
||||
|
||||
temp_rpe = indexing_sample_rpe_base(
|
||||
self, sample_r, sample_r_indices, sample_layout, part_label, query_len, sample_idx
|
||||
)
|
||||
|
||||
attn_scores_for_part_with_rpe[sample_idx] = attn_scores_for_part_with_rpe[sample_idx] + temp_rpe
|
||||
|
||||
attn_scores_for_part_with_rpe = {
|
||||
part_label: attn_scores_for_part_with_rpe
|
||||
}
|
||||
|
||||
return attn_scores_for_part_with_rpe
|
|
@ -0,0 +1,39 @@
|
|||
import torch
|
||||
from .....blocksparse import BlocksparseSoftmax
|
||||
|
||||
|
||||
def do_attn_softmax_for_part(self, attn_scores, attn_mask, real_part, mask_again=False):
|
||||
attn_scores_for_real_part = attn_scores[real_part] \
|
||||
# samples list of (heads, head_selected_blocks, block, block)
|
||||
attn_mask = attn_mask[real_part] # samples list of layout, block_mask
|
||||
bsz = len(attn_mask)
|
||||
result = [None] * bsz
|
||||
for sample_idx in range(bsz):
|
||||
sample_attn_mask = attn_mask[sample_idx]
|
||||
if sample_attn_mask is None:
|
||||
continue
|
||||
sample_layout, sample_block_mask = sample_attn_mask
|
||||
if sample_layout is None:
|
||||
continue
|
||||
|
||||
if sample_block_mask.dtype == torch.uint8:
|
||||
sample_block_mask = sample_block_mask.eq(1)
|
||||
assert sample_block_mask.dtype == torch.bool
|
||||
result[sample_idx] = attn_scores_for_real_part[sample_idx].masked_fill(sample_block_mask[None], -10000)
|
||||
|
||||
softmax_label = (real_part, 'softmax', self.layer_sv, sample_idx)
|
||||
if softmax_label in self.instant_pocket:
|
||||
softmax = self.instant_pocket[softmax_label]
|
||||
else:
|
||||
softmax = BlocksparseSoftmax(sample_layout, self.block_size)
|
||||
self.instant_pocket[softmax_label] = softmax
|
||||
|
||||
temp = softmax(result[sample_idx])
|
||||
if mask_again:
|
||||
temp = temp.masked_fill(sample_block_mask[None], 0.0)
|
||||
if self.dropout_module is not None:
|
||||
temp = self.dropout_module(temp)
|
||||
result[sample_idx] = temp
|
||||
|
||||
result = {real_part: result}
|
||||
return result
|
|
@ -0,0 +1,5 @@
|
|||
def print_redundant_params(kwargs, class_name=None):
|
||||
print('=====Redundant Params%s=====' % ('' if class_name is None else ' for %s' % class_name))
|
||||
for key in kwargs:
|
||||
print('{key}:\t{value}'.format(key=key, value=kwargs[key]))
|
||||
print('=============')
|
|
@ -0,0 +1,62 @@
|
|||
import torch
|
||||
|
||||
from .rpe_self_attention_v2s1 import RpeSelfAttentionV2S1
|
||||
from ..common.blocksparse_common_operations.qk_mul.qk_mul_1 import do_qk_scores_for_part
|
||||
from ..common.blocksparse_common_operations.softmax.softmax_1 import do_attn_softmax_for_part
|
||||
from ..common.blocksparse_common_operations.av_mul.av_mul_1 import do_av_mul_for_part
|
||||
|
||||
|
||||
class BlocksparseRpeSelfAttentionV2S1(RpeSelfAttentionV2S1):
|
||||
def __init__(self, *args, block_size=64, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.block_size = block_size
|
||||
|
||||
# --- for indexing relative position embeddings ---
|
||||
# num_heads_arange = torch.arange(self.num_heads)[:, None, None, None] # (num_heads, 1, 1, 1)
|
||||
# self.register_buffer('num_heads_arange', num_heads_arange, persistent=False)
|
||||
# block_size_arange = torch.arange(self.block_size)[None, None, :, None] # (1, 1, block_size, 1)
|
||||
# self.register_buffer('block_size_arange', block_size_arange, persistent=False)
|
||||
|
||||
# ============= Interfaces =============
|
||||
def do_qk_scores_for_sx(self, base_sum_q, base_sum_k, base_reg_k, bsz, sum_len, reg_len, attn_mask=None):
|
||||
# base_sum_q: (sum_len, bsz, heads, head_dim)
|
||||
# base_sum_k: (sum_len, bsz, heads, head_dim)
|
||||
# base_reg_k: (reg_len, bsz, heads, head_dim)
|
||||
tgt = base_sum_q
|
||||
src = torch.cat((base_sum_k, base_reg_k), dim=0)
|
||||
return do_qk_scores_for_part(self, tgt, src, bsz, sum_len, sum_len + reg_len, attn_mask, 'sx')
|
||||
|
||||
def do_masking_for_sx(self, attn_scores_inc, attn_mask):
|
||||
return attn_scores_inc
|
||||
|
||||
def do_attn_softmax_for_sx(self, attn_scores_inc_sr, attn_mask=None):
|
||||
return do_attn_softmax_for_part(self, attn_scores_inc_sr, attn_mask, 'sx')
|
||||
|
||||
def do_av_mul_for_sx(self, attn_weights_inc_sr, base_sum_v, base_reg_v, attn_mask=None, tgt_len=None):
|
||||
v = torch.cat((base_sum_v, base_reg_v), dim=0)
|
||||
return do_av_mul_for_part(self, attn_weights_inc_sr, v, attn_mask, 'sx', tgt_len)
|
||||
|
||||
def do_qk_scores_for_rx(
|
||||
self,
|
||||
reg_q, sum_k, reg_k,
|
||||
bsz, sum_len, reg_len,
|
||||
attn_mask=None
|
||||
):
|
||||
if sum_k is None:
|
||||
k = reg_k
|
||||
else:
|
||||
k = torch.cat((sum_k, reg_k), dim=0)
|
||||
return do_qk_scores_for_part(self, reg_q, k, bsz, reg_len, sum_len + reg_len, attn_mask, 'rx')
|
||||
|
||||
def do_masking_for_rx(self, attn_scores_inc, attn_mask):
|
||||
return attn_scores_inc
|
||||
|
||||
def do_attn_softmax_for_rx(self, attn_scores_inc, attn_mask=None):
|
||||
return do_attn_softmax_for_part(self, attn_scores_inc, attn_mask, 'rx')
|
||||
|
||||
def do_av_mul_for_rx(self, attn_weights_inc, base_sum_v2, base_reg_v, attn_mask=None, tgt_len=None):
|
||||
if base_sum_v2 is None:
|
||||
v = base_reg_v
|
||||
else:
|
||||
v = torch.cat((base_sum_v2, base_reg_v), dim=0)
|
||||
return do_av_mul_for_part(self, attn_weights_inc, v, attn_mask, 'rx', tgt_len)
|
|
@ -0,0 +1,337 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq.modules.fairseq_dropout import FairseqDropout
|
||||
|
||||
from ...data_structures.four_dim_pocket import FourDimPocket
|
||||
from ..common import common_funcs
|
||||
|
||||
|
||||
class RpeSelfAttentionV2S1(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim,
|
||||
num_heads,
|
||||
|
||||
num_summary,
|
||||
|
||||
layer_idx,
|
||||
|
||||
dropout=0.0, # attention dropout
|
||||
|
||||
query_proj_bias=True,
|
||||
key_proj_bias=True,
|
||||
value_proj_bias=True,
|
||||
sum_key2_proj_bias=True,
|
||||
sum_value2_proj_bias=True,
|
||||
out_proj_bias=True,
|
||||
add_different_kqv_bias_for_sum_and_reg=False,
|
||||
add_different_out_bias_for_sum_and_reg=False,
|
||||
|
||||
share_query_proj=False,
|
||||
share_key_proj=False,
|
||||
share_value_proj=False,
|
||||
share_out_proj=False,
|
||||
share_key2_value2_proj_weight=False,
|
||||
|
||||
max_summary=None,
|
||||
|
||||
no_sum_out=False,
|
||||
|
||||
single_head_masks=False,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
assert single_head_masks, "Currently, we only support single head masks."
|
||||
common_funcs.print_redundant_params(kwargs, self.__class__.__name__)
|
||||
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.single_head_masks = single_head_masks
|
||||
self.num_summary = num_summary
|
||||
self.max_summary = self.num_summary if max_summary is None else max_summary
|
||||
|
||||
self.no_sum_out = no_sum_out
|
||||
|
||||
self.pocket = FourDimPocket()
|
||||
self.instant_pocket = self.pocket['instant']
|
||||
constant_pocket = self.pocket['constant']
|
||||
layer_to_sv = constant_pocket['layer_to_sv']
|
||||
self.layer_sv = layer_to_sv[self.layer_idx]
|
||||
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "attention_embed_dim must be divisible by num_heads"
|
||||
self.scaling = self.head_dim ** -0.5
|
||||
|
||||
if add_different_kqv_bias_for_sum_and_reg:
|
||||
query_proj_bias = False
|
||||
key_proj_bias = False
|
||||
value_proj_bias = False
|
||||
for proj_name in ('query', 'key', 'value'):
|
||||
for target in (('sum', 'reg') if self.num_summary > 0 else ('reg',)):
|
||||
bias_tensor = torch.zeros(self.embed_dim)
|
||||
self.register_parameter(
|
||||
'%s_%s_bias' % (target, proj_name),
|
||||
nn.Parameter(bias_tensor, requires_grad=True)
|
||||
)
|
||||
|
||||
if add_different_out_bias_for_sum_and_reg:
|
||||
out_proj_bias = False
|
||||
for target in (('sum', 'reg') if not self.no_sum_out and self.num_summary > 0 else ('reg',)):
|
||||
bias_tensor = torch.zeros(self.embed_dim)
|
||||
self.register_parameter(
|
||||
'%s_out_bias' % target,
|
||||
nn.Parameter(bias_tensor, requires_grad=True)
|
||||
)
|
||||
|
||||
self.reg_query_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_proj_bias)
|
||||
self.reg_key_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_proj_bias)
|
||||
self.reg_value_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_proj_bias)
|
||||
self.reg_out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=out_proj_bias)
|
||||
|
||||
if self.num_summary > 0:
|
||||
if share_query_proj:
|
||||
self.sum_query_proj = self.reg_query_proj
|
||||
else:
|
||||
self.sum_query_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_proj_bias)
|
||||
|
||||
if share_key_proj:
|
||||
self.sum_key_proj = self.reg_key_proj
|
||||
else:
|
||||
self.sum_key_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_proj_bias)
|
||||
|
||||
if share_value_proj:
|
||||
self.sum_value_proj = self.reg_value_proj
|
||||
else:
|
||||
self.sum_value_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_proj_bias)
|
||||
|
||||
if not self.no_sum_out:
|
||||
if share_out_proj:
|
||||
self.sum_out_proj = self.reg_out_proj
|
||||
else:
|
||||
self.sum_out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=out_proj_bias)
|
||||
|
||||
self.share_key2_value2_proj_weight = share_key2_value2_proj_weight
|
||||
if share_key2_value2_proj_weight:
|
||||
self.sum_key2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
|
||||
self.sum_value2_proj = self.sum_key2_proj
|
||||
if sum_key2_proj_bias:
|
||||
self.sum_key2_bias = nn.Parameter(torch.zeros(self.embed_dim), requires_grad=True)
|
||||
if sum_value2_proj_bias:
|
||||
self.sum_value2_bias = nn.Parameter(torch.zeros(self.embed_dim), requires_grad=True)
|
||||
else:
|
||||
self.sum_key2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=sum_key2_proj_bias)
|
||||
self.sum_value2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=sum_value2_proj_bias)
|
||||
|
||||
self.dropout_module = FairseqDropout(
|
||||
dropout, module_name=self.__class__.__name__
|
||||
) if dropout > 0.0 else None
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.xavier_uniform_(self.reg_query_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.reg_key_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.reg_value_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.reg_out_proj.weight)
|
||||
if self.num_summary > 0:
|
||||
if id(self.sum_query_proj) != id(self.reg_query_proj):
|
||||
nn.init.xavier_uniform_(self.sum_query_proj.weight, gain=1 / math.sqrt(2))
|
||||
if id(self.sum_key_proj) != id(self.reg_key_proj):
|
||||
nn.init.xavier_uniform_(self.sum_key_proj.weight, gain=1 / math.sqrt(2))
|
||||
if id(self.sum_value_proj) != id(self.reg_value_proj):
|
||||
nn.init.xavier_uniform_(self.sum_value_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.sum_key2_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.sum_value2_proj.weight, gain=1 / math.sqrt(2))
|
||||
if not self.no_sum_out and id(self.sum_out_proj) != id(self.reg_out_proj):
|
||||
nn.init.xavier_uniform_(self.sum_out_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.reg_out_proj.weight)
|
||||
if self.reg_out_proj.bias is not None:
|
||||
nn.init.constant_(self.reg_out_proj.bias, 0.0)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: tuple, # (sum_len, bsz, embed_dim), (reg_len, bsz, embed_dim)
|
||||
sum_token_ids, # (bsz, sum_len)
|
||||
sum_len,
|
||||
reg_len,
|
||||
key_padding_mask=None, # (bsz, all_seq_len)
|
||||
attn_mask=None,
|
||||
need_weights: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
*args, **kwargs,
|
||||
):
|
||||
if key_padding_mask is not None:
|
||||
raise NotImplementedError("Please combine key_padding_mask into attn_mask ahead.")
|
||||
del key_padding_mask
|
||||
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
# ===== Input Checking =====
|
||||
sum_x, reg_x = x
|
||||
bsz = reg_x.shape[1]
|
||||
del sum_token_ids
|
||||
|
||||
# ===== Summarize =====
|
||||
base_reg_k = self.reg_key_proj(reg_x)
|
||||
bias = getattr(self, 'reg_key_bias', None)
|
||||
if bias is not None:
|
||||
base_reg_k = base_reg_k + bias
|
||||
base_reg_k = base_reg_k.view(reg_len, bsz, self.num_heads, self.head_dim)
|
||||
# base_reg_k: (reg_len, bsz, num_heads, head_dim)
|
||||
|
||||
base_reg_v = self.reg_value_proj(reg_x)
|
||||
bias = getattr(self, 'reg_value_bias', None)
|
||||
if bias is not None:
|
||||
base_reg_v = base_reg_v + bias
|
||||
base_reg_v = base_reg_v.view(reg_len, bsz, self.num_heads, self.head_dim)
|
||||
# base_reg_v: (reg_len, bsz, num_heads, head_dim)
|
||||
|
||||
if sum_len > 0:
|
||||
base_sum_q = self.sum_query_proj(sum_x)
|
||||
bias = getattr(self, 'sum_query_bias', None)
|
||||
if bias is not None:
|
||||
base_sum_q = base_sum_q + bias
|
||||
base_sum_q = base_sum_q.view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
# base_sum_q: (sum_len, bsz, num_heads, head_dim)
|
||||
|
||||
base_sum_k = self.sum_key_proj(sum_x)
|
||||
bias = getattr(self, 'sum_key_bias', None)
|
||||
if bias is not None:
|
||||
base_sum_k = base_sum_k + bias
|
||||
base_sum_k = base_sum_k.view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
base_sum_v = self.sum_value_proj(sum_x)
|
||||
bias = getattr(self, 'sum_value_bias', None)
|
||||
if bias is not None:
|
||||
base_sum_v = base_sum_v + bias
|
||||
base_sum_v = base_sum_v.view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
attn_scores_inc_sx = self.do_qk_scores_for_sx(
|
||||
base_sum_q, base_sum_k, base_reg_k,
|
||||
bsz, sum_len, reg_len,
|
||||
attn_mask=attn_mask
|
||||
) # real_parts dict of sample list of (heads, head_selected_blocks, block, block)
|
||||
|
||||
for real_part in attn_scores_inc_sx:
|
||||
for sample_item in attn_scores_inc_sx[real_part]:
|
||||
sample_item.mul_(self.scaling)
|
||||
|
||||
attn_scores_inc_sx = self.do_masking_for_sx(attn_scores_inc_sx, attn_mask)
|
||||
|
||||
attn_weights_inc_sx = self.do_attn_softmax_for_sx(attn_scores_inc_sx, attn_mask=attn_mask)
|
||||
del attn_scores_inc_sx
|
||||
|
||||
sum_x2 = self.do_av_mul_for_sx(
|
||||
attn_weights_inc_sx, base_sum_v, base_reg_v, attn_mask=attn_mask, tgt_len=sum_len
|
||||
) # samples list of (sum_len, 1, num_heads, head_dim)
|
||||
assert sum_x2.shape == (sum_len, bsz, self.num_heads, self.head_dim)
|
||||
sum_x2 = sum_x2.view(sum_len, bsz, self.embed_dim)
|
||||
|
||||
if self.share_key2_value2_proj_weight:
|
||||
base_sum_k2 = self.sum_key2_proj(sum_x2)
|
||||
base_sum_v2 = base_sum_k2
|
||||
sum_key2_bias = getattr(self, 'sum_key2_bias', None)
|
||||
if sum_key2_bias is not None:
|
||||
base_sum_k2 = base_sum_k2 + sum_key2_bias
|
||||
sum_value2_bias = getattr(self, 'sum_value2_bias', None)
|
||||
if sum_value2_bias is not None:
|
||||
base_sum_v2 = base_sum_v2 + sum_value2_bias
|
||||
base_sum_k2 = base_sum_k2.view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
base_sum_v2 = base_sum_v2.view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
else:
|
||||
base_sum_k2 = self.sum_key2_proj(sum_x2).view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
base_sum_v2 = self.sum_value2_proj(sum_x2).view(sum_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
else:
|
||||
sum_x2 = reg_x.new_empty(0, bsz, self.embed_dim)
|
||||
base_sum_k2 = None
|
||||
base_sum_v2 = None
|
||||
|
||||
# ===== Updating =====
|
||||
base_reg_q = self.reg_query_proj(reg_x)
|
||||
reg_query_bias = getattr(self, 'reg_query_bias', None)
|
||||
if reg_query_bias is not None:
|
||||
base_reg_q = base_reg_q + reg_query_bias
|
||||
base_reg_q = base_reg_q.view(reg_len, bsz, self.num_heads, self.head_dim)
|
||||
|
||||
attn_scores_inc_rx = self.do_qk_scores_for_rx(
|
||||
base_reg_q, base_sum_k2, base_reg_k,
|
||||
bsz, sum_len, reg_len, attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
for real_part in attn_scores_inc_rx:
|
||||
for sample_item in attn_scores_inc_rx[real_part]:
|
||||
sample_item.mul_(self.scaling)
|
||||
|
||||
attn_scores_inc_rx = self.do_masking_for_rx(attn_scores_inc_rx, attn_mask)
|
||||
|
||||
attn_weights_inc_rx = self.do_attn_softmax_for_rx(attn_scores_inc_rx, attn_mask=attn_mask)
|
||||
|
||||
# if self.layer_idx == 3:
|
||||
# with open('attn_weights_inc_rx.bin', 'wb') as f:
|
||||
# torch.save(attn_weights_inc_rx, f)
|
||||
# print('saved attn_weights_inc_rx')
|
||||
|
||||
reg_output = self.do_av_mul_for_rx(
|
||||
attn_weights_inc_rx, base_sum_v2, base_reg_v, attn_mask=attn_mask, tgt_len=reg_len
|
||||
) # (reg_len, bsz, num_heads, head_dim)
|
||||
|
||||
# ----- gate to combine sum_output and reg_output -----
|
||||
reg_output = reg_output.view(reg_len, bsz, self.embed_dim)
|
||||
reg_output = self.reg_out_proj(reg_output)
|
||||
reg_out_bias = getattr(self, 'reg_out_bias', None)
|
||||
if reg_out_bias is not None:
|
||||
reg_output = reg_output + reg_out_bias
|
||||
if not self.no_sum_out and self.num_summary > 0:
|
||||
sum_output = self.sum_out_proj(sum_x2)
|
||||
sum_out_bias = getattr(self, 'sum_out_bias', None)
|
||||
if sum_out_bias is not None:
|
||||
sum_output = sum_output + sum_out_bias
|
||||
else:
|
||||
sum_output = None
|
||||
|
||||
if need_weights:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
attn_weights = None
|
||||
|
||||
return (sum_output, reg_output), attn_weights
|
||||
# (sum_len, bsz, embed_dim) (reg_len, bsz, embed_dim)
|
||||
# None, (bsz, num_heads, all_seq_len, all_seq_len) or (bsz, all_seq_len, all_seq_len)
|
||||
|
||||
def do_qk_scores_for_sx(self, base_sum_q, base_sum_k, base_reg_k, bsz, sum_len, reg_len, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_masking_for_sx(self, attn_scores_inc, attn_mask):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_attn_softmax_for_sx(self, attn_scores_for_sum, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_av_mul_for_sx(self, attn_weights_inc_sr, base_sum_v, base_reg_v, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_qk_scores_for_rx(
|
||||
self,
|
||||
reg_q, sum_k, reg_k,
|
||||
bsz, sum_len, reg_len,
|
||||
**kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_masking_for_rx(self, attn_scores_for_reg, attn_mask):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_attn_softmax_for_rx(self, attn_scores_for_reg, attn_mask=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def do_av_mul_for_rx(self, attn_weights_inc, base_sum_v2, base_reg_v, **kwargs):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,404 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .multihead_block_selection_generation import MultiheadBlockSelectionGeneration, get_block_ranges
|
||||
from ..kernels.block_fill import block_fill_
|
||||
from ..tools import computation_tools
|
||||
from ..data_structures.attention_scheme import LayerAttentionScheme
|
||||
|
||||
PREF2PREF_CHOICES = ('full', 'lt', 'none')
|
||||
SUM2PREF_CHOICES = ('none',)
|
||||
PREF2SUM_CHOICES = ('none',)
|
||||
CON2PREF_CHOICES = ('full',)
|
||||
PREF2CON_CHOICES = ('none',)
|
||||
|
||||
|
||||
def construct_sum_chunk_ranges(max_complete_chunk, num_summary, device):
|
||||
ranges = torch.arange(0, max_complete_chunk * num_summary + num_summary, num_summary, device=device)
|
||||
ranges = torch.stack((ranges[:-1], ranges[1:]), dim=-1) # (max_complete_chunk, 2)
|
||||
return ranges
|
||||
|
||||
|
||||
def select_self_(selection, max_comp_chunk, mode):
|
||||
if mode == 'default':
|
||||
return selection
|
||||
ndim = selection.ndim
|
||||
eye = torch.eye(max_comp_chunk, dtype=torch.bool, device=selection.device)
|
||||
if ndim >= 3:
|
||||
num_heads = selection.shape[-1]
|
||||
eye = eye[:, :, None].expand(-1, -1, num_heads)
|
||||
if mode == 'none':
|
||||
selection[eye] = False
|
||||
elif mode == 'full':
|
||||
selection[eye] = True
|
||||
else:
|
||||
raise ValueError(mode)
|
||||
return selection
|
||||
|
||||
|
||||
def fill_mask_with_selection_(selection_gen: MultiheadBlockSelectionGeneration,
|
||||
row_ranges, col_ranges, selection_mask, fill_value, out):
|
||||
"""
|
||||
|
||||
:param selection_gen:
|
||||
:param row_ranges: (num_row_chunks, 2)
|
||||
:param col_ranges: (num_col_ranges, 2)
|
||||
:param selection_mask: (num_row_chunks, num_col_ranges, num_heads)
|
||||
:param fill_value: bool
|
||||
:param out: (num_heads, row_len, col_len)
|
||||
:return:
|
||||
"""
|
||||
assert row_ranges.ndim == 2
|
||||
assert col_ranges.ndim == 2
|
||||
assert out.is_contiguous()
|
||||
block_indices, command_masks = selection_gen.get_block_indices_and_masks_from_selection_masks(
|
||||
selection_mask, overall_mask=None
|
||||
) # (num_selected_blocks, 2) (num_selected_blocks, num_heads)
|
||||
block_ranges = get_block_ranges(row_ranges, col_ranges, block_indices) # (num_selected_blocks, 4)
|
||||
block_fill_(out, block_ranges, command_masks, fill_value)
|
||||
return out
|
||||
|
||||
|
||||
def partially_combine_part_masks(part_masks_dict):
|
||||
row_sum = computation_tools.may_bi_cat(part_masks_dict['ss'], part_masks_dict['sr'], dim=3)
|
||||
row_reg = computation_tools.may_bi_cat(part_masks_dict['rs'], part_masks_dict['rr'], dim=3)
|
||||
return row_sum, row_reg
|
||||
|
||||
|
||||
def combine_part_masks(part_masks_dict):
|
||||
row_sum, row_reg = partially_combine_part_masks(part_masks_dict)
|
||||
return computation_tools.may_bi_cat(row_sum, row_reg, dim=2)
|
||||
|
||||
|
||||
def combine_attn_mask_and_key_padding_mask_(attn_mask, key_padding_mask):
|
||||
"""
|
||||
|
||||
:param attn_mask: (bsz, num_heads, seq_len, seq_len)
|
||||
:param key_padding_mask: (bsz, seq_len)
|
||||
:return:
|
||||
"""
|
||||
if attn_mask is None:
|
||||
return attn_mask
|
||||
attn_mask = attn_mask.contiguous()
|
||||
attn_mask.masked_fill_(key_padding_mask[:, None, None], True)
|
||||
return attn_mask
|
||||
|
||||
|
||||
def layout_full_zero_check(input_layout):
|
||||
row_check = input_layout.sum(dim=2).eq(0) # (H, L // block)
|
||||
col_check = input_layout.sum(dim=1).eq(0)
|
||||
row_answer = bool(row_check.any())
|
||||
col_answer = bool(col_check.any())
|
||||
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
|
||||
|
||||
|
||||
def transfer_attn_mask_to_block_layout(attn_mask, block_size, avoid_empty_row_or_column=True):
|
||||
"""
|
||||
|
||||
:param attn_mask: (bsz_heads, tgt_len, src_len) bool False for selected
|
||||
:param block_size: int
|
||||
:param avoid_empty_row_or_column
|
||||
:return:
|
||||
"""
|
||||
bsz_heads, tgt_len, src_len = attn_mask.shape
|
||||
num_tgt_blocks = tgt_len // block_size
|
||||
num_src_blocks = src_len // block_size
|
||||
|
||||
assert num_tgt_blocks * block_size == tgt_len
|
||||
assert num_src_blocks * block_size == src_len
|
||||
|
||||
assert attn_mask.dtype in (torch.bool, torch.uint8)
|
||||
attn_mask = attn_mask.view(
|
||||
bsz_heads, num_tgt_blocks, block_size, num_src_blocks, block_size
|
||||
).permute(0, 1, 3, 2, 4) # (bsz * num_heads, num_tgt_blocks, num_src_blocks, block_size, block_size)
|
||||
layout = (~(attn_mask.bool())).reshape(
|
||||
bsz_heads, num_tgt_blocks, num_src_blocks, block_size * block_size
|
||||
).sum(dim=-1).gt(0) # (bsz * num_heads, num_tgt_blocks, num_src_blocks)
|
||||
if not layout.any(): # for this part of attn_mask, there's nothing to compute
|
||||
return None, None
|
||||
if avoid_empty_row_or_column:
|
||||
from ..blocksparse import layout_full_zero_check
|
||||
answer, row_answer, col_answer, row_check, col_check = layout_full_zero_check(layout)
|
||||
if answer:
|
||||
if row_answer:
|
||||
row_zero_indices = row_check.nonzero()
|
||||
# print('row:')
|
||||
# print(row_zero_indices)
|
||||
layout[row_zero_indices[:, 0], row_zero_indices[:, 1], 0] = True
|
||||
if col_answer:
|
||||
col_zero_indices = col_check.nonzero()
|
||||
# print('col:')
|
||||
# print(col_zero_indices)
|
||||
layout[col_zero_indices[:, 0], 0, col_zero_indices[:, 1]] = True
|
||||
assert not layout_full_zero_check(layout)[0]
|
||||
block_mask = attn_mask.masked_select(layout[:, :, :, None, None]).view(-1, block_size, block_size)
|
||||
|
||||
return layout, block_mask
|
||||
|
||||
|
||||
class LayerAttentionMaskGeneration(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_attention_scheme: LayerAttentionScheme,
|
||||
gen_parts=None,
|
||||
single_head=True,
|
||||
init_max_chunk=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.layer_attention_scheme = layer_attention_scheme
|
||||
self.single_head = single_head
|
||||
if not self.layer_attention_scheme.same_for_all_heads:
|
||||
raise NotImplementedError("Not support for different attention schemes over different heads.")
|
||||
|
||||
self.num_heads = self.layer_attention_scheme.num_layer_heads
|
||||
self.num_fake_heads = 1
|
||||
self.single_head_scheme = self.layer_attention_scheme[0]
|
||||
|
||||
assert self.single_head_scheme.head_pref2pref_mode in PREF2PREF_CHOICES
|
||||
self.pref2pref_mode = self.single_head_scheme.head_pref2pref_mode
|
||||
assert self.single_head_scheme.head_sum2pref_mode in SUM2PREF_CHOICES
|
||||
self.sum2pref_mode = self.single_head_scheme.head_sum2pref_mode
|
||||
assert self.single_head_scheme.head_pref2sum_mode in PREF2SUM_CHOICES
|
||||
self.pref2sum_mode = self.single_head_scheme.head_pref2sum_mode
|
||||
assert self.single_head_scheme.head_con2pref_mode in CON2PREF_CHOICES
|
||||
self.con2pref_mode = self.single_head_scheme.head_con2pref_mode
|
||||
assert self.single_head_scheme.head_pref2con_mode in PREF2CON_CHOICES
|
||||
self.pref2con_mode = self.single_head_scheme.head_pref2con_mode
|
||||
|
||||
self.gen_parts = gen_parts
|
||||
|
||||
self.con2con_gen, self.con2sum_gen, self.sum2con_gen, self.sum2sum_gen = None, None, None, None
|
||||
|
||||
if gen_parts is None or 'rr' in gen_parts:
|
||||
self.con2con_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_con2con,),
|
||||
init_max_chunk=init_max_chunk)
|
||||
if gen_parts is None or 'rs' in gen_parts:
|
||||
self.con2sum_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_con2sum,),
|
||||
init_max_chunk=init_max_chunk)
|
||||
if gen_parts is None or 'sr' in gen_parts:
|
||||
self.sum2con_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_sum2con,),
|
||||
init_max_chunk=init_max_chunk)
|
||||
if gen_parts is None or 'ss' in gen_parts:
|
||||
self.sum2sum_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_sum2sum,),
|
||||
init_max_chunk=init_max_chunk)
|
||||
|
||||
def gen_ss_mask(
|
||||
self,
|
||||
num_complete_chunks: torch.Tensor, # (bsz,)
|
||||
num_summary: int,
|
||||
bsz, sum_len,
|
||||
device,
|
||||
):
|
||||
if sum_len <= 0 or num_summary <= 0:
|
||||
return None
|
||||
max_complete_chunk = int(num_complete_chunks.max())
|
||||
temp_sum_len = max_complete_chunk * num_summary
|
||||
mask = self.sum2sum_gen.get_block_selection_masks(max_complete_chunk).to(device) \
|
||||
# (max_complete_chunk, max_complete_chunk, 1) # True for selected
|
||||
mask = select_self_(mask, max_complete_chunk, self.single_head_scheme.head_sum2sum_self) # Add Self
|
||||
mask = ~mask # False for selected
|
||||
mask = mask.permute(2, 0, 1) # (num_heads, max_comp_chunk, max_comp_chunk)
|
||||
mask = mask[None].repeat(bsz, 1, 1, 1).contiguous() # (bsz, num_heads, max_comp_chunk, max_comp_chunk)
|
||||
for idx, sample_comp_chunk in enumerate(num_complete_chunks):
|
||||
mask[idx, :, sample_comp_chunk:] = True
|
||||
mask[idx, :, :, sample_comp_chunk:] = True
|
||||
mask = mask[:, :, :, :, None, None].expand(-1, -1, -1, -1, num_summary, num_summary)
|
||||
mask = mask.permute(0, 1, 2, 4, 3, 5).reshape(bsz, self.num_fake_heads, temp_sum_len, temp_sum_len) \
|
||||
# (bsz, num_heads, temp_sum_len, temp_sum_len)
|
||||
if sum_len != temp_sum_len:
|
||||
temp_mask = torch.ones(bsz, self.num_fake_heads, sum_len, sum_len, dtype=torch.bool, device=device)
|
||||
temp_mask[:, :, :temp_sum_len, :temp_sum_len] = mask
|
||||
mask = temp_mask
|
||||
return mask
|
||||
|
||||
def gen_sr_mask(
|
||||
self,
|
||||
sum_chunk_ranges, reg_chunk_ranges, num_complete_chunks,
|
||||
sum_len, reg_len,
|
||||
num_pref,
|
||||
sum2pref_mode,
|
||||
bsz,
|
||||
device
|
||||
):
|
||||
if sum_len <= 0:
|
||||
return None
|
||||
mask = torch.ones(bsz, self.num_fake_heads, sum_len, reg_len, dtype=torch.bool, device=device)
|
||||
for sample_idx, sample_comp_chunk in enumerate(num_complete_chunks):
|
||||
if sample_comp_chunk <= 0:
|
||||
continue
|
||||
sample_selections = self.sum2con_gen.get_block_selection_masks(sample_comp_chunk) \
|
||||
# (sample_comp_chunk, sample_comp_chunk, num_heads)
|
||||
sample_selections = select_self_(sample_selections, sample_comp_chunk,
|
||||
self.single_head_scheme.head_sum2con_self)
|
||||
fill_mask_with_selection_(self.sum2con_gen,
|
||||
sum_chunk_ranges[sample_idx], reg_chunk_ranges[sample_idx],
|
||||
sample_selections,
|
||||
False, mask[sample_idx])
|
||||
if sum2pref_mode == 'none':
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(sum2pref_mode)
|
||||
return mask
|
||||
|
||||
def gen_rs_mask(
|
||||
self,
|
||||
reg_chunk_ranges, sum_chunk_ranges, num_complete_chunks, num_chunks,
|
||||
sum_len, reg_len,
|
||||
num_pref,
|
||||
pref2sum_mode,
|
||||
bsz, device
|
||||
):
|
||||
if sum_len <= 0:
|
||||
return None
|
||||
mask = torch.ones(bsz, self.num_fake_heads, reg_len, sum_len, dtype=torch.bool, device=device)
|
||||
for sample_idx, (sample_comp_chunk, sample_chunk) in enumerate(zip(num_complete_chunks, num_chunks)):
|
||||
if sample_chunk <= 0:
|
||||
continue
|
||||
sample_selections = self.con2sum_gen.get_block_selection_masks(sample_chunk)
|
||||
sample_selections = select_self_(sample_selections, sample_chunk, self.single_head_scheme.head_con2sum_self)
|
||||
sample_selections = sample_selections[:, :sample_comp_chunk]
|
||||
fill_mask_with_selection_(self.con2sum_gen,
|
||||
reg_chunk_ranges[sample_idx], sum_chunk_ranges[sample_idx],
|
||||
sample_selections,
|
||||
False, mask[sample_idx])
|
||||
if pref2sum_mode == 'none':
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(pref2sum_mode)
|
||||
return mask
|
||||
|
||||
def gen_rr_mask(
|
||||
self,
|
||||
reg_chunk_ranges, num_chunks,
|
||||
reg_len,
|
||||
num_pref,
|
||||
pref2pref_mode, pref2con_mode, con2pref_mode,
|
||||
bsz, device
|
||||
):
|
||||
mask = torch.ones(bsz, self.num_fake_heads, reg_len, reg_len, dtype=torch.bool, device=device)
|
||||
for sample_idx, sample_chunk in enumerate(num_chunks):
|
||||
if sample_chunk <= 0:
|
||||
continue
|
||||
sample_selections = self.con2con_gen.get_block_selection_masks(sample_chunk)
|
||||
sample_selections = select_self_(
|
||||
sample_selections, sample_chunk,
|
||||
'full' if self.single_head_scheme.head_con2con_self == 'lt'
|
||||
else self.single_head_scheme.head_con2con_self
|
||||
)
|
||||
fill_mask_with_selection_(self.con2con_gen,
|
||||
reg_chunk_ranges[sample_idx], reg_chunk_ranges[sample_idx],
|
||||
sample_selections,
|
||||
False, mask[sample_idx])
|
||||
sample_num_pref = num_pref[sample_idx]
|
||||
if self.single_head_scheme.head_con2con_causal:
|
||||
up_triangle = torch.ones(reg_len - sample_num_pref, reg_len - sample_num_pref,
|
||||
dtype=torch.bool, device=device)
|
||||
up_triangle.triu_(1) # (con_len, con_len)
|
||||
mask[sample_idx, :, sample_num_pref:, sample_num_pref:].masked_fill_(up_triangle[None], True)
|
||||
if pref2pref_mode == 'none':
|
||||
pass
|
||||
elif pref2pref_mode == 'full':
|
||||
for sample_idx, sample_num_pref in enumerate(num_pref):
|
||||
mask[sample_idx, :, :sample_num_pref, :sample_num_pref] = False
|
||||
elif pref2pref_mode == 'lt':
|
||||
for sample_idx, sample_num_pref in enumerate(num_pref):
|
||||
if sample_num_pref == 1:
|
||||
mask[sample_idx, :, :sample_num_pref, :sample_num_pref] = False
|
||||
continue
|
||||
mask[sample_idx, :, :sample_num_pref, :sample_num_pref].triu_(1)
|
||||
else:
|
||||
raise NotImplementedError(pref2pref_mode)
|
||||
|
||||
if pref2con_mode == 'none':
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError(pref2con_mode)
|
||||
|
||||
if con2pref_mode == 'none':
|
||||
pass
|
||||
elif con2pref_mode == 'full':
|
||||
for sample_idx, sample_num_pref in enumerate(num_pref):
|
||||
sample_num_chunks = num_chunks[sample_idx]
|
||||
if sample_num_chunks == 0:
|
||||
continue
|
||||
sample_reg_len = reg_chunk_ranges[sample_idx, sample_num_chunks - 1, -1]
|
||||
mask[sample_idx, :, sample_num_pref: sample_reg_len, :sample_num_pref] = False
|
||||
else:
|
||||
raise NotImplementedError(con2pref_mode)
|
||||
|
||||
return mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
reg_chunk_ranges: torch.Tensor, # (bsz, max_chunk, 2)
|
||||
num_chunks: torch.Tensor, # (bsz,)
|
||||
num_complete_chunks: torch.Tensor, # (bsz,)
|
||||
num_summary: int,
|
||||
num_pref: torch.Tensor, # (bsz,)
|
||||
sum_len: int,
|
||||
reg_len: int,
|
||||
):
|
||||
# print('===attn_mask===')
|
||||
# print(reg_chunk_ranges)
|
||||
# print(num_chunks)
|
||||
# print(num_complete_chunks)
|
||||
# print(num_summary)
|
||||
# print(num_pref)
|
||||
# print(sum_len, reg_len)
|
||||
# print('=====')
|
||||
# === Preliminaries ===
|
||||
device = reg_chunk_ranges.device
|
||||
bsz = num_chunks.shape[0]
|
||||
# all_seq_len = temp_sum_len + reg_len
|
||||
max_complete_chunk = int(num_complete_chunks.max())
|
||||
|
||||
if num_summary <= 0 or sum_len <= 0:
|
||||
sum_chunk_ranges = None
|
||||
else:
|
||||
sum_chunk_ranges = construct_sum_chunk_ranges(max_complete_chunk, num_summary, device=device) \
|
||||
# (max_comp_chunk, 2)
|
||||
sum_chunk_ranges = sum_chunk_ranges[None].repeat(bsz, 1, 1).contiguous() # (bsz, max_comp_chunk, 2)
|
||||
|
||||
# === Generate Masks for Parts (True is for the masked out)
|
||||
ss_mask, sr_mask, rs_mask, rr_mask = None, None, None, None
|
||||
if self.gen_parts is None or 'ss' in self.gen_parts:
|
||||
ss_mask = self.gen_ss_mask(num_complete_chunks, num_summary, bsz, sum_len, device) \
|
||||
# (bsz, num_heads, temp_sum_len, temp_sum_len)
|
||||
assert ss_mask is None or ss_mask.dtype == torch.bool
|
||||
if self.gen_parts is None or 'sr' in self.gen_parts:
|
||||
sr_mask = self.gen_sr_mask(sum_chunk_ranges, reg_chunk_ranges, num_complete_chunks, sum_len, reg_len,
|
||||
num_pref, self.sum2pref_mode, bsz, device)
|
||||
assert sr_mask is None or sr_mask.dtype == torch.bool
|
||||
if self.gen_parts is None or 'rs' in self.gen_parts:
|
||||
rs_mask = self.gen_rs_mask(reg_chunk_ranges, sum_chunk_ranges, num_complete_chunks, num_chunks,
|
||||
sum_len, reg_len, num_pref, self.pref2sum_mode, bsz, device)
|
||||
assert rs_mask is None or rs_mask.dtype == torch.bool
|
||||
if self.gen_parts is None or 'rr' in self.gen_parts:
|
||||
rr_mask = self.gen_rr_mask(reg_chunk_ranges, num_chunks, reg_len,
|
||||
num_pref, self.pref2pref_mode, self.pref2con_mode, self.con2pref_mode,
|
||||
bsz, device)
|
||||
assert rr_mask.dtype == torch.bool
|
||||
|
||||
if not self.single_head and self.num_heads > 1:
|
||||
if ss_mask is not None:
|
||||
ss_mask = ss_mask.expand(-1, self.real_num_heads, -1, -1)
|
||||
if sr_mask is not None:
|
||||
sr_mask = sr_mask.expand(-1, self.real_num_heads, -1, -1)
|
||||
if rs_mask is not None:
|
||||
rs_mask = rs_mask.expand(-1, self.real_num_heads, -1, -1)
|
||||
if rr_mask is not None:
|
||||
rr_mask = rr_mask.expand(-1, self.real_num_heads, -1, -1)
|
||||
|
||||
attn_mask = {
|
||||
'ss': ss_mask, # (bsz, num_heads, temp_sum_len, temp_sum_len)
|
||||
'sr': sr_mask, # (bsz, num_heads, temp_sum_len, reg_len)
|
||||
'rs': rs_mask, # (bsz, num_heads, reg_len, temp_sum_len)
|
||||
'rr': rr_mask, # (bsz, num_heads, reg_len, reg_len)
|
||||
}
|
||||
|
||||
for key in ('ss', 'sr', 'rs', 'rr'):
|
||||
if self.gen_parts is not None and key not in self.gen_parts:
|
||||
attn_mask.pop(key)
|
||||
|
||||
return attn_mask
|
|
@ -0,0 +1,165 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..tools.singleton_decorator import singleton
|
||||
|
||||
|
||||
def generate_index_matrix(num_chunks, device='cpu'):
|
||||
chunk_arange = torch.arange(num_chunks, device=device)
|
||||
rows = chunk_arange.view(num_chunks, 1).expand(num_chunks, num_chunks)
|
||||
cols = chunk_arange.view(1, num_chunks).expand(num_chunks, num_chunks)
|
||||
index_matrix = torch.stack((rows, cols), dim=-1) # (num_chunks, num_chunks, 2)
|
||||
return index_matrix
|
||||
|
||||
|
||||
def generate_mask_template(num_chunks, device='cpu'):
|
||||
return torch.ones(num_chunks, num_chunks, dtype=torch.bool, device=device)
|
||||
|
||||
|
||||
def generate_diagonal_indices(num_chunks, device='cpu'):
|
||||
diagonal = torch.arange(num_chunks, device=device)
|
||||
diagonal = torch.stack((diagonal, diagonal), dim=-1)
|
||||
return diagonal
|
||||
|
||||
|
||||
def generate_mask(head_range_command, num_chunks, mask_template=None, device='cpu'):
|
||||
"""
|
||||
Generate a selection mask of size (num_chunks, num_chunks) according to range_command. Selected ones are True.
|
||||
:param head_range_command:
|
||||
:param num_chunks:
|
||||
:param mask_template:
|
||||
:param device:
|
||||
:return:
|
||||
"""
|
||||
if head_range_command is None:
|
||||
return None
|
||||
|
||||
if mask_template is None:
|
||||
mask_template = generate_mask_template(num_chunks, device=device)
|
||||
else:
|
||||
w, h = mask_template.shape
|
||||
assert w == h == num_chunks
|
||||
mask_template = mask_template.to(device)
|
||||
|
||||
mask = torch.zeros_like(mask_template)
|
||||
for command_item in head_range_command:
|
||||
if isinstance(command_item, int):
|
||||
command_item = (command_item, command_item + 1)
|
||||
|
||||
begin, end = command_item
|
||||
if begin is None:
|
||||
begin = -num_chunks + 1
|
||||
if end is None:
|
||||
end = num_chunks
|
||||
|
||||
if begin >= end:
|
||||
continue
|
||||
|
||||
left = torch.triu(mask_template, begin)
|
||||
right = torch.tril(mask_template, end - 1)
|
||||
mask = mask | (left & right)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
@singleton
|
||||
class BlockSelectionTemplateManager(nn.Module):
|
||||
"""
|
||||
Generate mask according to some specific range c, and provide related management. Singleton design mode.
|
||||
"""
|
||||
|
||||
INIT_MAX_CHUNK = 1024
|
||||
_RESERVED_ATTR_NAMES = ('max_chunks', 'index_matrix', 'mask_template', 'diagonal_indices', 'mask')
|
||||
|
||||
def __init__(self, init_max_chunk=None):
|
||||
super().__init__()
|
||||
if init_max_chunk is None:
|
||||
init_max_chunk = self.__class__.INIT_MAX_CHUNK
|
||||
self._max_chunk = init_max_chunk
|
||||
|
||||
index_matrix = generate_index_matrix(self._max_chunk)
|
||||
self.register_buffer('_index_matrix', index_matrix, persistent=False)
|
||||
|
||||
mask_template = generate_mask_template(self._max_chunk)
|
||||
self.register_buffer('_mask_template', mask_template, persistent=False)
|
||||
|
||||
diagonal_indices = generate_diagonal_indices(self._max_chunk)
|
||||
self.register_buffer('_diagonal_indices', diagonal_indices, persistent=False)
|
||||
|
||||
self.__range_commands_and_names = []
|
||||
|
||||
def __update_index_matrix(self, num_chunks):
|
||||
new_index_matrix = generate_index_matrix(num_chunks, device=self._index_matrix.device)
|
||||
self._index_matrix = new_index_matrix
|
||||
|
||||
def __update_mask_template(self, num_chunks):
|
||||
new_mask_template = generate_mask_template(num_chunks, device=self._mask_template.device)
|
||||
self._mask_template = new_mask_template
|
||||
|
||||
def __update_diagonal_indices(self, num_chunks):
|
||||
new_diagonal_indices = generate_diagonal_indices(
|
||||
num_chunks, device=self._diagonal_indices.device
|
||||
)
|
||||
self._diagonal_indices = new_diagonal_indices
|
||||
|
||||
def __update_masks(self, num_chunks):
|
||||
w, h = self._mask_template.shape
|
||||
assert w == h == num_chunks, "Please update mask_template first."
|
||||
for range_command, mask_name in self.__range_commands_and_names:
|
||||
setattr(self, mask_name,
|
||||
generate_mask(range_command, num_chunks,
|
||||
mask_template=self._mask_template,
|
||||
device=self._mask_template.device))
|
||||
|
||||
def update(self, num_chunks):
|
||||
if num_chunks <= self._max_chunk:
|
||||
return
|
||||
self.__update_index_matrix(num_chunks)
|
||||
self.__update_mask_template(num_chunks)
|
||||
self.__update_diagonal_indices(num_chunks)
|
||||
self.__update_masks(num_chunks)
|
||||
self._max_chunk = num_chunks
|
||||
|
||||
@property
|
||||
def max_chunk(self):
|
||||
return self._max_chunk
|
||||
|
||||
@max_chunk.setter
|
||||
def max_chunk(self, num_chunks):
|
||||
self.update(num_chunks)
|
||||
|
||||
@property
|
||||
def index_matrix(self):
|
||||
return self._index_matrix
|
||||
|
||||
@property
|
||||
def mask_template(self):
|
||||
return self._mask_template
|
||||
|
||||
@property
|
||||
def diagonal_indices(self):
|
||||
return self._diagonal_indices
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return getattr(self, '_mask_template').device
|
||||
|
||||
def register_range_command(self, range_command):
|
||||
name = str(range_command)
|
||||
if name in self.__class__._RESERVED_ATTR_NAMES:
|
||||
raise ValueError
|
||||
if hasattr(self, name):
|
||||
return
|
||||
mask = generate_mask(range_command, self.max_chunk, mask_template=self.mask_template)
|
||||
self.register_buffer(name, mask, persistent=False)
|
||||
self.__range_commands_and_names.append((range_command, name))
|
||||
|
||||
def mask(self, range_command):
|
||||
name = str(range_command)
|
||||
assert (range_command, name) in self.__range_commands_and_names
|
||||
return getattr(self, name)
|
||||
|
||||
def get_diagonal_indices(self, num_chunks):
|
||||
assert num_chunks > 0
|
||||
self.update(num_chunks)
|
||||
return self._diagonal_indices[:num_chunks]
|
|
@ -0,0 +1,101 @@
|
|||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .block_selection_template_manager import BlockSelectionTemplateManager
|
||||
|
||||
|
||||
class MultiheadBlockSelectionGeneration(nn.Module):
|
||||
"""
|
||||
Generate a set of block masks according to a set of range commands, typically for heads in a layer.
|
||||
"""
|
||||
def __init__(self, multihead_range_commands, init_max_chunk=None):
|
||||
"""
|
||||
|
||||
:param multihead_range_commands: tuple consisting of head_range_command
|
||||
:param init_max_chunk: int
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self._manager = BlockSelectionTemplateManager(init_max_chunk)
|
||||
self.range_commands = multihead_range_commands
|
||||
|
||||
for range_command in self.range_commands:
|
||||
self._manager.register_range_command(range_command)
|
||||
|
||||
@lru_cache(maxsize=128, typed=False)
|
||||
def get_block_selection_masks(self, num_chunks: int):
|
||||
"""
|
||||
Get selection masks of size (num_chunks, num_chunks, num_commands) corresponding to all the range_commands.
|
||||
:param num_chunks:
|
||||
:return:
|
||||
"""
|
||||
assert num_chunks > 0
|
||||
self._manager.update(num_chunks)
|
||||
|
||||
commands_masks = []
|
||||
false_chunk = torch.zeros(num_chunks, num_chunks, dtype=torch.bool, device=self._manager.device)
|
||||
considered_range_commands = set()
|
||||
for idx, range_command in enumerate(self.range_commands):
|
||||
mask = self._manager.mask(range_command)
|
||||
if mask is None:
|
||||
assert self.range_commands[idx] is None
|
||||
commands_masks.append(false_chunk)
|
||||
else:
|
||||
mask = mask[:num_chunks, :num_chunks]
|
||||
commands_masks.append(mask)
|
||||
considered_range_commands.add(range_command)
|
||||
commands_masks = torch.stack(commands_masks, dim=-1) # (num_chunks, num_chunks, num_commands)
|
||||
|
||||
return commands_masks
|
||||
|
||||
@staticmethod
|
||||
def get_overall_mask_from_commands_masks(commands_masks):
|
||||
return commands_masks.sum(dim=-1).gt(0)
|
||||
|
||||
def get_block_indices_and_masks_from_selection_masks(self, commands_masks, overall_mask=None):
|
||||
"""
|
||||
|
||||
:param commands_masks:
|
||||
:param overall_mask:
|
||||
:return:
|
||||
"""
|
||||
if overall_mask is None:
|
||||
overall_mask = self.get_overall_mask_from_commands_masks(commands_masks)
|
||||
else:
|
||||
assert overall_mask.ndim == 2
|
||||
num_chunks_1, num_chunks_2 = overall_mask.shape[:2]
|
||||
index_matrix = self._manager.index_matrix[:num_chunks_1, :num_chunks_2]
|
||||
indices = index_matrix.masked_select(overall_mask.unsqueeze(-1)).view(-1, 2)
|
||||
commands_masks = commands_masks.masked_select(overall_mask.unsqueeze(-1)).view(-1, len(self.range_commands))
|
||||
return indices, commands_masks
|
||||
|
||||
@lru_cache(maxsize=128, typed=False)
|
||||
def get_block_indices_and_masks(self, num_chunks: int):
|
||||
"""
|
||||
|
||||
:param num_chunks: int, number of chunks in one sample
|
||||
:return: block indices (query, key) to compute (num, 2);
|
||||
mask indicating whether to compute in each c (head)
|
||||
"""
|
||||
commands_masks = self.get_block_selection_masks(num_chunks)
|
||||
overall_mask = self.get_overall_mask_from_commands_masks(commands_masks)
|
||||
return self.get_block_indices_and_masks_from_selection_masks(commands_masks, overall_mask=overall_mask)
|
||||
|
||||
def get_diagonal_indices(self, num_chunks):
|
||||
return self._manager.get_diagonal_indices(num_chunks)
|
||||
|
||||
|
||||
def get_block_ranges(row_ranges, col_ranges, block_indices):
|
||||
"""
|
||||
|
||||
:param row_ranges: begins and endings for chunks on row dimension. (num_tgt_chunks, 2)
|
||||
:param col_ranges: begins and endings for chunks on col dimension. (num_src_chunks, 2)
|
||||
:param block_indices: row and col indices of selected blocks. (num_blocks, 2)
|
||||
:return:
|
||||
"""
|
||||
tgt_ranges = row_ranges[block_indices[:, 0]] # (num_blocks, 2)
|
||||
src_ranges = col_ranges[block_indices[:, 1]] # (num_blocks, 2)
|
||||
block_ranges = torch.cat((tgt_ranges, src_ranges), dim=1) # (num_blocks, 4)
|
||||
return block_ranges
|
|
@ -0,0 +1,2 @@
|
|||
from .operator_manager import BlocksparseMatMulManager, BlocksparseSoftmaxManager, BlocksparseMatMul, BlocksparseSoftmax
|
||||
from .utils import layout_full_zero_check
|
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
|
||||
from .optimized_matmul import matmul as BlocksparseMatMul
|
||||
from .optimized_softmax import softmax as BlocksparseSoftmax
|
||||
from ..tools.singleton_decorator import singleton
|
||||
|
||||
|
||||
def layout_full_zero_check(input_layout):
|
||||
row_check = input_layout.sum(dim=2).eq(0) # (head, row_len // block)
|
||||
col_check = input_layout.sum(dim=1).eq(0) # (head, col_len // block)
|
||||
row_answer = bool(row_check.any())
|
||||
col_answer = bool(col_check.any())
|
||||
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
|
||||
|
||||
|
||||
class CacheList(object):
|
||||
def __init__(self, cache_size):
|
||||
self._cache = list()
|
||||
self._num = 0
|
||||
self.cache_size = cache_size
|
||||
|
||||
def __len__(self):
|
||||
return self._num
|
||||
|
||||
def __iter__(self):
|
||||
for idx in range(self._num - 1, -1, -1):
|
||||
yield self._cache[idx]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._cache[idx]
|
||||
|
||||
def put(self, x):
|
||||
if self._num >= self.cache_size:
|
||||
self._cache = self._cache[1:]
|
||||
self._num -= 1
|
||||
self._cache.append(x)
|
||||
self._num += 1
|
||||
|
||||
def clear(self):
|
||||
self._cache.clear()
|
||||
self._num = 0
|
||||
|
||||
|
||||
class OperatorManager(object):
|
||||
def __init__(self, operator_cls, cache_size=0):
|
||||
self.operator_cls = operator_cls
|
||||
self.cache_size = cache_size if cache_size > 0 else 0
|
||||
self.cache_dict = None if self.cache_size == 0 else {}
|
||||
|
||||
def clear_cache(self):
|
||||
self.cache_dict.clear()
|
||||
|
||||
def get_operator(self, layout, block, **kwargs):
|
||||
assert layout.dtype == torch.bool
|
||||
has_empty, row_answer, col_answer, _, _ = layout_full_zero_check(layout)
|
||||
assert not has_empty, "layout has empty %s, which may lead to error computation. Please check and fix." % (
|
||||
'row' if row_answer else 'column'
|
||||
)
|
||||
layout_cpu = None
|
||||
all_args = None
|
||||
if self.cache_dict is not None:
|
||||
layout_cpu = layout.cpu()
|
||||
all_args = (('block', block),) + tuple(kwargs.items())
|
||||
all_args = tuple(sorted(all_args, key=lambda x: x[0]))
|
||||
if all_args in self.cache_dict:
|
||||
cache_list = self.cache_dict[all_args]
|
||||
for cached_layout, operator in cache_list:
|
||||
if torch.equal(layout_cpu, cached_layout):
|
||||
return operator
|
||||
operator = self.operator_cls(layout=layout.long(), block=block, **kwargs)
|
||||
if self.cache_dict is not None:
|
||||
if all_args not in self.cache_dict:
|
||||
self.cache_dict[all_args] = CacheList(self.cache_size)
|
||||
self.cache_dict[all_args].put((layout_cpu, operator))
|
||||
return operator
|
||||
|
||||
|
||||
@singleton
|
||||
class BlocksparseMatMulManager(OperatorManager):
|
||||
def __init__(self, cache_size=0):
|
||||
super().__init__(BlocksparseMatMul, cache_size=cache_size)
|
||||
|
||||
|
||||
@singleton
|
||||
class BlocksparseSoftmaxManager(OperatorManager):
|
||||
def __init__(self, cache_size=0):
|
||||
super().__init__(BlocksparseSoftmax, cache_size=cache_size)
|
|
@ -0,0 +1,535 @@
|
|||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
# ********************************************************
|
||||
# --------------------------------------------------------
|
||||
# Sparse = Dense x Dense (SDD)
|
||||
# This operation uses super-blocking to make sure that
|
||||
# it's done efficiently when small blocks can be grouped
|
||||
# together
|
||||
# --------------------------------------------------------
|
||||
# ********************************************************
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
|
||||
})
|
||||
@triton.jit
|
||||
def _sdd_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_nb,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
K, grid_offset, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
block_id = tl.program_id(1) + grid_offset
|
||||
lut += block_id * 3
|
||||
# offsets
|
||||
off_z = tl.program_id(2) # batch
|
||||
off_h = tl.load(lut + 0) # head
|
||||
|
||||
# initialize pointers to A
|
||||
start_am = tl.load(lut + 1)
|
||||
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
|
||||
offs_ak = tl.arange(0, TILE_K)
|
||||
a_ptrs = A \
|
||||
+ off_z * stride_za \
|
||||
+ off_h * stride_ha \
|
||||
+ offs_am[:, None] * stride_ma \
|
||||
+ offs_ak[None, :] * stride_ak
|
||||
# initialize pointers to B
|
||||
start_bn = tl.load(lut + 2)
|
||||
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
|
||||
offs_bk = tl.arange(0, TILE_K)
|
||||
b_ptrs = B \
|
||||
+ off_z * stride_zb \
|
||||
+ off_h * stride_hb \
|
||||
+ offs_bn[None, :] * stride_nb \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
# ---------------- #
|
||||
# Inner Loop #
|
||||
# ---------------- #
|
||||
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
|
||||
for k in range(K, 0, -TILE_K):
|
||||
if EVEN_K:
|
||||
a = tl.load(a_ptrs)
|
||||
b = tl.load(b_ptrs)
|
||||
else:
|
||||
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
|
||||
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += TILE_K * stride_ak
|
||||
b_ptrs += TILE_K * stride_bk
|
||||
c = acc.to(C.dtype.element_ty)
|
||||
# ---------------- #
|
||||
# Epilogue #
|
||||
# ---------------- #
|
||||
offs_cm = tl.arange(0, TILE_M) % BLOCK
|
||||
offs_cn = tl.arange(0, TILE_N) % BLOCK
|
||||
pc = C \
|
||||
+ off_z * stride_zc \
|
||||
+ block_id * stride_hc \
|
||||
+ offs_cm[:, None] * stride_mc \
|
||||
+ offs_cn[None, :] * stride_nc
|
||||
tl.store(pc, c, mask=True)
|
||||
|
||||
|
||||
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# (A * B)^T = B^T * A^T
|
||||
if trans_c:
|
||||
a, b = b, a
|
||||
trans_a, trans_b = not trans_b, not trans_a
|
||||
# shape constraints
|
||||
a_dim = -2 if trans_a else -1
|
||||
b_dim = -1 if trans_b else -2
|
||||
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
|
||||
if Ka != Kb:
|
||||
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
|
||||
# allocate output
|
||||
if out is None:
|
||||
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
|
||||
else:
|
||||
assert out.shape == (a.shape[0], lut.shape[0], block, block)
|
||||
c = out
|
||||
grid = [1, c.shape[1], c.shape[0]]
|
||||
_sdd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
|
||||
Ka, 0, lut,
|
||||
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
|
||||
num_warps=4,
|
||||
)
|
||||
return c
|
||||
|
||||
|
||||
def sdd_lut(layout, block, device):
|
||||
lut = layout.nonzero(as_tuple=False).to(device).int()
|
||||
lut = lut.contiguous()
|
||||
return lut, None
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Sparse x Dense (DSD)
|
||||
# This operation uses a look-up table that contains pre-computed pointer increments
|
||||
# in order to minimize computations in the inner loop of the matmul kernel.
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dsd_kernel(
|
||||
A, B, C,
|
||||
stride_az, stride_ha, stride_am, stride_ak,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_cm, stride_cn,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
num_pid_m = tl.num_programs(0)
|
||||
num_pid_n = tl.num_programs(1)
|
||||
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
|
||||
pidz = tl.program_id(2)
|
||||
header = lut + pid_n * 4
|
||||
offset = tl.load(header + 0)
|
||||
K = tl.load(header + 1)
|
||||
column = tl.load(header + 2)
|
||||
off_h = tl.load(header + 3)
|
||||
pinc = lut + offset
|
||||
# initialize pointers to A (sparse)
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8) # compiler hint
|
||||
offs_am = tl.arange(0, TILE_M)
|
||||
offs_ak = tl.arange(0, TILE_K)
|
||||
pa = A + pidz * stride_az \
|
||||
+ block_id * stride_ha \
|
||||
+ offs_am[:, None] * stride_am \
|
||||
+ offs_ak[None, :] * stride_ak
|
||||
# initialize pointers to B (dense)
|
||||
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
|
||||
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
|
||||
start_bk = tl.load(pinc)
|
||||
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
|
||||
offs_bk = start_bk + tl.arange(0, TILE_K)
|
||||
pb = B + pidz * stride_zb \
|
||||
+ off_h * stride_hb \
|
||||
+ offs_bn[None, :] * stride_bn \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
# ---------------- #
|
||||
# Inner Loop #
|
||||
# ---------------- #
|
||||
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
|
||||
pinc += 2
|
||||
inc_a = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.load(pinc)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
for k in range(K, 0, -TILE_K):
|
||||
a = tl.load(pa, mask=True)
|
||||
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
|
||||
acc += tl.dot(a, b)
|
||||
pa += inc_a
|
||||
pb += inc_b * stride_bk
|
||||
pinc += 2
|
||||
inc_a = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.load(pinc)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
c = acc.to(C.dtype.element_ty)
|
||||
# initialize pointers to C
|
||||
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
|
||||
pc = C \
|
||||
+ off_h * stride_hc \
|
||||
+ pidz * stride_zc \
|
||||
+ offs_cm[:, None] * stride_cm \
|
||||
+ offs_cn[None, :] * stride_cn
|
||||
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
|
||||
|
||||
|
||||
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# shapes / dtypes
|
||||
AS1 = block * spdims[2 if trans_a else 1]
|
||||
BS0 = b.size(0)
|
||||
BS1 = b.size(1)
|
||||
BS3 = b.size(2 if trans_b else 3)
|
||||
dtype = a.dtype
|
||||
# allocate output
|
||||
CS0 = BS0
|
||||
CS1 = BS1
|
||||
CS2 = BS3 if trans_c else AS1
|
||||
CS3 = AS1 if trans_c else BS3
|
||||
if out is None:
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
else:
|
||||
assert out.shape == (CS0, CS1, CS2, CS3)
|
||||
c = out
|
||||
# meta-parameter heuristics
|
||||
TILE_N = 128
|
||||
# compute output
|
||||
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
|
||||
_dsd_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
BS3, AS1, lut,
|
||||
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
)
|
||||
# exit()
|
||||
return c
|
||||
|
||||
|
||||
def dsd_lut(layout, block, step, trans, device):
|
||||
sizes = torch.sum(layout, 2 if trans else 1)
|
||||
head_id, col_id = sizes.nonzero(as_tuple=True)
|
||||
sizes = sizes.flatten()
|
||||
segments = sizes * step
|
||||
# pointer increments
|
||||
if trans:
|
||||
nnz = layout.nonzero(as_tuple=False)
|
||||
else:
|
||||
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
|
||||
num_blocks = nnz.size(0)
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
|
||||
# -------------------------------
|
||||
# dense input pointer increments
|
||||
# -------------------------------
|
||||
# given a list of the indices for the first element of each non-zero block.
|
||||
# For example, for the indices
|
||||
# [32, 80, 128, 256, 288]
|
||||
# we would generate the increments
|
||||
# [32, 48, 48, 128, 32]
|
||||
# ^
|
||||
# index of first element
|
||||
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
|
||||
# that is smaller than the block size, so we need to do a bit of extra work
|
||||
# to handle this case
|
||||
B_idx = nnz[:, 2] * block
|
||||
B_incs = B_idx.clone()
|
||||
B_incs[1:] -= B_idx[:-1]
|
||||
div = block // step
|
||||
B_incs = B_incs.view(-1, 1).repeat(1, div)
|
||||
B_incs[:, 1:] = step
|
||||
B_incs[:, 0] -= (div - 1) * step
|
||||
# first increment for each reduction is actually the offset
|
||||
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
|
||||
B_incs = B_incs.view(-1)
|
||||
# -------------------------------
|
||||
# sparse input pointer increments
|
||||
# -------------------------------
|
||||
# same as above, except that the increments are in the sparse memory layout
|
||||
if trans:
|
||||
A_idx = torch.arange(num_blocks, device=layout.device)
|
||||
else:
|
||||
A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
|
||||
current_offset = 0
|
||||
for z in range(layout.size(0)):
|
||||
layoutw = layout[z, :, :].clone().long()
|
||||
msum = layoutw.sum()
|
||||
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
|
||||
A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
|
||||
current_offset += msum
|
||||
A_incs = A_idx * block * block
|
||||
A_incs[1:] -= A_idx[:-1] * block * block
|
||||
A_incs = A_incs.view(-1, 1).repeat(1, div)
|
||||
if trans:
|
||||
A_incs[:, 1:] = step
|
||||
A_incs[:, 0] -= (div - 1) * step
|
||||
else:
|
||||
A_incs[:, 1:] = step * block
|
||||
A_incs[:, 0] -= (div - 1) * step * block
|
||||
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
|
||||
A_incs = A_incs.view(-1)
|
||||
# create header
|
||||
width = col_id.size(0)
|
||||
offsets = offsets * 2 * div + 4 * width
|
||||
segments = segments * div
|
||||
try:
|
||||
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
|
||||
except RuntimeError:
|
||||
print(offsets.shape, segments.shape, col_id.shape, head_id.shape)
|
||||
from .utils import layout_full_zero_check
|
||||
answer, row_answer, col_answer, row_check, col_check = layout_full_zero_check(layout)
|
||||
if answer:
|
||||
if row_answer:
|
||||
print('layout contains empty rows:', row_check.nonzero(as_tuple=False).squeeze())
|
||||
if col_answer:
|
||||
print('layout contains empty columns:', col_check.nonzero(as_tuple=False).squeeze())
|
||||
print(layout[0].long())
|
||||
raise
|
||||
# create increments
|
||||
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
|
||||
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
|
||||
# create lut
|
||||
lut = torch.cat((header, incs))
|
||||
lut = lut.type(torch.int32).to(device)
|
||||
# create locks
|
||||
return lut, width
|
||||
|
||||
# -----------------------------
|
||||
# Dense = Dense x Sparse (DDS)
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dds_kernel(
|
||||
A, B, C,
|
||||
stride_za, stride_ha, stride_ma, stride_ka,
|
||||
stride_zb, stride_hb, stride_bk, stride_bn,
|
||||
stride_zc, stride_hc, stride_mc, stride_nc,
|
||||
DS0, DS1, lut,
|
||||
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
|
||||
):
|
||||
# ------------ #
|
||||
# - Prologue - #
|
||||
# ------------ #
|
||||
pid_m = tl.program_id(0)
|
||||
pid_n = tl.program_id(1)
|
||||
num_pid_m = tl.num_programs(0)
|
||||
num_pid_n = tl.num_programs(1)
|
||||
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
|
||||
pid_z = tl.program_id(2)
|
||||
header = lut + pid_n * 4
|
||||
offset = tl.load(header + 0)
|
||||
AS1 = tl.load(header + 1)
|
||||
column = tl.load(header + 2)
|
||||
off_h = tl.load(header + 3)
|
||||
pinc = lut + offset
|
||||
# initialize pointers to A (dense)
|
||||
offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
|
||||
start_ak = tl.load(pinc)
|
||||
start_ak = tl.multiple_of(start_ak, 8)
|
||||
offs_ak = start_ak + tl.arange(0, TILE_K)
|
||||
ptrs_a = A + pid_z * stride_za \
|
||||
+ off_h * stride_ha \
|
||||
+ offs_am[:, None] * stride_ma \
|
||||
+ offs_ak[None, :] * stride_ka
|
||||
# initialize pointers to B (sparse)
|
||||
block_id = tl.load(pinc + 1)
|
||||
block_id = tl.multiple_of(block_id, 8)
|
||||
offs_bn = tl.arange(0, TILE_N)
|
||||
offs_bk = tl.arange(0, TILE_K)
|
||||
ptrs_b = B + pid_z * stride_zb \
|
||||
+ block_id * stride_hb \
|
||||
+ offs_bn[None, :] * stride_bn \
|
||||
+ offs_bk[:, None] * stride_bk
|
||||
# ---------------- #
|
||||
# Inner Loop #
|
||||
# ---------------- #
|
||||
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
|
||||
for k in range(AS1, 0, -TILE_K):
|
||||
a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
|
||||
b = tl.load(ptrs_b, mask=True)
|
||||
acc += tl.dot(a, b)
|
||||
pinc += 2
|
||||
inc_a = tl.load(pinc)
|
||||
inc_b = tl.load(pinc + 1)
|
||||
inc_a = tl.multiple_of(inc_a, 8)
|
||||
inc_b = tl.multiple_of(inc_b, 8)
|
||||
inc_a = inc_a * stride_ka
|
||||
ptrs_a += inc_a
|
||||
ptrs_b += inc_b
|
||||
# ---------------- #
|
||||
# Epilogue #
|
||||
# ---------------- #
|
||||
c = acc.to(C.dtype.element_ty)
|
||||
# initialize pointers to C (dense)
|
||||
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
|
||||
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
|
||||
ptrs_c = C + off_h * stride_hc \
|
||||
+ pid_z * stride_zc \
|
||||
+ offs_cm[:, None] * stride_mc \
|
||||
+ offs_cn[None, :] * stride_nc
|
||||
# write back
|
||||
tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
|
||||
|
||||
|
||||
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
|
||||
if a.stride(2) != 1 and a.stride(3) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(2) != 1 and b.stride(3) != 1:
|
||||
b = b.contiguous()
|
||||
# shapes / dtypes
|
||||
AS0 = a.size(0)
|
||||
AS1 = a.size(1)
|
||||
AS2 = a.size(3 if trans_a else 2)
|
||||
BS2 = block * spdims[1 if trans_b else 2]
|
||||
dtype = a.dtype
|
||||
# output
|
||||
CS0 = AS0
|
||||
CS1 = AS1
|
||||
CS2 = BS2 if trans_c else AS2
|
||||
CS3 = AS2 if trans_c else BS2
|
||||
if out is None:
|
||||
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
|
||||
else:
|
||||
assert out.shape == (CS0, CS1, CS2, CS3)
|
||||
c = out
|
||||
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
|
||||
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
|
||||
_dds_kernel[grid](
|
||||
a, b, c,
|
||||
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
|
||||
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
|
||||
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
|
||||
AS2, BS2, lut,
|
||||
TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
|
||||
num_warps=4, GROUP_SIZE_M=4,
|
||||
)
|
||||
return c
|
||||
|
||||
##############
|
||||
# MAIN API #
|
||||
##############
|
||||
|
||||
|
||||
class _matmul(torch.autograd.Function):
|
||||
|
||||
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
|
||||
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
|
||||
):
|
||||
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
|
||||
# save for backward
|
||||
ctx.save_for_backward(a, b)
|
||||
ctx.da_lut = da_lut
|
||||
ctx.da_width = da_width
|
||||
ctx.db_lut = db_lut
|
||||
ctx.db_width = db_width
|
||||
ctx.mode = mode
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.trans_a = trans_a
|
||||
ctx.trans_b = trans_b
|
||||
ctx.trans_c = trans_c
|
||||
ctx.has_out = out is not None
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dc):
|
||||
# saved for backward
|
||||
a, b = ctx.saved_tensors
|
||||
da, db = None, None
|
||||
mode = ctx.mode
|
||||
# gradients w.r.t. a
|
||||
if ctx.needs_input_grad[0]:
|
||||
mode_da = mode[1] + mode[0] + mode[2]
|
||||
da = _matmul.fn[mode_da](
|
||||
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
|
||||
)
|
||||
# gradients w.r.t. b
|
||||
if ctx.needs_input_grad[1]:
|
||||
mode_db = mode[2] + mode[1] + mode[0]
|
||||
db = _matmul.fn[mode_db](
|
||||
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
|
||||
)
|
||||
dout = dc if ctx.has_out else None
|
||||
return da, db, None, None, None,\
|
||||
None, None, None, None,\
|
||||
None, None, None, None, None, dout
|
||||
|
||||
|
||||
class matmul:
|
||||
|
||||
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
|
||||
if mode not in ['sdd', 'dsd', 'dds']:
|
||||
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
|
||||
self.block = block
|
||||
self.mode = mode
|
||||
self.trans_a = trans_a
|
||||
self.trans_b = trans_b
|
||||
self.trans_c = trans_c
|
||||
self.layout = layout
|
||||
self.spdims = layout.shape
|
||||
step = min(block, 32)
|
||||
if self.mode == 'sdd':
|
||||
self.c_lut, self.c_width = sdd_lut(layout, block, device)
|
||||
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
|
||||
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
|
||||
if self.mode == 'dsd':
|
||||
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
|
||||
self.da_lut, self.da_width = sdd_lut(layout, block, device)
|
||||
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
|
||||
if self.mode == 'dds':
|
||||
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
|
||||
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
|
||||
self.db_lut, self.db_width = sdd_lut(layout, block, device)
|
||||
|
||||
def __call__(self, a, b, out=None):
|
||||
c = _matmul.apply(
|
||||
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
|
||||
self.c_lut, self.c_width,
|
||||
self.da_lut, self.da_width,
|
||||
self.db_lut, self.db_width,
|
||||
out
|
||||
)
|
||||
return c
|
|
@ -0,0 +1,234 @@
|
|||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def num_warps(n):
|
||||
if n < 512:
|
||||
return 4
|
||||
if n < 2048:
|
||||
return 8
|
||||
return 16
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.jit
|
||||
def _forward(
|
||||
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
|
||||
KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
|
||||
):
|
||||
pidhm = tl.program_id(0)
|
||||
pidz = tl.program_id(1)
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = tl.arange(0, TN) % BLOCK
|
||||
rbn = tl.arange(0, TN) // BLOCK
|
||||
# extract information from LUT
|
||||
header = LUT + rbm * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
check = rbn < size
|
||||
rbmn = tl.where(check, rbn, size - 1)
|
||||
# block id and column id
|
||||
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
|
||||
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
|
||||
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
|
||||
headid = tl.load(LUT + offset + rbmn * 4 + 3)
|
||||
# pointers to X
|
||||
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
x = tl.load(px, mask=check, other=-float('inf'))
|
||||
x = x.to(tl.float32)
|
||||
# apply scale
|
||||
if APPLY_SCALE:
|
||||
x = x * scale
|
||||
# apply RPE
|
||||
if APPLY_RPE:
|
||||
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
|
||||
rpe = tl.load(prpe, mask=check, other=0)
|
||||
x = x + rpe
|
||||
# apply key-padding mask
|
||||
if APPLY_KP_MASK:
|
||||
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
|
||||
kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
|
||||
if KP_MASK_MUL:
|
||||
kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
|
||||
x = x + kp_m
|
||||
# apply attention mask
|
||||
if APPLY_ATTN_MASK:
|
||||
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
|
||||
attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
|
||||
if ATTN_MASK_MUL:
|
||||
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
|
||||
x = x + attn_m
|
||||
# apply causal mask
|
||||
is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
|
||||
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
|
||||
# computation
|
||||
x = tl.softmax(x)
|
||||
tl.store(px, x, mask=check)
|
||||
|
||||
|
||||
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
|
||||
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
|
||||
@triton.jit
|
||||
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
|
||||
pidhm = tl.program_id(0)
|
||||
pidz = tl.program_id(1)
|
||||
# create index ranges
|
||||
rxm = pidhm % BLOCK
|
||||
rbm = pidhm // BLOCK
|
||||
rxn = tl.arange(0, TN) % BLOCK
|
||||
rbn = tl.arange(0, TN) // BLOCK
|
||||
# extract information from look-up table
|
||||
header = LUT + rbm * 2
|
||||
size = tl.load(header + 0)
|
||||
offset = tl.load(header + 1)
|
||||
# bounds checking on lut
|
||||
check = rbn < size
|
||||
rbmn = tl.where(check, rbn, size - 1)
|
||||
# initialize pointers to block-sparse input
|
||||
blockid = tl.load(LUT + offset + rbmn * 4)
|
||||
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
|
||||
# compute fused softmax backward
|
||||
x = tl.load(X, mask=check, other=0)
|
||||
dx = tl.load(DX, mask=check, other=0)
|
||||
x = x.to(tl.float32)
|
||||
dx = dx.to(tl.float32)
|
||||
y = x * (dx - tl.sum(x * dx, 0)) * scale
|
||||
tl.store(DX, y, mask=check)
|
||||
|
||||
|
||||
class _softmax(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def make_lut(layout, block, device):
|
||||
# sizes along rows
|
||||
sizes = layout.sum(-1).view(-1)
|
||||
# offsets in block format
|
||||
offsets = torch.zeros_like(sizes)
|
||||
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
|
||||
# block indices
|
||||
layout_sum = sizes.sum()
|
||||
idx = torch.arange(layout_sum, device=layout.device)
|
||||
layout_nonzero = layout.nonzero(as_tuple=False)
|
||||
head = layout_nonzero[:, 0]
|
||||
rows = layout_nonzero[:, 1]
|
||||
columns = layout_nonzero[:, 2]
|
||||
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
|
||||
# construct look-up table
|
||||
offsets = offsets * 4 + 2 * sizes.numel()
|
||||
header = torch.stack((sizes, offsets), dim=1).view(-1)
|
||||
lut = torch.cat((header, core)).type(torch.int32).to(device)
|
||||
return lut, int(sizes.max())
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx, x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
kp_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
spdims, block, lut, maxlut
|
||||
):
|
||||
apply_scale = False if scale == 1.0 else True
|
||||
# handle None rpe
|
||||
if rpe is None:
|
||||
apply_rpe = False
|
||||
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
|
||||
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_rpe = True
|
||||
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
|
||||
# handle None key_padding_mask
|
||||
if key_padding_mask is None:
|
||||
apply_kp_mask = False
|
||||
stride_zkpm = 0
|
||||
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_kp_mask = True
|
||||
stride_zkpm = key_padding_mask.stride(0)
|
||||
# handle None attention_mask
|
||||
if attn_mask is None:
|
||||
apply_attn_mask = False
|
||||
stride_zattnm = 0
|
||||
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
|
||||
else:
|
||||
apply_attn_mask = True
|
||||
stride_zattnm = attn_mask.stride(0)
|
||||
# run kernel
|
||||
M = x.shape[0]
|
||||
grid = [spdims[0] * spdims[1] * block, M]
|
||||
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
|
||||
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
|
||||
BLOCK=block,
|
||||
APPLY_SCALE=apply_scale,
|
||||
APPLY_RPE=apply_rpe,
|
||||
APPLY_KP_MASK=apply_kp_mask,
|
||||
APPLY_ATTN_MASK=apply_attn_mask,
|
||||
KP_MASK_MUL=(kp_mask_mode == 'mul'),
|
||||
ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
|
||||
# save to context
|
||||
ctx.mark_dirty(x)
|
||||
ctx.save_for_backward(x, lut)
|
||||
ctx.spdims = spdims
|
||||
ctx.block = block
|
||||
ctx.maxlut = maxlut
|
||||
ctx.scale = scale
|
||||
ctx.apply_scale = apply_scale
|
||||
ctx.apply_rpe = apply_rpe
|
||||
ctx.apply_kp_mask = apply_kp_mask
|
||||
ctx.apply_attn_mask = apply_attn_mask
|
||||
ctx.kp_mask_mode = kp_mask_mode
|
||||
ctx.attn_mask_mode = attn_mask_mode
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dx):
|
||||
# retrieve from context
|
||||
x, lut = ctx.saved_tensors
|
||||
# run kernel
|
||||
M = x.shape[0]
|
||||
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
|
||||
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
|
||||
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class softmax:
|
||||
|
||||
def make_lut(self, device):
|
||||
key = (device, )
|
||||
if key not in self.lut_cache:
|
||||
self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
|
||||
return self.lut_cache[key]
|
||||
|
||||
def __init__(self, layout, block):
|
||||
self.spdims = layout.shape
|
||||
self.layout = layout
|
||||
self.block = block
|
||||
self.lut_cache = dict()
|
||||
|
||||
def __call__(
|
||||
self, x, scale=1., rpe=None,
|
||||
key_padding_mask=None, attn_mask=None,
|
||||
key_padding_mask_mode='add', attn_mask_mode='add',
|
||||
is_causal=False
|
||||
):
|
||||
if rpe is not None and rpe.dtype != x.dtype:
|
||||
raise ValueError('relative position embedding must be %s' % x.dtype)
|
||||
if attn_mask is not None and attn_mask.dtype != x.dtype:
|
||||
raise ValueError('Attention mask must be %s' % x.dtype)
|
||||
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
|
||||
raise ValueError('Key padding mask must be %s' % x.dtype)
|
||||
lut, maxlut = self.make_lut(x.device)
|
||||
x = _softmax.apply(
|
||||
x, scale, rpe,
|
||||
key_padding_mask, attn_mask,
|
||||
key_padding_mask_mode, attn_mask_mode,
|
||||
is_causal,
|
||||
self.spdims, self.block,
|
||||
lut, maxlut
|
||||
)
|
||||
return x
|
|
@ -0,0 +1,6 @@
|
|||
def layout_full_zero_check(input_layout):
|
||||
row_check = input_layout.sum(dim=2).eq(0) # (H, L // block)
|
||||
col_check = input_layout.sum(dim=1).eq(0)
|
||||
row_answer = bool(row_check.any())
|
||||
col_answer = bool(col_check.any())
|
||||
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
|
|
@ -0,0 +1,129 @@
|
|||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def get_item(schemes, idx):
|
||||
if schemes is None:
|
||||
return None
|
||||
len_schemes = len(schemes)
|
||||
if len_schemes > idx:
|
||||
return schemes[idx]
|
||||
if len_schemes == 1:
|
||||
return schemes[0]
|
||||
raise IndexError('idx (%d) is out of the length (%d) of the schemes' % (idx, len_schemes), schemes)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HeadAttentionScheme:
|
||||
head_sum2sum: Optional[tuple]
|
||||
head_sum2sum_self: str
|
||||
head_sum2con: Optional[tuple]
|
||||
head_sum2con_self: str
|
||||
head_con2sum: Optional[tuple]
|
||||
head_con2sum_self: str
|
||||
head_con2con: Optional[tuple]
|
||||
head_con2con_self: str
|
||||
head_con2con_causal: bool
|
||||
head_pref2pref_mode: str
|
||||
head_sum2pref_mode: str
|
||||
head_pref2sum_mode: str
|
||||
head_con2pref_mode: str
|
||||
head_pref2con_mode: str
|
||||
|
||||
|
||||
class LayerAttentionScheme(object):
|
||||
def __init__(
|
||||
self,
|
||||
layer_sum2sum, layer_sum2sum_self,
|
||||
layer_sum2con, layer_sum2con_self,
|
||||
layer_con2sum, layer_con2sum_self,
|
||||
layer_con2con, layer_con2con_self,
|
||||
layer_con2con_causal,
|
||||
layer_pref2pref_mode,
|
||||
layer_sum2pref_mode,
|
||||
layer_pref2sum_mode,
|
||||
layer_con2pref_mode,
|
||||
layer_pref2con_mode,
|
||||
num_layer_heads,
|
||||
):
|
||||
self.num_layer_heads = num_layer_heads
|
||||
|
||||
self.heads_schemes = []
|
||||
for idx in range(self.num_layer_heads):
|
||||
self.heads_schemes.append(
|
||||
HeadAttentionScheme(
|
||||
get_item(layer_sum2sum, idx), layer_sum2sum_self,
|
||||
get_item(layer_sum2con, idx), layer_sum2con_self,
|
||||
get_item(layer_con2sum, idx), layer_con2sum_self,
|
||||
get_item(layer_con2con, idx), layer_con2con_self,
|
||||
layer_con2con_causal,
|
||||
layer_pref2pref_mode,
|
||||
layer_sum2pref_mode,
|
||||
layer_pref2sum_mode,
|
||||
layer_con2pref_mode,
|
||||
layer_pref2con_mode
|
||||
)
|
||||
)
|
||||
self.heads_schemes = tuple(self.heads_schemes)
|
||||
self.same_for_all_heads = len(set(self.heads_schemes)) == 1
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.heads_schemes)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.heads_schemes == other.heads_schemes
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.heads_schemes[idx]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_layer_heads
|
||||
|
||||
|
||||
class AttentionScheme(object):
|
||||
def __init__(
|
||||
self,
|
||||
sum2sum, sum2sum_self,
|
||||
sum2con, sum2con_self,
|
||||
con2sum, con2sum_self,
|
||||
con2con, con2con_self,
|
||||
con2con_causal,
|
||||
pref2pref_mode,
|
||||
sum2pref_mode,
|
||||
pref2sum_mode,
|
||||
con2pref_mode,
|
||||
pref2con_mode,
|
||||
num_layers, num_layers_heads,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.layers_schemes = []
|
||||
for idx in range(self.num_layers):
|
||||
self.layers_schemes.append(
|
||||
LayerAttentionScheme(
|
||||
get_item(sum2sum, idx), sum2sum_self,
|
||||
get_item(sum2con, idx), sum2con_self,
|
||||
get_item(con2sum, idx), con2sum_self,
|
||||
get_item(con2con, idx), con2con_self,
|
||||
con2con_causal,
|
||||
pref2pref_mode,
|
||||
sum2pref_mode,
|
||||
pref2sum_mode,
|
||||
con2pref_mode,
|
||||
pref2con_mode,
|
||||
get_item(num_layers_heads, idx)
|
||||
)
|
||||
)
|
||||
self.layers_schemes = tuple(self.layers_schemes)
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.layers_schemes)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.layers_schemes == other.layers_schemes
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.layers_schemes[idx]
|
||||
|
||||
def __len__(self):
|
||||
return self.num_layers
|
|
@ -0,0 +1,6 @@
|
|||
from museformer.tools.singleton_decorator import singleton
|
||||
|
||||
|
||||
@singleton
|
||||
class FourDimPocket(dict):
|
||||
pass
|
|
@ -0,0 +1,87 @@
|
|||
import logging
|
||||
import os
|
||||
import numpy as np
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
from fairseq.data.indexed_dataset import MMapIndexedDatasetBuilder
|
||||
from ..dictionary.compound_dictionary import CompoundDictionary
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_beat_ids(src_tokens, dictionary, ts_instead_of_tempo=False):
|
||||
change = 's' if ts_instead_of_tempo else 't'
|
||||
|
||||
beat_mask = src_tokens.ge(dictionary.type_dict['o'][0]) & src_tokens.lt(dictionary.type_dict['o'][1])
|
||||
change_mask = src_tokens.ge(dictionary.type_dict[change][0]) & src_tokens.lt(dictionary.type_dict[change][1])
|
||||
bar_mask = src_tokens.ge(dictionary.type_dict['b'][0]) & src_tokens.lt(dictionary.type_dict['b'][1])
|
||||
special_mask = src_tokens.lt(4)
|
||||
no_beat_mask = change_mask | bar_mask | special_mask
|
||||
del change_mask, bar_mask
|
||||
src_tokens = src_tokens.clone() - (dictionary.type_dict['o'][0] - 1)
|
||||
cur_beat = 0
|
||||
for idx, (token, beat_token) in enumerate(zip(src_tokens, beat_mask)):
|
||||
if beat_token:
|
||||
cur_beat = token
|
||||
else:
|
||||
src_tokens[idx] = cur_beat
|
||||
src_tokens.masked_fill_(no_beat_mask, 0)
|
||||
return src_tokens
|
||||
|
||||
|
||||
class AddBeatDataset(BaseWrapperDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset, dictionary: CompoundDictionary,
|
||||
cache_data_label='default', dataset_name=None,
|
||||
mask_ts_instead_of_tempo=False,
|
||||
):
|
||||
super().__init__(dataset)
|
||||
self.dictionary = dictionary
|
||||
self.cache_data_label = cache_data_label
|
||||
self.dataset_name = dataset_name
|
||||
self.mask_ts_instead_of_tempo = mask_ts_instead_of_tempo
|
||||
|
||||
self.cache_list = None
|
||||
|
||||
cache_name = 'add-beat-dataset_%s_%s' % (self.cache_data_label, self.dataset_name)
|
||||
cache_dataset_dir = 'cached_datasets/'
|
||||
if cache_data_label is not None:
|
||||
cache_dataset_dir = os.path.join(cache_dataset_dir, cache_data_label)
|
||||
cache_path = os.path.join(cache_dataset_dir, cache_name)
|
||||
|
||||
if all([os.path.isfile(cache_path + suffix) for suffix in ('.beat_ids.bin', '.beat_ids.idx')]):
|
||||
pass
|
||||
else:
|
||||
logger.info('Building up beat_ids dataset for %s ...' % dataset_name)
|
||||
os.makedirs(cache_dataset_dir, exist_ok=True)
|
||||
self.beat_ids_builder = MMapIndexedDatasetBuilder(
|
||||
cache_path + '.beat_ids.bin', dtype=np.int32
|
||||
)
|
||||
|
||||
self.__prepare_dataset()
|
||||
|
||||
self.beat_ids_builder.finalize(cache_path + '.beat_ids.idx')
|
||||
del self.beat_ids_builder
|
||||
|
||||
self.beat_ids_dataset = data_utils.load_indexed_dataset(cache_path + '.beat_ids',
|
||||
dictionary=None, dataset_impl='mmap')
|
||||
assert len(self.beat_ids_dataset) == len(self.dataset)
|
||||
for idx, beat_ids in enumerate(self.beat_ids_dataset):
|
||||
assert len(beat_ids) == self.dataset.size(idx), (idx, self.dataset.size(idx), len(beat_ids))
|
||||
logger.info('Checked the cached beat_ids dataset.')
|
||||
|
||||
def __prepare_dataset(self):
|
||||
for sample in self.dataset:
|
||||
src_tokens = sample[0]
|
||||
beat_ids = get_beat_ids(src_tokens, self.dictionary, ts_instead_of_tempo=self.mask_ts_instead_of_tempo)
|
||||
self.beat_ids_builder.add_item(beat_ids)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
beat_ids = self.beat_ids_dataset[idx]
|
||||
sample = self.dataset[idx]
|
||||
return (*sample, beat_ids)
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,129 @@
|
|||
# import os
|
||||
# from copy import deepcopy
|
||||
# import pickle
|
||||
import logging
|
||||
|
||||
# import numpy as np
|
||||
import torch
|
||||
# from fairseq.data import data_utils, FairseqDataset
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
# from fairseq.data.indexed_dataset import MMapIndexedDatasetBuilder
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_bar_chunk_points(seq: torch.Tensor, eob_index, begin_idx=0, take_bos_as_bar=False, bos_index=None):
|
||||
# seq: (seq_len,)
|
||||
# eob_index: int
|
||||
is_complete_bar = seq[-1] == eob_index
|
||||
indices = seq.eq(eob_index).nonzero(as_tuple=False).squeeze(1) # (num_bars,)
|
||||
indices = indices + 1
|
||||
indices = torch.cat(
|
||||
(indices.new_tensor([begin_idx]), indices), dim=0
|
||||
)
|
||||
len_seq = len(seq)
|
||||
if not is_complete_bar and len_seq > begin_idx:
|
||||
indices = torch.cat(
|
||||
(indices, indices.new_tensor([len_seq])), dim=0
|
||||
)
|
||||
if take_bos_as_bar:
|
||||
assert seq[0] == bos_index
|
||||
assert begin_idx == 1
|
||||
indices = torch.cat(
|
||||
(torch.tensor([0]), indices), dim=0
|
||||
)
|
||||
|
||||
return indices, is_complete_bar
|
||||
|
||||
|
||||
class BarChunkSequenceDataset(BaseWrapperDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset, src_dict, eob,
|
||||
eos_appended=True,
|
||||
offset=1,
|
||||
take_bos_as_bar=True,
|
||||
bos_index=None,
|
||||
):
|
||||
super().__init__(dataset)
|
||||
self.src_dict = src_dict
|
||||
self.eos_appended = eos_appended
|
||||
self.eob = eob
|
||||
self.offset = offset
|
||||
self.take_bos_as_bar = take_bos_as_bar
|
||||
self.bos_index = bos_index
|
||||
self.cache = [None] * len(self.dataset)
|
||||
|
||||
def __iter__(self):
|
||||
len_dataset = len(self)
|
||||
for idx in range(len_dataset):
|
||||
yield self[idx]
|
||||
|
||||
def __getitem__(self, index):
|
||||
src, tgt = self.dataset[index] # all include eoc
|
||||
chunk_points = self.cache[index]
|
||||
if chunk_points is None:
|
||||
chunk_points, complete = get_bar_chunk_points(
|
||||
src, self.eob, begin_idx=self.offset,
|
||||
take_bos_as_bar=self.take_bos_as_bar, bos_index=self.bos_index
|
||||
)
|
||||
assert complete
|
||||
self.cache[index] = chunk_points
|
||||
return src, tgt, chunk_points
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)
|
||||
|
||||
|
||||
class FixedChunkingLengthDataset(BaseWrapperDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset, fixed_chunking_length
|
||||
):
|
||||
assert fixed_chunking_length is not None
|
||||
super().__init__(dataset)
|
||||
self.fixed_chunking_length = fixed_chunking_length
|
||||
|
||||
def __iter__(self):
|
||||
len_dataset = len(self)
|
||||
for idx in range(len_dataset):
|
||||
yield self[idx]
|
||||
|
||||
def __getitem__(self, index):
|
||||
src, tgt = self.dataset[index] # all include eoc
|
||||
|
||||
sample_len = len(src)
|
||||
chunk_points = torch.arange(0, sample_len, self.fixed_chunking_length)
|
||||
chunk_points = torch.cat((chunk_points, chunk_points.new_tensor([sample_len])), dim=0)
|
||||
|
||||
return src, tgt, chunk_points
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)
|
||||
|
||||
|
||||
def ChunkSequenceDataset(
|
||||
dataset, src_dict,
|
||||
eob, eoc,
|
||||
chunking_scheme='bar_aware',
|
||||
chunking_length=None,
|
||||
dataset_name=None,
|
||||
cache_data_label=None,
|
||||
cache_sequence=None,
|
||||
offset=0,
|
||||
take_bos_as_bar=False, bos_index=None
|
||||
):
|
||||
if chunking_scheme == 'bar_aware':
|
||||
return BarChunkSequenceDataset(
|
||||
dataset, src_dict, eob,
|
||||
offset=offset,
|
||||
take_bos_as_bar=take_bos_as_bar,
|
||||
bos_index=bos_index
|
||||
)
|
||||
elif chunking_scheme == 'fixed':
|
||||
return FixedChunkingLengthDataset(
|
||||
dataset, chunking_length
|
||||
)
|
||||
|
||||
raise NotImplementedError(chunking_scheme)
|
|
@ -0,0 +1,43 @@
|
|||
from fairseq.data import FairseqDataset
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
|
||||
class ExtendedWrapperDataset(FairseqDataset):
|
||||
def __init__(self, dataset):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.dataset[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
if hasattr(self.dataset, "collater"):
|
||||
return self.dataset.collater(samples)
|
||||
else:
|
||||
return default_collate(samples)
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self.dataset.sizes
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.dataset.sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
return self.dataset.sizes[index]
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return getattr(self.dataset, "supports_prefetch", False)
|
||||
|
||||
def attr(self, attr: str, index: int):
|
||||
return self.dataset.attr(attr, index)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
super().set_epoch(epoch)
|
||||
if hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
import torch
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
|
||||
class MusicMonolingualDataset(BaseWrapperDataset):
|
||||
def __int__(self, dataset):
|
||||
super().__init__(dataset)
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.dataset[index] # (len,)
|
||||
return torch.cat((sample[-1:], sample[:-1]), dim=0), sample
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,185 @@
|
|||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostProcessDataset(BaseWrapperDataset):
|
||||
def __init__(self, dataset):
|
||||
super().__init__(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self):
|
||||
for idx in range(len(self)):
|
||||
yield self[idx]
|
||||
|
||||
def __getitem__(self, index):
|
||||
src, tgt, chunk_points, beat_ids, num_chunks, num_complete_chunks, num_prefix = self.dataset[index]
|
||||
seq_len = src.shape[0]
|
||||
|
||||
new_sample = {
|
||||
'id': index,
|
||||
'src_tokens': src,
|
||||
'src_length': seq_len,
|
||||
'target': tgt,
|
||||
'chunk_points': chunk_points,
|
||||
'num_chunks': num_chunks,
|
||||
'num_complete_chunks': num_complete_chunks,
|
||||
'num_pref': num_prefix,
|
||||
'beat_ids': beat_ids,
|
||||
}
|
||||
|
||||
return new_sample
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self.dataset.sizes
|
||||
|
||||
def size(self, index):
|
||||
return self.dataset.size(index)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.dataset.size(index)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
return np.arange(len(self), dtype=np.int64)
|
||||
|
||||
def collater(self, samples):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
bsz = len(samples)
|
||||
sample_id = torch.tensor([s['id'] for s in samples])
|
||||
src_tokens = [s['src_tokens'] for s in samples]
|
||||
src_lengths = [s['src_length'] for s in samples]
|
||||
target = [s['target'] for s in samples]
|
||||
chunk_points = [s['chunk_points'] for s in samples]
|
||||
num_chunks = [s['num_chunks'] for s in samples]
|
||||
num_complete_chunks = [s['num_complete_chunks'] for s in samples]
|
||||
num_prefix = [s['num_pref'] for s in samples]
|
||||
beat_ids = [s['beat_ids'] for s in samples]
|
||||
ntokens = sum(src_lengths)
|
||||
|
||||
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=0)
|
||||
src_lengths = torch.tensor(src_lengths, dtype=torch.long)
|
||||
target = pad_sequence(target, batch_first=True, padding_value=0)
|
||||
chunk_points = data_utils.collate_tokens(
|
||||
chunk_points, 0
|
||||
)
|
||||
num_chunks = torch.tensor(num_chunks, dtype=torch.long)
|
||||
num_complete_chunks = torch.tensor(num_complete_chunks, dtype=torch.long)
|
||||
num_prefix = torch.tensor(num_prefix, dtype=torch.long)
|
||||
beat_ids = pad_sequence(beat_ids, batch_first=True, padding_value=0)
|
||||
|
||||
batch = {
|
||||
'id': sample_id,
|
||||
'nsentences': bsz,
|
||||
'ntokens': ntokens,
|
||||
'net_input': {
|
||||
'src_tokens': src_tokens,
|
||||
'src_lengths': src_lengths,
|
||||
'chunk_points': chunk_points,
|
||||
'num_chunks': num_chunks,
|
||||
'num_complete_chunks': num_complete_chunks,
|
||||
'num_pref': num_prefix,
|
||||
'beat_ids': beat_ids,
|
||||
},
|
||||
'target': target,
|
||||
}
|
||||
|
||||
return batch
|
||||
|
||||
def filter_indices_by_size(self, indices, max_sizes):
|
||||
"""
|
||||
Filter a list of sample indices. Remove those that are longer than
|
||||
specified in *max_sizes*.
|
||||
|
||||
WARNING: don't update, override method in child classes
|
||||
|
||||
Args:
|
||||
indices (np.array): original array of sample indices
|
||||
max_sizes (int or list[int] or tuple[int]): max sample size,
|
||||
can be defined separately for src and tgt (then list or tuple)
|
||||
|
||||
Returns:
|
||||
np.array: filtered sample array
|
||||
list: list of removed indices
|
||||
"""
|
||||
# print(indices)
|
||||
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
|
||||
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
|
||||
ignored = indices[self.sizes[indices] > max_sizes].tolist()
|
||||
indices = indices[self.sizes[indices] <= max_sizes]
|
||||
elif (
|
||||
hasattr(self, "sizes")
|
||||
and isinstance(self.sizes, list)
|
||||
and len(self.sizes) == 1
|
||||
):
|
||||
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
|
||||
indices = indices[self.sizes[0][indices] <= max_sizes]
|
||||
else:
|
||||
indices, ignored = data_utils._filter_by_size_dynamic(
|
||||
indices, self.size, max_sizes
|
||||
)
|
||||
else:
|
||||
indices, ignored = data_utils._filter_by_size_dynamic(
|
||||
indices, self.size, max_sizes
|
||||
)
|
||||
if len(ignored) > 0:
|
||||
print(self.sizes)
|
||||
print(ignored)
|
||||
print(max_sizes)
|
||||
return indices, ignored
|
||||
|
||||
def batch_by_size(
|
||||
self,
|
||||
indices,
|
||||
max_tokens=None,
|
||||
max_sentences=None,
|
||||
required_batch_size_multiple=1,
|
||||
):
|
||||
"""
|
||||
Given an ordered set of indices, return batches according to
|
||||
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
|
||||
"""
|
||||
from fairseq.data import data_utils
|
||||
|
||||
fixed_shapes = self.get_batch_shapes()
|
||||
if fixed_shapes is not None:
|
||||
|
||||
def adjust_bsz(bsz, num_tokens):
|
||||
if bsz is None:
|
||||
assert max_tokens is not None, "Must specify --max-tokens"
|
||||
bsz = max_tokens // num_tokens
|
||||
if max_sentences is not None:
|
||||
bsz = min(bsz, max_sentences)
|
||||
elif (
|
||||
bsz >= required_batch_size_multiple
|
||||
and bsz % required_batch_size_multiple != 0
|
||||
):
|
||||
bsz -= bsz % required_batch_size_multiple
|
||||
return bsz
|
||||
|
||||
fixed_shapes = np.array(
|
||||
[
|
||||
[adjust_bsz(bsz, num_tokens), num_tokens]
|
||||
for (bsz, num_tokens) in fixed_shapes
|
||||
]
|
||||
)
|
||||
|
||||
return data_utils.batch_by_size(
|
||||
indices,
|
||||
num_tokens_fn=self.num_tokens,
|
||||
max_tokens=max_tokens,
|
||||
max_sentences=max_sentences,
|
||||
required_batch_size_multiple=required_batch_size_multiple,
|
||||
fixed_shapes=fixed_shapes,
|
||||
)
|
|
@ -0,0 +1,15 @@
|
|||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
|
||||
class PrefixTagDataset(BaseWrapperDataset):
|
||||
def __init__(self, dataset, num_prefix):
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.num_prefix = num_prefix
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sample = self.dataset[idx]
|
||||
return (*sample, self.num_prefix)
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)
|
|
@ -0,0 +1,53 @@
|
|||
# Author: Botao Yu
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
|
||||
class RemoveLongSizeSamplesDataset(BaseWrapperDataset):
|
||||
def __init__(self, dataset, max_size):
|
||||
super().__init__(dataset)
|
||||
self.max_size = max_size
|
||||
max_token_select = self.dataset.sizes <= self.max_size
|
||||
final_select = max_token_select
|
||||
self.selected_index = np.nonzero(final_select)[0]
|
||||
|
||||
def __getitem__(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset[origin_index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.selected_index)
|
||||
|
||||
def __iter__(self):
|
||||
len_dataset = len(self)
|
||||
for idx in range(len_dataset):
|
||||
yield self[idx]
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self.dataset.sizes[self.selected_index]
|
||||
|
||||
def size(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset.size(origin_index)
|
||||
|
||||
def num_tokens(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset.num_tokens(origin_index)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
return np.arange(len(self), dtype=np.int64)
|
||||
|
||||
|
||||
def MayRemoveLongSizeSamplesDataset(dataset, max_size):
|
||||
if max_size is None:
|
||||
return dataset
|
||||
|
||||
new_dataset = RemoveLongSizeSamplesDataset(dataset, max_size)
|
||||
if len(new_dataset) == len(dataset):
|
||||
del new_dataset
|
||||
return dataset
|
||||
return new_dataset
|
|
@ -0,0 +1,61 @@
|
|||
# Author: Botao Yu
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data import FairseqDataset
|
||||
|
||||
|
||||
class RemoveShortSizeSamplesDataset(FairseqDataset):
|
||||
def __init__(self, dataset, min_size):
|
||||
super().__init__()
|
||||
# st = time.time()
|
||||
self.dataset = dataset
|
||||
|
||||
self.min_size = min_size
|
||||
|
||||
min_token_select = self.dataset.sizes >= self.min_size
|
||||
|
||||
final_select = min_token_select
|
||||
self.selected_index = np.nonzero(final_select)[0]
|
||||
|
||||
# et = time.time()
|
||||
# print('%s dataset: %.2fs' % (self.__class__.__name__, et - st))
|
||||
|
||||
def __getitem__(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset[origin_index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.selected_index)
|
||||
|
||||
def __iter__(self):
|
||||
len_dataset = len(self)
|
||||
for idx in range(len_dataset):
|
||||
yield self[idx]
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self.dataset.sizes[self.selected_index]
|
||||
|
||||
def size(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset.size(origin_index)
|
||||
|
||||
def num_tokens(self, index):
|
||||
origin_index = self.selected_index[index]
|
||||
return self.dataset.num_tokens(origin_index)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
return np.arange(len(self), dtype=np.int64)
|
||||
|
||||
|
||||
def MayRemoveShortSizeSamplesDataset(dataset, min_size):
|
||||
if min_size is None or min_size == 0:
|
||||
return dataset
|
||||
|
||||
new_dataset = RemoveShortSizeSamplesDataset(dataset, min_size)
|
||||
if len(new_dataset) == len(dataset):
|
||||
del new_dataset
|
||||
return dataset
|
||||
return new_dataset
|
|
@ -0,0 +1,44 @@
|
|||
import torch
|
||||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
|
||||
class TruncateMusicDataset(BaseWrapperDataset):
|
||||
def __init__(self, dataset, truncated_length):
|
||||
super().__init__(dataset)
|
||||
assert truncated_length is None or truncated_length > 1
|
||||
self.truncated_length = truncated_length
|
||||
|
||||
self._sizes = self.dataset.sizes.copy()
|
||||
self._sizes[self._sizes > self.truncated_length] = self.truncated_length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
src, tgt, chunk_points, beat_ids = self.dataset[idx]
|
||||
num_chunks = len(chunk_points) - 1
|
||||
if self.truncated_length is None or self.dataset.size(idx) <= self.truncated_length:
|
||||
return src, tgt, chunk_points, beat_ids, num_chunks, num_chunks # num_complete_chunks
|
||||
|
||||
src = src[:self.truncated_length]
|
||||
tgt = tgt[:self.truncated_length]
|
||||
beat_ids = beat_ids[:self.truncated_length]
|
||||
chunk_points = chunk_points[chunk_points.le(self.truncated_length)]
|
||||
if chunk_points[-1] == self.truncated_length:
|
||||
num_chunks = len(chunk_points) - 1
|
||||
num_complete_chunks = num_chunks
|
||||
else:
|
||||
num_chunks = len(chunk_points)
|
||||
num_complete_chunks = num_chunks - 1
|
||||
chunk_points = torch.cat((chunk_points, chunk_points.new_tensor([self.truncated_length])))
|
||||
return src, tgt, chunk_points, beat_ids, num_chunks, num_complete_chunks
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
def size(self, index):
|
||||
return self._sizes[index]
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self._sizes[index]
|
||||
|
||||
def collater(self, samples):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,51 @@
|
|||
import re
|
||||
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
|
||||
class CompoundDictionary(Dictionary):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.type_dict = None
|
||||
self.special_pad_word = None
|
||||
self.special_pad_index = None
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
d.construct_types()
|
||||
return d
|
||||
|
||||
def construct_types(self):
|
||||
type_dict = {}
|
||||
type_token = None
|
||||
current_begin = None
|
||||
considered_type_token = set()
|
||||
for idx, symbol in enumerate(self.symbols):
|
||||
match = re.fullmatch('([a-z]+)\-\d+', symbol)
|
||||
if match is None:
|
||||
if current_begin is not None:
|
||||
type_dict[type_token] = (current_begin, idx)
|
||||
type_token = None
|
||||
current_begin = None
|
||||
else:
|
||||
now_type_token = match.group(1)
|
||||
if current_begin is not None:
|
||||
if type_token == now_type_token:
|
||||
continue
|
||||
else:
|
||||
type_dict[type_token] = (current_begin, idx)
|
||||
assert now_type_token not in considered_type_token
|
||||
type_token = now_type_token
|
||||
considered_type_token.add(type_token)
|
||||
current_begin = idx
|
||||
else:
|
||||
assert now_type_token not in considered_type_token
|
||||
type_token = now_type_token
|
||||
considered_type_token.add(type_token)
|
||||
current_begin = idx
|
||||
if current_begin is not None:
|
||||
type_dict[type_token] = (current_begin, len(self.symbols))
|
||||
self.type_dict = type_dict
|
||||
return type_dict
|
|
@ -0,0 +1,6 @@
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
class EmbeddingLayer(nn.Module):
|
||||
def __init__(self):
|
||||
pass
|
|
@ -0,0 +1,25 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def generate_sinusoid_embeddings(max_len, dim):
|
||||
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim, 2.0) / dim))
|
||||
position = torch.arange(max_len) # (max_len, 1)
|
||||
sinusoid_inp = torch.ger(position, inv_freq)
|
||||
pe = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
|
||||
return pe
|
||||
|
||||
|
||||
def generate_sinusoid_position_embedding_with_padding(num_positions, embed_dim):
|
||||
embedding_weight = generate_sinusoid_embeddings(num_positions, embed_dim)
|
||||
embedding_weight = torch.cat(
|
||||
(torch.zeros(1, embed_dim), embedding_weight), dim=0
|
||||
)
|
||||
return embedding_weight
|
||||
|
||||
|
||||
def generate_randomly_initialized_position_embedding_with_padding(num_positions, embed_dim):
|
||||
embedding_weight = torch.empty(num_positions + 1, embed_dim)
|
||||
nn.init.normal_(embedding_weight)
|
||||
nn.init.zeros_(embedding_weight[0])
|
||||
return embedding_weight
|
|
@ -0,0 +1,71 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SinusoidalEmbedding(nn.Module):
|
||||
def __init__(self, embedding_dim, padding_idx, init_size=1024):
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.num_embedding = init_size
|
||||
weights = self.get_embedding(
|
||||
init_size, embedding_dim, padding_idx
|
||||
)
|
||||
self.register_buffer('weights', weights, persistent=False)
|
||||
|
||||
@staticmethod
|
||||
def get_embedding(
|
||||
num_embeddings: int, embedding_dim: int, padding_idx=None
|
||||
):
|
||||
"""Build sinusoidal embeddings.
|
||||
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
|
||||
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
|
||||
1
|
||||
) * emb.unsqueeze(0)
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
|
||||
num_embeddings, -1
|
||||
)
|
||||
if padding_idx is not None:
|
||||
emb[padding_idx, :] = 0
|
||||
return emb
|
||||
|
||||
def forward(self, x):
|
||||
if x.numel() > 0:
|
||||
x_max = int(x.max())
|
||||
if x_max >= self.num_embedding:
|
||||
self.num_embedding = x_max + 32
|
||||
weights = self.get_embedding(
|
||||
self.num_embedding, self.embedding_dim, self.padding_idx
|
||||
)
|
||||
self.weights = weights.to(self.weights)
|
||||
r = F.embedding(x, self.weights, padding_idx=self.padding_idx)
|
||||
return r
|
||||
|
||||
|
||||
def TaoEmbedding(
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int,
|
||||
learned: bool = False,
|
||||
):
|
||||
if learned:
|
||||
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
|
||||
if padding_idx is not None:
|
||||
nn.init.constant_(m.weight[padding_idx], 0)
|
||||
else:
|
||||
m = SinusoidalEmbedding(
|
||||
embedding_dim,
|
||||
padding_idx,
|
||||
init_size=num_embeddings,
|
||||
)
|
||||
return m
|
|
@ -0,0 +1 @@
|
|||
from .main import block_fill_
|
|
@ -0,0 +1,32 @@
|
|||
// #pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
// CUDA函数声明
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor&, const at::Tensor&,
|
||||
const long, const long, const long, const long, const long, const long, const bool, at::Tensor&
|
||||
);
|
||||
|
||||
|
||||
// C++函数包装
|
||||
int cuda_forward(const at::Tensor& block_ranges,
|
||||
const at::Tensor& head_masks,
|
||||
const long num_lines,
|
||||
const long max_query,
|
||||
const long max_key,
|
||||
const long num_heads,
|
||||
const long seq_len_1,
|
||||
const long seq_len_2,
|
||||
const bool fill_value,
|
||||
at::Tensor& out) {
|
||||
at::DeviceGuard guard(block_ranges.device());
|
||||
cudaForwardLauncher(block_ranges, head_masks, num_lines, max_query, max_key, num_heads,
|
||||
seq_len_1, seq_len_2, fill_value, out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 绑定
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
|
||||
m.def("cuda_forward", &cuda_forward, "cuda_forward");
|
||||
}
|
|
@ -0,0 +1,107 @@
|
|||
//#include <cstdio>
|
||||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
|
||||
const long THREADS_PER_BLOCK = 1024;
|
||||
const long MAX_GRID_NUM = 2147483647;
|
||||
|
||||
|
||||
inline long GET_BLOCKS(const long N) {
|
||||
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
||||
return long(min(optimal_block_num, MAX_GRID_NUM));
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void cudaForward(
|
||||
const scalar_t * block_ranges,
|
||||
const bool * head_masks,
|
||||
const long input_size,
|
||||
const long num_lines, const long max_query, const long max_key, const long num_heads,
|
||||
const long seq_len_2,
|
||||
const long line_stride, const long query_stride, const long key_stride, const long seq_len_square,
|
||||
const bool fill_value,
|
||||
bool * out) {
|
||||
|
||||
// block_ranges: (num_lines, 4)
|
||||
// head_masks: (num_lines, num_heads)
|
||||
// input_size == num_lines * max_query * max_key * num_heads
|
||||
// line_stride == max_query * max_key * num_heads
|
||||
// query_stride == max_key * num_heads
|
||||
// key_stride == num_heads
|
||||
// out: (num_heads, seq_len_1, seq_len_2)
|
||||
|
||||
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
|
||||
|
||||
if (index >= input_size) return;
|
||||
|
||||
const long line_index = index / line_stride;
|
||||
assert (line_index < num_lines);
|
||||
const long query_index = (index % line_stride) / query_stride;
|
||||
assert (query_index < max_query);
|
||||
const long key_index = (index % query_stride) / key_stride;
|
||||
assert (key_index < max_key);
|
||||
const long head_index = (index % key_stride);
|
||||
|
||||
const long line_start = line_index * 4;
|
||||
|
||||
const long query_begin = block_ranges[line_start];
|
||||
const long query_end = block_ranges[line_start + 1];
|
||||
const long query_index_end = query_index + query_begin;
|
||||
if (!(query_index_end < query_end)) return;
|
||||
|
||||
const long key_begin = block_ranges[line_start + 2];
|
||||
const long key_end = block_ranges[line_start + 3];
|
||||
const long key_index_end = key_index + key_begin;
|
||||
if (!(key_index_end < key_end)) return;
|
||||
|
||||
if (head_masks[line_index * num_heads + head_index])
|
||||
out[head_index * seq_len_square + query_index_end * seq_len_2 + key_index_end] = fill_value;
|
||||
}
|
||||
|
||||
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor& block_ranges,
|
||||
const at::Tensor& head_masks,
|
||||
const long num_lines,
|
||||
const long max_query,
|
||||
const long max_key,
|
||||
const long num_heads,
|
||||
const long seq_len_1,
|
||||
const long seq_len_2,
|
||||
const bool fill_value,
|
||||
at::Tensor& out
|
||||
) {
|
||||
const long input_size = num_lines * max_query * max_key * num_heads;
|
||||
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
|
||||
|
||||
const long key_stride = num_heads;
|
||||
const long query_stride = max_key * key_stride;
|
||||
const long line_stride = max_query * query_stride;
|
||||
const long seq_len_square = seq_len_1 * seq_len_2;
|
||||
|
||||
AT_DISPATCH_INTEGRAL_TYPES(
|
||||
block_ranges.type(), "cudaForward",
|
||||
([&] {
|
||||
const scalar_t *block_ranges_ = block_ranges.data_ptr<scalar_t>();
|
||||
const bool *head_masks_ = head_masks.data_ptr<bool>();
|
||||
bool *out_ = out.data_ptr<bool>();
|
||||
|
||||
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
|
||||
block_ranges_, head_masks_, input_size,
|
||||
num_lines, max_query, max_key, num_heads, seq_len_2,
|
||||
line_stride, query_stride, key_stride, seq_len_square,
|
||||
fill_value,
|
||||
out_
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
import os
|
||||
import glob
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
block_fill_cuda_module = None
|
||||
|
||||
|
||||
def load_cuda_module():
|
||||
global block_fill_cuda_module
|
||||
|
||||
if block_fill_cuda_module is not None:
|
||||
return block_fill_cuda_module
|
||||
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
|
||||
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
|
||||
sources = list(cpps) + list(cudas)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
module = load(name='block_fill_cuda',
|
||||
sources=sources,
|
||||
extra_cflags=['-O2'],
|
||||
with_cuda=True,
|
||||
verbose=False)
|
||||
|
||||
block_fill_cuda_module = module
|
||||
|
||||
return block_fill_cuda_module
|
||||
|
||||
|
||||
def block_fill_(out, block_ranges, head_masks, fill_value, no_cuda_kernel=False):
|
||||
"""
|
||||
|
||||
:param block_ranges: (num_blocks, 4)
|
||||
:param head_masks: (num_blocks, num_heads)
|
||||
:param out:
|
||||
:param fill_value:
|
||||
:param no_cuda_kernel:
|
||||
:return:
|
||||
"""
|
||||
if block_ranges.shape[0] == 0:
|
||||
return out
|
||||
|
||||
assert isinstance(fill_value, bool)
|
||||
if block_ranges.is_cuda and not no_cuda_kernel:
|
||||
return block_fill_cuda(out, block_ranges, head_masks, fill_value)
|
||||
else:
|
||||
return block_fill_pytorch(out, block_ranges, head_masks, fill_value)
|
||||
|
||||
|
||||
def _check_input(block_ranges, head_masks, out):
|
||||
assert head_masks.dtype == torch.bool
|
||||
assert block_ranges.device == head_masks.device
|
||||
assert block_ranges.device == out.device
|
||||
num_blocks, col = block_ranges.shape
|
||||
num_blocks_2, num_heads = head_masks.shape
|
||||
assert num_blocks == num_blocks_2
|
||||
assert col == 4
|
||||
assert num_heads > 0
|
||||
return num_blocks, num_heads
|
||||
|
||||
|
||||
def block_fill_cuda(out, block_ranges, head_masks, fill_value):
|
||||
num_blocks, num_heads = _check_input(block_ranges, head_masks, out)
|
||||
|
||||
num_heads_2, seq_len_1, seq_len_2 = out.shape
|
||||
assert num_heads_2 == num_heads
|
||||
assert block_ranges.is_contiguous()
|
||||
assert head_masks.is_contiguous()
|
||||
assert out.is_contiguous()
|
||||
|
||||
max_query = (block_ranges[:, 1] - block_ranges[:, 0]).max().item()
|
||||
max_key = (block_ranges[:, 3] - block_ranges[:, 2]).max().item()
|
||||
|
||||
if max_query <= 0 or max_key <= 0:
|
||||
return out
|
||||
|
||||
module = load_cuda_module()
|
||||
|
||||
module.cuda_forward(block_ranges, head_masks, num_blocks, max_query, max_key, num_heads,
|
||||
seq_len_1, seq_len_2, fill_value, out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def block_fill_pytorch(out, block_ranges, head_masks, fill_value):
|
||||
num_blocks, num_heads = _check_input(block_ranges, head_masks, out)
|
||||
|
||||
assert out.shape[0] == num_heads
|
||||
num_heads_2, _, _ = out.shape
|
||||
assert num_heads_2 == num_heads
|
||||
|
||||
for (query_begin, query_end, key_begin, key_end), line_mask in zip(block_ranges, head_masks):
|
||||
line_head_idx = torch.nonzero(line_mask, as_tuple=False).squeeze(-1) # (num,)
|
||||
out[line_head_idx, query_begin: query_end, key_begin: key_end] = fill_value
|
||||
|
||||
return out
|
|
@ -0,0 +1 @@
|
|||
from .main import range_fill
|
|
@ -0,0 +1,27 @@
|
|||
// #pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
// CUDA函数声明
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor&,
|
||||
const at::Tensor&,
|
||||
const long,
|
||||
at::Tensor&
|
||||
);
|
||||
|
||||
|
||||
// C++函数包装
|
||||
int cuda_forward(const at::Tensor& ranges,
|
||||
const at::Tensor& values,
|
||||
const long num_chunks,
|
||||
at::Tensor& output) {
|
||||
at::DeviceGuard guard(ranges.device());
|
||||
cudaForwardLauncher(ranges, values, num_chunks, output);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 绑定
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
|
||||
m.def("cuda_forward", &cuda_forward, "");
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
|
||||
const long THREADS_PER_BLOCK = 1024;
|
||||
const long MAX_GRID_NUM = 2147483647;
|
||||
|
||||
|
||||
inline long GET_BLOCKS(const long N) {
|
||||
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
||||
return long(min(optimal_block_num, MAX_GRID_NUM));
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void cudaForward(
|
||||
const long * ranges,
|
||||
const scalar_t * values,
|
||||
const long input_size,
|
||||
scalar_t * output) {
|
||||
|
||||
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
|
||||
|
||||
if (index >= input_size) return;
|
||||
|
||||
const long line_start = index * 2;
|
||||
const long begin_idx = ranges[line_start];
|
||||
const long end_idx = ranges[line_start + 1];
|
||||
|
||||
long value = values[index];
|
||||
for (long idx = begin_idx; idx < end_idx; idx++) {
|
||||
output[idx] = value;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor& ranges,
|
||||
const at::Tensor& values,
|
||||
const long num_chunks,
|
||||
at::Tensor& output
|
||||
) {
|
||||
const long input_size = num_chunks;
|
||||
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
|
||||
|
||||
AT_DISPATCH_INTEGRAL_TYPES(
|
||||
values.type(), "cudaForward",
|
||||
([&] {
|
||||
const long *ranges_ = ranges.data_ptr<long>();
|
||||
const scalar_t *values_ = values.data_ptr<scalar_t>();
|
||||
scalar_t *output_ = output.data_ptr<scalar_t>();
|
||||
|
||||
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
|
||||
ranges_, values_, input_size, output_
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
import glob
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
range_fill_cuda_module = None
|
||||
|
||||
|
||||
def load_cuda_module():
|
||||
global range_fill_cuda_module
|
||||
|
||||
if range_fill_cuda_module is not None:
|
||||
return range_fill_cuda_module
|
||||
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
|
||||
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
|
||||
sources = list(cpps) + list(cudas)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
module = load(name='range_fill_cuda',
|
||||
sources=sources,
|
||||
extra_cflags=['-O2'],
|
||||
with_cuda=True,
|
||||
verbose=False)
|
||||
|
||||
range_fill_cuda_module = module
|
||||
|
||||
return range_fill_cuda_module
|
||||
|
||||
|
||||
def range_fill(
|
||||
ranges: torch.Tensor,
|
||||
values: torch.Tensor,
|
||||
seq_len: int,
|
||||
pad_value,
|
||||
dtype=torch.long,
|
||||
no_cuda_kernel: bool = False
|
||||
):
|
||||
"""
|
||||
Fill a set of values into some ranges of 1D vector.
|
||||
Note: The ranges cannot be overlapped, or some unexpected behavior may appear. Current code does not check it.
|
||||
:param ranges: (num_ranges, 2) begin, end
|
||||
:param values: (num_ranges,)
|
||||
:param seq_len:
|
||||
:param pad_value:
|
||||
:param dtype:
|
||||
:param no_cuda_kernel:
|
||||
:return:
|
||||
"""
|
||||
num_ranges, dim = ranges.shape
|
||||
values_shape, = values.shape
|
||||
assert num_ranges == values_shape
|
||||
assert dim == 2
|
||||
assert ranges[:, 0].lt(ranges[:, 1]).all()
|
||||
if ranges.is_cuda and not no_cuda_kernel:
|
||||
return range_fill_cuda(ranges, values, seq_len, pad_value, dtype=dtype)
|
||||
else:
|
||||
return range_fill_pytorch(ranges, values, seq_len, pad_value, dtype=dtype)
|
||||
|
||||
|
||||
def range_fill_cuda(ranges, values, seq_len, pad_value, dtype=torch.long):
|
||||
if dtype not in (torch.long,):
|
||||
raise NotImplementedError
|
||||
num_ranges = ranges.shape[0]
|
||||
out = torch.full((seq_len,), pad_value, dtype=dtype, device=values.device)
|
||||
if num_ranges == 0:
|
||||
return out
|
||||
module = load_cuda_module()
|
||||
module.cuda_forward(ranges, values, num_ranges, out)
|
||||
return out
|
||||
|
||||
|
||||
def range_fill_pytorch(ranges, values, seq_len, pad_value, dtype=torch.long):
|
||||
out = torch.full((seq_len,), pad_value, dtype=dtype, device=values.device)
|
||||
for idx, (begin, end) in enumerate(ranges):
|
||||
out[begin: end] = values[idx]
|
||||
return out
|
|
@ -0,0 +1 @@
|
|||
from .main import segment_arange
|
|
@ -0,0 +1,27 @@
|
|||
// #pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
|
||||
// CUDA函数声明
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor&,
|
||||
const long,
|
||||
const long,
|
||||
at::Tensor&
|
||||
);
|
||||
|
||||
|
||||
// C++函数包装
|
||||
int cuda_forward(const at::Tensor& ranges,
|
||||
const long start,
|
||||
const long num_chunks,
|
||||
at::Tensor& output) {
|
||||
at::DeviceGuard guard(ranges.device());
|
||||
cudaForwardLauncher(ranges, start, num_chunks, output);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 绑定
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
|
||||
m.def("cuda_forward", &cuda_forward, "");
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
#include <torch/torch.h>
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <THC/THCAtomics.cuh>
|
||||
|
||||
|
||||
const long THREADS_PER_BLOCK = 1024;
|
||||
const long MAX_GRID_NUM = 2147483647;
|
||||
|
||||
|
||||
inline long GET_BLOCKS(const long N) {
|
||||
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
|
||||
return long(min(optimal_block_num, MAX_GRID_NUM));
|
||||
}
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void cudaForward(
|
||||
const long * ranges,
|
||||
const long start,
|
||||
const long input_size,
|
||||
scalar_t * output) {
|
||||
|
||||
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
|
||||
|
||||
if (index >= input_size) return;
|
||||
|
||||
const long line_start = index * 2;
|
||||
const long begin_idx = ranges[line_start];
|
||||
const long end_idx = ranges[line_start + 1];
|
||||
|
||||
long cur = 0;
|
||||
for (long idx = begin_idx; idx < end_idx; idx++) {
|
||||
output[idx] = cur + start;
|
||||
cur++;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int cudaForwardLauncher(
|
||||
const at::Tensor& ranges,
|
||||
const long start,
|
||||
const long num_chunks,
|
||||
at::Tensor& output
|
||||
) {
|
||||
const long input_size = num_chunks;
|
||||
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
|
||||
|
||||
AT_DISPATCH_INTEGRAL_TYPES(
|
||||
ranges.type(), "cudaForward",
|
||||
([&] {
|
||||
const long *ranges_ = ranges.data_ptr<long>();
|
||||
scalar_t *output_ = output.data_ptr<scalar_t>();
|
||||
|
||||
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
|
||||
ranges_, start, input_size, output_
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
|
||||
THCudaCheck(cudaGetLastError());
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
import os
|
||||
import glob
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
segment_arange_cuda_module = None
|
||||
|
||||
|
||||
def load_cuda_module():
|
||||
global segment_arange_cuda_module
|
||||
|
||||
if segment_arange_cuda_module is not None:
|
||||
return segment_arange_cuda_module
|
||||
|
||||
path = os.path.dirname(os.path.abspath(__file__))
|
||||
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
|
||||
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
|
||||
sources = list(cpps) + list(cudas)
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
module = load(name='segment_arange_cuda',
|
||||
sources=sources,
|
||||
extra_cflags=['-O2'],
|
||||
with_cuda=True,
|
||||
verbose=False)
|
||||
|
||||
segment_arange_cuda_module = module
|
||||
|
||||
return segment_arange_cuda_module
|
||||
|
||||
|
||||
def segment_arange(
|
||||
ranges: torch.Tensor,
|
||||
start: int,
|
||||
seq_len: int,
|
||||
pad_value,
|
||||
dtype=torch.long,
|
||||
no_cuda_kernel: bool = False
|
||||
):
|
||||
"""
|
||||
Fill a set of values into some ranges of 1D vector.
|
||||
Note: The ranges cannot be overlapped, or some unexpected behavior may appear. Current code does not check it.
|
||||
:param ranges: (num_ranges, 2) begin, end
|
||||
:param start: int
|
||||
:param seq_len:
|
||||
:param pad_value:
|
||||
:param dtype:
|
||||
:param no_cuda_kernel:
|
||||
:return:
|
||||
"""
|
||||
# Todo: Verify the effect of the segment_arange kernel.
|
||||
num_ranges, dim = ranges.shape
|
||||
assert dim == 2
|
||||
assert ranges[:, 0].le(ranges[:, 1]).all()
|
||||
if ranges.is_cuda and not no_cuda_kernel:
|
||||
return segment_arange_cuda(ranges, start, seq_len, pad_value, dtype=dtype)
|
||||
else:
|
||||
return segment_arange_pytorch(ranges, start, seq_len, pad_value, dtype=dtype)
|
||||
|
||||
|
||||
def segment_arange_cuda(ranges, start, seq_len, pad_value, dtype=torch.long):
|
||||
if dtype not in (torch.long,):
|
||||
raise NotImplementedError
|
||||
num_ranges = ranges.shape[0]
|
||||
out = torch.full((seq_len,), pad_value, dtype=dtype, device=ranges.device)
|
||||
if num_ranges == 0:
|
||||
return out
|
||||
module = load_cuda_module()
|
||||
module.cuda_forward(ranges, start, num_ranges, out)
|
||||
return out
|
||||
|
||||
|
||||
def segment_arange_pytorch(ranges, start, seq_len, pad_value, dtype=torch.long):
|
||||
out = torch.full((seq_len,), pad_value, dtype=dtype, device=ranges.device)
|
||||
for idx, (begin, end) in enumerate(ranges):
|
||||
out[begin: end] = torch.arange(start, start + (end - begin), dtype=dtype, device=ranges.device)
|
||||
return out
|
|
@ -0,0 +1,794 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from fairseq import utils
|
||||
from fairseq.models import FairseqDecoder
|
||||
from fairseq.modules.fairseq_dropout import FairseqDropout
|
||||
from fairseq.modules.layer_norm import LayerNorm
|
||||
|
||||
from .tools import instant_info_construction as iic
|
||||
from .tools import arg_tools, computation_tools
|
||||
from .data_structures.four_dim_pocket import FourDimPocket
|
||||
from .kernels.range_fill import range_fill
|
||||
from .kernels.segment_arange import segment_arange
|
||||
from .museformer_decoder_layer import MuseformerDecoderLayer
|
||||
from .data_structures.attention_scheme import AttentionScheme
|
||||
from .embedding.tao_embedding import TaoEmbedding
|
||||
from .datasets.add_beat_dataset import get_beat_ids
|
||||
|
||||
|
||||
def construct_reg_bar_ids(chunk_ranges, num_chunks, reg_len):
|
||||
device = chunk_ranges.device
|
||||
bar_ids = []
|
||||
for sample_ranges, sample_num_chunk in zip(chunk_ranges, num_chunks):
|
||||
if sample_num_chunk == 0:
|
||||
sample_bar_ids = torch.zeros(reg_len, dtype=torch.long, device=device)
|
||||
else:
|
||||
sample_bar_ids = range_fill(sample_ranges, torch.arange(1, sample_num_chunk + 1, device=device),
|
||||
reg_len, pad_value=0)
|
||||
bar_ids.append(sample_bar_ids)
|
||||
bar_ids = torch.stack(bar_ids, dim=0)
|
||||
return bar_ids # (bsz, reg_len)
|
||||
|
||||
|
||||
def construct_reg_token_in_chunk_ids(chunk_ranges, num_chunks, reg_len):
|
||||
device = chunk_ranges.device
|
||||
ids = []
|
||||
for sample_ranges, sample_num_chunk in zip(chunk_ranges, num_chunks):
|
||||
if sample_num_chunk == 0:
|
||||
sample_ids = torch.zeros(reg_len, dtype=torch.long, device=device)
|
||||
else:
|
||||
sample_ranges = sample_ranges[:sample_num_chunk]
|
||||
sample_ids = segment_arange(sample_ranges, 1, reg_len, 0, dtype=torch.long, no_cuda_kernel=False)
|
||||
ids.append(sample_ids)
|
||||
ids = torch.stack(ids, dim=0)
|
||||
return ids
|
||||
|
||||
|
||||
class MuseformerDecoder(FairseqDecoder):
|
||||
_submodules = (MuseformerDecoderLayer,)
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
# === Implementation ===
|
||||
parser.add_argument('--attention-impl', choices=('mask', 'triton', 'sparta'))
|
||||
parser.add_argument('--block-size', type=int, choices=(64, 32))
|
||||
parser.add_argument('--attention-mode', choices=('v2s1',))
|
||||
|
||||
# === Transformer ===
|
||||
parser.add_argument('--attention-embed-dim', type=int)
|
||||
parser.add_argument('--num-layers', type=int)
|
||||
parser.add_argument('--num-attention-heads', type=eval)
|
||||
parser.add_argument('--normalize-before', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--ffn-embed-dim', type=int)
|
||||
parser.add_argument('--dropout', type=float)
|
||||
parser.add_argument('--attention-dropout', type=float)
|
||||
parser.add_argument('--activation-fn', choices=utils.get_available_activation_fns())
|
||||
parser.add_argument('--no-final-norm', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
# === Attention Scheme ===
|
||||
parser.add_argument('--con2con', type=eval)
|
||||
parser.add_argument('--con2con-self', choices=('default', 'none', 'full'))
|
||||
parser.add_argument('--con2sum', type=eval)
|
||||
parser.add_argument('--con2sum-self', choices=('default', 'none', 'full'))
|
||||
parser.add_argument('--sum2con', type=eval)
|
||||
parser.add_argument('--sum2con-self', choices=('default', 'none', 'full'))
|
||||
parser.add_argument('--sum2sum', type=eval)
|
||||
parser.add_argument('--sum2sum-self', choices=('default', 'none', 'full'))
|
||||
parser.add_argument('--con2con-causal', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--pref2pref-mode', choices=('full', 'lt', 'none'))
|
||||
parser.add_argument('--sum2pref-mode', choices=('none',))
|
||||
parser.add_argument('--pref2sum-mode', choices=('none',))
|
||||
parser.add_argument('--con2pref-mode', choices=('full',))
|
||||
parser.add_argument('--pref2con-mode', choices=('none',))
|
||||
|
||||
# === Summary Tokens ===
|
||||
parser.add_argument('--num-summary-tokens-each-chunk', type=int)
|
||||
parser.add_argument('--max-summary-tokens', type=int)
|
||||
|
||||
# === Embedding ===
|
||||
parser.add_argument('--concat-sum-embeddings', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--concat-reg-embeddings', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--proj-sum-embeddings', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--proj-reg-embeddings', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--share-embedding-projection', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
parser.add_argument('--share-input-output-embed', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--no-scale-embedding', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--layernorm-embedding', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--share-layernorm-embedding', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
# token embedding
|
||||
parser.add_argument('--token-embed-dim', type=int)
|
||||
|
||||
# absolute token in-chunk-position embedding
|
||||
parser.add_argument('--use-token-in-chunk-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--learned-token-in-chunk-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--max-token-in-chunk-abs-pos', type=int)
|
||||
parser.add_argument('--token-in-chunk-abs-pos-embed-dim', type=int)
|
||||
|
||||
# absolute token position embedding
|
||||
parser.add_argument('--use-token-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--learned-token-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--max-token-abs-pos', type=int)
|
||||
parser.add_argument('--token-abs-pos-embed-dim', type=int)
|
||||
|
||||
# absolute bar position embedding
|
||||
parser.add_argument('--use-bar-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--learned-bar-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--max-bar-abs-pos', type=int)
|
||||
parser.add_argument('--bar-abs-pos-embed-dim', type=int)
|
||||
parser.add_argument('--valid-parts-for-bar-abs-pos', type=arg_tools.comma_split_tuple_of_specific_type(str))
|
||||
|
||||
# absolute beat position embedding
|
||||
parser.add_argument('--use-beat-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--learned-beat-abs-pos', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--max-beat-abs-pos', type=int)
|
||||
parser.add_argument('--beat-abs-pos-embed-dim', type=int)
|
||||
parser.add_argument('--beat-abs-pos-padding-on-sum', type=arg_tools.str_bool_with_default_error)
|
||||
|
||||
# === Gradient Checkpointing ===
|
||||
parser.add_argument('--gradient-checkpointing', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--gradient-checkpointing-every-n-layer', type=int)
|
||||
parser.add_argument('--gradient-checkpointing-layers',
|
||||
type=lambda x: tuple([int(item) for item in x.split(',')]))
|
||||
|
||||
arg_tools.add_submodule_args(cls, parser)
|
||||
|
||||
def __init__(self, args, dictionary):
|
||||
super().__init__(dictionary)
|
||||
self.args = args
|
||||
|
||||
# === Meta Preparations ===
|
||||
# pocket to store things that can be used everywhere
|
||||
self.pocket = FourDimPocket()
|
||||
self.pocket['constant'] = self.pocket_constant = {}
|
||||
self.pocket['instant'] = self.pocket_instant = {}
|
||||
|
||||
# --- control parameters ---
|
||||
self.attention_impl = self.args.attention_impl
|
||||
self.block_size = getattr(self.args, "block_size", None)
|
||||
if self.attention_impl == 'triton':
|
||||
if self.block_size is None:
|
||||
self.block_size = 64
|
||||
self.attention_mode = getattr(self.args, "attention_mode", "simu")
|
||||
self.attn_mask_gen_parts = None # which parts are involved for attn_mask
|
||||
if self.attention_mode == 'v2s1':
|
||||
self.attn_mask_combination = {'sx': ('ss', 'sr'), 'rx': ('rs', 'rr')}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.need_embedding_summary_tokens = True
|
||||
|
||||
self.pocket_constant['attention_impl'] = self.attention_impl
|
||||
self.pocket_constant['block_size'] = self.block_size
|
||||
self.pocket_constant['attention_mode'] = self.attention_mode
|
||||
self.pocket_constant['attn_mask_combination'] = self.attn_mask_combination
|
||||
self.pocket_constant['attn_mask_gen_parts'] = self.attn_mask_gen_parts
|
||||
self.pocket_constant['need_embedding_summary_tokens'] = self.need_embedding_summary_tokens
|
||||
|
||||
# === Basic Parameters ===
|
||||
self.embed_dim = self.args.attention_embed_dim
|
||||
self.num_layers = self.args.num_layers
|
||||
self.num_attention_heads = arg_tools.possibly_extend_tuple(
|
||||
self.args.num_attention_heads, self.num_layers
|
||||
) # tuple of size num_layers
|
||||
self.valid_dictionary_len = getattr(self.args, 'valid_dict_len', len(dictionary))
|
||||
|
||||
# === Embedding ===
|
||||
self.embed_scale = None if args.no_scale_embedding else math.sqrt(self.args.token_embed_dim)
|
||||
|
||||
# --- Regular Token Embedding ---
|
||||
self.embed_regular_tokens = nn.Embedding(
|
||||
self.valid_dictionary_len, self.args.token_embed_dim, self.dictionary.pad()
|
||||
)
|
||||
# --- Summary Token Embedding ---
|
||||
self.num_summary = self.args.num_summary_tokens_each_chunk
|
||||
assert self.num_summary >= 0
|
||||
if self.num_summary == 0 or not self.need_embedding_summary_tokens:
|
||||
self.embed_summary_tokens = None
|
||||
else:
|
||||
self.embed_summary_tokens = nn.Embedding(
|
||||
getattr(self.args, 'max_summary_tokens', self.num_summary) + 1,
|
||||
self.args.token_embed_dim,
|
||||
padding_idx=0
|
||||
)
|
||||
|
||||
# --- Absolute Token In-Chunk Absolute Position Embedding ---
|
||||
self.use_token_in_chunk_abs_pos = getattr(self.args, 'use_token_in_chunk_abs_pos', False)
|
||||
if self.use_token_in_chunk_abs_pos:
|
||||
token_in_chunk_abs_pos_embed_dim = getattr(self.args, "token_in_chunk_abs_pos_embed_dim",
|
||||
self.args.token_embed_dim)
|
||||
self.embed_token_in_chunk_abs_pos = TaoEmbedding(
|
||||
self.args.max_token_in_chunk_abs_pos + 1, token_in_chunk_abs_pos_embed_dim, padding_idx=0,
|
||||
learned=getattr(self.args, 'learned_token_in_chunk_abs_pos', False)
|
||||
)
|
||||
else:
|
||||
self.embed_token_in_chunk_abs_pos = None
|
||||
token_in_chunk_abs_pos_embed_dim = 0
|
||||
|
||||
# --- Absolute Token Position Embedding ---
|
||||
self.use_token_abs_pos = getattr(self.args, 'use_token_abs_pos', False)
|
||||
if self.use_token_abs_pos:
|
||||
token_abs_pos_embed_dim = getattr(self.args, "token_abs_pos_embed_dim", self.args.token_embed_dim)
|
||||
if getattr(self.args, 'learned_token_abs_pos', False):
|
||||
raise NotImplementedError("The support for overly long token absolute position embedding is not done.")
|
||||
self.embed_token_abs_pos = TaoEmbedding(
|
||||
self.args.max_token_abs_pos, token_abs_pos_embed_dim, padding_idx=0,
|
||||
learned=getattr(self.args, 'learned_token_abs_pos', False)
|
||||
)
|
||||
else:
|
||||
self.embed_token_abs_pos = None
|
||||
token_abs_pos_embed_dim = 0
|
||||
|
||||
# --- Absolute Bar Position Embedding ---
|
||||
self.use_bar_abs_pos = getattr(self.args, 'use_bar_abs_pos', False)
|
||||
if self.use_bar_abs_pos:
|
||||
bar_abs_pos_embed_dim = getattr(self.args, "bar_abs_pos_embed_dim", self.args.token_embed_dim)
|
||||
# if getattr(self.args, 'learned_bar_abs_pos', False):
|
||||
# raise NotImplementedError("The support for overly long bar absolute position embedding is not done.")
|
||||
self.embed_bar_abs_pos = TaoEmbedding(
|
||||
self.args.max_bar_abs_pos + 1, bar_abs_pos_embed_dim, padding_idx=0,
|
||||
learned=getattr(self.args, 'learned_bar_abs_pos', False)
|
||||
)
|
||||
self.valid_parts_for_bar_abs_pos = getattr(self.args, 'valid_parts_for_bar_abs_pos', None)
|
||||
else:
|
||||
self.embed_bar_abs_pos = None
|
||||
bar_abs_pos_embed_dim = 0
|
||||
self.valid_parts_for_bar_abs_pos = None
|
||||
|
||||
# --- Absolute Beat Position Embedding ---
|
||||
self.use_beat_abs_pos = getattr(self.args, 'use_beat_abs_pos', False)
|
||||
if self.use_beat_abs_pos:
|
||||
beat_abs_pos_embed_dim = getattr(self.args, "beat_abs_pos_embed_dim", self.args.token_embed_dim)
|
||||
# if getattr(self.args, 'learned_beat_abs_pos', False):
|
||||
# raise NotImplementedError("The support for overly long beat absolute position embedding is not done.")
|
||||
self.embed_beat_abs_pos = TaoEmbedding(
|
||||
self.args.max_beat_abs_pos + 1, beat_abs_pos_embed_dim, padding_idx=0,
|
||||
learned=getattr(self.args, 'learned_beat_abs_pos', False)
|
||||
)
|
||||
self.valid_parts_for_beat_abs_pos = getattr(self.args, 'valid_parts_for_beat_abs_pos', None)
|
||||
else:
|
||||
self.embed_beat_abs_pos = None
|
||||
beat_abs_pos_embed_dim = 0
|
||||
self.valid_parts_for_beat_abs_pos = None
|
||||
|
||||
# --- Conclude Embedding ---
|
||||
self.sum_proj_embeddings = None
|
||||
self.reg_proj_embeddings = None
|
||||
self.concat_reg_embeddings = getattr(self.args, "concat_reg_embeddings", False)
|
||||
if self.concat_reg_embeddings:
|
||||
if getattr(self.args, 'proj_reg_embeddings', False):
|
||||
self.reg_proj_embeddings = nn.Linear(
|
||||
self.args.token_embed_dim + token_in_chunk_abs_pos_embed_dim +
|
||||
token_abs_pos_embed_dim +
|
||||
(
|
||||
bar_abs_pos_embed_dim
|
||||
if (self.valid_parts_for_bar_abs_pos is None
|
||||
or 'r' in self.valid_parts_for_bar_abs_pos)
|
||||
else 0
|
||||
) + beat_abs_pos_embed_dim,
|
||||
self.embed_dim
|
||||
)
|
||||
self.concat_sum_embeddings = False
|
||||
self.beat_abs_pos_padding_on_sum = getattr(self.args, 'beat_abs_pos_padding_on_sum', False)
|
||||
if self.num_summary > 0 and self.need_embedding_summary_tokens:
|
||||
self.concat_sum_embeddings = getattr(self.args, "concat_sum_embeddings", False)
|
||||
if getattr(self.args, 'proj_sum_embeddings', False):
|
||||
if getattr(self.args, 'share_embedding_projection', False):
|
||||
self.sum_proj_embeddings = self.reg_proj_embeddings
|
||||
else:
|
||||
self.sum_proj_embeddings = nn.Linear(
|
||||
self.args.token_embed_dim + (
|
||||
bar_abs_pos_embed_dim
|
||||
if (self.valid_parts_for_bar_abs_pos is None
|
||||
or 's' in self.valid_parts_for_bar_abs_pos)
|
||||
else 0
|
||||
) +
|
||||
(
|
||||
beat_abs_pos_embed_dim if self.beat_abs_pos_padding_on_sum else 0
|
||||
),
|
||||
self.embed_dim
|
||||
)
|
||||
|
||||
self.reg_layernorm_embedding = None
|
||||
self.sum_layernorm_embedding = None
|
||||
if getattr(args, "layernorm_embedding", False):
|
||||
self.reg_layernorm_embedding = LayerNorm(self.embed_dim)
|
||||
if self.num_summary > 0 and self.need_embedding_summary_tokens:
|
||||
if getattr(self.args, "share_layernorm_embedding", False):
|
||||
self.sum_layernorm_embedding = self.reg_layernorm_embedding
|
||||
else:
|
||||
self.sum_layernorm_embedding = LayerNorm(self.embed_dim)
|
||||
|
||||
# === Attention Scheme ===
|
||||
self.attn_scheme = AttentionScheme(
|
||||
self.args.sum2sum, self.args.sum2sum_self,
|
||||
self.args.sum2con, self.args.sum2con_self,
|
||||
self.args.con2sum, self.args.con2sum_self,
|
||||
self.args.con2con, self.args.con2con_self,
|
||||
self.args.con2con_causal,
|
||||
self.args.pref2pref_mode,
|
||||
self.args.sum2pref_mode,
|
||||
self.args.pref2sum_mode,
|
||||
self.args.con2pref_mode,
|
||||
self.args.pref2con_mode,
|
||||
self.num_layers, self.num_attention_heads
|
||||
)
|
||||
self.layer_to_sv = {}
|
||||
self.sv_to_layers = {}
|
||||
attn_scheme_set = set(self.attn_scheme)
|
||||
for layer_idx, layer_scheme in enumerate(self.attn_scheme):
|
||||
for sv, unique_scheme in enumerate(attn_scheme_set):
|
||||
if layer_scheme == unique_scheme:
|
||||
self.layer_to_sv[layer_idx] = sv
|
||||
if sv not in self.sv_to_layers:
|
||||
self.sv_to_layers[sv] = []
|
||||
self.sv_to_layers[sv].append(layer_idx)
|
||||
for sv in self.sv_to_layers:
|
||||
self.sv_to_layers[sv] = set(self.sv_to_layers[sv])
|
||||
self.pocket_constant['layer_to_sv'] = self.layer_to_sv
|
||||
self.pocket_constant['sv_to_layers'] = self.sv_to_layers
|
||||
|
||||
# === Transformer Blocks ===
|
||||
self.dropout_module = FairseqDropout(
|
||||
self.args.dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers.extend(
|
||||
[
|
||||
self.build_decoder_layer(
|
||||
args, layer_idx, self.num_attention_heads[layer_idx],
|
||||
self.attn_scheme[layer_idx],
|
||||
args.attention_dropout
|
||||
) for layer_idx in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if args.normalize_before and not getattr(args, "no_final_norm", False):
|
||||
self.layer_norm = LayerNorm(self.embed_dim)
|
||||
else:
|
||||
self.layer_norm = None
|
||||
|
||||
# === Output ===
|
||||
if self.args.share_input_output_embed:
|
||||
self.output_projection = nn.Linear(
|
||||
self.embed_regular_tokens.weight.shape[1],
|
||||
self.embed_regular_tokens.weight.shape[0],
|
||||
bias=False,
|
||||
)
|
||||
self.output_projection.weight = self.embed_regular_tokens.weight
|
||||
else:
|
||||
self.output_projection = nn.Linear(
|
||||
self.embed_dim, self.valid_dictionary_len, bias=False
|
||||
)
|
||||
nn.init.normal_(
|
||||
self.output_projection.weight, mean=0, std=self.embed_dim ** -0.5
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = getattr(self.args, 'gradient_checkpointing', False)
|
||||
if self.gradient_checkpointing:
|
||||
checkpointing_layers = getattr(self.args, 'gradient_checkpointing_layers', None)
|
||||
if checkpointing_layers is None:
|
||||
gradient_checkpointing_every_n_layer = getattr(self.args, 'gradient_checkpointing_every_n_layer', 1)
|
||||
checkpointing_layers = tuple(range(0, self.num_layers, gradient_checkpointing_every_n_layer))
|
||||
self.checkpointing_layers = checkpointing_layers
|
||||
|
||||
def build_decoder_layer(
|
||||
self,
|
||||
args,
|
||||
layer_idx,
|
||||
num_attention_heads,
|
||||
layer_attention_scheme,
|
||||
attention_dropout,
|
||||
**kwargs
|
||||
):
|
||||
return MuseformerDecoderLayer(
|
||||
args,
|
||||
layer_idx,
|
||||
num_attention_heads,
|
||||
layer_attention_scheme,
|
||||
attention_dropout,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src_tokens, # (batch, reg_len)
|
||||
src_lengths=None,
|
||||
chunk_points=None, # (batch, max_chunk + 1)
|
||||
num_chunks=None, # (batch,)
|
||||
num_complete_chunks=None, # (batch,)
|
||||
num_pref=None, # (batch,)
|
||||
beat_ids=None, # (batch, reg_len)
|
||||
|
||||
last_state_only=False,
|
||||
features_only=False,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
# print('Beg', src_tokens.shape)
|
||||
x, extra = self.extract_features(
|
||||
src_tokens,
|
||||
src_lengths=src_lengths,
|
||||
reg_chunk_points=chunk_points,
|
||||
num_chunks=num_chunks,
|
||||
num_complete_chunks=num_complete_chunks,
|
||||
num_pref=num_pref,
|
||||
beat_ids=beat_ids,
|
||||
last_state_only=last_state_only,
|
||||
**kwargs
|
||||
)
|
||||
if not features_only:
|
||||
x = self.output_layer(x)
|
||||
|
||||
return x, extra
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
reg_tokens, # (batch, reg_len)
|
||||
src_lengths=None,
|
||||
reg_chunk_points=None, # (batch, max_chunk + 1)
|
||||
num_chunks=None, # (batch,)
|
||||
num_complete_chunks=None, # (batch,)
|
||||
num_pref=None, # (batch,)
|
||||
beat_ids=None,
|
||||
|
||||
last_state_only=False,
|
||||
|
||||
**kwargs
|
||||
):
|
||||
self.pocket_instant.clear()
|
||||
|
||||
bsz, reg_len = reg_tokens.shape
|
||||
device = reg_tokens.device
|
||||
|
||||
# ===== Instant Info Construction ===
|
||||
# Construct the missing input. Only for inference.
|
||||
if not self.training:
|
||||
if all([item is None for item in (reg_chunk_points, num_chunks, num_complete_chunks, num_pref)]):
|
||||
if src_lengths is None:
|
||||
src_lengths = reg_tokens.ne(self.dictionary.pad()).sum(dim=-1)
|
||||
reg_chunk_points, num_chunks, num_complete_chunks, num_pref = iic.construct_bar_chunk_info(
|
||||
reg_tokens, src_lengths,
|
||||
self.dictionary.index(getattr(self.args, 'eob_token', 'b-1')),
|
||||
begin_idx=1,
|
||||
only_bos_prefix=True,
|
||||
device=device
|
||||
)
|
||||
if getattr(self.args, 'take_bos_as_bar', False):
|
||||
assert reg_chunk_points.shape[0] == 1
|
||||
reg_chunk_points = torch.cat(
|
||||
(reg_chunk_points.new_tensor([[0]]), reg_chunk_points), dim=-1
|
||||
)
|
||||
num_pref = num_pref.new_zeros(1,)
|
||||
num_chunks = num_chunks + 1
|
||||
num_complete_chunks = num_complete_chunks + 1
|
||||
|
||||
if getattr(self.args, 'use_beat_abs_pos', False) and beat_ids is None:
|
||||
beat_ids = get_beat_ids(
|
||||
reg_tokens[0], self.dictionary,
|
||||
ts_instead_of_tempo=getattr(self.args, 'beat_mask_ts', False)
|
||||
)[None]
|
||||
|
||||
# ===== Input Checking ===
|
||||
if num_complete_chunks is None:
|
||||
num_complete_chunks = num_chunks # (bsz,)
|
||||
else:
|
||||
temp_check = num_chunks - num_complete_chunks
|
||||
assert (temp_check.eq(0) | temp_check.eq(1)).all()
|
||||
del temp_check
|
||||
max_chunk = int(num_chunks.max())
|
||||
assert reg_chunk_points.shape == (bsz, max_chunk + 1)
|
||||
|
||||
reg_chunk_ranges = computation_tools.transfer_chunk_points_to_ranges(reg_chunk_points) # (bsz, max_chunk, 2)
|
||||
assert reg_chunk_ranges.shape == (bsz, max_chunk, 2)
|
||||
del reg_chunk_points
|
||||
|
||||
# ===== Summary Sequence and Embeddings =====
|
||||
max_comp_chunk = int(num_complete_chunks.max())
|
||||
sum_len = max_comp_chunk * self.num_summary
|
||||
|
||||
# ===== Padding for Blocksparse Computation =====
|
||||
sum_pad_len = 0
|
||||
reg_pad_len = 0
|
||||
real_sum_len = sum_len
|
||||
real_reg_len = reg_len
|
||||
if self.block_size is not None:
|
||||
if sum_len > 0:
|
||||
sum_pad_len = self.block_size - sum_len % self.block_size
|
||||
if sum_pad_len == self.block_size:
|
||||
sum_pad_len = 0
|
||||
if sum_pad_len > 0:
|
||||
sum_len = real_sum_len + sum_pad_len
|
||||
reg_pad_len = self.block_size - reg_len % self.block_size
|
||||
if reg_pad_len == self.block_size:
|
||||
reg_pad_len = 0
|
||||
if reg_pad_len > 0:
|
||||
reg_len = real_reg_len + reg_pad_len
|
||||
|
||||
# ===== Embedding Layer =====
|
||||
# --- Summary and Regular Token Embeddings ---
|
||||
sum_x = None
|
||||
sum_key_padding_mask = None
|
||||
sum_token_ids = None
|
||||
if self.num_summary > 0:
|
||||
sum_key_padding_mask = torch.arange(max_comp_chunk, device=device)[None].expand(
|
||||
bsz, -1).ge(num_complete_chunks[:, None])[:, :, None].expand(-1, -1, self.num_summary).reshape(
|
||||
bsz, real_sum_len
|
||||
) # (bsz, real_sum_len)
|
||||
if sum_pad_len > 0:
|
||||
sum_key_padding_mask = torch.cat(
|
||||
(sum_key_padding_mask, sum_key_padding_mask.new_ones(bsz, sum_pad_len)), dim=1
|
||||
) # (bsz, sum_len)
|
||||
sum_token_ids = torch.arange(1, self.num_summary + 1, device=device)[None, None].repeat(
|
||||
bsz, max_comp_chunk, 1
|
||||
).reshape(bsz, real_sum_len)
|
||||
if sum_pad_len > 0:
|
||||
sum_token_ids = torch.cat(
|
||||
(sum_token_ids, sum_token_ids.new_zeros(bsz, sum_pad_len)), dim=1
|
||||
)
|
||||
sum_token_ids.masked_fill_(sum_key_padding_mask, 0)
|
||||
if self.need_embedding_summary_tokens:
|
||||
sum_x = self.embed_summary_tokens(sum_token_ids.transpose(0, 1))
|
||||
if self.embed_scale is not None:
|
||||
sum_x.mul_(self.embed_scale)
|
||||
self.pocket_instant['sum_token_ids'] = sum_token_ids
|
||||
if reg_pad_len > 0:
|
||||
reg_tokens = torch.cat((reg_tokens, reg_tokens.new_full((bsz, reg_pad_len), self.dictionary.pad())), dim=1)
|
||||
beat_ids = torch.cat((beat_ids, beat_ids.new_full((bsz, reg_pad_len), 0)), dim=1)
|
||||
reg_key_padding_mask = reg_tokens.eq(self.embed_regular_tokens.padding_idx) # (bsz, reg_len)
|
||||
reg_x = self.embed_regular_tokens(reg_tokens.transpose(0, 1)) # (reg_len, bsz, token_embed_dim)
|
||||
del reg_tokens
|
||||
if self.embed_scale is not None:
|
||||
reg_x.mul_(self.embed_scale)
|
||||
|
||||
# --- Absolute Token In-Chunk Position Embedding ---
|
||||
if self.embed_token_in_chunk_abs_pos is not None:
|
||||
token_in_chunk_abs_pos = construct_reg_token_in_chunk_ids(reg_chunk_ranges, num_chunks, reg_len)
|
||||
token_in_chunk_abs_pos_embed = self.embed_token_in_chunk_abs_pos(token_in_chunk_abs_pos).transpose(0, 1) \
|
||||
# (reg_len, bsz, dim)
|
||||
del token_in_chunk_abs_pos
|
||||
else:
|
||||
token_in_chunk_abs_pos_embed = None
|
||||
|
||||
# --- Absolute Beat Position Embedding ---
|
||||
reg_beat_abs_pos_embed = None
|
||||
sum_beat_abs_pos_embed = None
|
||||
if self.embed_beat_abs_pos is not None:
|
||||
reg_beat_abs_pos_embed = self.embed_beat_abs_pos(beat_ids.transpose(0, 1)) # (l, bsz, dim)
|
||||
sum_beat_abs_pos_embed = None
|
||||
if sum_x is not None and self.beat_abs_pos_padding_on_sum:
|
||||
sum_beat_abs_pos_embed = reg_x.new_zeros(sum_x.shape[0], bsz, self.embed_beat_abs_pos.embedding_dim)
|
||||
del beat_ids
|
||||
|
||||
# --- Absolute Token Position Embedding ---
|
||||
token_abs_pos_ids = None
|
||||
if self.use_token_abs_pos:
|
||||
token_abs_pos_ids = torch.arange(1, reg_len + 1, device=device)[None].repeat(bsz, 1)
|
||||
token_abs_pos_ids.masked_fill_(reg_key_padding_mask, 0)
|
||||
if self.embed_token_abs_pos is not None:
|
||||
token_abs_pos_embed = self.embed_token_abs_pos(token_abs_pos_ids.transpose(0, 1))
|
||||
else:
|
||||
token_abs_pos_embed = None
|
||||
del token_abs_pos_ids
|
||||
|
||||
# --- Absolute Bar Position Embedding ---
|
||||
sum_bar_pos_ids = None
|
||||
reg_bar_pos_ids = None
|
||||
if self.use_bar_abs_pos:
|
||||
reg_bar_pos_ids = construct_reg_bar_ids(reg_chunk_ranges, num_chunks, reg_len) # (bsz, reg_len)
|
||||
if self.num_summary > 0:
|
||||
sum_bar_pos_ids = torch.arange(1, max_comp_chunk + 1, device=device) # (max_comp_chunk,)
|
||||
sum_bar_pos_ids = sum_bar_pos_ids[None, :, None].expand(bsz, -1, self.num_summary).reshape(
|
||||
bsz, real_sum_len
|
||||
) # (bsz, sum_len)
|
||||
sum_bar_pos_ids = torch.cat((sum_bar_pos_ids, sum_bar_pos_ids.new_zeros(bsz, sum_pad_len)), dim=1)
|
||||
sum_bar_pos_ids.masked_fill_(sum_key_padding_mask, 0)
|
||||
sum_bar_abs_pos_embed = None
|
||||
reg_bar_abs_pos_embed = None
|
||||
if self.embed_bar_abs_pos is not None:
|
||||
if (
|
||||
self.num_summary > 0 and self.need_embedding_summary_tokens and
|
||||
(self.valid_parts_for_bar_abs_pos is None or 's' in self.valid_parts_for_bar_abs_pos)
|
||||
):
|
||||
sum_bar_abs_pos_embed = self.embed_bar_abs_pos(sum_bar_pos_ids.transpose(0, 1)) # (l, bsz, dim)
|
||||
if self.valid_parts_for_bar_abs_pos is None or 'r' in self.valid_parts_for_bar_abs_pos:
|
||||
reg_bar_abs_pos_embed = self.embed_bar_abs_pos(reg_bar_pos_ids).transpose(0, 1) # (l, bsz, dim)
|
||||
del sum_bar_pos_ids, reg_bar_pos_ids
|
||||
|
||||
# --- Conclude Embeddings ---
|
||||
sum_x = [item for item in (sum_x, sum_bar_abs_pos_embed, sum_beat_abs_pos_embed) if item is not None]
|
||||
if len(sum_x) == 0:
|
||||
sum_x = None
|
||||
reg_x = [item for item in (
|
||||
reg_x, token_in_chunk_abs_pos_embed, token_abs_pos_embed, reg_bar_abs_pos_embed, reg_beat_abs_pos_embed
|
||||
) if item is not None]
|
||||
del sum_bar_abs_pos_embed
|
||||
del token_in_chunk_abs_pos_embed, token_abs_pos_embed, reg_bar_abs_pos_embed
|
||||
|
||||
if self.concat_reg_embeddings:
|
||||
reg_x = torch.cat(reg_x, dim=-1)
|
||||
if self.reg_proj_embeddings is not None:
|
||||
reg_x = self.reg_proj_embeddings(reg_x)
|
||||
else:
|
||||
reg_x = sum(reg_x)
|
||||
if self.num_summary > 0 and self.need_embedding_summary_tokens:
|
||||
if self.concat_sum_embeddings:
|
||||
sum_x = torch.cat(sum_x, dim=-1)
|
||||
if self.sum_proj_embeddings is not None:
|
||||
sum_x = self.sum_proj_embeddings(sum_x)
|
||||
else:
|
||||
sum_x = sum(sum_x)
|
||||
|
||||
if self.sum_layernorm_embedding is not None and sum_x is not None:
|
||||
sum_x = self.sum_layernorm_embedding(sum_x)
|
||||
if self.reg_layernorm_embedding is not None:
|
||||
reg_x = self.reg_layernorm_embedding(reg_x)
|
||||
|
||||
if sum_x is not None:
|
||||
sum_x = self.dropout_module(sum_x)
|
||||
reg_x = self.dropout_module(reg_x)
|
||||
|
||||
key_padding_mask = computation_tools.may_bi_cat(sum_key_padding_mask, reg_key_padding_mask, dim=1)
|
||||
if key_padding_mask is not None and not key_padding_mask.any():
|
||||
key_padding_mask = None
|
||||
del sum_key_padding_mask, reg_key_padding_mask
|
||||
|
||||
# with open('meta.bin', 'wb') as f:
|
||||
# torch.save((sum_len, reg_len), f)
|
||||
# print('saved meta')
|
||||
|
||||
# === Transformer Layers ===
|
||||
(sum_x, reg_x), inner_states = self.run_layers(
|
||||
(sum_x, reg_x),
|
||||
reg_chunk_ranges=reg_chunk_ranges,
|
||||
num_chunks=num_chunks,
|
||||
num_complete_chunks=num_complete_chunks,
|
||||
num_pref=num_pref,
|
||||
sum_len=sum_len,
|
||||
reg_len=reg_len,
|
||||
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=None,
|
||||
|
||||
need_weights=False,
|
||||
need_head_weights=False,
|
||||
|
||||
last_state_only=last_state_only,
|
||||
)
|
||||
|
||||
if sum_x is not None:
|
||||
sum_x = sum_x[:real_sum_len]
|
||||
sum_len = real_sum_len
|
||||
sum_x = sum_x.transpose(0, 1)
|
||||
assert sum_x.shape == (bsz, sum_len, self.embed_dim), (sum_x.shape, (bsz, sum_len, self.embed_dim))
|
||||
|
||||
reg_x = reg_x[:real_reg_len]
|
||||
reg_len = real_reg_len
|
||||
if self.layer_norm is not None:
|
||||
reg_x = self.layer_norm(reg_x)
|
||||
reg_x = reg_x.transpose(0, 1)
|
||||
assert reg_x.shape == (bsz, reg_len, self.embed_dim), (reg_x.shape, (bsz, reg_len, self.embed_dim))
|
||||
|
||||
others = {
|
||||
# Uncomment if needed
|
||||
# 'summary': sum_x,
|
||||
'attn': None,
|
||||
# 'inner_states': inner_states,
|
||||
}
|
||||
|
||||
return reg_x, others
|
||||
|
||||
def run_layers(
|
||||
self,
|
||||
x,
|
||||
reg_chunk_ranges,
|
||||
num_chunks,
|
||||
num_complete_chunks,
|
||||
num_pref,
|
||||
sum_len,
|
||||
reg_len,
|
||||
|
||||
key_padding_mask,
|
||||
attn_mask,
|
||||
|
||||
need_weights,
|
||||
need_head_weights,
|
||||
|
||||
last_state_only
|
||||
):
|
||||
inner_states = []
|
||||
if not last_state_only:
|
||||
inner_states.append(x)
|
||||
|
||||
for layer in self.layers:
|
||||
layer_idx = layer.layer_idx
|
||||
|
||||
if (
|
||||
getattr(self.args, "gradient_checkpointing", False) and self.training and
|
||||
layer_idx in self.checkpointing_layers
|
||||
):
|
||||
x, _ = checkpoint(
|
||||
layer,
|
||||
x,
|
||||
reg_chunk_ranges,
|
||||
num_chunks,
|
||||
num_complete_chunks,
|
||||
num_pref,
|
||||
sum_len,
|
||||
reg_len,
|
||||
|
||||
key_padding_mask,
|
||||
None if attn_mask is None else attn_mask[layer_idx],
|
||||
|
||||
need_weights,
|
||||
need_head_weights,
|
||||
)
|
||||
else:
|
||||
x, _ = layer(
|
||||
x,
|
||||
reg_chunk_ranges,
|
||||
num_chunks,
|
||||
num_complete_chunks,
|
||||
num_pref,
|
||||
sum_len,
|
||||
reg_len,
|
||||
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=None if attn_mask is None else attn_mask[layer_idx],
|
||||
|
||||
need_weights=need_weights,
|
||||
need_head_weights=need_head_weights,
|
||||
)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask[layer_idx] = None
|
||||
|
||||
if not last_state_only:
|
||||
inner_states.append(x)
|
||||
|
||||
if last_state_only:
|
||||
inner_states = [x]
|
||||
|
||||
return x, inner_states
|
||||
|
||||
def output_layer(self, features, **kwargs):
|
||||
"""Project features to the vocabulary size."""
|
||||
return self.output_projection(features)
|
||||
|
||||
def max_positions(self):
|
||||
"""Maximum output length supported by the decoder."""
|
||||
if getattr(self.args, 'use_token_abs_pos', False) and getattr(self.args, 'learned_token_abs_pos', False):
|
||||
return min(self.args.max_target_positions, self.args.max_token_abs_pos)
|
||||
return self.args.max_target_positions
|
||||
|
||||
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
||||
"""Get normalized probabilities (or log probs) from a net's output."""
|
||||
# assert sample is not None
|
||||
|
||||
logits, others = net_output # (batch, seq_len, vocab_len)
|
||||
embed_dim = logits.shape[-1]
|
||||
if sample is not None and 'target_mask' in sample:
|
||||
target_mask = sample['target_mask']
|
||||
else:
|
||||
target_mask = None
|
||||
if target_mask is not None:
|
||||
logits = logits.masked_select(target_mask.unsqueeze(-1)).view(-1, embed_dim)
|
||||
if log_probs:
|
||||
return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
|
||||
else:
|
||||
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
|
||||
|
||||
def get_targets(self, sample, net_output):
|
||||
target = sample['target']
|
||||
_, others = net_output
|
||||
if 'target_mask' in sample:
|
||||
target_mask = sample['target_mask']
|
||||
else:
|
||||
target_mask = None
|
||||
if target_mask is not None:
|
||||
target = target.masked_select(target_mask)
|
||||
return target
|
|
@ -0,0 +1,549 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from fairseq import utils
|
||||
from fairseq.modules import FairseqDropout, LayerNorm
|
||||
|
||||
from .tools import arg_tools, computation_tools
|
||||
from . import attention
|
||||
from .attention_mask_generation.attention_mask_generation import LayerAttentionMaskGeneration, \
|
||||
combine_part_masks, transfer_attn_mask_to_block_layout, combine_attn_mask_and_key_padding_mask_
|
||||
from .data_structures.four_dim_pocket import FourDimPocket
|
||||
from .embedding.embedding_weight_generation import \
|
||||
generate_sinusoid_position_embedding_with_padding, generate_randomly_initialized_position_embedding_with_padding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_layer_attn_mask(
|
||||
layer_gen: LayerAttentionMaskGeneration,
|
||||
reg_chunk_ranges, num_chunks, num_complete_chunks,
|
||||
num_summary, num_pref, sum_len, reg_len,
|
||||
key_padding_mask=None,
|
||||
highlight_prefix=False,
|
||||
combine_parts=None,
|
||||
block_size=None,
|
||||
avoid_empty_row_or_column=True
|
||||
):
|
||||
# all_seq_len = sum_len + reg_len
|
||||
bsz = reg_chunk_ranges.shape[0]
|
||||
attn_mask = layer_gen(
|
||||
reg_chunk_ranges, num_chunks, num_complete_chunks,
|
||||
num_summary, num_pref, sum_len, reg_len
|
||||
) # each part: (bsz, 1, part_tgt_len, part_src_len)
|
||||
|
||||
# assert (~attn_mask['ss']).any()
|
||||
|
||||
# combine key_padding_mask into attn_mask
|
||||
if key_padding_mask is not None:
|
||||
sum_key_padding_mask = key_padding_mask[:, :sum_len] # (bsz, sum_len)
|
||||
if sum_key_padding_mask.any():
|
||||
for key in ('ss', 'rs'):
|
||||
if key not in attn_mask:
|
||||
continue
|
||||
combine_attn_mask_and_key_padding_mask_(attn_mask[key], sum_key_padding_mask)
|
||||
reg_key_padding_mask = key_padding_mask[:, sum_len:] # (bsz, reg_len)
|
||||
if reg_key_padding_mask.any():
|
||||
for key in ('sr', 'rr'):
|
||||
if key not in attn_mask:
|
||||
continue
|
||||
combine_attn_mask_and_key_padding_mask_(attn_mask[key], reg_key_padding_mask)
|
||||
del sum_key_padding_mask, reg_key_padding_mask
|
||||
del key_padding_mask
|
||||
|
||||
if highlight_prefix:
|
||||
raise NotImplementedError
|
||||
rr_attn_mask = attn_mask['rr']
|
||||
assert rr_attn_mask.shape[1:] == (1, reg_len, reg_len)
|
||||
rr_attn_mask = rr_attn_mask.type(torch.uint8) # (bsz, 1, reg_len, reg_len)
|
||||
num_unique_pref = len(torch.unique(num_pref))
|
||||
if num_unique_pref == 1:
|
||||
unique_prefix_number = num_pref[0].item()
|
||||
change_2_mask = rr_attn_mask[:, :, :, :unique_prefix_number].eq(0)
|
||||
rr_attn_mask[:, :, :, :unique_prefix_number][change_2_mask] = 2
|
||||
else:
|
||||
raise NotImplementedError
|
||||
attn_mask['rr'] = rr_attn_mask
|
||||
|
||||
# for key in attn_mask:
|
||||
# print(key)
|
||||
# temp_mask = attn_mask[key]
|
||||
# if temp_mask is not None:
|
||||
# print(temp_mask.dtype)
|
||||
# temp_mask = temp_mask.long()
|
||||
# temp_mask = temp_mask[0, 0]
|
||||
#
|
||||
# for line in temp_mask:
|
||||
# for col in line:
|
||||
# print(int(col), end=' ')
|
||||
# print()
|
||||
# else:
|
||||
# print('None')
|
||||
|
||||
if combine_parts == 'all':
|
||||
attn_mask = combine_part_masks(attn_mask) # (bsz, 1, all_seq_len, all_seq_len)
|
||||
attn_mask = {'all': attn_mask}
|
||||
elif isinstance(combine_parts, dict):
|
||||
new_attn_mask = {}
|
||||
processed_parts = set()
|
||||
for key in combine_parts:
|
||||
part1, part2 = combine_parts[key]
|
||||
assert part1 not in processed_parts
|
||||
assert part2 not in processed_parts
|
||||
assert part1[0] == part2[0] # only ss + sr, or rs + rr can be combined
|
||||
new_attn_mask[key] = computation_tools.may_bi_cat(attn_mask[part1], attn_mask[part2], dim=3)
|
||||
processed_parts.add(part1)
|
||||
processed_parts.add(part2)
|
||||
for key in attn_mask:
|
||||
if key in processed_parts:
|
||||
continue
|
||||
new_attn_mask[key] = attn_mask[key]
|
||||
attn_mask = new_attn_mask
|
||||
del processed_parts, new_attn_mask
|
||||
elif combine_parts is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError('combine_parts has wrong value:', combine_parts)
|
||||
|
||||
if block_size is None:
|
||||
return attn_mask
|
||||
|
||||
for key in attn_mask:
|
||||
batch_key_attn_mask = []
|
||||
for sample_idx in range(bsz):
|
||||
if attn_mask[key] is None:
|
||||
r = None
|
||||
else:
|
||||
r = transfer_attn_mask_to_block_layout(attn_mask[key][sample_idx], block_size,
|
||||
avoid_empty_row_or_column=avoid_empty_row_or_column)
|
||||
batch_key_attn_mask.append(r)
|
||||
attn_mask[key] = batch_key_attn_mask
|
||||
# real_part_sample_mask: layout, block_mask
|
||||
# layout: (head, num_tgt_blocks, num_src_blocks)
|
||||
# block_mask: (-1, block, block)
|
||||
|
||||
return attn_mask
|
||||
|
||||
|
||||
class MuseformerDecoderLayer(nn.Module):
|
||||
_submodules = (attention,)
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
parser.add_argument('--share-self-attention-layer-norm', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--share-ffn', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--share-final-layer-norm', type=arg_tools.str_bool_with_default_error)
|
||||
parser.add_argument('--chunk-ffn', type=int)
|
||||
|
||||
arg_tools.add_submodule_args(cls, parser)
|
||||
|
||||
def __init__(self,
|
||||
args,
|
||||
layer_idx,
|
||||
num_attention_heads,
|
||||
layer_attention_scheme,
|
||||
attention_dropout,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.pocket = FourDimPocket()
|
||||
self.pocket_constant = self.pocket['constant']
|
||||
self.pocket_instant = self.pocket['instant']
|
||||
|
||||
# === Basic Settings ===
|
||||
self.args = args
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.attention_impl = self.pocket_constant['attention_impl']
|
||||
self.block_size = self.pocket_constant['block_size']
|
||||
self.attention_mode = self.pocket_constant['attention_mode']
|
||||
self.attn_mask_combination = self.pocket_constant['attn_mask_combination']
|
||||
self.attn_mask_gen_parts = self.pocket_constant['attn_mask_gen_parts']
|
||||
self.need_embedding_summary_tokens = self.pocket_constant['need_embedding_summary_tokens']
|
||||
|
||||
self.layer_to_sv = self.pocket_constant['layer_to_sv']
|
||||
self.sv_to_layers = self.pocket_constant['sv_to_layers']
|
||||
self.layer_sv = self.layer_to_sv[self.layer_idx]
|
||||
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.embed_dim = args.attention_embed_dim
|
||||
self.num_summary = args.num_summary_tokens_each_chunk
|
||||
self.normalize_before = args.normalize_before
|
||||
|
||||
self.chunk_ffn = getattr(args, 'chunk_ffn', None)
|
||||
if self.chunk_ffn is not None:
|
||||
if self.chunk_ffn <= 0:
|
||||
self.chunk_ffn = None
|
||||
|
||||
# === Layer Attention Mask Generator ===
|
||||
self.layer_attention_scheme = layer_attention_scheme
|
||||
layer_attention_mask_generator_label = ('layer_attention_mask_generator', self.layer_sv)
|
||||
if layer_attention_mask_generator_label in self.pocket_constant:
|
||||
self.layer_attention_mask_generator = self.pocket_constant[layer_attention_mask_generator_label]
|
||||
else:
|
||||
self.layer_attention_mask_generator = LayerAttentionMaskGeneration(self.layer_attention_scheme,
|
||||
gen_parts=self.attn_mask_gen_parts)
|
||||
self.pocket_constant[layer_attention_mask_generator_label] = self.layer_attention_mask_generator
|
||||
|
||||
# === Construct Relative Embeddings ===
|
||||
self.use_token_rel_pos = getattr(self.args, 'use_token_rel_pos', False)
|
||||
if self.use_token_rel_pos:
|
||||
self.max_token_rel_pos = args.max_token_rel_pos
|
||||
token_rel_pos_embed_dim = getattr(self.args, 'token_rel_pos_embed_dim', self.embed_dim)
|
||||
if getattr(self.args, 'learned_token_rel_pos', False):
|
||||
token_rel_embed = generate_randomly_initialized_position_embedding_with_padding(
|
||||
self.max_token_rel_pos + 1, token_rel_pos_embed_dim
|
||||
)
|
||||
self.register_parameter('token_rel_embed', nn.Parameter(token_rel_embed, requires_grad=True))
|
||||
else:
|
||||
token_rel_embed = generate_sinusoid_position_embedding_with_padding(
|
||||
self.max_token_rel_pos + 1, token_rel_pos_embed_dim
|
||||
)
|
||||
self.register_buffer('token_rel_embed', token_rel_embed, persistent=False)
|
||||
else:
|
||||
self.token_rel_embed = None
|
||||
|
||||
self.no_token_rel_pos_for_prefix = getattr(self.args, 'no_token_rel_pos_for_prefix', False)
|
||||
|
||||
self.use_bar_rel_pos = getattr(self.args, 'use_bar_rel_pos', False)
|
||||
if self.use_bar_rel_pos:
|
||||
self.max_bar_rel_pos = args.max_bar_rel_pos
|
||||
bar_rel_pos_embed_dim = getattr(self.args, 'bar_rel_pos_embed_dim', self.embed_dim)
|
||||
if getattr(self.args, 'learned_bar_rel_pos', False):
|
||||
bar_rel_embed = generate_randomly_initialized_position_embedding_with_padding(
|
||||
self.max_bar_rel_pos + 1, bar_rel_pos_embed_dim
|
||||
)
|
||||
self.register_parameter('bar_rel_embed', nn.Parameter(bar_rel_embed, requires_grad=True))
|
||||
else:
|
||||
bar_rel_embed = generate_sinusoid_position_embedding_with_padding(
|
||||
self.max_bar_rel_pos + 1, bar_rel_pos_embed_dim
|
||||
)
|
||||
self.register_buffer('bar_rel_embed', bar_rel_embed, persistent=False)
|
||||
else:
|
||||
self.bar_rel_embed = None
|
||||
|
||||
# === Self Attention ===
|
||||
self.self_attn = self.build_self_attention(
|
||||
self.args,
|
||||
self.embed_dim,
|
||||
self.num_attention_heads,
|
||||
attention_dropout,
|
||||
token_rel_pos_embeddings=self.token_rel_embed,
|
||||
bar_rel_pos_embeddings=self.bar_rel_embed
|
||||
)
|
||||
|
||||
# === Other Modules ===
|
||||
self.dropout_module = FairseqDropout(
|
||||
args.dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
|
||||
self.activation_fn = utils.get_activation_fn(
|
||||
activation=str(args.activation_fn)
|
||||
if getattr(args, "activation_fn", None) is not None
|
||||
else "relu"
|
||||
)
|
||||
# activation_dropout_p = getattr(args, "activation_dropout", None)
|
||||
# if activation_dropout_p is not None:
|
||||
# assert activation_dropout_p > 0
|
||||
# self.activation_dropout_module = FairseqDropout(
|
||||
# float(activation_dropout_p), module_name=self.__class__.__name__
|
||||
# )
|
||||
# else:
|
||||
# self.activation_dropout_module = None
|
||||
|
||||
self.reg_self_attn_layer_norm = LayerNorm(self.embed_dim, export=False)
|
||||
if self.need_embedding_summary_tokens:
|
||||
if getattr(self.args, 'share_self_attention_layer_norm', False):
|
||||
self.sum_self_attn_layer_norm = self.reg_self_attn_layer_norm
|
||||
else:
|
||||
self.sum_self_attn_layer_norm = LayerNorm(self.embed_dim, export=False)
|
||||
else:
|
||||
self.sum_self_attn_layer_norm = None
|
||||
|
||||
self.reg_fc1 = nn.Linear(self.embed_dim, args.ffn_embed_dim)
|
||||
self.reg_fc2 = nn.Linear(args.ffn_embed_dim, self.embed_dim)
|
||||
share_ffn = getattr(self.args, 'share_ffn', False)
|
||||
if self.need_embedding_summary_tokens:
|
||||
if share_ffn:
|
||||
self.sum_fc1 = self.reg_fc1
|
||||
self.sum_fc2 = self.reg_fc2
|
||||
else:
|
||||
self.sum_fc1 = nn.Linear(self.embed_dim, args.ffn_embed_dim)
|
||||
self.sum_fc2 = nn.Linear(args.ffn_embed_dim, self.embed_dim)
|
||||
else:
|
||||
self.sum_fc1 = None
|
||||
self.sum_fc2 = None
|
||||
|
||||
self.reg_final_layer_norm = LayerNorm(self.embed_dim, export=False)
|
||||
if self.need_embedding_summary_tokens:
|
||||
if getattr(self.args, 'share_final_layer_norm', False):
|
||||
self.sum_final_layer_norm = self.reg_final_layer_norm
|
||||
else:
|
||||
self.sum_final_layer_norm = LayerNorm(self.embed_dim, export=False)
|
||||
else:
|
||||
self.sum_final_layer_norm = None
|
||||
|
||||
def build_self_attention(
|
||||
self,
|
||||
args,
|
||||
embed_dim,
|
||||
num_attention_heads,
|
||||
dropout,
|
||||
token_rel_pos_embeddings,
|
||||
bar_rel_pos_embeddings,
|
||||
**kwargs,
|
||||
):
|
||||
rel_embeddings = []
|
||||
no_rel_projections = []
|
||||
if self.use_token_rel_pos:
|
||||
rel_embeddings.append(token_rel_pos_embeddings)
|
||||
no_rel_projections.append(getattr(args, 'no_token_rel_pos_proj', False))
|
||||
if self.use_bar_rel_pos:
|
||||
rel_embeddings.append(bar_rel_pos_embeddings)
|
||||
no_rel_projections.append(getattr(args, 'no_bar_rel_pos_proj', False))
|
||||
|
||||
return attention.create_attention(
|
||||
implementation=self.attention_impl,
|
||||
attention_mode=self.attention_mode,
|
||||
block_size=self.block_size,
|
||||
|
||||
embed_dim=embed_dim,
|
||||
num_heads=num_attention_heads,
|
||||
|
||||
num_summary=self.args.num_summary_tokens_each_chunk,
|
||||
|
||||
rel_embeddings=rel_embeddings,
|
||||
|
||||
layer_idx=self.layer_idx,
|
||||
|
||||
dropout=dropout,
|
||||
|
||||
query_proj_bias=getattr(args, 'attn_query_proj_bias', False),
|
||||
key_proj_bias=getattr(args, 'attn_key_proj_bias', False),
|
||||
value_proj_bias=getattr(args, 'attn_value_proj_bias', False),
|
||||
out_proj_bias=getattr(args, 'attn_out_proj_bias', False),
|
||||
|
||||
no_rel_proj=no_rel_projections,
|
||||
rel_proj_bias=getattr(args, 'rel_proj_bias', False),
|
||||
|
||||
single_head_masks=self.layer_attention_mask_generator.single_head,
|
||||
|
||||
# For v1, v2, v2.1
|
||||
add_different_kqv_bias_for_sum_and_reg=getattr(args, 'add_different_kqv_bias_for_sum_and_reg', False),
|
||||
add_different_out_bias_for_sum_and_reg=getattr(args, 'add_different_out_bias_for_sum_and_reg', False),
|
||||
|
||||
# For v2, v2.1
|
||||
sum_key2_proj_bias=getattr(args, 'attn_sum_key2_proj_bias', False),
|
||||
sum_value2_proj_bias=getattr(args, 'attn_sum_value2_proj_bias', False),
|
||||
share_query_proj=getattr(args, 'attn_share_query_proj', False),
|
||||
share_key_proj=getattr(args, 'attn_share_key_proj', False),
|
||||
share_value_proj=getattr(args, 'attn_share_value_proj', False),
|
||||
share_out_proj=getattr(args, 'attn_share_out_proj', False),
|
||||
share_key2_value2_proj_weight=getattr(args, 'attn_share_key2_value2_proj_weight', False),
|
||||
|
||||
no_sum_out=((self.layer_idx == self.args.num_layers - 1)
|
||||
if getattr(args, 'add_different_out_bias_for_sum_and_reg', False) else False
|
||||
), # to make compatible with previous checkpoints
|
||||
|
||||
# For v5 (sum_then_reg_3)
|
||||
share_reg_kv_proj=getattr(args, 'share_reg_kv_proj', False),
|
||||
#
|
||||
# key_rel_proj_bias=getattr(args, 'attn_key_rel_proj_bias', True), # renamed to rel_proj_bias
|
||||
# add_global_rel_bias=getattr(args, 'attn_add_global_rel_bias', True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x, # tuple of (Optional[Tensor], Tensor) # (sum_len, bsz, embed_dim) (reg_len, bsz, embed_dim)
|
||||
reg_chunk_ranges, # (bsz, max_chunk, 2)
|
||||
num_chunks, # (bsz,)
|
||||
num_complete_chunks, # (bsz,)
|
||||
num_pref, # (batch,)
|
||||
sum_len,
|
||||
reg_len,
|
||||
|
||||
key_padding_mask=None, # (bsz, all_seq_len)
|
||||
attn_mask=None, # (bsz, num_heads, all_seq_len, all_seq_len)
|
||||
|
||||
need_weights=False,
|
||||
need_head_weights=False,
|
||||
):
|
||||
if need_head_weights:
|
||||
need_weights = True
|
||||
|
||||
if attn_mask is None:
|
||||
attn_mask_label = 'attn_mask_sv%d' % self.layer_sv
|
||||
if attn_mask_label in self.pocket_instant:
|
||||
attn_mask = self.pocket_instant[attn_mask_label]
|
||||
else:
|
||||
attn_mask = generate_layer_attn_mask(
|
||||
self.layer_attention_mask_generator,
|
||||
reg_chunk_ranges, num_chunks, num_complete_chunks,
|
||||
self.args.num_summary_tokens_each_chunk,
|
||||
num_pref, sum_len, reg_len,
|
||||
key_padding_mask=key_padding_mask,
|
||||
highlight_prefix=False,
|
||||
combine_parts=self.attn_mask_combination,
|
||||
block_size=self.block_size if self.attention_impl == 'triton' else None,
|
||||
avoid_empty_row_or_column=True
|
||||
)
|
||||
self.pocket_instant[attn_mask_label] = attn_mask
|
||||
else:
|
||||
raise NotImplementedError('Passing attn_mask into a Museformer layer is not supported yet.')
|
||||
key_padding_mask = None # key_padding_mask is useless, after combined into attn_mask
|
||||
|
||||
token_rel_indices = None
|
||||
bar_rel_indices = None
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = computation_tools.may_bi_op(
|
||||
self.sum_self_attn_layer_norm, self.reg_self_attn_layer_norm,
|
||||
x, sum_len, reg_len, self.embed_dim, as_tuple=True
|
||||
)
|
||||
|
||||
# print(attn_mask)
|
||||
# print(token_rel_indices)
|
||||
# print(bar_rel_indices)
|
||||
|
||||
# st = time.time()
|
||||
# try:
|
||||
x, attn = self.run_self_attn(
|
||||
x, x, x,
|
||||
sum_len, reg_len,
|
||||
key_padding_mask=key_padding_mask, attn_mask=attn_mask,
|
||||
incremental_state=None,
|
||||
need_weights=need_weights, need_head_weights=need_head_weights,
|
||||
token_rel_indices=token_rel_indices, bar_rel_indices=bar_rel_indices,
|
||||
)
|
||||
# except RuntimeError:
|
||||
# print('x:', x.shape)
|
||||
# print('num_heads:', self.num_attention_heads)
|
||||
# print('attn_mask:', attn_mask.shape)
|
||||
# print('key_padding_mask:', key_padding_mask.shape)
|
||||
# raise
|
||||
# et = time.time()
|
||||
# if debug:
|
||||
# logger.info('Run Self-Attn: %f' % (et - st))
|
||||
|
||||
x = computation_tools.may_bi_op(self.dropout_module, self.dropout_module, x,
|
||||
sum_len, reg_len, self.embed_dim, as_tuple=True)
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = computation_tools.may_bi_op(
|
||||
self.sum_self_attn_layer_norm, self.reg_self_attn_layer_norm,
|
||||
x, sum_len, reg_len, self.embed_dim, as_tuple=True
|
||||
)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = computation_tools.may_bi_op(
|
||||
self.sum_final_layer_norm, self.reg_final_layer_norm,
|
||||
x, sum_len, reg_len, self.embed_dim, as_tuple=True
|
||||
)
|
||||
|
||||
x = self.run_ffn(x, sum_len, reg_len)
|
||||
|
||||
x = computation_tools.may_bi_op(self.dropout_module, self.dropout_module, x,
|
||||
sum_len, reg_len, self.embed_dim, as_tuple=True)
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = computation_tools.may_bi_op(
|
||||
self.sum_final_layer_norm, self.reg_final_layer_norm,
|
||||
x, sum_len, reg_len, self.embed_dim, as_tuple=True
|
||||
)
|
||||
|
||||
return x, attn
|
||||
|
||||
@staticmethod
|
||||
def residual_connection(x, residual):
|
||||
return residual[0] + x[0] if (residual[0] is not None and x[0] is not None) else None, residual[1] + x[1]
|
||||
|
||||
def run_self_attn(
|
||||
self,
|
||||
query, # (all_seq_len, bsz, embed_dim)
|
||||
key, # (all_seq_len, bsz, embed_dim)
|
||||
value, # (all_seq_len, bsz, embed_dim)
|
||||
sum_len,
|
||||
reg_len,
|
||||
key_padding_mask=None,
|
||||
attn_mask=None,
|
||||
incremental_state=None,
|
||||
need_weights=True,
|
||||
need_head_weights=False,
|
||||
token_rel_indices=None,
|
||||
bar_rel_indices=None,
|
||||
**kwargs,
|
||||
):
|
||||
assert incremental_state is None
|
||||
rel_indices = []
|
||||
if self.use_token_rel_pos:
|
||||
rel_indices.append(token_rel_indices)
|
||||
if self.use_bar_rel_pos:
|
||||
rel_indices.append(bar_rel_indices)
|
||||
r, weight = self.self_attn(
|
||||
query,
|
||||
self.pocket_instant['sum_token_ids'],
|
||||
sum_len, reg_len,
|
||||
rel_indices,
|
||||
key_padding_mask=key_padding_mask,
|
||||
attn_mask=attn_mask,
|
||||
need_weights=need_weights,
|
||||
need_head_weights=need_head_weights,
|
||||
**kwargs,
|
||||
)
|
||||
return r, weight
|
||||
|
||||
def run_ffn(self, x, sum_len, reg_len):
|
||||
sum_x, reg_x = x
|
||||
del x
|
||||
|
||||
def reg_ffn(input_x):
|
||||
input_x = self.reg_fc1(input_x)
|
||||
input_x = self.activation_fn(input_x)
|
||||
input_x = self.reg_fc2(input_x)
|
||||
return input_x
|
||||
|
||||
if sum_x is None:
|
||||
sum_ffn = None
|
||||
else:
|
||||
if getattr(self.args, 'share_ffn', False):
|
||||
sum_ffn = reg_ffn
|
||||
else:
|
||||
def sum_ffn(input_x):
|
||||
input_x = self.sum_fc1(input_x)
|
||||
input_x = self.activation_fn(input_x)
|
||||
input_x = self.sum_fc2(input_x)
|
||||
return input_x
|
||||
|
||||
if self.chunk_ffn is None:
|
||||
if sum_x is not None:
|
||||
sum_x = sum_ffn(sum_x)
|
||||
reg_x = reg_ffn(reg_x)
|
||||
else:
|
||||
def do_chunk_ffn(input_x, ffn, split_size):
|
||||
input_x = torch.split(input_x, split_size, dim=0)
|
||||
result = []
|
||||
for chunk_x in input_x:
|
||||
chunk_x = ffn(chunk_x)
|
||||
result.append(chunk_x)
|
||||
del input_x
|
||||
result = torch.cat(result, dim=0)
|
||||
return result
|
||||
if reg_x.requires_grad:
|
||||
if id(reg_ffn) == id(sum_ffn):
|
||||
def do_multi_chunk_ffn(input_x, ffn, split_size):
|
||||
return [do_chunk_ffn(one_x, ffn, split_size) for one_x in input_x]
|
||||
sum_x, reg_x = do_multi_chunk_ffn((sum_x, reg_x), reg_ffn, self.chunk_ffn)
|
||||
else:
|
||||
reg_x = checkpoint(do_chunk_ffn, reg_x, reg_ffn, self.chunk_ffn)
|
||||
if sum_x is not None:
|
||||
sum_x = do_chunk_ffn(sum_x, sum_ffn, self.chunk_ffn)
|
||||
else:
|
||||
reg_x = do_chunk_ffn(reg_x, reg_ffn, self.chunk_ffn)
|
||||
if sum_x is not None:
|
||||
sum_x = do_chunk_ffn(sum_x, sum_ffn, self.chunk_ffn)
|
||||
|
||||
return sum_x, reg_x
|
|
@ -0,0 +1,158 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from omegaconf import II
|
||||
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
from fairseq.models import FairseqLanguageModel, register_model, register_model_architecture
|
||||
|
||||
from .tools import arg_tools
|
||||
from .museformer_decoder import MuseformerDecoder
|
||||
|
||||
DEFAULT_MAX_TARGET_POSITIONS = 100000
|
||||
|
||||
|
||||
@dataclass
|
||||
class MuseformerLanguageModelConfig(FairseqDataclass):
|
||||
tokens_per_sample: int = II("task.tokens_per_sample")
|
||||
max_target_positions: Optional[int] = II("task.max_target_positions")
|
||||
eob_token: str = II("task.eob_token")
|
||||
eoc_token: str = II("task.eoc_token")
|
||||
chunking_scheme: str = II("task.chunking_scheme")
|
||||
fixed_chunking_length = int = II("task.fixed_chunking_length")
|
||||
|
||||
|
||||
@register_model('museformer_lm', dataclass=MuseformerLanguageModelConfig)
|
||||
class MuseformerLanguageModel(FairseqLanguageModel):
|
||||
_submodules = (MuseformerDecoder,)
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
super(MuseformerLanguageModel, cls).add_args(parser)
|
||||
arg_tools.add_submodule_args(cls, parser)
|
||||
|
||||
def __init__(self, decoder):
|
||||
super().__init__(decoder)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
decoder = MuseformerDecoder(args, task.target_dictionary)
|
||||
print(args)
|
||||
return cls(decoder)
|
||||
|
||||
def get_targets(self, sample, net_output):
|
||||
return self.decoder.get_targets(sample, net_output)
|
||||
|
||||
|
||||
def base_lm_scheme(args):
|
||||
args.pref2pref_mode = getattr(args, "pref2pref_mode", 'lt')
|
||||
args.sum2pref_mode = getattr(args, "sum2pref_mode", 'none')
|
||||
args.pref2sum_mode = getattr(args, "pref2sum_mode", "none")
|
||||
args.con2pref_mode = getattr(args, "con2pref_mode", 'full')
|
||||
args.pref2con_mode = getattr(args, "pref2con_mode", 'none')
|
||||
|
||||
|
||||
def b1248_lm_scheme(args):
|
||||
args.con2con = getattr(args, "con2con", ((((-2, 0), -4, -8, -12, -16, -24, -32),),))
|
||||
args.con2con_self = getattr(args, "con2con_self", "full")
|
||||
args.con2con_causal = getattr(args, "con2con_causal", True)
|
||||
args.con2sum = getattr(args, "con2sum",
|
||||
((((None, -32), (-31, -24), (-23, -16), (-15, -12), (-11, -8), (-7, -4), -3,),),))
|
||||
args.con2sum_self = getattr(args, "con2sum_self", "none")
|
||||
args.sum2con = getattr(args, "sum2con", ((None,),))
|
||||
args.sum2con_self = getattr(args, "sum2con_self", "full")
|
||||
args.sum2sum = getattr(args, "sum2sum", ((None,),))
|
||||
args.sum2sum_self = getattr(args, "sum2sum_self", "full")
|
||||
base_lm_scheme(args)
|
||||
|
||||
|
||||
def b1234_lm_scheme(args):
|
||||
args.con2con = getattr(args, "con2con", ((((-8, 0),),),))
|
||||
args.con2con_self = getattr(args, "con2con_self", "full")
|
||||
args.con2con_causal = getattr(args, "con2con_causal", True)
|
||||
args.con2sum = getattr(args, "con2sum", ((((None, -8),),),))
|
||||
args.con2sum_self = getattr(args, "con2sum_self", "none")
|
||||
args.sum2con = getattr(args, "sum2con", ((None,),))
|
||||
args.sum2con_self = getattr(args, "sum2con_self", "full")
|
||||
args.sum2sum = getattr(args, "sum2sum", ((None,),))
|
||||
args.sum2sum_self = getattr(args, "sum2sum_self", "full")
|
||||
base_lm_scheme(args)
|
||||
|
||||
|
||||
def share_sum_reg_params(args):
|
||||
args.share_layernorm_embedding = getattr(args, 'share_layernorm_embedding', True)
|
||||
args.attn_share_query_proj = getattr(args, 'attn_share_query_proj', True)
|
||||
args.attn_share_key_proj = getattr(args, 'attn_share_key_proj', True)
|
||||
args.attn_share_value_proj = getattr(args, 'attn_share_value_proj', True)
|
||||
args.attn_share_out_proj = getattr(args, 'attn_share_out_proj', True)
|
||||
args.share_self_attention_layer_norm = getattr(args, 'share_self_attention_layer_norm', True)
|
||||
args.share_ffn = getattr(args, 'share_ffn', True)
|
||||
args.share_final_layer_norm = getattr(args, 'share_final_layer_norm', True)
|
||||
|
||||
|
||||
def base_lm_architecture(args):
|
||||
args.attention_embed_dim = getattr(args, "attention_embed_dim", 512)
|
||||
args.num_layers = getattr(args, "num_layers", 4)
|
||||
args.num_attention_heads = getattr(args, "num_attention_heads", (8,))
|
||||
args.normalize_before = getattr(args, "normalize_before", True)
|
||||
args.ffn_embed_dim = getattr(args, "ffn_embed_dim", 2048)
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
args.no_final_norm = getattr(args, "no_final_norm", False)
|
||||
|
||||
args.take_bos_as_bar = getattr(args, 'take_bos_as_bar', True)
|
||||
args.num_summary_tokens_each_chunk = getattr(args, "num_summary_tokens_each_chunk", 1)
|
||||
|
||||
args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
|
||||
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
|
||||
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
|
||||
|
||||
args.token_embed_dim = getattr(args, "token_embed_dim", args.attention_embed_dim)
|
||||
|
||||
if getattr(args, "max_target_positions", None) is None:
|
||||
args.max_target_positions = getattr(
|
||||
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
|
||||
)
|
||||
if args.num_summary_tokens_each_chunk <= 0:
|
||||
args.num_summary_tokens_each_chunk = 0
|
||||
args.con2sum = ((None,),)
|
||||
args.sum2con = ((None,),)
|
||||
args.sum2sum = ((None,),)
|
||||
|
||||
args.attention_impl = getattr(args, "attention_impl", "triton")
|
||||
if args.attention_impl == 'triton':
|
||||
args.block_size = getattr(args, "block_size", 64)
|
||||
|
||||
|
||||
@register_model_architecture("museformer_lm", "museformer_lm")
|
||||
def museformer_lm_v2s1(args):
|
||||
args.attention_mode = getattr(args, 'attention_mode', 'v2s1')
|
||||
|
||||
b1248_lm_scheme(args)
|
||||
share_sum_reg_params(args)
|
||||
|
||||
args.tokens_per_sample = getattr(args, 'tokens_per_sample', 100000)
|
||||
|
||||
args.use_token_in_chunk_abs_pos = getattr(args, 'use_token_in_chunk_abs_pos', False)
|
||||
args.use_token_abs_pos = getattr(args, 'use_token_abs_pos', False)
|
||||
args.use_bar_abs_pos = getattr(args, 'use_bar_abs_pos', True)
|
||||
args.max_bar_abs_pos = getattr(args, 'max_bar_abs_pos', 512)
|
||||
args.bar_abs_pos_embed_dim = getattr(args, 'bar_abs_pos_embed_dim', 256)
|
||||
args.use_beat_abs_pos = getattr(args, 'use_beat_abs_pos', True)
|
||||
args.max_beat_abs_pos = getattr(args, 'max_beat_abs_pos', 64)
|
||||
args.beat_abs_pos_embed_dim = getattr(args, 'beat_abs_pos_embed_dim', 128)
|
||||
|
||||
args.concat_reg_embeddings = getattr(args, 'concat_reg_embeddings', True)
|
||||
args.proj_reg_embeddings = getattr(args, 'proj_reg_embeddings', True)
|
||||
args.concat_sum_embeddings = getattr(args, 'concat_sum_embeddings', True)
|
||||
args.proj_sum_embeddings = getattr(args, 'proj_sum_embeddings', True)
|
||||
|
||||
args.attn_query_proj_bias = getattr(args, 'attn_query_proj_bias', True)
|
||||
args.attn_key_proj_bias = getattr(args, 'attn_key_proj_bias', True)
|
||||
args.sum_key2_proj_bias = getattr(args, 'attn_sum_key2_proj_bias', True)
|
||||
args.attn_value_proj_bias = getattr(args, 'attn_value_proj_bias', True)
|
||||
args.sum_value2_proj_bias = getattr(args, 'attn_sum_value2_proj_bias', True)
|
||||
args.attn_out_proj_bias = getattr(args, 'attn_out_proj_bias', True)
|
||||
args.add_different_kqv_bias_for_sum_and_reg = getattr(args, 'add_different_kqv_bias_for_sum_and_reg', True)
|
||||
|
||||
base_lm_architecture(args)
|
|
@ -0,0 +1,159 @@
|
|||
import os
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from fairseq import utils
|
||||
from fairseq.tasks.language_modeling import LanguageModelingTask
|
||||
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
|
||||
from omegaconf import II
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.data.indexed_dataset import get_available_dataset_impl
|
||||
from fairseq.tasks import register_task
|
||||
|
||||
from .dictionary.compound_dictionary import CompoundDictionary
|
||||
from .tools import arg_tools
|
||||
from .datasets.extended_wrapper_dataset import ExtendedWrapperDataset
|
||||
from .datasets.remove_short_size_samples_dataset import MayRemoveShortSizeSamplesDataset
|
||||
from .datasets.chunk_sequence_dataset_2 import ChunkSequenceDataset as ChunkSequenceDataset2
|
||||
from .datasets.prefix_tag_dataset import PrefixTagDataset
|
||||
from .datasets.music_monolingual_dataset_2 import MusicMonolingualDataset as MusicMonolingualDataset2
|
||||
from .datasets.remove_long_size_samples_dataset import MayRemoveLongSizeSamplesDataset
|
||||
from .datasets.truncate_music_dataset_2 import TruncateMusicDataset as TruncateMusicDataset2
|
||||
from .datasets.post_process_dataset_2 import PostProcessDataset as PostProcessDataset2
|
||||
from .datasets.add_beat_dataset import AddBeatDataset
|
||||
from .sequence_generator import MuseformerSequenceGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MuseformerLanguageModelingConfig(FairseqDataclass):
|
||||
data: Optional[str] = field(
|
||||
default=None, metadata={"help": "path to data directory"}
|
||||
)
|
||||
tokens_per_sample: int = field(
|
||||
default=1024,
|
||||
metadata={"help": "max number of tokens per sample for LM dataset"},
|
||||
)
|
||||
max_target_positions: Optional[int] = field(
|
||||
default=None, metadata={"help": "max number of tokens in the target sequence"}
|
||||
)
|
||||
output_dictionary_size: int = field(
|
||||
default=-1, metadata={"help": "limit the size of output dictionary"}
|
||||
)
|
||||
seed: int = II("params.common.seed")
|
||||
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
|
||||
"params.dataset.dataset_impl"
|
||||
)
|
||||
|
||||
|
||||
@register_task('museformer_language_modeling', dataclass=MuseformerLanguageModelingConfig)
|
||||
class MuseformerLanguageModelingTask(LanguageModelingTask):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
super(MuseformerLanguageModelingTask, cls).add_args(parser)
|
||||
|
||||
# Basic
|
||||
parser.add_argument('--eob-token', default='b-1')
|
||||
parser.add_argument('--eoc-token', type=arg_tools.str_to_type_with_specific_word_as_none(str, 'None'),
|
||||
default='e-1')
|
||||
|
||||
parser.add_argument('--chunking-scheme', choices=('bar_aware', 'fixed'), default='bar_aware')
|
||||
parser.add_argument('--fixed-chunking-length', type=int, default=None)
|
||||
|
||||
parser.add_argument('--max-size-train', type=int)
|
||||
parser.add_argument('--max-size-valid', type=int)
|
||||
parser.add_argument('--max-size-test', type=int)
|
||||
|
||||
parser.add_argument('--truncate-train', type=int)
|
||||
parser.add_argument('--truncate-valid', type=int)
|
||||
parser.add_argument('--truncate-test', type=int)
|
||||
|
||||
parser.add_argument('--take-bos-as-bar', type=arg_tools.str_bool_with_default_error, default=False)
|
||||
parser.add_argument('--beat-mask-ts', type=arg_tools.str_bool_with_default_error, default=False)
|
||||
|
||||
@classmethod
|
||||
def setup_dictionary(cls, args, **kwargs):
|
||||
dictionary = None
|
||||
output_dictionary = None
|
||||
if args.data:
|
||||
paths = utils.split_paths(args.data)
|
||||
assert len(paths) > 0
|
||||
dictionary = CompoundDictionary.load(os.path.join(paths[0], "dict.txt"))
|
||||
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||
output_dictionary = dictionary
|
||||
if args.output_dictionary_size >= 0:
|
||||
raise NotImplementedError
|
||||
# output_dictionary = TruncatedDictionary(
|
||||
# dictionary, args.output_dictionary_size
|
||||
# )
|
||||
return (dictionary, output_dictionary)
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
paths = utils.split_paths(self.args.data)
|
||||
assert len(paths) > 0
|
||||
|
||||
data_path = paths[(epoch - 1) % len(paths)]
|
||||
split_path = os.path.join(data_path, split)
|
||||
|
||||
dataset = data_utils.load_indexed_dataset(
|
||||
split_path, self.dictionary, self.args.dataset_impl, combine=combine
|
||||
)
|
||||
if dataset is None:
|
||||
raise FileNotFoundError(
|
||||
"Dataset not found: {} ({})".format(split, split_path)
|
||||
)
|
||||
|
||||
dataset = ExtendedWrapperDataset(dataset)
|
||||
|
||||
dataset = MayRemoveShortSizeSamplesDataset(dataset, 2) # Delete empty samples
|
||||
|
||||
assert self.args.eob_token in self.dictionary
|
||||
eob_index = self.dictionary.index(self.args.eob_token)
|
||||
eoc_index = None
|
||||
|
||||
data_name = self.args.data.replace('\\', '/')
|
||||
data_name = data_name[:-1] if self.args.data.endswith('/') else data_name
|
||||
data_name = data_name.split('/')[-1]
|
||||
|
||||
take_bos_as_bar = getattr(self.args, 'take_bos_as_bar', False)
|
||||
|
||||
dataset = MusicMonolingualDataset2(dataset)
|
||||
|
||||
dataset = ChunkSequenceDataset2(
|
||||
dataset, self.source_dictionary,
|
||||
eob_index, eoc_index,
|
||||
chunking_scheme=self.args.chunking_scheme,
|
||||
chunking_length=getattr(self.args, 'fixed_chunking_length', None),
|
||||
dataset_name=split,
|
||||
cache_data_label=data_name,
|
||||
cache_sequence=True,
|
||||
offset=1 if take_bos_as_bar else 0,
|
||||
take_bos_as_bar=take_bos_as_bar,
|
||||
bos_index=self.source_dictionary.eos_index
|
||||
)
|
||||
|
||||
dataset = AddBeatDataset(dataset, self.source_dictionary, cache_data_label=data_name, dataset_name=split,
|
||||
mask_ts_instead_of_tempo=self.args.beat_mask_ts)
|
||||
|
||||
max_size_split = getattr(self.args, 'max_size_%s' % split, None)
|
||||
dataset = MayRemoveLongSizeSamplesDataset(dataset, max_size_split)
|
||||
|
||||
truncate_length = getattr(self.args, 'truncate_%s' % split, None)
|
||||
dataset = TruncateMusicDataset2(dataset, truncate_length)
|
||||
|
||||
dataset = PrefixTagDataset(dataset, 0 if take_bos_as_bar or self.args.chunking_scheme != 'bar_aware' else 1)
|
||||
|
||||
dataset = PostProcessDataset2(dataset)
|
||||
|
||||
self.datasets[split] = dataset
|
||||
|
||||
logger.info('loaded %d samples for %s' % (len(self.datasets[split]), split))
|
||||
|
||||
def build_generator(
|
||||
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
|
||||
):
|
||||
if seq_gen_cls is None:
|
||||
seq_gen_cls = MuseformerSequenceGenerator
|
||||
return super().build_generator(models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs)
|
|
@ -0,0 +1,8 @@
|
|||
from fairseq.sequence_generator import SequenceGenerator
|
||||
|
||||
|
||||
class MuseformerSequenceGenerator(SequenceGenerator):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vocab_size = self.model.single_model.decoder.valid_dictionary_len
|
||||
self.beam_size = min(self.beam_size, self.vocab_size - 1)
|
|
@ -0,0 +1,48 @@
|
|||
def add_submodule_args(cls, parser, submodules_attr_name='_submodules'):
|
||||
submodules = getattr(cls, submodules_attr_name, None)
|
||||
if submodules is None:
|
||||
return
|
||||
for submodule in submodules:
|
||||
if hasattr(submodule, 'add_args'):
|
||||
submodule.add_args(parser)
|
||||
|
||||
|
||||
def str_bool(c):
|
||||
c_lower = c.lower()
|
||||
if c_lower in ('true', 'yes', 'y'):
|
||||
return True
|
||||
elif c_lower in ('false', 'no', 'n'):
|
||||
return False
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def str_bool_with_default_error(c):
|
||||
r = str_bool(c)
|
||||
if r is None:
|
||||
raise ValueError('Value "%s" is not valid.' % c)
|
||||
else:
|
||||
return r
|
||||
|
||||
|
||||
def str_to_type_with_specific_word_as_none(type_func, none_word):
|
||||
def f(x):
|
||||
return None if x == none_word else type_func(x)
|
||||
return f
|
||||
|
||||
|
||||
def comma_split_tuple_of_specific_type(type_func):
|
||||
def inner(x):
|
||||
return tuple([type_func(item) for item in x.split(',')])
|
||||
return inner
|
||||
|
||||
|
||||
def possibly_extend_tuple(c, n):
|
||||
assert isinstance(c, tuple)
|
||||
if len(c) == 1 and n > 1:
|
||||
c = c * n
|
||||
else:
|
||||
assert len(c) == n, \
|
||||
"%s for %d layers, len(c) == %d, type(c) == %s" % \
|
||||
(str(c), n, len(c), str(type(c)))
|
||||
return c
|
|
@ -0,0 +1,158 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def may_bi_op(
|
||||
op1,
|
||||
op2,
|
||||
x: (torch.Tensor, tuple),
|
||||
seq_len_1: int,
|
||||
seq_len_2: int,
|
||||
out_dim: Optional[int],
|
||||
as_tuple=False,
|
||||
):
|
||||
"""
|
||||
Do two different projections over two different parts of x.
|
||||
:param op1:
|
||||
:param op2:
|
||||
:param x: input, either tensor or tuple of two tensors
|
||||
:param seq_len_1: int
|
||||
:param seq_len_2: int
|
||||
:param out_dim: int
|
||||
:param as_tuple: bool
|
||||
:return:
|
||||
"""
|
||||
if isinstance(x, torch.Tensor):
|
||||
seq_len, bsz, _ = x.shape
|
||||
assert seq_len == seq_len_1 + seq_len_2
|
||||
if id(op1) == id(op2):
|
||||
if as_tuple:
|
||||
r = op1(x)
|
||||
return r[:seq_len_1], r[seq_len_1:]
|
||||
return op1(x)
|
||||
x1 = x[:seq_len_1]
|
||||
x2 = x[seq_len_1:]
|
||||
assert x2.shape[0] == seq_len_2
|
||||
else:
|
||||
x1, x2 = x
|
||||
assert x1 is None or x1.shape[0] == seq_len_1
|
||||
assert x2 is None or x2.shape[0] == seq_len_2
|
||||
bsz = x2.shape[1]
|
||||
|
||||
if as_tuple:
|
||||
return op1(x1) if op1 is not None and (x1 is not None and seq_len_1 > 0) else x1, \
|
||||
op2(x2) if op2 is not None and x2 is not None else x2
|
||||
out = x1.new_empty(seq_len_1 + seq_len_2, bsz, out_dim)
|
||||
if seq_len_1 > 0:
|
||||
out[:seq_len_1] = op1(x1) if op1 is not None else x1
|
||||
if seq_len_2 > 0:
|
||||
out[seq_len_1:] = op2(x2) if op2 is not None else x2
|
||||
return out
|
||||
|
||||
|
||||
def bi_projection(
|
||||
proj1: torch.nn.Linear,
|
||||
proj2: torch.nn.Linear,
|
||||
x: (torch.Tensor, tuple),
|
||||
seq_len_1: int,
|
||||
seq_len_2: int,
|
||||
as_tuple=False,
|
||||
):
|
||||
"""
|
||||
Do two different projections over two different parts of x.
|
||||
:param proj1: Linear instance
|
||||
:param proj2: Linear instance
|
||||
:param x: input, either tensor or tuple of two tensors
|
||||
:param seq_len_1: int
|
||||
:param seq_len_2: int
|
||||
:param as_tuple:
|
||||
:return:
|
||||
"""
|
||||
out_dim = proj2.weight.shape[0]
|
||||
return may_bi_op(proj1, proj2, x, seq_len_1, seq_len_2, out_dim, as_tuple=as_tuple)
|
||||
|
||||
|
||||
def may_bi_add(x, add1, add2, seq_len_1, seq_len_2):
|
||||
if add1 is None and add2 is None:
|
||||
return x
|
||||
if isinstance(x, torch.Tensor):
|
||||
seq_len, bsz = x.shape[:2]
|
||||
assert seq_len == seq_len_1 + seq_len_2
|
||||
x1 = x[:seq_len_1]
|
||||
x2 = x[seq_len_1:]
|
||||
assert x2.shape[0] == seq_len_2
|
||||
else:
|
||||
x1, x2 = x
|
||||
assert x1.shape[0] == seq_len_1
|
||||
assert x2.shape[0] == seq_len_2
|
||||
bsz = x2.shape[1]
|
||||
out = x1.new_empty(seq_len_1 + seq_len_2, bsz, *x1.shape[2:])
|
||||
if add1 is None:
|
||||
out[:seq_len_1] = x1
|
||||
else:
|
||||
out[:seq_len_1] = x1 + add1
|
||||
if add2 is None:
|
||||
out[seq_len_1:] = x2
|
||||
else:
|
||||
out[seq_len_1:] = x2 + add2
|
||||
return out
|
||||
|
||||
|
||||
def may_bi_cat(x1, x2, dim=0):
|
||||
if x1 is None or x1.shape[dim] == 0:
|
||||
return x2
|
||||
if x2 is None or x1.shape[dim] == 0:
|
||||
return x1
|
||||
return torch.cat((x1, x2), dim=dim)
|
||||
|
||||
|
||||
def pad_embed_first_dim(x, pad_len):
|
||||
if pad_len is None or pad_len == 0:
|
||||
return x
|
||||
shape = x.shape
|
||||
first_dim = shape[0]
|
||||
new_shape = ((first_dim + pad_len,) + shape[1:])
|
||||
r = x.new_zeros(*new_shape)
|
||||
r[:first_dim] = x
|
||||
return r
|
||||
|
||||
|
||||
def pad_2d(x, pad_len_1, pad_len_2, pad_value):
|
||||
if pad_len_1 is None:
|
||||
pad_len_1 = 0
|
||||
if pad_len_2 is None:
|
||||
pad_len_2 = 0
|
||||
if pad_len_1 == 0 and pad_len_2 == 0:
|
||||
return x
|
||||
bsz, len_1, len_2 = x.shape
|
||||
r = x.new_full((bsz, len_1 + pad_len_1, len_2 + pad_len_2), pad_value)
|
||||
r[:, :len_1, :len_2] = x
|
||||
return r
|
||||
|
||||
|
||||
def projection_and_pad(proj: torch.nn.Linear, x: torch.Tensor, padded_len: int):
|
||||
"""
|
||||
Do projection and pad to a specific length. Combine together to save memory.
|
||||
:param proj:
|
||||
:param x: (seq_len, bsz, in_dim)
|
||||
:param padded_len: target seq_len
|
||||
:return:
|
||||
"""
|
||||
seq_len, bsz = x.shape[:2]
|
||||
if seq_len == padded_len:
|
||||
return proj(x)
|
||||
assert padded_len > seq_len
|
||||
out_dim = proj.weight.shape[0]
|
||||
out = x.new_zeros((padded_len, bsz, out_dim))
|
||||
out[:seq_len] = proj(x)
|
||||
return out
|
||||
|
||||
|
||||
def transfer_chunk_points_to_ranges(chunk_points):
|
||||
"""
|
||||
|
||||
:param chunk_points: (bsz, max_chunk + 1)
|
||||
:return:
|
||||
"""
|
||||
return torch.stack((chunk_points[:, :-1], chunk_points[:, 1:]), dim=-1) # (bsz, max_chunk, 2)
|
|
@ -0,0 +1,41 @@
|
|||
import torch
|
||||
from fairseq.data import data_utils
|
||||
|
||||
from ..datasets.chunk_sequence_dataset_2 import get_bar_chunk_points
|
||||
|
||||
|
||||
def construct_bar_chunk_info(reg_tokens, src_lengths, eob, begin_idx=0, only_bos_prefix=True, device='cpu'):
|
||||
"""
|
||||
|
||||
:param reg_tokens: (bsz, seq_len)
|
||||
:param src_lengths: (bsz,)
|
||||
:param eob:
|
||||
:param begin_idx:
|
||||
:param only_bos_prefix:
|
||||
:param device:
|
||||
:return:
|
||||
"""
|
||||
assert only_bos_prefix
|
||||
chunk_points = []
|
||||
num_chunks = []
|
||||
should_minus_one = []
|
||||
bsz = len(reg_tokens)
|
||||
for idx, (sample_reg_tokens, sample_length) in enumerate(zip(reg_tokens, src_lengths)):
|
||||
sample_reg_tokens = sample_reg_tokens[:sample_length]
|
||||
sample_chunk_points, sample_is_complete_bar = get_bar_chunk_points(
|
||||
sample_reg_tokens,
|
||||
eob, begin_idx=begin_idx
|
||||
) #
|
||||
chunk_points.append(sample_chunk_points)
|
||||
sample_num_chunks = len(sample_chunk_points) - 1
|
||||
num_chunks.append(sample_num_chunks)
|
||||
should_minus_one.append(not sample_is_complete_bar and sample_num_chunks > 0)
|
||||
chunk_points = data_utils.collate_tokens(
|
||||
chunk_points, 0
|
||||
).to(device) # (bsz, max_chunk + 1)
|
||||
num_chunks = torch.tensor(num_chunks, device=device)
|
||||
should_minus_one = torch.tensor(should_minus_one, device=device).long()
|
||||
num_complete_chunks = num_chunks - should_minus_one
|
||||
num_pref = torch.tensor((1,) * bsz, device=device)
|
||||
|
||||
return chunk_points, num_chunks, num_complete_chunks, num_pref
|
|
@ -0,0 +1,8 @@
|
|||
def singleton(cls):
|
||||
_instance = {}
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
if cls not in _instance:
|
||||
_instance[cls] = cls(*args, **kwargs)
|
||||
return _instance[cls]
|
||||
return inner
|
|
@ -0,0 +1,15 @@
|
|||
DATA_DIR=data-bin/lmd6remi
|
||||
MODEL_NAME=mf-lmd6remi-$1
|
||||
|
||||
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
|
||||
NUM_WORKERS=$OMP_NUM_THREADS
|
||||
|
||||
fairseq-interactive $DATA_DIR \
|
||||
--path checkpoints/$MODEL_NAME/$2 \
|
||||
--user-dir museformer \
|
||||
--task museformer_language_modeling \
|
||||
--sampling --sampling-topk 8 --beam 1 --nbest 1 \
|
||||
--min-len 8192 \
|
||||
--max-len-b 20480 \
|
||||
--num-workers $NUM_WORKERS \
|
||||
--seed $3
|
|
@ -0,0 +1,82 @@
|
|||
import os
|
||||
import argparse
|
||||
import re
|
||||
|
||||
|
||||
def load_midi_processor():
|
||||
if midi_processor_cls is not None:
|
||||
return midi_processor_cls
|
||||
|
||||
import midiprocessor
|
||||
midi_processor_cls = midiprocessor
|
||||
return midi_processor_cls
|
||||
|
||||
|
||||
class GenerationLogExtractor(object):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
parser.add_argument('--base_name', default='random')
|
||||
parser.add_argument('--start_idx', type=int, default=1)
|
||||
parser.add_argument('--process-token-str-method', default='default')
|
||||
|
||||
@classmethod
|
||||
def build(cls, args):
|
||||
return cls(base_name=args.base_name, start_idx=args.start_idx,
|
||||
process_token_str_method=args.process_token_str_method)
|
||||
|
||||
def __init__(self, base_name='random', start_idx=1, process_token_str_method='default'):
|
||||
self.base_name = base_name
|
||||
self.start_idx = start_idx
|
||||
self.process_token_str_method = process_token_str_method
|
||||
|
||||
def do(self, log_path, token_output_dir, base_name=None, start_idx=None, process_token_str_method=None):
|
||||
if base_name is None:
|
||||
base_name = self.base_name
|
||||
if start_idx is None:
|
||||
start_idx = self.start_idx
|
||||
if process_token_str_method is None:
|
||||
process_token_str_method = self.process_token_str_method
|
||||
process_token_str_func = self.get_process_token_str_func(process_token_str_method)
|
||||
return self.extract_midi_tokens_from_output_log(
|
||||
log_path, token_output_dir, base_name, start_idx, process_token_str_func
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_process_token_str_func(cls, method):
|
||||
if method == 'default':
|
||||
return cls.default_process_token_str
|
||||
else:
|
||||
raise ValueError(method)
|
||||
|
||||
@staticmethod
|
||||
def default_process_token_str(token_str):
|
||||
return token_str.strip()
|
||||
|
||||
@staticmethod
|
||||
def extract_midi_tokens_from_output_log(log_path, token_output_dir, base_name, start_idx, process_token_str):
|
||||
with open(log_path, 'r') as f:
|
||||
s = f.read()
|
||||
r = re.findall('D-\d+?\t.+?\t(.+?)\n', s)
|
||||
|
||||
os.makedirs(token_output_dir, exist_ok=True)
|
||||
for idx, token_str in enumerate(r, start=start_idx):
|
||||
token_str = process_token_str(token_str)
|
||||
with open(os.path.join(token_output_dir, '%s-%d.txt') % (base_name, idx), 'w') as f:
|
||||
f.write(token_str)
|
||||
num_songs = len(r)
|
||||
print('Extract %d songs from log. (%s-%d ~ %s-%d)' %
|
||||
(num_songs, base_name, start_idx, base_name, start_idx + num_songs - 1))
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
GenerationLogExtractor.add_args(parser)
|
||||
parser.add_argument('log_path')
|
||||
parser.add_argument('token_output_dir')
|
||||
args = parser.parse_args()
|
||||
generation_log_extractor = GenerationLogExtractor.build(args)
|
||||
generation_log_extractor.do(args.log_path, args.token_output_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,102 @@
|
|||
import os
|
||||
import argparse
|
||||
import midiprocessor as mp
|
||||
|
||||
|
||||
class MidiGenerator(object):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
parser.add_argument('--encoding-method', required=True)
|
||||
parser.add_argument('--process-token-str-list-method', default='default')
|
||||
parser.add_argument('--input-dir', required=True)
|
||||
parser.add_argument('--output-dir')
|
||||
parser.add_argument('--suffix', default='.txt')
|
||||
parser.add_argument('--skip-error', type=lambda x: x.lower() == 'true', default=True)
|
||||
|
||||
@classmethod
|
||||
def build(cls, args):
|
||||
return cls(args.encoding_method, process_token_str_list_method=args.process_token_str_list_method)
|
||||
|
||||
def __init__(self, encoding_method, process_token_str_list_method='default'):
|
||||
self.encoding_method = encoding_method
|
||||
self.process_token_str_list_method = process_token_str_list_method
|
||||
self.process_token_str_list_func = self.get_process_token_str_list_func(self.process_token_str_list_method)
|
||||
self.midi_decoder = mp.MidiDecoder(self.encoding_method)
|
||||
|
||||
def do(self, token_text_path, output_midi_path):
|
||||
with open(token_text_path, 'r') as f:
|
||||
x = f.read()
|
||||
line = self._process_token_str(x)
|
||||
midi_obj = self.midi_decoder.decode_from_token_str_list(line)
|
||||
dir_name = os.path.dirname(output_midi_path)
|
||||
if dir_name not in ('', '.'):
|
||||
os.makedirs(dir_name, exist_ok=True)
|
||||
midi_obj.dump(output_midi_path)
|
||||
|
||||
def do_batch(self, token_output_dir, midi_output_dir, suffix='.txt', skip_error=True):
|
||||
count = 0
|
||||
error_count = 0
|
||||
for root_dir, dirs, files in os.walk(token_output_dir):
|
||||
for file_name in files:
|
||||
if not file_name.endswith(suffix):
|
||||
continue
|
||||
file_path = os.path.join(root_dir, file_name)
|
||||
relative_file_path = os.path.relpath(file_path, token_output_dir)
|
||||
|
||||
base_name = file_name
|
||||
save_dir = os.path.dirname(os.path.join(midi_output_dir, relative_file_path))
|
||||
save_path = os.path.join(save_dir, base_name + '.mid')
|
||||
|
||||
print('parsing', file_path)
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
self.do(file_path, save_path)
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
print('Error:', file_path)
|
||||
error_count += 1
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
if skip_error:
|
||||
continue
|
||||
else:
|
||||
count += 1
|
||||
|
||||
return count, error_count
|
||||
|
||||
def _process_token_str(self, x):
|
||||
x = x.split('\n')[0]
|
||||
x = x.strip().split(' ')
|
||||
x = self.process_token_str_list_func(x)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def get_process_token_str_list_func(cls, method):
|
||||
if method == 'default':
|
||||
return cls.default_process_token_str_list
|
||||
else:
|
||||
raise ValueError(method)
|
||||
|
||||
@staticmethod
|
||||
def default_process_token_str_list(x):
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
|
||||
|
||||
MidiGenerator.add_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
midi_generator = MidiGenerator.build(args)
|
||||
count, error_count = midi_generator.do_batch(
|
||||
args.input_dir, getattr(args, 'output_dir', args.input_dir),
|
||||
suffix=args.suffix, skip_error=args.skip_error
|
||||
)
|
||||
print('Done. %d succeed! %d failed.' % (count, error_count))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,46 @@
|
|||
import os
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def read_file_list(file_list):
|
||||
path_list = []
|
||||
with open(file_list, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line == '':
|
||||
continue
|
||||
path_list.append(line)
|
||||
return path_list
|
||||
|
||||
|
||||
def main(file_list, token_dir, save_dir, file_list_suffix='.txt'):
|
||||
file_list_base = os.path.basename(file_list)[:-len(file_list_suffix)]
|
||||
file_list = read_file_list(file_list)
|
||||
# print('processing %d files...' % len(file_list))
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
with open(os.path.join(save_dir, file_list_base + '.data'), 'w') as save_f:
|
||||
for file_path in tqdm(file_list):
|
||||
file_name = os.path.basename(file_path)
|
||||
token_path = os.path.join(token_dir, file_name + '.txt')
|
||||
with open(token_path, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line == '':
|
||||
continue
|
||||
save_f.write(line + '\n')
|
||||
break
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('file_list')
|
||||
parser.add_argument('token_dir')
|
||||
parser.add_argument('save_dir')
|
||||
parser.add_argument('--file_list_suffix', default='.txt')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args.file_list, args.token_dir, args.save_dir, args.file_list_suffix)
|
||||
|
||||
print('Done')
|
|
@ -0,0 +1,44 @@
|
|||
MODEL_NAME='mf-lmd6remi-1'
|
||||
|
||||
DATA_DIR=data-bin/lmd6remi # Data dir
|
||||
|
||||
# 4 GPUs
|
||||
UPDATE_FREQ=1
|
||||
|
||||
PEAK_LR=5e-4 # Peak learning rate, adjust as needed
|
||||
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
|
||||
|
||||
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
|
||||
|
||||
ulimit -n 4096
|
||||
|
||||
mkdir -p log
|
||||
|
||||
fairseq-train \
|
||||
$DATA_DIR \
|
||||
--user-dir museformer \
|
||||
--task museformer_language_modeling \
|
||||
--arch museformer_lm \
|
||||
--num-layers 4 \
|
||||
--truncate-train 15360 \
|
||||
--truncate-valid 10240 \
|
||||
--batch-size 1 \
|
||||
--update-freq $UPDATE_FREQ \
|
||||
--optimizer adam \
|
||||
--adam-betas '(0.9, 0.98)' \
|
||||
--adam-eps 1e-9 \
|
||||
--weight-decay 0.01 \
|
||||
--lr $PEAK_LR \
|
||||
--lr-scheduler inverse_sqrt \
|
||||
--warmup-updates $WARMUP_UPDATES \
|
||||
--max-update 1000000 \
|
||||
--validate-interval 1000000000 \
|
||||
--save-interval 1000000000 \
|
||||
--save-interval-updates 5000 \
|
||||
--fp16 \
|
||||
--log-format simple \
|
||||
--log-interval 10 \
|
||||
--tensorboard-logdir tb_log/$MODEL_NAME \
|
||||
--num-workers "$OMP_NUM_THREADS" \
|
||||
--save-dir checkpoints/$MODEL_NAME \
|
||||
| tee log/${MODEL_NAME}.log
|
|
@ -0,0 +1,13 @@
|
|||
DATA_DIR=data-bin/lmd6remi # Data dir
|
||||
|
||||
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
|
||||
|
||||
fairseq-validate \
|
||||
$DATA_DIR \
|
||||
--user-dir museformer \
|
||||
--task museformer_language_modeling \
|
||||
--path checkpoints/mf-lmd6remi-$1/$2 \
|
||||
--batch-size 1 \
|
||||
--truncate-test $3 \
|
||||
--valid-subset test \
|
||||
--num-workers $OMP_NUM_THREADS
|
Загрузка…
Ссылка в новой задаче