This commit is contained in:
btyu 2022-10-30 20:11:25 +08:00
Parent 8ad59b7ecb
Commit 8ef9e354ae
69 changed files with 36731 additions and 0 deletions

Binary data
img/museformer_visualization.png Normal file

Binary file not shown. Size: 191 KiB

139
museformer/.gitignore vendored Normal file

@ -0,0 +1,139 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/
.vscode/
data-bin/
cached_datasets/
checkpoints*/
log*/
tb_log*/
output*/

139
museformer/README.md Normal file

@ -0,0 +1,139 @@
# Museformer
[Museformer: Transformer with Fine- and Coarse-Grained Attention for Music Generation](https://arxiv.org/abs/2210.10349), by Botao Yu, Peiling Lu, Rui Wang, Wei Hu, Xu Tan, Wei Ye, Shikun Zhang, Tao Qin, Tie-Yan Liu, NeurIPS 2022, is a Transformer with a novel fine- and coarse-grained attention (FC-Attention) for music generation. Specifically, with the fine-grained attention, a token of a specific bar directly attends to all the tokens of the bars that are most relevant to music structure (e.g., the previous 1st, 2nd, 4th and 8th bars, selected via similarity statistics); with the coarse-grained attention, a token attends only to the summarization of the other bars rather than to each of their tokens, which reduces the computational cost. The advantages are twofold. First, it can capture both music structure-related correlations (via the fine-grained attention) and other contextual information (via the coarse-grained attention). Second, it is efficient and can model over 3x longer music sequences than its full-attention counterpart. Both objective and subjective experimental results demonstrate its ability to generate long music sequences with high quality and better structure.
demo: [link](https://ai-muzic.github.io/museformer)
<p align="center"><img src="../img/museformer_visualization.png" width="800"><br/>Information flow of FC-Attention. </p>
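To make the attention pattern concrete, below is a minimal, illustrative sketch of how such an FC-Attention mask could be constructed from per-token bar indices. It is not the implementation used in this repository; the function name `build_fc_attention_mask`, the layout that appends one summary token per bar after the regular tokens, and the causal treatment of the current bar are assumptions made purely for illustration.
```python
import torch

def build_fc_attention_mask(bar_ids: torch.Tensor, related_offsets=(1, 2, 4, 8)):
    """bar_ids: (reg_len,) non-decreasing bar index of each regular token.
    Returns a bool mask of shape (reg_len + num_bars, reg_len + num_bars); True = may attend."""
    reg_len = bar_ids.numel()
    num_bars = int(bar_ids.max()) + 1
    total = reg_len + num_bars  # regular tokens followed by one summary token per bar
    mask = torch.zeros(total, total, dtype=torch.bool)
    for i in range(reg_len):
        b = int(bar_ids[i])
        # fine-grained: the current bar (causally) and the structure-related previous bars
        fine_bars = {b} | {b - off for off in related_offsets if b - off >= 0}
        for j in range(i + 1):  # only earlier or current positions
            if int(bar_ids[j]) in fine_bars:
                mask[i, j] = True
        # coarse-grained: summaries of the remaining earlier bars
        for prev_bar in range(b):
            if prev_bar not in fine_bars:
                mask[i, reg_len + prev_bar] = True
    # each bar summary token aggregates (attends to) the tokens of its own bar
    for bar in range(num_bars):
        mask[reg_len + bar, :reg_len] = bar_ids.eq(bar)
    return mask

# toy example: 3 tokens in each of 4 bars
bar_ids = torch.tensor([b for b in range(4) for _ in range(3)])
print(build_fc_attention_mask(bar_ids).int())
```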
The following content describes the steps to run Museformer. All commands are run from the root directory of Museformer (referred to as `root_dir`) unless otherwise specified.
## 1. Dataset
We use [the Lakh MIDI dataset](https://colinraffel.com/projects/lmd/) (LMD-full). Specifically, we first preprocess it as described in the Appendix of our paper. The final dataset (see the file lists [here](data/meta)) contains 29,940 MIDI files. Their time signatures are all 4/4, and the instruments are normalized to 6 basic ones: square synthesizer (80), piano (0), guitar (25), string (48), bass (43) and drum, where the numbers in parentheses are MIDI program IDs where applicable. We put all the MIDI files in `data/midi`.
Install [MidiProcessor](https://github.com/btyu/MidiProcessor). Then, encode the MIDI files into tokens:
```bash
midi_dir=data/midi
token_dir=data/token
mp-batch-encoding $midi_dir $token_dir --encoding-method REMIGEN2 --normalize-pitch-value --remove-empty-bars --ignore-ts --sort-insts 6tracks_cst1
```
where the arguments are explained as follows:
- `encoding-method`: the representation method. We use REMIGEN2, a REMI-like representation method that encodes each piece of musical information as a separate token.
- `normalize-pitch-value`: normalize pitches so that the song is in *C major* or *A minor*.
- `remove-empty-bars`: remove empty bars at the beginning or the end of a song.
- `ignore-ts`: do not add time signature tokens. Since all the data are in 4/4, time signatures are not encoded.
- `sort-insts`: designate a method that sorts the instruments. `6tracks_cst1` sorts the instruments in the following order: square synthesizer, drum, bass, guitar, piano, string.
After encoding, you should see the token representation of each MIDI file in `token_dir`.
Then, run the following command to gather the tokens for each split.
```bash
token_dir=data/token
split_dir=data/split
for split in train valid test; do
  python tools/generate_token_data_by_file_list.py data/meta/${split}.txt $token_dir $split_dir
done
```
Next, use `fairseq-preprocess` to make binary data:
```bash
split_dir=data/split
data_bin_dir=data-bin/lmd6remi
mkdir -p data-bin
fairseq-preprocess \
--only-source \
--trainpref $split_dir/train.data \
--validpref $split_dir/valid.data \
--testpref $split_dir/test.data \
--destdir $data_bin_dir \
--srcdict data/meta/dict.txt
```
Now, you should see the binary data in `data-bin/lmd6remi`.
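As a quick sanity check (the file names below assume fairseq's default output naming; they are not listed in the original README):
```bash
ls data-bin/lmd6remi
# expected roughly: dict.txt  preprocess.log  test.bin  test.idx  train.bin  train.idx  valid.bin  valid.idx
```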
## 2. Environment
The implementation of Museformer relies on a specific hardware and software environment.
- Hardware environment: we recommend an Nvidia V100 32GB GPU. Other GPUs have not been tested.
- Software environment:
please ensure the following packages are installed:
```
CUDA: 11.3.1
Python: 3.8
fairseq: 0.10.2
tensorboardX: 2.2
```
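For reference, one possible way to install the Python packages with pip (an assumption rather than an instruction from the original README; a PyTorch build matching CUDA 11.3 is expected to be installed beforehand):
```bash
pip install fairseq==0.10.2 tensorboardX==2.2
```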
Then, install [triton](https://github.com/openai/triton) in an arbitrary directory other than `root_dir`:
```bash
cd /path/to/a/directory/for/triton
git clone https://github.com/openai/triton.git
cd triton/python
pip install -e .
```
## 3. Train
Run the following command to train Museformer:
```bash
bash ttrain/mf-lmd6remi-1.sh
```
In our experiment, we run it on 4 GPUs, and the per-GPU batch size is set to 1, so the effective batch size is 4. The current implementation only supports a per-GPU batch size of 1. You may change `UPDATE_FREQ` to adjust the effective batch size.
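As an illustrative calculation (an assumption based on fairseq's standard gradient-accumulation behavior, not a statement from the original README), the effective batch size is the product of the number of GPUs, the per-GPU batch size, and `UPDATE_FREQ`:
```bash
# effective batch size = num_GPUs x per-GPU batch size x UPDATE_FREQ
# the setting above:     4 x 1 x 1 = 4
# e.g., UPDATE_FREQ=4 -> 4 x 1 x 4 = 16
```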
On the first run, it may take some time to build auxiliary information and compile the CUDA kernels, so feel free to grab a cup of coffee while you wait.
## 4. Evaluation
You can obtain perplexity on the test set with the following command:
```bash
bash tval/val__mf-lmd6remi-x.sh 1 checkpoint_best.pt 10240
```
The number `10240` specifies the maximum sequence length used in the calculation.
## 5. Inference
Use the following command to generate 5 music pieces, with the random seed set to 1:
```bash
mkdir -p output_log
seed=1
printf '\n\n\n\n\n' | bash tgen/generation__mf-lmd6remi-x.sh 1 checkpoint_best.pt ${seed} | tee output_log/generation.log
```
The number of `\n` characters controls the number of generated music pieces. Generation may take a while. Once done, the generation log is saved at `output_log/generation.log`.
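For instance, to generate 10 pieces instead of 5 (a variation of the command above, assuming the same script interface), the newlines can be produced programmatically:
```bash
mkdir -p output_log
seed=42
printf '\n%.0s' {1..10} | bash tgen/generation__mf-lmd6remi-x.sh 1 checkpoint_best.pt ${seed} | tee output_log/generation.log
```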
Then, use the following command to extract the generated token sequences from the generation log:
```bash
python tools/batch_extract_log.py output_log/generation.log output/generation --start_idx 1
```
You should see the token representation of each generated music piece in `output/generation`.
Finally, run the following command to convert the token sequences into MIDI files:
```bash
python tools/batch_generate_midis.py --encoding-method REMIGEN2 --input-dir output/generation --output-dir output/generation
```
You should see the MIDI files in `output/generation`.


@ -0,0 +1,556 @@
b-0 1
b-1 1
o-0 1
o-1 1
o-2 1
o-3 1
o-4 1
o-5 1
o-6 1
o-7 1
o-8 1
o-9 1
o-10 1
o-11 1
o-12 1
o-13 1
o-14 1
o-15 1
o-16 1
o-17 1
o-18 1
o-19 1
o-20 1
o-21 1
o-22 1
o-23 1
o-24 1
o-25 1
o-26 1
o-27 1
o-28 1
o-29 1
o-30 1
o-31 1
o-32 1
o-33 1
o-34 1
o-35 1
o-36 1
o-37 1
o-38 1
o-39 1
o-40 1
o-41 1
o-42 1
o-43 1
o-44 1
o-45 1
o-46 1
o-47 1
o-48 1
o-49 1
o-50 1
o-51 1
o-52 1
o-53 1
o-54 1
o-55 1
o-56 1
o-57 1
o-58 1
o-59 1
o-60 1
o-61 1
o-62 1
o-63 1
i-0 1
i-25 1
i-43 1
i-48 1
i-80 1
i-128 1
p-0 1
p-1 1
p-2 1
p-3 1
p-4 1
p-5 1
p-6 1
p-7 1
p-8 1
p-9 1
p-10 1
p-11 1
p-12 1
p-13 1
p-14 1
p-15 1
p-16 1
p-17 1
p-18 1
p-19 1
p-20 1
p-21 1
p-22 1
p-23 1
p-24 1
p-25 1
p-26 1
p-27 1
p-28 1
p-29 1
p-30 1
p-31 1
p-32 1
p-33 1
p-34 1
p-35 1
p-36 1
p-37 1
p-38 1
p-39 1
p-40 1
p-41 1
p-42 1
p-43 1
p-44 1
p-45 1
p-46 1
p-47 1
p-48 1
p-49 1
p-50 1
p-51 1
p-52 1
p-53 1
p-54 1
p-55 1
p-56 1
p-57 1
p-58 1
p-59 1
p-60 1
p-61 1
p-62 1
p-63 1
p-64 1
p-65 1
p-66 1
p-67 1
p-68 1
p-69 1
p-70 1
p-71 1
p-72 1
p-73 1
p-74 1
p-75 1
p-76 1
p-77 1
p-78 1
p-79 1
p-80 1
p-81 1
p-82 1
p-83 1
p-84 1
p-85 1
p-86 1
p-87 1
p-88 1
p-89 1
p-90 1
p-91 1
p-92 1
p-93 1
p-94 1
p-95 1
p-96 1
p-97 1
p-98 1
p-99 1
p-100 1
p-101 1
p-102 1
p-103 1
p-104 1
p-105 1
p-106 1
p-107 1
p-108 1
p-109 1
p-110 1
p-111 1
p-112 1
p-113 1
p-114 1
p-115 1
p-116 1
p-117 1
p-118 1
p-119 1
p-120 1
p-121 1
p-122 1
p-123 1
p-124 1
p-125 1
p-126 1
p-127 1
p-128 1
p-129 1
p-130 1
p-131 1
p-132 1
p-133 1
p-134 1
p-135 1
p-136 1
p-137 1
p-138 1
p-139 1
p-140 1
p-141 1
p-142 1
p-143 1
p-144 1
p-145 1
p-146 1
p-147 1
p-148 1
p-149 1
p-150 1
p-151 1
p-152 1
p-153 1
p-154 1
p-155 1
p-156 1
p-157 1
p-158 1
p-159 1
p-160 1
p-161 1
p-162 1
p-163 1
p-164 1
p-165 1
p-166 1
p-167 1
p-168 1
p-169 1
p-170 1
p-171 1
p-172 1
p-173 1
p-174 1
p-175 1
p-176 1
p-177 1
p-178 1
p-179 1
p-180 1
p-181 1
p-182 1
p-183 1
p-184 1
p-185 1
p-186 1
p-187 1
p-188 1
p-189 1
p-190 1
p-191 1
p-192 1
p-193 1
p-194 1
p-195 1
p-196 1
p-197 1
p-198 1
p-199 1
p-200 1
p-201 1
p-202 1
p-203 1
p-204 1
p-205 1
p-206 1
p-207 1
p-208 1
p-209 1
p-210 1
p-211 1
p-212 1
p-213 1
p-214 1
p-215 1
p-216 1
p-217 1
p-218 1
p-219 1
p-220 1
p-221 1
p-222 1
p-223 1
p-224 1
p-225 1
p-226 1
p-227 1
p-228 1
p-229 1
p-230 1
p-231 1
p-232 1
p-233 1
p-234 1
p-235 1
p-236 1
p-237 1
p-238 1
p-239 1
p-240 1
p-241 1
p-242 1
p-243 1
p-244 1
p-245 1
p-246 1
p-247 1
p-248 1
p-249 1
p-250 1
p-251 1
p-252 1
p-253 1
p-254 1
p-255 1
d-0 1
d-1 1
d-2 1
d-3 1
d-4 1
d-5 1
d-6 1
d-7 1
d-8 1
d-9 1
d-10 1
d-11 1
d-12 1
d-13 1
d-14 1
d-15 1
d-16 1
d-17 1
d-18 1
d-19 1
d-20 1
d-21 1
d-22 1
d-23 1
d-24 1
d-25 1
d-26 1
d-27 1
d-28 1
d-29 1
d-30 1
d-31 1
d-32 1
d-33 1
d-34 1
d-35 1
d-36 1
d-37 1
d-38 1
d-39 1
d-40 1
d-41 1
d-42 1
d-43 1
d-44 1
d-45 1
d-46 1
d-47 1
d-48 1
d-49 1
d-50 1
d-51 1
d-52 1
d-53 1
d-54 1
d-55 1
d-56 1
d-57 1
d-58 1
d-59 1
d-60 1
d-61 1
d-62 1
d-63 1
d-64 1
d-65 1
d-66 1
d-67 1
d-68 1
d-69 1
d-70 1
d-71 1
d-72 1
d-73 1
d-74 1
d-75 1
d-76 1
d-77 1
d-78 1
d-79 1
d-80 1
d-81 1
d-82 1
d-83 1
d-84 1
d-85 1
d-86 1
d-87 1
d-88 1
d-89 1
d-90 1
d-91 1
d-92 1
d-93 1
d-94 1
d-95 1
d-96 1
d-97 1
d-98 1
d-99 1
d-100 1
d-101 1
d-102 1
d-103 1
d-104 1
d-105 1
d-106 1
d-107 1
d-108 1
d-109 1
d-110 1
d-111 1
d-112 1
d-113 1
d-114 1
d-115 1
d-116 1
d-117 1
d-118 1
d-119 1
d-120 1
d-121 1
d-122 1
d-123 1
d-124 1
d-125 1
d-126 1
d-127 1
v-0 1
v-1 1
v-2 1
v-3 1
v-4 1
v-5 1
v-6 1
v-7 1
v-8 1
v-9 1
v-10 1
v-11 1
v-12 1
v-13 1
v-14 1
v-15 1
v-16 1
v-17 1
v-18 1
v-19 1
v-20 1
v-21 1
v-22 1
v-23 1
v-24 1
v-25 1
v-26 1
v-27 1
v-28 1
v-29 1
v-30 1
v-31 1
t-0 1
t-1 1
t-2 1
t-3 1
t-4 1
t-5 1
t-6 1
t-7 1
t-8 1
t-9 1
t-10 1
t-11 1
t-12 1
t-13 1
t-14 1
t-15 1
t-16 1
t-17 1
t-18 1
t-19 1
t-20 1
t-21 1
t-22 1
t-23 1
t-24 1
t-25 1
t-26 1
t-27 1
t-28 1
t-29 1
t-30 1
t-31 1
t-32 1
t-33 1
t-34 1
t-35 1
t-36 1
t-37 1
t-38 1
t-39 1
t-40 1
t-41 1
t-42 1
t-43 1
t-44 1
t-45 1
t-46 1
t-47 1
t-48 1
<c> 1
</c> 1
u-0 1
u-1 1
u-2 1
u-3 1
u-4 1
u-5 1
u-6 1
u-7 1
u-8 1
u-9 1
u-10 1
u-11 1
u-12 1
u-13 1
u-14 1
u-15 1
u-16 1

Diff not shown because the file is too large.

23952
museformer/data/meta/train.txt Normal file

Diff not shown because the file is too large.

Diff not shown because the file is too large.

Binary data
museformer/data/midi/0ab8dd236dc77863a3ac4ad2b186616a.mid Normal file

Binary file not shown.


@ -0,0 +1,2 @@
from .museformer_lm_task import MuseformerLanguageModelingTask
from .museformer_lm import MuseformerLanguageModel


@ -0,0 +1,39 @@
from ..tools import arg_tools
def add_args(parser):
parser.add_argument('--attn-query-proj-bias', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-key-proj-bias', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-value-proj-bias', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-out-proj-bias', type=arg_tools.str_bool_with_default_error)
# valid: sum_then_reg, v2.1
parser.add_argument('--attn-sum-key2-proj-bias', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-sum-value2-proj-bias', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-share-key2-value2-proj-weight', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--add-different-kqv-bias-for-sum-and-reg', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--add-different-out-bias-for-sum-and-reg', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-share-query-proj', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-share-key-proj', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-share-value-proj', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--attn-share-out-proj', type=arg_tools.str_bool_with_default_error)
def create_attention_v2_s1(
*args, implementation='mask', block_size=64, **kwargs
):
if implementation == 'triton':
from .self_attention_v2s1.blocksparse_rpe_self_attention_v2s1 import BlocksparseRpeSelfAttentionV2S1
return BlocksparseRpeSelfAttentionV2S1(
*args, block_size=block_size, **kwargs
)
raise NotImplementedError(implementation)
def create_attention(
*args, attention_mode='v2s1', **kwargs
):
if attention_mode == 'v2s1': # v2.1
return create_attention_v2_s1(*args, **kwargs)
raise NotImplementedError(attention_mode)
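# Illustrative usage (a sketch, not part of the original source; the argument names follow
# RpeSelfAttentionV2S1 in rpe_self_attention_v2s1.py, and only implementation='triton'
# is currently implemented):
#   attn = create_attention(embed_dim, num_heads, num_summary, layer_idx,
#                           attention_mode='v2s1', implementation='triton',
#                           block_size=64, single_head_masks=True)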


@ -0,0 +1,44 @@
import torch
from .....blocksparse import BlocksparseMatMul
def do_sample_av_mul_base(self, sample_attn_weights, sample_v, sample_layout, real_part, sample_idx, tgt_len):
if sample_layout is None:
_, head, head_dim = sample_v.shape
return sample_v.new_zeros(tgt_len, 1, head, head_dim)
sample_v = sample_v.transpose(0, 1)[:, None] # (head, 1, reg_len, head_dim)
dsd_matmul_key = (real_part, 'dsd_matmul', self.layer_sv, sample_idx)
if dsd_matmul_key in self.instant_pocket:
dsd_matmul = self.instant_pocket[dsd_matmul_key]
else:
dsd_matmul = BlocksparseMatMul(sample_layout, self.block_size, 'dsd',
device=sample_v.device)
self.instant_pocket[dsd_matmul_key] = dsd_matmul
sample_out = dsd_matmul(sample_attn_weights, sample_v) # (head, 1, tgt_len, head_dim)
sample_out = sample_out.permute(2, 1, 0, 3) # (tgt_len, 1, head, head_dim)
return sample_out
def do_av_mul_for_part(self, attn_weights_inc_part, v, attn_mask, real_part, tgt_len):
attn_weights_for_part = attn_weights_inc_part[real_part]
# samples list of (head, head_selected_blocks, block, block)
bsz = len(attn_weights_for_part)
attn_mask = attn_mask[real_part]
result = []
for sample_idx in range(bsz):
sample_v = v[:, sample_idx]
sample_attn_weights = attn_weights_for_part[sample_idx] # (head, head_selected_blocks, block, block)
sample_layout = attn_mask[sample_idx][0]
sample_out = do_sample_av_mul_base(self, sample_attn_weights, sample_v, sample_layout, real_part,
sample_idx, tgt_len)
result.append(sample_out)
if bsz > 1:
result = torch.cat(result, dim=1) # (tgt_len, bsz, num_heads, head_dim)
else:
result = result[0].contiguous()
return result


@ -0,0 +1,56 @@
from .....blocksparse import BlocksparseMatMul
def do_sample_qk_scores_base(
self, sample_layout, sample_tgt, sample_src,
tgt_len, src_len, tgt_label, sample_idx
):
if sample_layout is None:
return None
# sample_layout: (1, tgt_block, src_block)
assert sample_tgt.shape == (tgt_len, self.num_heads, self.head_dim)
assert sample_src.shape == (src_len, self.num_heads, self.head_dim)
assert sample_layout.shape == (1, tgt_len // self.block_size, src_len // self.block_size), \
str(sample_layout.shape) + ' %d %d' % (tgt_len // self.block_size, src_len // self.block_size)
if tgt_len == 0 or src_len == 0:
return sample_tgt.new_empty(self.num_heads, 0, self.block_size, self.block_size)
sdd_matmul_key = (tgt_label, 'sdd_matmul', self.layer_sv, sample_idx)
if sdd_matmul_key in self.instant_pocket:
sdd_matmul = self.instant_pocket[sdd_matmul_key]
else:
sdd_matmul = BlocksparseMatMul(sample_layout, self.block_size, 'sdd',
device=sample_tgt.device, trans_b=True)
self.instant_pocket[sdd_matmul_key] = sdd_matmul
sample_attn_scores = sdd_matmul(
sample_tgt.transpose(0, 1)[:, None], # (heads, 1, sum_len, head_dim)
sample_src.transpose(0, 1)[:, None], # (heads, 1, reg_len, head_dim)
) # (heads, head_selected_blocks, block, block)
assert sample_attn_scores.shape[1] == int(sample_layout[0].sum())
return sample_attn_scores
def do_qk_scores_for_part(
self,
tgt, src,
bsz, tgt_len, src_len,
attn_mask, part_label,
):
# tgt: (tgt_len, bsz, num_heads, head_dim)
# src: (src_len, bsz, num_heads, head_dim)
part_attn_mask = attn_mask[part_label]
attn_scores = []
for idx in range(bsz):
sample_layout = part_attn_mask[idx][0] # (1, tgt_block, src_block)
sample_attn_scores = do_sample_qk_scores_base(
self,
sample_layout, tgt[:, idx], src[:, idx],
tgt_len, src_len, part_label, idx
)
attn_scores.append(sample_attn_scores)
attn_scores = {part_label: attn_scores}
return attn_scores


@ -0,0 +1,55 @@
import torch
import torch.nn.functional as F
from .....kernels.range_fill import range_fill
def select_and_do_r_proj(self, rel_indices):
"""
:param rel_indices: relation list of real_part dict
:return:
"""
r_embed_list = []
r_modified_indices = []
for rel_idx, one_rel_indices in enumerate(rel_indices):
# Collect those indices for one kind of relative positions that are used for all the samples and real_parts
one_rel_used_indices = []
for real_part in one_rel_indices: # ('sr', 'rx')
one_rel_part_indices = one_rel_indices[real_part]
if one_rel_part_indices is None:
continue
for sample_rel_indices in one_rel_part_indices:
assert sample_rel_indices.shape[1:] == (self.block_size, self.block_size)
sample_used_rel_indices = torch.unique(sample_rel_indices) # (num_unique,)
one_rel_used_indices.append(sample_used_rel_indices)
one_rel_used_indices = torch.cat(one_rel_used_indices).unique()
rel_selected_embed = F.embedding(one_rel_used_indices, self.rel_embeddings[rel_idx], padding_idx=0)
rel_proj = getattr(self, 'rel%d_proj' % rel_idx, None)
if rel_proj is not None:
rel_selected_embed = rel_proj(rel_selected_embed)
label_transform = range_fill(
torch.stack((one_rel_used_indices, one_rel_used_indices + 1), dim=-1),
torch.arange(len(one_rel_used_indices), device=one_rel_used_indices.device),
self.num_rel_embeddings[rel_idx], 0
)
one_r_indices = {}
for real_part in one_rel_indices:
one_rel_part_indices = one_rel_indices[real_part]
if one_rel_part_indices is None:
one_r_indices[real_part] = None
continue
samples_r_indices = []
for sample_rel_indices in one_rel_part_indices:
sample_r_indices = label_transform[sample_rel_indices]
samples_r_indices.append(sample_r_indices)
one_r_indices[real_part] = samples_r_indices
r_embed_list.append(rel_selected_embed)
r_modified_indices.append(one_r_indices)
return r_embed_list, r_modified_indices


@ -0,0 +1,52 @@
import torch
def indexing_sample_rpe_base(self, sample_r, sample_r_indices, sample_layout, tgt_label, tgt_len, sample_idx):
# sample_r: (heads, tgt_len, num_selected_distance)
# sample_r_indices: (head_selected_blocks, block, block)
# sample_layout: (1, tgt_block, src_block)
sample_r = sample_r.view(self.num_heads, tgt_len // self.block_size, self.block_size, -1)
sample_tgt_block_ids = sample_layout[0].nonzero()[:, 0] # (head_selected_blocks,)
temp_rpe = sample_r[
self.num_heads_arange, # (8, 1, 1, 1)
sample_tgt_block_ids[None, :, None, None], # (1, 4, 1, 1)
self.block_size_arange, # (1, 1, block_size, 1)
sample_r_indices[None], # (1, head_selected_blocks, block, block)
] # (head, head_selected_blocks, block, block)
return temp_rpe
def add_rpe_for_part(self, attn_scores, qs, r_list, rel_indices, bsz, query_len, attn_mask, part_label):
attn_mask = attn_mask[part_label] # samples list of tuple (layout, block_mask)
attn_scores_for_part = attn_scores[part_label] # samples list of (heads, head_selected_blocks, block, block)
attn_scores_for_part_with_rpe = [item for item in attn_scores_for_part]
for rel_idx in range(self.num_relation_types):
r_indices = rel_indices[rel_idx]
r_indices = r_indices[part_label]
if r_indices is None:
continue
r_embed = r_list[rel_idx].view(-1, self.num_heads, self.head_dim) \
# (num_selected_pos, heads, head_dim)
r_qs = qs[rel_idx] # (sum_len, bsz, heads, head_dim)
temp_r = torch.einsum("ibhd,jhd->bhij", r_qs, r_embed) # (bsz, heads, sum_len, num_selected_distance)
for sample_idx in range(bsz):
sample_r = temp_r[sample_idx] # (heads, sum_len, num_selected_distance)
sample_r_indices = r_indices[sample_idx] # (head_selected_blocks, block, block)
sample_layout = attn_mask[sample_idx][0]
temp_rpe = indexing_sample_rpe_base(
self, sample_r, sample_r_indices, sample_layout, part_label, query_len, sample_idx
)
attn_scores_for_part_with_rpe[sample_idx] = attn_scores_for_part_with_rpe[sample_idx] + temp_rpe
attn_scores_for_part_with_rpe = {
part_label: attn_scores_for_part_with_rpe
}
return attn_scores_for_part_with_rpe


@ -0,0 +1,39 @@
import torch
from .....blocksparse import BlocksparseSoftmax
def do_attn_softmax_for_part(self, attn_scores, attn_mask, real_part, mask_again=False):
attn_scores_for_real_part = attn_scores[real_part] \
# samples list of (heads, head_selected_blocks, block, block)
attn_mask = attn_mask[real_part] # samples list of layout, block_mask
bsz = len(attn_mask)
result = [None] * bsz
for sample_idx in range(bsz):
sample_attn_mask = attn_mask[sample_idx]
if sample_attn_mask is None:
continue
sample_layout, sample_block_mask = sample_attn_mask
if sample_layout is None:
continue
if sample_block_mask.dtype == torch.uint8:
sample_block_mask = sample_block_mask.eq(1)
assert sample_block_mask.dtype == torch.bool
result[sample_idx] = attn_scores_for_real_part[sample_idx].masked_fill(sample_block_mask[None], -10000)
softmax_label = (real_part, 'softmax', self.layer_sv, sample_idx)
if softmax_label in self.instant_pocket:
softmax = self.instant_pocket[softmax_label]
else:
softmax = BlocksparseSoftmax(sample_layout, self.block_size)
self.instant_pocket[softmax_label] = softmax
temp = softmax(result[sample_idx])
if mask_again:
temp = temp.masked_fill(sample_block_mask[None], 0.0)
if self.dropout_module is not None:
temp = self.dropout_module(temp)
result[sample_idx] = temp
result = {real_part: result}
return result


@ -0,0 +1,5 @@
def print_redundant_params(kwargs, class_name=None):
print('=====Redundant Params%s=====' % ('' if class_name is None else ' for %s' % class_name))
for key in kwargs:
print('{key}:\t{value}'.format(key=key, value=kwargs[key]))
print('=============')


@ -0,0 +1,62 @@
import torch
from .rpe_self_attention_v2s1 import RpeSelfAttentionV2S1
from ..common.blocksparse_common_operations.qk_mul.qk_mul_1 import do_qk_scores_for_part
from ..common.blocksparse_common_operations.softmax.softmax_1 import do_attn_softmax_for_part
from ..common.blocksparse_common_operations.av_mul.av_mul_1 import do_av_mul_for_part
class BlocksparseRpeSelfAttentionV2S1(RpeSelfAttentionV2S1):
def __init__(self, *args, block_size=64, **kwargs):
super().__init__(*args, **kwargs)
self.block_size = block_size
# --- for indexing relative position embeddings ---
# num_heads_arange = torch.arange(self.num_heads)[:, None, None, None] # (num_heads, 1, 1, 1)
# self.register_buffer('num_heads_arange', num_heads_arange, persistent=False)
# block_size_arange = torch.arange(self.block_size)[None, None, :, None] # (1, 1, block_size, 1)
# self.register_buffer('block_size_arange', block_size_arange, persistent=False)
# ============= Interfaces =============
def do_qk_scores_for_sx(self, base_sum_q, base_sum_k, base_reg_k, bsz, sum_len, reg_len, attn_mask=None):
# base_sum_q: (sum_len, bsz, heads, head_dim)
# base_sum_k: (sum_len, bsz, heads, head_dim)
# base_reg_k: (reg_len, bsz, heads, head_dim)
tgt = base_sum_q
src = torch.cat((base_sum_k, base_reg_k), dim=0)
return do_qk_scores_for_part(self, tgt, src, bsz, sum_len, sum_len + reg_len, attn_mask, 'sx')
def do_masking_for_sx(self, attn_scores_inc, attn_mask):
return attn_scores_inc
def do_attn_softmax_for_sx(self, attn_scores_inc_sr, attn_mask=None):
return do_attn_softmax_for_part(self, attn_scores_inc_sr, attn_mask, 'sx')
def do_av_mul_for_sx(self, attn_weights_inc_sr, base_sum_v, base_reg_v, attn_mask=None, tgt_len=None):
v = torch.cat((base_sum_v, base_reg_v), dim=0)
return do_av_mul_for_part(self, attn_weights_inc_sr, v, attn_mask, 'sx', tgt_len)
def do_qk_scores_for_rx(
self,
reg_q, sum_k, reg_k,
bsz, sum_len, reg_len,
attn_mask=None
):
if sum_k is None:
k = reg_k
else:
k = torch.cat((sum_k, reg_k), dim=0)
return do_qk_scores_for_part(self, reg_q, k, bsz, reg_len, sum_len + reg_len, attn_mask, 'rx')
def do_masking_for_rx(self, attn_scores_inc, attn_mask):
return attn_scores_inc
def do_attn_softmax_for_rx(self, attn_scores_inc, attn_mask=None):
return do_attn_softmax_for_part(self, attn_scores_inc, attn_mask, 'rx')
def do_av_mul_for_rx(self, attn_weights_inc, base_sum_v2, base_reg_v, attn_mask=None, tgt_len=None):
if base_sum_v2 is None:
v = base_reg_v
else:
v = torch.cat((base_sum_v2, base_reg_v), dim=0)
return do_av_mul_for_part(self, attn_weights_inc, v, attn_mask, 'rx', tgt_len)


@ -0,0 +1,337 @@
import math
import torch
import torch.nn as nn
from fairseq.modules.fairseq_dropout import FairseqDropout
from ...data_structures.four_dim_pocket import FourDimPocket
from ..common import common_funcs
class RpeSelfAttentionV2S1(nn.Module):
def __init__(
self,
embed_dim,
num_heads,
num_summary,
layer_idx,
dropout=0.0, # attention dropout
query_proj_bias=True,
key_proj_bias=True,
value_proj_bias=True,
sum_key2_proj_bias=True,
sum_value2_proj_bias=True,
out_proj_bias=True,
add_different_kqv_bias_for_sum_and_reg=False,
add_different_out_bias_for_sum_and_reg=False,
share_query_proj=False,
share_key_proj=False,
share_value_proj=False,
share_out_proj=False,
share_key2_value2_proj_weight=False,
max_summary=None,
no_sum_out=False,
single_head_masks=False,
**kwargs
):
assert single_head_masks, "Currently, we only support single head masks."
common_funcs.print_redundant_params(kwargs, self.__class__.__name__)
super().__init__()
self.layer_idx = layer_idx
self.embed_dim = embed_dim
self.num_heads = num_heads
self.single_head_masks = single_head_masks
self.num_summary = num_summary
self.max_summary = self.num_summary if max_summary is None else max_summary
self.no_sum_out = no_sum_out
self.pocket = FourDimPocket()
self.instant_pocket = self.pocket['instant']
constant_pocket = self.pocket['constant']
layer_to_sv = constant_pocket['layer_to_sv']
self.layer_sv = layer_to_sv[self.layer_idx]
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "attention_embed_dim must be divisible by num_heads"
self.scaling = self.head_dim ** -0.5
if add_different_kqv_bias_for_sum_and_reg:
query_proj_bias = False
key_proj_bias = False
value_proj_bias = False
for proj_name in ('query', 'key', 'value'):
for target in (('sum', 'reg') if self.num_summary > 0 else ('reg',)):
bias_tensor = torch.zeros(self.embed_dim)
self.register_parameter(
'%s_%s_bias' % (target, proj_name),
nn.Parameter(bias_tensor, requires_grad=True)
)
if add_different_out_bias_for_sum_and_reg:
out_proj_bias = False
for target in (('sum', 'reg') if not self.no_sum_out and self.num_summary > 0 else ('reg',)):
bias_tensor = torch.zeros(self.embed_dim)
self.register_parameter(
'%s_out_bias' % target,
nn.Parameter(bias_tensor, requires_grad=True)
)
self.reg_query_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_proj_bias)
self.reg_key_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_proj_bias)
self.reg_value_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_proj_bias)
self.reg_out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=out_proj_bias)
if self.num_summary > 0:
if share_query_proj:
self.sum_query_proj = self.reg_query_proj
else:
self.sum_query_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=query_proj_bias)
if share_key_proj:
self.sum_key_proj = self.reg_key_proj
else:
self.sum_key_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=key_proj_bias)
if share_value_proj:
self.sum_value_proj = self.reg_value_proj
else:
self.sum_value_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=value_proj_bias)
if not self.no_sum_out:
if share_out_proj:
self.sum_out_proj = self.reg_out_proj
else:
self.sum_out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=out_proj_bias)
self.share_key2_value2_proj_weight = share_key2_value2_proj_weight
if share_key2_value2_proj_weight:
self.sum_key2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.sum_value2_proj = self.sum_key2_proj
if sum_key2_proj_bias:
self.sum_key2_bias = nn.Parameter(torch.zeros(self.embed_dim), requires_grad=True)
if sum_value2_proj_bias:
self.sum_value2_bias = nn.Parameter(torch.zeros(self.embed_dim), requires_grad=True)
else:
self.sum_key2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=sum_key2_proj_bias)
self.sum_value2_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=sum_value2_proj_bias)
self.dropout_module = FairseqDropout(
dropout, module_name=self.__class__.__name__
) if dropout > 0.0 else None
self.reset_parameters()
def reset_parameters(self):
nn.init.xavier_uniform_(self.reg_query_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.reg_key_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.reg_value_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.reg_out_proj.weight)
if self.num_summary > 0:
if id(self.sum_query_proj) != id(self.reg_query_proj):
nn.init.xavier_uniform_(self.sum_query_proj.weight, gain=1 / math.sqrt(2))
if id(self.sum_key_proj) != id(self.reg_key_proj):
nn.init.xavier_uniform_(self.sum_key_proj.weight, gain=1 / math.sqrt(2))
if id(self.sum_value_proj) != id(self.reg_value_proj):
nn.init.xavier_uniform_(self.sum_value_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.sum_key2_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.sum_value2_proj.weight, gain=1 / math.sqrt(2))
if not self.no_sum_out and id(self.sum_out_proj) != id(self.reg_out_proj):
nn.init.xavier_uniform_(self.sum_out_proj.weight)
nn.init.xavier_uniform_(self.reg_out_proj.weight)
if self.reg_out_proj.bias is not None:
nn.init.constant_(self.reg_out_proj.bias, 0.0)
def forward(
self,
x: tuple, # (sum_len, bsz, embed_dim), (reg_len, bsz, embed_dim)
sum_token_ids, # (bsz, sum_len)
sum_len,
reg_len,
key_padding_mask=None, # (bsz, all_seq_len)
attn_mask=None,
need_weights: bool = False,
need_head_weights: bool = False,
*args, **kwargs,
):
if key_padding_mask is not None:
raise NotImplementedError("Please combine key_padding_mask into attn_mask ahead.")
del key_padding_mask
if need_head_weights:
need_weights = True
# ===== Input Checking =====
sum_x, reg_x = x
bsz = reg_x.shape[1]
del sum_token_ids
# ===== Summarize =====
base_reg_k = self.reg_key_proj(reg_x)
bias = getattr(self, 'reg_key_bias', None)
if bias is not None:
base_reg_k = base_reg_k + bias
base_reg_k = base_reg_k.view(reg_len, bsz, self.num_heads, self.head_dim)
# base_reg_k: (reg_len, bsz, num_heads, head_dim)
base_reg_v = self.reg_value_proj(reg_x)
bias = getattr(self, 'reg_value_bias', None)
if bias is not None:
base_reg_v = base_reg_v + bias
base_reg_v = base_reg_v.view(reg_len, bsz, self.num_heads, self.head_dim)
# base_reg_v: (reg_len, bsz, num_heads, head_dim)
if sum_len > 0:
base_sum_q = self.sum_query_proj(sum_x)
bias = getattr(self, 'sum_query_bias', None)
if bias is not None:
base_sum_q = base_sum_q + bias
base_sum_q = base_sum_q.view(sum_len, bsz, self.num_heads, self.head_dim)
# base_sum_q: (sum_len, bsz, num_heads, head_dim)
base_sum_k = self.sum_key_proj(sum_x)
bias = getattr(self, 'sum_key_bias', None)
if bias is not None:
base_sum_k = base_sum_k + bias
base_sum_k = base_sum_k.view(sum_len, bsz, self.num_heads, self.head_dim)
base_sum_v = self.sum_value_proj(sum_x)
bias = getattr(self, 'sum_value_bias', None)
if bias is not None:
base_sum_v = base_sum_v + bias
base_sum_v = base_sum_v.view(sum_len, bsz, self.num_heads, self.head_dim)
attn_scores_inc_sx = self.do_qk_scores_for_sx(
base_sum_q, base_sum_k, base_reg_k,
bsz, sum_len, reg_len,
attn_mask=attn_mask
) # real_parts dict of sample list of (heads, head_selected_blocks, block, block)
for real_part in attn_scores_inc_sx:
for sample_item in attn_scores_inc_sx[real_part]:
sample_item.mul_(self.scaling)
attn_scores_inc_sx = self.do_masking_for_sx(attn_scores_inc_sx, attn_mask)
attn_weights_inc_sx = self.do_attn_softmax_for_sx(attn_scores_inc_sx, attn_mask=attn_mask)
del attn_scores_inc_sx
sum_x2 = self.do_av_mul_for_sx(
attn_weights_inc_sx, base_sum_v, base_reg_v, attn_mask=attn_mask, tgt_len=sum_len
) # samples list of (sum_len, 1, num_heads, head_dim)
assert sum_x2.shape == (sum_len, bsz, self.num_heads, self.head_dim)
sum_x2 = sum_x2.view(sum_len, bsz, self.embed_dim)
if self.share_key2_value2_proj_weight:
base_sum_k2 = self.sum_key2_proj(sum_x2)
base_sum_v2 = base_sum_k2
sum_key2_bias = getattr(self, 'sum_key2_bias', None)
if sum_key2_bias is not None:
base_sum_k2 = base_sum_k2 + sum_key2_bias
sum_value2_bias = getattr(self, 'sum_value2_bias', None)
if sum_value2_bias is not None:
base_sum_v2 = base_sum_v2 + sum_value2_bias
base_sum_k2 = base_sum_k2.view(sum_len, bsz, self.num_heads, self.head_dim)
base_sum_v2 = base_sum_v2.view(sum_len, bsz, self.num_heads, self.head_dim)
else:
base_sum_k2 = self.sum_key2_proj(sum_x2).view(sum_len, bsz, self.num_heads, self.head_dim)
base_sum_v2 = self.sum_value2_proj(sum_x2).view(sum_len, bsz, self.num_heads, self.head_dim)
else:
sum_x2 = reg_x.new_empty(0, bsz, self.embed_dim)
base_sum_k2 = None
base_sum_v2 = None
# ===== Updating =====
base_reg_q = self.reg_query_proj(reg_x)
reg_query_bias = getattr(self, 'reg_query_bias', None)
if reg_query_bias is not None:
base_reg_q = base_reg_q + reg_query_bias
base_reg_q = base_reg_q.view(reg_len, bsz, self.num_heads, self.head_dim)
attn_scores_inc_rx = self.do_qk_scores_for_rx(
base_reg_q, base_sum_k2, base_reg_k,
bsz, sum_len, reg_len, attn_mask=attn_mask,
)
for real_part in attn_scores_inc_rx:
for sample_item in attn_scores_inc_rx[real_part]:
sample_item.mul_(self.scaling)
attn_scores_inc_rx = self.do_masking_for_rx(attn_scores_inc_rx, attn_mask)
attn_weights_inc_rx = self.do_attn_softmax_for_rx(attn_scores_inc_rx, attn_mask=attn_mask)
# if self.layer_idx == 3:
# with open('attn_weights_inc_rx.bin', 'wb') as f:
# torch.save(attn_weights_inc_rx, f)
# print('saved attn_weights_inc_rx')
reg_output = self.do_av_mul_for_rx(
attn_weights_inc_rx, base_sum_v2, base_reg_v, attn_mask=attn_mask, tgt_len=reg_len
) # (reg_len, bsz, num_heads, head_dim)
# ----- gate to combine sum_output and reg_output -----
reg_output = reg_output.view(reg_len, bsz, self.embed_dim)
reg_output = self.reg_out_proj(reg_output)
reg_out_bias = getattr(self, 'reg_out_bias', None)
if reg_out_bias is not None:
reg_output = reg_output + reg_out_bias
if not self.no_sum_out and self.num_summary > 0:
sum_output = self.sum_out_proj(sum_x2)
sum_out_bias = getattr(self, 'sum_out_bias', None)
if sum_out_bias is not None:
sum_output = sum_output + sum_out_bias
else:
sum_output = None
if need_weights:
raise NotImplementedError
else:
attn_weights = None
return (sum_output, reg_output), attn_weights
# (sum_len, bsz, embed_dim) (reg_len, bsz, embed_dim)
# None, (bsz, num_heads, all_seq_len, all_seq_len) or (bsz, all_seq_len, all_seq_len)
def do_qk_scores_for_sx(self, base_sum_q, base_sum_k, base_reg_k, bsz, sum_len, reg_len, **kwargs):
raise NotImplementedError
def do_masking_for_sx(self, attn_scores_inc, attn_mask):
raise NotImplementedError
def do_attn_softmax_for_sx(self, attn_scores_for_sum, **kwargs):
raise NotImplementedError
def do_av_mul_for_sx(self, attn_weights_inc_sr, base_sum_v, base_reg_v, **kwargs):
raise NotImplementedError
def do_qk_scores_for_rx(
self,
reg_q, sum_k, reg_k,
bsz, sum_len, reg_len,
**kwargs
):
raise NotImplementedError
def do_masking_for_rx(self, attn_scores_for_reg, attn_mask):
raise NotImplementedError
def do_attn_softmax_for_rx(self, attn_scores_for_reg, attn_mask=None):
raise NotImplementedError
def do_av_mul_for_rx(self, attn_weights_inc, base_sum_v2, base_reg_v, **kwargs):
raise NotImplementedError



@ -0,0 +1,404 @@
import torch
import torch.nn as nn
from .multihead_block_selection_generation import MultiheadBlockSelectionGeneration, get_block_ranges
from ..kernels.block_fill import block_fill_
from ..tools import computation_tools
from ..data_structures.attention_scheme import LayerAttentionScheme
PREF2PREF_CHOICES = ('full', 'lt', 'none')
SUM2PREF_CHOICES = ('none',)
PREF2SUM_CHOICES = ('none',)
CON2PREF_CHOICES = ('full',)
PREF2CON_CHOICES = ('none',)
def construct_sum_chunk_ranges(max_complete_chunk, num_summary, device):
ranges = torch.arange(0, max_complete_chunk * num_summary + num_summary, num_summary, device=device)
ranges = torch.stack((ranges[:-1], ranges[1:]), dim=-1) # (max_complete_chunk, 2)
return ranges
def select_self_(selection, max_comp_chunk, mode):
if mode == 'default':
return selection
ndim = selection.ndim
eye = torch.eye(max_comp_chunk, dtype=torch.bool, device=selection.device)
if ndim >= 3:
num_heads = selection.shape[-1]
eye = eye[:, :, None].expand(-1, -1, num_heads)
if mode == 'none':
selection[eye] = False
elif mode == 'full':
selection[eye] = True
else:
raise ValueError(mode)
return selection
def fill_mask_with_selection_(selection_gen: MultiheadBlockSelectionGeneration,
row_ranges, col_ranges, selection_mask, fill_value, out):
"""
:param selection_gen:
:param row_ranges: (num_row_chunks, 2)
:param col_ranges: (num_col_ranges, 2)
:param selection_mask: (num_row_chunks, num_col_ranges, num_heads)
:param fill_value: bool
:param out: (num_heads, row_len, col_len)
:return:
"""
assert row_ranges.ndim == 2
assert col_ranges.ndim == 2
assert out.is_contiguous()
block_indices, command_masks = selection_gen.get_block_indices_and_masks_from_selection_masks(
selection_mask, overall_mask=None
) # (num_selected_blocks, 2) (num_selected_blocks, num_heads)
block_ranges = get_block_ranges(row_ranges, col_ranges, block_indices) # (num_selected_blocks, 4)
block_fill_(out, block_ranges, command_masks, fill_value)
return out
def partially_combine_part_masks(part_masks_dict):
row_sum = computation_tools.may_bi_cat(part_masks_dict['ss'], part_masks_dict['sr'], dim=3)
row_reg = computation_tools.may_bi_cat(part_masks_dict['rs'], part_masks_dict['rr'], dim=3)
return row_sum, row_reg
def combine_part_masks(part_masks_dict):
row_sum, row_reg = partially_combine_part_masks(part_masks_dict)
return computation_tools.may_bi_cat(row_sum, row_reg, dim=2)
def combine_attn_mask_and_key_padding_mask_(attn_mask, key_padding_mask):
"""
:param attn_mask: (bsz, num_heads, seq_len, seq_len)
:param key_padding_mask: (bsz, seq_len)
:return:
"""
if attn_mask is None:
return attn_mask
attn_mask = attn_mask.contiguous()
attn_mask.masked_fill_(key_padding_mask[:, None, None], True)
return attn_mask
def layout_full_zero_check(input_layout):
row_check = input_layout.sum(dim=2).eq(0) # (H, L // block)
col_check = input_layout.sum(dim=1).eq(0)
row_answer = bool(row_check.any())
col_answer = bool(col_check.any())
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
def transfer_attn_mask_to_block_layout(attn_mask, block_size, avoid_empty_row_or_column=True):
"""
:param attn_mask: (bsz_heads, tgt_len, src_len) bool False for selected
:param block_size: int
:param avoid_empty_row_or_column
:return:
"""
bsz_heads, tgt_len, src_len = attn_mask.shape
num_tgt_blocks = tgt_len // block_size
num_src_blocks = src_len // block_size
assert num_tgt_blocks * block_size == tgt_len
assert num_src_blocks * block_size == src_len
assert attn_mask.dtype in (torch.bool, torch.uint8)
attn_mask = attn_mask.view(
bsz_heads, num_tgt_blocks, block_size, num_src_blocks, block_size
).permute(0, 1, 3, 2, 4) # (bsz * num_heads, num_tgt_blocks, num_src_blocks, block_size, block_size)
layout = (~(attn_mask.bool())).reshape(
bsz_heads, num_tgt_blocks, num_src_blocks, block_size * block_size
).sum(dim=-1).gt(0) # (bsz * num_heads, num_tgt_blocks, num_src_blocks)
if not layout.any(): # for this part of attn_mask, there's nothing to compute
return None, None
if avoid_empty_row_or_column:
from ..blocksparse import layout_full_zero_check
answer, row_answer, col_answer, row_check, col_check = layout_full_zero_check(layout)
if answer:
if row_answer:
row_zero_indices = row_check.nonzero()
# print('row:')
# print(row_zero_indices)
layout[row_zero_indices[:, 0], row_zero_indices[:, 1], 0] = True
if col_answer:
col_zero_indices = col_check.nonzero()
# print('col:')
# print(col_zero_indices)
layout[col_zero_indices[:, 0], 0, col_zero_indices[:, 1]] = True
assert not layout_full_zero_check(layout)[0]
block_mask = attn_mask.masked_select(layout[:, :, :, None, None]).view(-1, block_size, block_size)
return layout, block_mask
class LayerAttentionMaskGeneration(nn.Module):
def __init__(
self,
layer_attention_scheme: LayerAttentionScheme,
gen_parts=None,
single_head=True,
init_max_chunk=None
):
super().__init__()
self.layer_attention_scheme = layer_attention_scheme
self.single_head = single_head
if not self.layer_attention_scheme.same_for_all_heads:
raise NotImplementedError("Not support for different attention schemes over different heads.")
self.num_heads = self.layer_attention_scheme.num_layer_heads
self.num_fake_heads = 1
self.single_head_scheme = self.layer_attention_scheme[0]
assert self.single_head_scheme.head_pref2pref_mode in PREF2PREF_CHOICES
self.pref2pref_mode = self.single_head_scheme.head_pref2pref_mode
assert self.single_head_scheme.head_sum2pref_mode in SUM2PREF_CHOICES
self.sum2pref_mode = self.single_head_scheme.head_sum2pref_mode
assert self.single_head_scheme.head_pref2sum_mode in PREF2SUM_CHOICES
self.pref2sum_mode = self.single_head_scheme.head_pref2sum_mode
assert self.single_head_scheme.head_con2pref_mode in CON2PREF_CHOICES
self.con2pref_mode = self.single_head_scheme.head_con2pref_mode
assert self.single_head_scheme.head_pref2con_mode in PREF2CON_CHOICES
self.pref2con_mode = self.single_head_scheme.head_pref2con_mode
self.gen_parts = gen_parts
self.con2con_gen, self.con2sum_gen, self.sum2con_gen, self.sum2sum_gen = None, None, None, None
if gen_parts is None or 'rr' in gen_parts:
self.con2con_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_con2con,),
init_max_chunk=init_max_chunk)
if gen_parts is None or 'rs' in gen_parts:
self.con2sum_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_con2sum,),
init_max_chunk=init_max_chunk)
if gen_parts is None or 'sr' in gen_parts:
self.sum2con_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_sum2con,),
init_max_chunk=init_max_chunk)
if gen_parts is None or 'ss' in gen_parts:
self.sum2sum_gen = MultiheadBlockSelectionGeneration((self.single_head_scheme.head_sum2sum,),
init_max_chunk=init_max_chunk)
def gen_ss_mask(
self,
num_complete_chunks: torch.Tensor, # (bsz,)
num_summary: int,
bsz, sum_len,
device,
):
if sum_len <= 0 or num_summary <= 0:
return None
max_complete_chunk = int(num_complete_chunks.max())
temp_sum_len = max_complete_chunk * num_summary
mask = self.sum2sum_gen.get_block_selection_masks(max_complete_chunk).to(device) \
# (max_complete_chunk, max_complete_chunk, 1) # True for selected
mask = select_self_(mask, max_complete_chunk, self.single_head_scheme.head_sum2sum_self) # Add Self
mask = ~mask # False for selected
mask = mask.permute(2, 0, 1) # (num_heads, max_comp_chunk, max_comp_chunk)
mask = mask[None].repeat(bsz, 1, 1, 1).contiguous() # (bsz, num_heads, max_comp_chunk, max_comp_chunk)
for idx, sample_comp_chunk in enumerate(num_complete_chunks):
mask[idx, :, sample_comp_chunk:] = True
mask[idx, :, :, sample_comp_chunk:] = True
mask = mask[:, :, :, :, None, None].expand(-1, -1, -1, -1, num_summary, num_summary)
mask = mask.permute(0, 1, 2, 4, 3, 5).reshape(bsz, self.num_fake_heads, temp_sum_len, temp_sum_len) \
# (bsz, num_heads, temp_sum_len, temp_sum_len)
if sum_len != temp_sum_len:
temp_mask = torch.ones(bsz, self.num_fake_heads, sum_len, sum_len, dtype=torch.bool, device=device)
temp_mask[:, :, :temp_sum_len, :temp_sum_len] = mask
mask = temp_mask
return mask
def gen_sr_mask(
self,
sum_chunk_ranges, reg_chunk_ranges, num_complete_chunks,
sum_len, reg_len,
num_pref,
sum2pref_mode,
bsz,
device
):
if sum_len <= 0:
return None
mask = torch.ones(bsz, self.num_fake_heads, sum_len, reg_len, dtype=torch.bool, device=device)
for sample_idx, sample_comp_chunk in enumerate(num_complete_chunks):
if sample_comp_chunk <= 0:
continue
sample_selections = self.sum2con_gen.get_block_selection_masks(sample_comp_chunk) \
# (sample_comp_chunk, sample_comp_chunk, num_heads)
sample_selections = select_self_(sample_selections, sample_comp_chunk,
self.single_head_scheme.head_sum2con_self)
fill_mask_with_selection_(self.sum2con_gen,
sum_chunk_ranges[sample_idx], reg_chunk_ranges[sample_idx],
sample_selections,
False, mask[sample_idx])
if sum2pref_mode == 'none':
pass
else:
raise NotImplementedError(sum2pref_mode)
return mask
def gen_rs_mask(
self,
reg_chunk_ranges, sum_chunk_ranges, num_complete_chunks, num_chunks,
sum_len, reg_len,
num_pref,
pref2sum_mode,
bsz, device
):
if sum_len <= 0:
return None
mask = torch.ones(bsz, self.num_fake_heads, reg_len, sum_len, dtype=torch.bool, device=device)
for sample_idx, (sample_comp_chunk, sample_chunk) in enumerate(zip(num_complete_chunks, num_chunks)):
if sample_chunk <= 0:
continue
sample_selections = self.con2sum_gen.get_block_selection_masks(sample_chunk)
sample_selections = select_self_(sample_selections, sample_chunk, self.single_head_scheme.head_con2sum_self)
sample_selections = sample_selections[:, :sample_comp_chunk]
fill_mask_with_selection_(self.con2sum_gen,
reg_chunk_ranges[sample_idx], sum_chunk_ranges[sample_idx],
sample_selections,
False, mask[sample_idx])
if pref2sum_mode == 'none':
pass
else:
raise NotImplementedError(pref2sum_mode)
return mask
def gen_rr_mask(
self,
reg_chunk_ranges, num_chunks,
reg_len,
num_pref,
pref2pref_mode, pref2con_mode, con2pref_mode,
bsz, device
):
mask = torch.ones(bsz, self.num_fake_heads, reg_len, reg_len, dtype=torch.bool, device=device)
for sample_idx, sample_chunk in enumerate(num_chunks):
if sample_chunk <= 0:
continue
sample_selections = self.con2con_gen.get_block_selection_masks(sample_chunk)
sample_selections = select_self_(
sample_selections, sample_chunk,
'full' if self.single_head_scheme.head_con2con_self == 'lt'
else self.single_head_scheme.head_con2con_self
)
fill_mask_with_selection_(self.con2con_gen,
reg_chunk_ranges[sample_idx], reg_chunk_ranges[sample_idx],
sample_selections,
False, mask[sample_idx])
sample_num_pref = num_pref[sample_idx]
if self.single_head_scheme.head_con2con_causal:
up_triangle = torch.ones(reg_len - sample_num_pref, reg_len - sample_num_pref,
dtype=torch.bool, device=device)
up_triangle.triu_(1) # (con_len, con_len)
mask[sample_idx, :, sample_num_pref:, sample_num_pref:].masked_fill_(up_triangle[None], True)
if pref2pref_mode == 'none':
pass
elif pref2pref_mode == 'full':
for sample_idx, sample_num_pref in enumerate(num_pref):
mask[sample_idx, :, :sample_num_pref, :sample_num_pref] = False
elif pref2pref_mode == 'lt':
for sample_idx, sample_num_pref in enumerate(num_pref):
if sample_num_pref == 1:
mask[sample_idx, :, :sample_num_pref, :sample_num_pref] = False
continue
mask[sample_idx, :, :sample_num_pref, :sample_num_pref].triu_(1)
else:
raise NotImplementedError(pref2pref_mode)
if pref2con_mode == 'none':
pass
else:
raise NotImplementedError(pref2con_mode)
if con2pref_mode == 'none':
pass
elif con2pref_mode == 'full':
for sample_idx, sample_num_pref in enumerate(num_pref):
sample_num_chunks = num_chunks[sample_idx]
if sample_num_chunks == 0:
continue
sample_reg_len = reg_chunk_ranges[sample_idx, sample_num_chunks - 1, -1]
mask[sample_idx, :, sample_num_pref: sample_reg_len, :sample_num_pref] = False
else:
raise NotImplementedError(con2pref_mode)
return mask
def forward(
self,
reg_chunk_ranges: torch.Tensor, # (bsz, max_chunk, 2)
num_chunks: torch.Tensor, # (bsz,)
num_complete_chunks: torch.Tensor, # (bsz,)
num_summary: int,
num_pref: torch.Tensor, # (bsz,)
sum_len: int,
reg_len: int,
):
# print('===attn_mask===')
# print(reg_chunk_ranges)
# print(num_chunks)
# print(num_complete_chunks)
# print(num_summary)
# print(num_pref)
# print(sum_len, reg_len)
# print('=====')
# === Preliminaries ===
device = reg_chunk_ranges.device
bsz = num_chunks.shape[0]
# all_seq_len = temp_sum_len + reg_len
max_complete_chunk = int(num_complete_chunks.max())
if num_summary <= 0 or sum_len <= 0:
sum_chunk_ranges = None
else:
sum_chunk_ranges = construct_sum_chunk_ranges(max_complete_chunk, num_summary, device=device) \
# (max_comp_chunk, 2)
sum_chunk_ranges = sum_chunk_ranges[None].repeat(bsz, 1, 1).contiguous() # (bsz, max_comp_chunk, 2)
# === Generate Masks for Parts (True is for the masked out)
ss_mask, sr_mask, rs_mask, rr_mask = None, None, None, None
if self.gen_parts is None or 'ss' in self.gen_parts:
ss_mask = self.gen_ss_mask(num_complete_chunks, num_summary, bsz, sum_len, device) \
# (bsz, num_heads, temp_sum_len, temp_sum_len)
assert ss_mask is None or ss_mask.dtype == torch.bool
if self.gen_parts is None or 'sr' in self.gen_parts:
sr_mask = self.gen_sr_mask(sum_chunk_ranges, reg_chunk_ranges, num_complete_chunks, sum_len, reg_len,
num_pref, self.sum2pref_mode, bsz, device)
assert sr_mask is None or sr_mask.dtype == torch.bool
if self.gen_parts is None or 'rs' in self.gen_parts:
rs_mask = self.gen_rs_mask(reg_chunk_ranges, sum_chunk_ranges, num_complete_chunks, num_chunks,
sum_len, reg_len, num_pref, self.pref2sum_mode, bsz, device)
assert rs_mask is None or rs_mask.dtype == torch.bool
if self.gen_parts is None or 'rr' in self.gen_parts:
rr_mask = self.gen_rr_mask(reg_chunk_ranges, num_chunks, reg_len,
num_pref, self.pref2pref_mode, self.pref2con_mode, self.con2pref_mode,
bsz, device)
assert rr_mask.dtype == torch.bool
if not self.single_head and self.num_heads > 1:
if ss_mask is not None:
ss_mask = ss_mask.expand(-1, self.real_num_heads, -1, -1)
if sr_mask is not None:
sr_mask = sr_mask.expand(-1, self.real_num_heads, -1, -1)
if rs_mask is not None:
rs_mask = rs_mask.expand(-1, self.real_num_heads, -1, -1)
if rr_mask is not None:
rr_mask = rr_mask.expand(-1, self.real_num_heads, -1, -1)
attn_mask = {
'ss': ss_mask, # (bsz, num_heads, temp_sum_len, temp_sum_len)
'sr': sr_mask, # (bsz, num_heads, temp_sum_len, reg_len)
'rs': rs_mask, # (bsz, num_heads, reg_len, temp_sum_len)
'rr': rr_mask, # (bsz, num_heads, reg_len, reg_len)
}
for key in ('ss', 'sr', 'rs', 'rr'):
if self.gen_parts is not None and key not in self.gen_parts:
attn_mask.pop(key)
return attn_mask


@ -0,0 +1,165 @@
import torch
import torch.nn as nn
from ..tools.singleton_decorator import singleton
def generate_index_matrix(num_chunks, device='cpu'):
chunk_arange = torch.arange(num_chunks, device=device)
rows = chunk_arange.view(num_chunks, 1).expand(num_chunks, num_chunks)
cols = chunk_arange.view(1, num_chunks).expand(num_chunks, num_chunks)
index_matrix = torch.stack((rows, cols), dim=-1) # (num_chunks, num_chunks, 2)
return index_matrix
def generate_mask_template(num_chunks, device='cpu'):
return torch.ones(num_chunks, num_chunks, dtype=torch.bool, device=device)
def generate_diagonal_indices(num_chunks, device='cpu'):
diagonal = torch.arange(num_chunks, device=device)
diagonal = torch.stack((diagonal, diagonal), dim=-1)
return diagonal
def generate_mask(head_range_command, num_chunks, mask_template=None, device='cpu'):
"""
Generate a selection mask of size (num_chunks, num_chunks) according to range_command. Selected ones are True.
:param head_range_command:
:param num_chunks:
:param mask_template:
:param device:
:return:
"""
if head_range_command is None:
return None
if mask_template is None:
mask_template = generate_mask_template(num_chunks, device=device)
else:
w, h = mask_template.shape
assert w == h == num_chunks
mask_template = mask_template.to(device)
mask = torch.zeros_like(mask_template)
for command_item in head_range_command:
if isinstance(command_item, int):
command_item = (command_item, command_item + 1)
begin, end = command_item
if begin is None:
begin = -num_chunks + 1
if end is None:
end = num_chunks
if begin >= end:
continue
left = torch.triu(mask_template, begin)
right = torch.tril(mask_template, end - 1)
mask = mask | (left & right)
return mask
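# Illustrative example (not part of the original source): with num_chunks = 4,
# generate_mask([(-1, 0), 2], 4) marks, for each row chunk i, the column chunks j with
# j - i == -1 or j - i == 2, i.e. the first subdiagonal and the second superdiagonal of a
# 4x4 boolean matrix; an integer k is shorthand for the range (k, k + 1), and a None bound
# extends the range to the edge of the matrix.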
@singleton
class BlockSelectionTemplateManager(nn.Module):
"""
Generate mask according to some specific range c, and provide related management. Singleton design mode.
"""
INIT_MAX_CHUNK = 1024
_RESERVED_ATTR_NAMES = ('max_chunks', 'index_matrix', 'mask_template', 'diagonal_indices', 'mask')
def __init__(self, init_max_chunk=None):
super().__init__()
if init_max_chunk is None:
init_max_chunk = self.__class__.INIT_MAX_CHUNK
self._max_chunk = init_max_chunk
index_matrix = generate_index_matrix(self._max_chunk)
self.register_buffer('_index_matrix', index_matrix, persistent=False)
mask_template = generate_mask_template(self._max_chunk)
self.register_buffer('_mask_template', mask_template, persistent=False)
diagonal_indices = generate_diagonal_indices(self._max_chunk)
self.register_buffer('_diagonal_indices', diagonal_indices, persistent=False)
self.__range_commands_and_names = []
def __update_index_matrix(self, num_chunks):
new_index_matrix = generate_index_matrix(num_chunks, device=self._index_matrix.device)
self._index_matrix = new_index_matrix
def __update_mask_template(self, num_chunks):
new_mask_template = generate_mask_template(num_chunks, device=self._mask_template.device)
self._mask_template = new_mask_template
def __update_diagonal_indices(self, num_chunks):
new_diagonal_indices = generate_diagonal_indices(
num_chunks, device=self._diagonal_indices.device
)
self._diagonal_indices = new_diagonal_indices
def __update_masks(self, num_chunks):
w, h = self._mask_template.shape
assert w == h == num_chunks, "Please update mask_template first."
for range_command, mask_name in self.__range_commands_and_names:
setattr(self, mask_name,
generate_mask(range_command, num_chunks,
mask_template=self._mask_template,
device=self._mask_template.device))
def update(self, num_chunks):
if num_chunks <= self._max_chunk:
return
self.__update_index_matrix(num_chunks)
self.__update_mask_template(num_chunks)
self.__update_diagonal_indices(num_chunks)
self.__update_masks(num_chunks)
self._max_chunk = num_chunks
@property
def max_chunk(self):
return self._max_chunk
@max_chunk.setter
def max_chunk(self, num_chunks):
self.update(num_chunks)
@property
def index_matrix(self):
return self._index_matrix
@property
def mask_template(self):
return self._mask_template
@property
def diagonal_indices(self):
return self._diagonal_indices
@property
def device(self):
return getattr(self, '_mask_template').device
def register_range_command(self, range_command):
name = str(range_command)
if name in self.__class__._RESERVED_ATTR_NAMES:
raise ValueError
if hasattr(self, name):
return
mask = generate_mask(range_command, self.max_chunk, mask_template=self.mask_template)
self.register_buffer(name, mask, persistent=False)
self.__range_commands_and_names.append((range_command, name))
def mask(self, range_command):
name = str(range_command)
assert (range_command, name) in self.__range_commands_and_names
return getattr(self, name)
def get_diagonal_indices(self, num_chunks):
assert num_chunks > 0
self.update(num_chunks)
return self._diagonal_indices[:num_chunks]
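# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# A hypothetical range command that lets each chunk select its previous two chunks;
# the exact band semantics depend on generate_mask_template defined earlier in this file.
def _demo_template_manager():
    manager = BlockSelectionTemplateManager(init_max_chunk=16)
    command = ((-2, 0),)  # assumed format: diagonal offsets -2 and -1
    manager.register_range_command(command)
    selection = manager.mask(command)[:8, :8]  # (8, 8) bool selection mask for an 8-chunk piece
    return selection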


@ -0,0 +1,101 @@
from functools import lru_cache
import torch
import torch.nn as nn
from .block_selection_template_manager import BlockSelectionTemplateManager
class MultiheadBlockSelectionGeneration(nn.Module):
"""
Generate a set of block masks according to a set of range commands, typically for heads in a layer.
"""
def __init__(self, multihead_range_commands, init_max_chunk=None):
"""
        :param multihead_range_commands: tuple of head_range_command entries, typically one per head
:param init_max_chunk: int
"""
super().__init__()
self._manager = BlockSelectionTemplateManager(init_max_chunk)
self.range_commands = multihead_range_commands
for range_command in self.range_commands:
self._manager.register_range_command(range_command)
@lru_cache(maxsize=128, typed=False)
def get_block_selection_masks(self, num_chunks: int):
"""
Get selection masks of size (num_chunks, num_chunks, num_commands) corresponding to all the range_commands.
        :param num_chunks: number of chunks in one sample
        :return: bool tensor of shape (num_chunks, num_chunks, num_commands)
"""
assert num_chunks > 0
self._manager.update(num_chunks)
commands_masks = []
false_chunk = torch.zeros(num_chunks, num_chunks, dtype=torch.bool, device=self._manager.device)
considered_range_commands = set()
for idx, range_command in enumerate(self.range_commands):
mask = self._manager.mask(range_command)
if mask is None:
assert self.range_commands[idx] is None
commands_masks.append(false_chunk)
else:
mask = mask[:num_chunks, :num_chunks]
commands_masks.append(mask)
considered_range_commands.add(range_command)
commands_masks = torch.stack(commands_masks, dim=-1) # (num_chunks, num_chunks, num_commands)
return commands_masks
@staticmethod
def get_overall_mask_from_commands_masks(commands_masks):
return commands_masks.sum(dim=-1).gt(0)
def get_block_indices_and_masks_from_selection_masks(self, commands_masks, overall_mask=None):
"""
        :param commands_masks: bool tensor of shape (num_chunks, num_chunks, num_commands)
        :param overall_mask: optional bool tensor of shape (num_chunks, num_chunks)
        :return: selected block indices (num_blocks, 2) and their per-command masks (num_blocks, num_commands)
"""
if overall_mask is None:
overall_mask = self.get_overall_mask_from_commands_masks(commands_masks)
else:
assert overall_mask.ndim == 2
num_chunks_1, num_chunks_2 = overall_mask.shape[:2]
index_matrix = self._manager.index_matrix[:num_chunks_1, :num_chunks_2]
indices = index_matrix.masked_select(overall_mask.unsqueeze(-1)).view(-1, 2)
commands_masks = commands_masks.masked_select(overall_mask.unsqueeze(-1)).view(-1, len(self.range_commands))
return indices, commands_masks
@lru_cache(maxsize=128, typed=False)
def get_block_indices_and_masks(self, num_chunks: int):
"""
:param num_chunks: int, number of chunks in one sample
        :return: block indices (query, key) to compute (num_blocks, 2);
                 masks indicating whether each block is computed for each range command (head)
"""
commands_masks = self.get_block_selection_masks(num_chunks)
overall_mask = self.get_overall_mask_from_commands_masks(commands_masks)
return self.get_block_indices_and_masks_from_selection_masks(commands_masks, overall_mask=overall_mask)
def get_diagonal_indices(self, num_chunks):
return self._manager.get_diagonal_indices(num_chunks)
def get_block_ranges(row_ranges, col_ranges, block_indices):
"""
    :param row_ranges: beginnings and ends of chunks along the row dimension. (num_tgt_chunks, 2)
    :param col_ranges: beginnings and ends of chunks along the col dimension. (num_src_chunks, 2)
    :param block_indices: row and col indices of the selected blocks. (num_blocks, 2)
    :return: (tgt_begin, tgt_end, src_begin, src_end) ranges of the selected blocks. (num_blocks, 4)
"""
tgt_ranges = row_ranges[block_indices[:, 0]] # (num_blocks, 2)
src_ranges = col_ranges[block_indices[:, 1]] # (num_blocks, 2)
block_ranges = torch.cat((tgt_ranges, src_ranges), dim=1) # (num_blocks, 4)
return block_ranges
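# --- Illustrative usage sketch (added for clarity; not part of the original file) ---
# Two hypothetical heads: head 0 selects the previous chunk, head 1 the previous two.
# The generator returns the chunk-level (query, key) indices of every selected block
# plus a per-head mask saying which heads compute each block.
def _demo_multihead_block_selection():
    gen = MultiheadBlockSelectionGeneration((((-1, 0),), ((-2, 0),)), init_max_chunk=16)
    block_indices, head_masks = gen.get_block_indices_and_masks(8)
    return block_indices, head_masks  # shapes: (num_blocks, 2), (num_blocks, 2)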


@ -0,0 +1,2 @@
from .operator_manager import BlocksparseMatMulManager, BlocksparseSoftmaxManager, BlocksparseMatMul, BlocksparseSoftmax
from .utils import layout_full_zero_check


@ -0,0 +1,87 @@
import torch
from .optimized_matmul import matmul as BlocksparseMatMul
from .optimized_softmax import softmax as BlocksparseSoftmax
from ..tools.singleton_decorator import singleton
def layout_full_zero_check(input_layout):
row_check = input_layout.sum(dim=2).eq(0) # (head, row_len // block)
col_check = input_layout.sum(dim=1).eq(0) # (head, col_len // block)
row_answer = bool(row_check.any())
col_answer = bool(col_check.any())
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
class CacheList(object):
def __init__(self, cache_size):
self._cache = list()
self._num = 0
self.cache_size = cache_size
def __len__(self):
return self._num
def __iter__(self):
for idx in range(self._num - 1, -1, -1):
yield self._cache[idx]
def __getitem__(self, idx):
return self._cache[idx]
def put(self, x):
if self._num >= self.cache_size:
self._cache = self._cache[1:]
self._num -= 1
self._cache.append(x)
self._num += 1
def clear(self):
self._cache.clear()
self._num = 0
class OperatorManager(object):
def __init__(self, operator_cls, cache_size=0):
self.operator_cls = operator_cls
self.cache_size = cache_size if cache_size > 0 else 0
self.cache_dict = None if self.cache_size == 0 else {}
    def clear_cache(self):
        if self.cache_dict is not None:
            self.cache_dict.clear()
def get_operator(self, layout, block, **kwargs):
assert layout.dtype == torch.bool
has_empty, row_answer, col_answer, _, _ = layout_full_zero_check(layout)
        assert not has_empty, "layout has an empty %s, which may lead to erroneous computation. Please check and fix the layout." % (
            'row' if row_answer else 'column'
        )
layout_cpu = None
all_args = None
if self.cache_dict is not None:
layout_cpu = layout.cpu()
all_args = (('block', block),) + tuple(kwargs.items())
all_args = tuple(sorted(all_args, key=lambda x: x[0]))
if all_args in self.cache_dict:
cache_list = self.cache_dict[all_args]
for cached_layout, operator in cache_list:
if torch.equal(layout_cpu, cached_layout):
return operator
operator = self.operator_cls(layout=layout.long(), block=block, **kwargs)
if self.cache_dict is not None:
if all_args not in self.cache_dict:
self.cache_dict[all_args] = CacheList(self.cache_size)
self.cache_dict[all_args].put((layout_cpu, operator))
return operator
@singleton
class BlocksparseMatMulManager(OperatorManager):
def __init__(self, cache_size=0):
super().__init__(BlocksparseMatMul, cache_size=cache_size)
@singleton
class BlocksparseSoftmaxManager(OperatorManager):
def __init__(self, cache_size=0):
super().__init__(BlocksparseSoftmax, cache_size=cache_size)
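# --- Quick illustration (added for clarity; not part of the original file) ---
# CacheList keeps at most `cache_size` entries, evicts the oldest, and iterates
# from the most recently added entry backwards.
def _demo_cache_list():
    cache = CacheList(cache_size=2)
    for name in ('op_a', 'op_b', 'op_c'):
        cache.put(name)
    return list(cache)  # ['op_c', 'op_b'] -- 'op_a' was evicted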


@ -0,0 +1,535 @@
import torch
import triton
import triton.language as tl
# ********************************************************
# --------------------------------------------------------
# Sparse = Dense x Dense (SDD)
# This operation uses super-blocking to make sure that
# it's done efficiently when small blocks can be grouped
# together
# --------------------------------------------------------
# ********************************************************
@triton.heuristics({
'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0,
})
@triton.jit
def _sdd_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ak,
stride_zb, stride_hb, stride_bk, stride_nb,
stride_zc, stride_hc, stride_mc, stride_nc,
K, grid_offset, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
BLOCK: tl.constexpr, EVEN_K: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
block_id = tl.program_id(1) + grid_offset
lut += block_id * 3
# offsets
off_z = tl.program_id(2) # batch
off_h = tl.load(lut + 0) # head
# initialize pointers to A
start_am = tl.load(lut + 1)
offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK)
offs_ak = tl.arange(0, TILE_K)
a_ptrs = A \
+ off_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B
start_bn = tl.load(lut + 2)
offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK)
offs_bk = tl.arange(0, TILE_K)
b_ptrs = B \
+ off_z * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_nb \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(K, 0, -TILE_K):
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.)
b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.)
acc += tl.dot(a, b)
a_ptrs += TILE_K * stride_ak
b_ptrs += TILE_K * stride_bk
c = acc.to(C.dtype.element_ty)
# ---------------- #
# Epilogue #
# ---------------- #
offs_cm = tl.arange(0, TILE_M) % BLOCK
offs_cn = tl.arange(0, TILE_N) % BLOCK
pc = C \
+ off_z * stride_zc \
+ block_id * stride_hc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
tl.store(pc, c, mask=True)
def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# (A * B)^T = B^T * A^T
if trans_c:
a, b = b, a
trans_a, trans_b = not trans_b, not trans_a
# shape constraints
a_dim = -2 if trans_a else -1
b_dim = -1 if trans_b else -2
Ka, Kb = a.shape[a_dim], b.shape[b_dim]
if Ka != Kb:
raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})")
# allocate output
if out is None:
c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device)
else:
assert out.shape == (a.shape[0], lut.shape[0], block, block)
c = out
grid = [1, c.shape[1], c.shape[0]]
_sdd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(2), c.stride(3),
Ka, 0, lut,
TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4,
num_warps=4,
)
return c
def sdd_lut(layout, block, device):
lut = layout.nonzero(as_tuple=False).to(device).int()
lut = lut.contiguous()
return lut, None
# -----------------------------
# Dense = Sparse x Dense (DSD)
# This operation uses a look-up table that contains pre-computed pointer increments
# in order to minimize computations in the inner loop of the matmul kernel.
# -----------------------------
@triton.jit
def _dsd_kernel(
A, B, C,
stride_az, stride_ha, stride_am, stride_ak,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_cm, stride_cn,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pidz = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
K = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8) # compiler hint
offs_am = tl.arange(0, TILE_M)
offs_ak = tl.arange(0, TILE_K)
pa = A + pidz * stride_az \
+ block_id * stride_ha \
+ offs_am[:, None] * stride_am \
+ offs_ak[None, :] * stride_ak
# initialize pointers to B (dense)
offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N)
offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N)
start_bk = tl.load(pinc)
start_bk = tl.multiple_of(start_bk, 8) # compiler hint
offs_bk = start_bk + tl.arange(0, TILE_K)
pb = B + pidz * stride_zb \
+ off_h * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
acc += tl.dot(a, b)
pa += inc_a
pb += inc_b * stride_bk
pinc += 2
inc_a = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
c = acc.to(C.dtype.element_ty)
# initialize pointers to C
offs_cm = column * TILE_M + tl.arange(0, TILE_M)
offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N)
pc = C \
+ off_h * stride_hc \
+ pidz * stride_zc \
+ offs_cm[:, None] * stride_cm \
+ offs_cn[None, :] * stride_cn
tl.store(pc, c, mask=offs_cn[None, :] < DS0)
def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS1 = block * spdims[2 if trans_a else 1]
BS0 = b.size(0)
BS1 = b.size(1)
BS3 = b.size(2 if trans_b else 3)
dtype = a.dtype
# allocate output
CS0 = BS0
CS1 = BS1
CS2 = BS3 if trans_c else AS1
CS3 = AS1 if trans_c else BS3
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
# meta-parameter heuristics
TILE_N = 128
# compute output
grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0]
_dsd_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
BS3, AS1, lut,
TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
return c
def dsd_lut(layout, block, step, trans, device):
sizes = torch.sum(layout, 2 if trans else 1)
head_id, col_id = sizes.nonzero(as_tuple=True)
sizes = sizes.flatten()
segments = sizes * step
# pointer increments
if trans:
nnz = layout.nonzero(as_tuple=False)
else:
nnz = layout.transpose(1, 2).nonzero(as_tuple=False)
num_blocks = nnz.size(0)
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets))
# -------------------------------
# dense input pointer increments
# -------------------------------
# given a list of the indices for the first element of each non-zero block.
# For example, for the indices
# [32, 80, 128, 256, 288]
# we would generate the increments
# [32, 48, 48, 128, 32]
# ^
# index of first element
# Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K)
# that is smaller than the block size, so we need to do a bit of extra work
# to handle this case
B_idx = nnz[:, 2] * block
B_incs = B_idx.clone()
B_incs[1:] -= B_idx[:-1]
div = block // step
B_incs = B_incs.view(-1, 1).repeat(1, div)
B_incs[:, 1:] = step
B_incs[:, 0] -= (div - 1) * step
# first increment for each reduction is actually the offset
B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]]
B_incs = B_incs.view(-1)
# -------------------------------
# sparse input pointer increments
# -------------------------------
# same as above, except that the increments are in the sparse memory layout
if trans:
A_idx = torch.arange(num_blocks, device=layout.device)
else:
A_idx = torch.tensor([], dtype=torch.int64, device=layout.device)
current_offset = 0
for z in range(layout.size(0)):
layoutw = layout[z, :, :].clone().long()
msum = layoutw.sum()
layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device)
A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1))
current_offset += msum
A_incs = A_idx * block * block
A_incs[1:] -= A_idx[:-1] * block * block
A_incs = A_incs.view(-1, 1).repeat(1, div)
if trans:
A_incs[:, 1:] = step
A_incs[:, 0] -= (div - 1) * step
else:
A_incs[:, 1:] = step * block
A_incs[:, 0] -= (div - 1) * step * block
A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]]
A_incs = A_incs.view(-1)
# create header
width = col_id.size(0)
offsets = offsets * 2 * div + 4 * width
segments = segments * div
try:
header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous()
except RuntimeError:
print(offsets.shape, segments.shape, col_id.shape, head_id.shape)
from .utils import layout_full_zero_check
answer, row_answer, col_answer, row_check, col_check = layout_full_zero_check(layout)
if answer:
if row_answer:
print('layout contains empty rows:', row_check.nonzero(as_tuple=False).squeeze())
if col_answer:
print('layout contains empty columns:', col_check.nonzero(as_tuple=False).squeeze())
print(layout[0].long())
raise
# create increments
incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous()
incs = torch.cat((incs, torch.zeros(2, device=incs.device, dtype=incs.dtype)))
# create lut
lut = torch.cat((header, incs))
lut = lut.type(torch.int32).to(device)
return lut, width
# -----------------------------
# Dense = Dense x Sparse (DDS)
# -----------------------------
@triton.jit
def _dds_kernel(
A, B, C,
stride_za, stride_ha, stride_ma, stride_ka,
stride_zb, stride_hb, stride_bk, stride_bn,
stride_zc, stride_hc, stride_mc, stride_nc,
DS0, DS1, lut,
TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr,
):
# ------------ #
# - Prologue - #
# ------------ #
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
num_pid_m = tl.num_programs(0)
num_pid_n = tl.num_programs(1)
pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M)
pid_z = tl.program_id(2)
header = lut + pid_n * 4
offset = tl.load(header + 0)
AS1 = tl.load(header + 1)
column = tl.load(header + 2)
off_h = tl.load(header + 3)
pinc = lut + offset
# initialize pointers to A (dense)
offs_am = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_am = tl.max_contiguous(tl.multiple_of(offs_am % DS0, TILE_M), TILE_M)
start_ak = tl.load(pinc)
start_ak = tl.multiple_of(start_ak, 8)
offs_ak = start_ak + tl.arange(0, TILE_K)
ptrs_a = A + pid_z * stride_za \
+ off_h * stride_ha \
+ offs_am[:, None] * stride_ma \
+ offs_ak[None, :] * stride_ka
# initialize pointers to B (sparse)
block_id = tl.load(pinc + 1)
block_id = tl.multiple_of(block_id, 8)
offs_bn = tl.arange(0, TILE_N)
offs_bk = tl.arange(0, TILE_K)
ptrs_b = B + pid_z * stride_zb \
+ block_id * stride_hb \
+ offs_bn[None, :] * stride_bn \
+ offs_bk[:, None] * stride_bk
# ---------------- #
# Inner Loop #
# ---------------- #
acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
for k in range(AS1, 0, -TILE_K):
a = tl.load(ptrs_a, mask=offs_am[:, None] < DS0)
b = tl.load(ptrs_b, mask=True)
acc += tl.dot(a, b)
pinc += 2
inc_a = tl.load(pinc)
inc_b = tl.load(pinc + 1)
inc_a = tl.multiple_of(inc_a, 8)
inc_b = tl.multiple_of(inc_b, 8)
inc_a = inc_a * stride_ka
ptrs_a += inc_a
ptrs_b += inc_b
# ---------------- #
# Epilogue #
# ---------------- #
c = acc.to(C.dtype.element_ty)
# initialize pointers to C (dense)
offs_cm = pid_m * TILE_M + tl.arange(0, TILE_M)
offs_cn = column * TILE_N + tl.arange(0, TILE_N)
ptrs_c = C + off_h * stride_hc \
+ pid_z * stride_zc \
+ offs_cm[:, None] * stride_mc \
+ offs_cn[None, :] * stride_nc
# write back
tl.store(ptrs_c, c, mask=offs_cm[:, None] < DS0)
def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None):
if a.stride(2) != 1 and a.stride(3) != 1:
a = a.contiguous()
if b.stride(2) != 1 and b.stride(3) != 1:
b = b.contiguous()
# shapes / dtypes
AS0 = a.size(0)
AS1 = a.size(1)
AS2 = a.size(3 if trans_a else 2)
BS2 = block * spdims[1 if trans_b else 2]
dtype = a.dtype
# output
CS0 = AS0
CS1 = AS1
CS2 = BS2 if trans_c else AS2
CS3 = AS2 if trans_c else BS2
if out is None:
c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device)
else:
assert out.shape == (CS0, CS1, CS2, CS3)
c = out
TILE_M = {16: 256, 32: 256, 64: 128, 128: 128}[block]
grid = lambda meta: [triton.cdiv(AS2, meta['TILE_M']), width, AS0]
_dds_kernel[grid](
a, b, c,
a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3),
b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3),
c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3),
AS2, BS2, lut,
TILE_M=TILE_M, TILE_N=block, TILE_K=min(block, 32), BLOCK=block, num_stages=4,
num_warps=4, GROUP_SIZE_M=4,
)
return c
##############
# MAIN API #
##############
class _matmul(torch.autograd.Function):
fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul}
@staticmethod
def forward(
ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block,
c_lut, c_width, da_lut, da_width, db_lut, db_width, out
):
c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out)
# save for backward
ctx.save_for_backward(a, b)
ctx.da_lut = da_lut
ctx.da_width = da_width
ctx.db_lut = db_lut
ctx.db_width = db_width
ctx.mode = mode
ctx.spdims = spdims
ctx.block = block
ctx.trans_a = trans_a
ctx.trans_b = trans_b
ctx.trans_c = trans_c
ctx.has_out = out is not None
return c
@staticmethod
def backward(ctx, dc):
# saved for backward
a, b = ctx.saved_tensors
da, db = None, None
mode = ctx.mode
# gradients w.r.t. a
if ctx.needs_input_grad[0]:
mode_da = mode[1] + mode[0] + mode[2]
da = _matmul.fn[mode_da](
dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width,
)
# gradients w.r.t. b
if ctx.needs_input_grad[1]:
mode_db = mode[2] + mode[1] + mode[0]
db = _matmul.fn[mode_db](
a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width,
)
dout = dc if ctx.has_out else None
return da, db, None, None, None,\
None, None, None, None,\
None, None, None, None, None, dout
class matmul:
def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False):
if mode not in ['sdd', 'dsd', 'dds']:
raise NotImplementedError('Supported modes are: sdd, dsd, dds')
self.block = block
self.mode = mode
self.trans_a = trans_a
self.trans_b = trans_b
self.trans_c = trans_c
self.layout = layout
self.spdims = layout.shape
step = min(block, 32)
if self.mode == 'sdd':
self.c_lut, self.c_width = sdd_lut(layout, block, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device)
if self.mode == 'dsd':
self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device)
self.da_lut, self.da_width = sdd_lut(layout, block, device)
self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device)
if self.mode == 'dds':
self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device)
self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device)
self.db_lut, self.db_width = sdd_lut(layout, block, device)
def __call__(self, a, b, out=None):
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block,
self.c_lut, self.c_width,
self.da_lut, self.da_width,
self.db_lut, self.db_width,
out
)
return c
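# --- Hedged usage sketch (added for clarity; not part of the original file) ---
# Requires a CUDA device and triton. Computes block-sparse attention scores
# q @ k^T with an all-ones layout, i.e. one (block, block) score tile per
# nonzero layout entry; shapes follow the (batch, head, seq, dim) convention
# used by the kernels above.
def _demo_sdd_scores():
    block = 32
    layout = torch.ones(4, 8, 8, dtype=torch.long)  # (heads, q_len // block, k_len // block)
    sdd = matmul(layout, block, 'sdd', device='cuda', trans_b=True)
    q = torch.randn(2, 4, 8 * block, 64, device='cuda', dtype=torch.float16)
    k = torch.randn(2, 4, 8 * block, 64, device='cuda', dtype=torch.float16)
    return sdd(q, k)  # (2, int(layout.sum()), block, block)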


@ -0,0 +1,234 @@
import torch
import triton
import triton.language as tl
def num_warps(n):
if n < 512:
return 4
if n < 2048:
return 8
return 16
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax'] * nargs['BLOCK'])})
@triton.jit
def _forward(
X, scale, LUT, RPE, KP_M, ATTN_M, is_causal, sizemax, stride_zx, stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
TN: tl.constexpr, BLOCK: tl.constexpr, APPLY_SCALE: tl.constexpr, APPLY_RPE: tl.constexpr, APPLY_KP_MASK: tl.constexpr,
KP_MASK_MUL: tl.constexpr, APPLY_ATTN_MASK: tl.constexpr, ATTN_MASK_MUL: tl.constexpr,
):
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = tl.arange(0, TN) % BLOCK
rbn = tl.arange(0, TN) // BLOCK
# extract information from LUT
header = LUT + rbm * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
check = rbn < size
rbmn = tl.where(check, rbn, size - 1)
# block id and column id
blockid = tl.load(LUT + offset + rbmn * 4 + 0)
columnid = tl.load(LUT + offset + rbmn * 4 + 1)
rowid = tl.load(LUT + offset + rbmn * 4 + 2)
headid = tl.load(LUT + offset + rbmn * 4 + 3)
# pointers to X
px = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
x = tl.load(px, mask=check, other=-float('inf'))
x = x.to(tl.float32)
# apply scale
if APPLY_SCALE:
x = x * scale
# apply RPE
if APPLY_RPE:
prpe = RPE + pidz * stride_zrpe + headid * stride_hrpe + columnid * BLOCK + rowid * BLOCK * stride_srpe + rxm * stride_srpe + rxn
rpe = tl.load(prpe, mask=check, other=0)
x = x + rpe
# apply key-padding mask
if APPLY_KP_MASK:
pkp_m = KP_M + pidz * stride_zkpm + columnid * BLOCK + rxn
kp_m = tl.load(pkp_m, mask=check, other=-float('inf'))
if KP_MASK_MUL:
kp_m = tl.where(kp_m == 0, -float('inf'), 0.)
x = x + kp_m
# apply attention mask
if APPLY_ATTN_MASK:
pattn_m = ATTN_M + columnid * BLOCK + rowid * BLOCK * stride_zattnm + rxm * stride_zattnm + rxn
attn_m = tl.load(pattn_m, mask=check, other=-float('inf'))
if ATTN_MASK_MUL:
attn_m = tl.where(attn_m == 0, -float('inf'), 0.)
x = x + attn_m
# apply causal mask
is_in_upper_triangle = columnid * BLOCK + rxn > rowid * BLOCK + rxm
x = x + tl.where(is_in_upper_triangle & is_causal, -float('inf'), 0.)
# computation
x = tl.softmax(x)
tl.store(px, x, mask=check)
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['sizemax'] * nargs['BLOCK'])})
@triton.heuristics({'TN': lambda nargs: triton.next_power_of_2(nargs['sizemax']) * nargs['BLOCK']})
@triton.jit
def _backward(X, scale, DX, LUT, sizemax, stride_zx, stride_zdx, TN: tl.constexpr, BLOCK: tl.constexpr):
pidhm = tl.program_id(0)
pidz = tl.program_id(1)
# create index ranges
rxm = pidhm % BLOCK
rbm = pidhm // BLOCK
rxn = tl.arange(0, TN) % BLOCK
rbn = tl.arange(0, TN) // BLOCK
# extract information from look-up table
header = LUT + rbm * 2
size = tl.load(header + 0)
offset = tl.load(header + 1)
# bounds checking on lut
check = rbn < size
rbmn = tl.where(check, rbn, size - 1)
# initialize pointers to block-sparse input
blockid = tl.load(LUT + offset + rbmn * 4)
X = X + pidz * stride_zx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
DX = DX + pidz * stride_zdx + blockid * BLOCK * BLOCK + rxm * BLOCK + rxn
# compute fused softmax backward
x = tl.load(X, mask=check, other=0)
dx = tl.load(DX, mask=check, other=0)
x = x.to(tl.float32)
dx = dx.to(tl.float32)
y = x * (dx - tl.sum(x * dx, 0)) * scale
tl.store(DX, y, mask=check)
class _softmax(torch.autograd.Function):
@staticmethod
def make_lut(layout, block, device):
# sizes along rows
sizes = layout.sum(-1).view(-1)
# offsets in block format
offsets = torch.zeros_like(sizes)
offsets[1:] = torch.cumsum(sizes[:-1], dim=0)
# block indices
layout_sum = sizes.sum()
idx = torch.arange(layout_sum, device=layout.device)
layout_nonzero = layout.nonzero(as_tuple=False)
head = layout_nonzero[:, 0]
rows = layout_nonzero[:, 1]
columns = layout_nonzero[:, 2]
core = torch.stack((idx, columns, rows, head), dim=1).view(-1)
# construct look-up table
offsets = offsets * 4 + 2 * sizes.numel()
header = torch.stack((sizes, offsets), dim=1).view(-1)
lut = torch.cat((header, core)).type(torch.int32).to(device)
return lut, int(sizes.max())
@staticmethod
def forward(
ctx, x, scale, rpe,
key_padding_mask, attn_mask,
kp_mask_mode, attn_mask_mode,
is_causal,
spdims, block, lut, maxlut
):
apply_scale = False if scale == 1.0 else True
# handle None rpe
if rpe is None:
apply_rpe = False
stride_zrpe, stride_hrpe, stride_srpe = 0, 0, 0
rpe = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_rpe = True
stride_zrpe, stride_hrpe, stride_srpe = rpe.stride(0), rpe.stride(1), rpe.stride(2)
# handle None key_padding_mask
if key_padding_mask is None:
apply_kp_mask = False
stride_zkpm = 0
key_padding_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_kp_mask = True
stride_zkpm = key_padding_mask.stride(0)
# handle None attention_mask
if attn_mask is None:
apply_attn_mask = False
stride_zattnm = 0
attn_mask = torch.empty(0, dtype=x.dtype, device=x.device)
else:
apply_attn_mask = True
stride_zattnm = attn_mask.stride(0)
# run kernel
M = x.shape[0]
grid = [spdims[0] * spdims[1] * block, M]
_forward[grid](x, scale, lut, rpe, key_padding_mask, attn_mask, is_causal, maxlut, x.stride(0),
stride_zrpe, stride_hrpe, stride_srpe, stride_zkpm, stride_zattnm,
BLOCK=block,
APPLY_SCALE=apply_scale,
APPLY_RPE=apply_rpe,
APPLY_KP_MASK=apply_kp_mask,
APPLY_ATTN_MASK=apply_attn_mask,
KP_MASK_MUL=(kp_mask_mode == 'mul'),
ATTN_MASK_MUL=(attn_mask_mode == 'mul'))
# save to context
ctx.mark_dirty(x)
ctx.save_for_backward(x, lut)
ctx.spdims = spdims
ctx.block = block
ctx.maxlut = maxlut
ctx.scale = scale
ctx.apply_scale = apply_scale
ctx.apply_rpe = apply_rpe
ctx.apply_kp_mask = apply_kp_mask
ctx.apply_attn_mask = apply_attn_mask
ctx.kp_mask_mode = kp_mask_mode
ctx.attn_mask_mode = attn_mask_mode
return x
@staticmethod
def backward(ctx, dx):
# retrieve from context
x, lut = ctx.saved_tensors
# run kernel
M = x.shape[0]
grid = lambda opt: [ctx.spdims[0] * ctx.spdims[1] * ctx.block, M]
_backward[grid](x, ctx.scale, dx, lut, ctx.maxlut, x.stride(0), dx.stride(0), BLOCK=ctx.block)
return dx, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class softmax:
def make_lut(self, device):
key = (device, )
if key not in self.lut_cache:
self.lut_cache[key] = _softmax.make_lut(self.layout, self.block, device)
return self.lut_cache[key]
def __init__(self, layout, block):
self.spdims = layout.shape
self.layout = layout
self.block = block
self.lut_cache = dict()
def __call__(
self, x, scale=1., rpe=None,
key_padding_mask=None, attn_mask=None,
key_padding_mask_mode='add', attn_mask_mode='add',
is_causal=False
):
if rpe is not None and rpe.dtype != x.dtype:
raise ValueError('relative position embedding must be %s' % x.dtype)
if attn_mask is not None and attn_mask.dtype != x.dtype:
raise ValueError('Attention mask must be %s' % x.dtype)
if key_padding_mask is not None and key_padding_mask.dtype != x.dtype:
raise ValueError('Key padding mask must be %s' % x.dtype)
lut, maxlut = self.make_lut(x.device)
x = _softmax.apply(
x, scale, rpe,
key_padding_mask, attn_mask,
key_padding_mask_mode, attn_mask_mode,
is_causal,
self.spdims, self.block,
lut, maxlut
)
return x
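# --- Hedged usage sketch (added for clarity; not part of the original file) ---
# Requires a CUDA device and triton. Normalizes block-sparse attention scores in
# place; `scores` holds one (block, block) tile per nonzero entry of `layout`, in
# the same order as the layout's nonzeros.
def _demo_sparse_softmax():
    block = 32
    layout = torch.ones(4, 8, 8, dtype=torch.long)  # (heads, q_blocks, k_blocks)
    scores = torch.randn(2, int(layout.sum()), block, block, device='cuda', dtype=torch.float16)
    sparse_softmax = softmax(layout, block)
    return sparse_softmax(scores, scale=0.125, is_causal=True)  # modifies `scores` in place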


@ -0,0 +1,6 @@
def layout_full_zero_check(input_layout):
row_check = input_layout.sum(dim=2).eq(0) # (H, L // block)
col_check = input_layout.sum(dim=1).eq(0)
row_answer = bool(row_check.any())
col_answer = bool(col_check.any())
return row_answer or col_answer, row_answer, col_answer, row_check, col_check
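# --- Quick illustration (added for clarity; not part of the original file) ---
# A layout whose second block-row and second block-column contain no nonzero
# entries trips both checks.
def _demo_layout_check():
    import torch
    layout = torch.tensor([[[1, 0], [0, 0]]], dtype=torch.long)  # (1 head, 2 x 2 blocks)
    has_empty, bad_row, bad_col, _, _ = layout_full_zero_check(layout)
    return has_empty, bad_row, bad_col  # (True, True, True)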


@ -0,0 +1,129 @@
from typing import Optional
from dataclasses import dataclass
def get_item(schemes, idx):
if schemes is None:
return None
len_schemes = len(schemes)
if len_schemes > idx:
return schemes[idx]
if len_schemes == 1:
return schemes[0]
    raise IndexError('idx (%d) is out of range for the schemes of length %d' % (idx, len_schemes), schemes)
@dataclass(frozen=True)
class HeadAttentionScheme:
head_sum2sum: Optional[tuple]
head_sum2sum_self: str
head_sum2con: Optional[tuple]
head_sum2con_self: str
head_con2sum: Optional[tuple]
head_con2sum_self: str
head_con2con: Optional[tuple]
head_con2con_self: str
head_con2con_causal: bool
head_pref2pref_mode: str
head_sum2pref_mode: str
head_pref2sum_mode: str
head_con2pref_mode: str
head_pref2con_mode: str
class LayerAttentionScheme(object):
def __init__(
self,
layer_sum2sum, layer_sum2sum_self,
layer_sum2con, layer_sum2con_self,
layer_con2sum, layer_con2sum_self,
layer_con2con, layer_con2con_self,
layer_con2con_causal,
layer_pref2pref_mode,
layer_sum2pref_mode,
layer_pref2sum_mode,
layer_con2pref_mode,
layer_pref2con_mode,
num_layer_heads,
):
self.num_layer_heads = num_layer_heads
self.heads_schemes = []
for idx in range(self.num_layer_heads):
self.heads_schemes.append(
HeadAttentionScheme(
get_item(layer_sum2sum, idx), layer_sum2sum_self,
get_item(layer_sum2con, idx), layer_sum2con_self,
get_item(layer_con2sum, idx), layer_con2sum_self,
get_item(layer_con2con, idx), layer_con2con_self,
layer_con2con_causal,
layer_pref2pref_mode,
layer_sum2pref_mode,
layer_pref2sum_mode,
layer_con2pref_mode,
layer_pref2con_mode
)
)
self.heads_schemes = tuple(self.heads_schemes)
self.same_for_all_heads = len(set(self.heads_schemes)) == 1
def __hash__(self):
return hash(self.heads_schemes)
def __eq__(self, other):
return self.heads_schemes == other.heads_schemes
def __getitem__(self, idx):
return self.heads_schemes[idx]
def __len__(self):
return self.num_layer_heads
class AttentionScheme(object):
def __init__(
self,
sum2sum, sum2sum_self,
sum2con, sum2con_self,
con2sum, con2sum_self,
con2con, con2con_self,
con2con_causal,
pref2pref_mode,
sum2pref_mode,
pref2sum_mode,
con2pref_mode,
pref2con_mode,
num_layers, num_layers_heads,
):
self.num_layers = num_layers
self.layers_schemes = []
for idx in range(self.num_layers):
self.layers_schemes.append(
LayerAttentionScheme(
get_item(sum2sum, idx), sum2sum_self,
get_item(sum2con, idx), sum2con_self,
get_item(con2sum, idx), con2sum_self,
get_item(con2con, idx), con2con_self,
con2con_causal,
pref2pref_mode,
sum2pref_mode,
pref2sum_mode,
con2pref_mode,
pref2con_mode,
get_item(num_layers_heads, idx)
)
)
self.layers_schemes = tuple(self.layers_schemes)
def __hash__(self):
return hash(self.layers_schemes)
def __eq__(self, other):
return self.layers_schemes == other.layers_schemes
def __getitem__(self, idx):
return self.layers_schemes[idx]
def __len__(self):
return self.num_layers
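# --- Quick illustration of get_item's broadcasting (added for clarity; not part of the original file) ---
# A length-1 tuple of settings is reused for every layer/head, a longer tuple is
# indexed directly, and None stays None; 'prev-1' below is a made-up placeholder.
def _demo_get_item():
    assert get_item(None, 3) is None                    # nothing configured
    assert get_item((('prev-1',),), 5) == ('prev-1',)   # a single scheme is reused for every layer
    assert get_item((0, 1, 2), 1) == 1                  # per-layer schemes are indexed directly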


@ -0,0 +1,6 @@
from museformer.tools.singleton_decorator import singleton
@singleton
class FourDimPocket(dict):
pass
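# --- Quick illustration (added for clarity; not part of the original file) ---
# Assuming the singleton decorator returns the same instance on every call, the
# pocket acts as a process-wide scratch dict shared across modules; 'demo_key'
# below is a hypothetical entry.
def _demo_pocket():
    pocket_a = FourDimPocket()
    pocket_a['demo_key'] = 64
    pocket_b = FourDimPocket()
    return pocket_b is pocket_a and pocket_b['demo_key'] == 64  # True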


@ -0,0 +1,87 @@
import logging
import os
import numpy as np
from fairseq.data import data_utils
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
from fairseq.data.indexed_dataset import MMapIndexedDatasetBuilder
from ..dictionary.compound_dictionary import CompoundDictionary
logger = logging.getLogger(__name__)
def get_beat_ids(src_tokens, dictionary, ts_instead_of_tempo=False):
change = 's' if ts_instead_of_tempo else 't'
beat_mask = src_tokens.ge(dictionary.type_dict['o'][0]) & src_tokens.lt(dictionary.type_dict['o'][1])
change_mask = src_tokens.ge(dictionary.type_dict[change][0]) & src_tokens.lt(dictionary.type_dict[change][1])
bar_mask = src_tokens.ge(dictionary.type_dict['b'][0]) & src_tokens.lt(dictionary.type_dict['b'][1])
special_mask = src_tokens.lt(4)
no_beat_mask = change_mask | bar_mask | special_mask
del change_mask, bar_mask
src_tokens = src_tokens.clone() - (dictionary.type_dict['o'][0] - 1)
cur_beat = 0
for idx, (token, beat_token) in enumerate(zip(src_tokens, beat_mask)):
if beat_token:
cur_beat = token
else:
src_tokens[idx] = cur_beat
src_tokens.masked_fill_(no_beat_mask, 0)
return src_tokens
class AddBeatDataset(BaseWrapperDataset):
def __init__(
self,
dataset, dictionary: CompoundDictionary,
cache_data_label='default', dataset_name=None,
mask_ts_instead_of_tempo=False,
):
super().__init__(dataset)
self.dictionary = dictionary
self.cache_data_label = cache_data_label
self.dataset_name = dataset_name
self.mask_ts_instead_of_tempo = mask_ts_instead_of_tempo
self.cache_list = None
cache_name = 'add-beat-dataset_%s_%s' % (self.cache_data_label, self.dataset_name)
cache_dataset_dir = 'cached_datasets/'
if cache_data_label is not None:
cache_dataset_dir = os.path.join(cache_dataset_dir, cache_data_label)
cache_path = os.path.join(cache_dataset_dir, cache_name)
        if not all(os.path.isfile(cache_path + suffix) for suffix in ('.beat_ids.bin', '.beat_ids.idx')):
            logger.info('Building up beat_ids dataset for %s ...' % dataset_name)
            os.makedirs(cache_dataset_dir, exist_ok=True)
self.beat_ids_builder = MMapIndexedDatasetBuilder(
cache_path + '.beat_ids.bin', dtype=np.int32
)
self.__prepare_dataset()
self.beat_ids_builder.finalize(cache_path + '.beat_ids.idx')
del self.beat_ids_builder
self.beat_ids_dataset = data_utils.load_indexed_dataset(cache_path + '.beat_ids',
dictionary=None, dataset_impl='mmap')
assert len(self.beat_ids_dataset) == len(self.dataset)
for idx, beat_ids in enumerate(self.beat_ids_dataset):
assert len(beat_ids) == self.dataset.size(idx), (idx, self.dataset.size(idx), len(beat_ids))
logger.info('Checked the cached beat_ids dataset.')
def __prepare_dataset(self):
for sample in self.dataset:
src_tokens = sample[0]
beat_ids = get_beat_ids(src_tokens, self.dictionary, ts_instead_of_tempo=self.mask_ts_instead_of_tempo)
self.beat_ids_builder.add_item(beat_ids)
def __getitem__(self, idx):
beat_ids = self.beat_ids_dataset[idx]
sample = self.dataset[idx]
return (*sample, beat_ids)
def collater(self, samples):
raise NotImplementedError


@ -0,0 +1,129 @@
import logging
import torch
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
logger = logging.getLogger(__name__)
def get_bar_chunk_points(seq: torch.Tensor, eob_index, begin_idx=0, take_bos_as_bar=False, bos_index=None):
# seq: (seq_len,)
# eob_index: int
is_complete_bar = seq[-1] == eob_index
indices = seq.eq(eob_index).nonzero(as_tuple=False).squeeze(1) # (num_bars,)
indices = indices + 1
indices = torch.cat(
(indices.new_tensor([begin_idx]), indices), dim=0
)
len_seq = len(seq)
if not is_complete_bar and len_seq > begin_idx:
indices = torch.cat(
(indices, indices.new_tensor([len_seq])), dim=0
)
if take_bos_as_bar:
assert seq[0] == bos_index
assert begin_idx == 1
indices = torch.cat(
(torch.tensor([0]), indices), dim=0
)
return indices, is_complete_bar
class BarChunkSequenceDataset(BaseWrapperDataset):
def __init__(
self,
dataset, src_dict, eob,
eos_appended=True,
offset=1,
take_bos_as_bar=True,
bos_index=None,
):
super().__init__(dataset)
self.src_dict = src_dict
self.eos_appended = eos_appended
self.eob = eob
self.offset = offset
self.take_bos_as_bar = take_bos_as_bar
self.bos_index = bos_index
self.cache = [None] * len(self.dataset)
def __iter__(self):
len_dataset = len(self)
for idx in range(len_dataset):
yield self[idx]
def __getitem__(self, index):
src, tgt = self.dataset[index] # all include eoc
chunk_points = self.cache[index]
if chunk_points is None:
chunk_points, complete = get_bar_chunk_points(
src, self.eob, begin_idx=self.offset,
take_bos_as_bar=self.take_bos_as_bar, bos_index=self.bos_index
)
assert complete
self.cache[index] = chunk_points
return src, tgt, chunk_points
def collater(self, samples):
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)
class FixedChunkingLengthDataset(BaseWrapperDataset):
def __init__(
self,
dataset, fixed_chunking_length
):
assert fixed_chunking_length is not None
super().__init__(dataset)
self.fixed_chunking_length = fixed_chunking_length
def __iter__(self):
len_dataset = len(self)
for idx in range(len_dataset):
yield self[idx]
def __getitem__(self, index):
src, tgt = self.dataset[index] # all include eoc
sample_len = len(src)
chunk_points = torch.arange(0, sample_len, self.fixed_chunking_length)
chunk_points = torch.cat((chunk_points, chunk_points.new_tensor([sample_len])), dim=0)
return src, tgt, chunk_points
def collater(self, samples):
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)
def ChunkSequenceDataset(
dataset, src_dict,
eob, eoc,
chunking_scheme='bar_aware',
chunking_length=None,
dataset_name=None,
cache_data_label=None,
cache_sequence=None,
offset=0,
take_bos_as_bar=False, bos_index=None
):
if chunking_scheme == 'bar_aware':
return BarChunkSequenceDataset(
dataset, src_dict, eob,
offset=offset,
take_bos_as_bar=take_bos_as_bar,
bos_index=bos_index
)
elif chunking_scheme == 'fixed':
return FixedChunkingLengthDataset(
dataset, chunking_length
)
raise NotImplementedError(chunking_scheme)
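# --- Usage sketch for bar-aware chunking (added for clarity; not part of the original file) ---
# Hypothetical token ids: 0 = <bos>, 7 = end-of-bar. The trailing tokens form an
# incomplete bar, so the final chunk point falls at the sequence length.
def _demo_bar_chunk_points():
    seq = torch.tensor([0, 5, 6, 7, 5, 7, 5, 6])
    points, complete = get_bar_chunk_points(seq, eob_index=7, begin_idx=1,
                                            take_bos_as_bar=True, bos_index=0)
    return points.tolist(), bool(complete)  # ([0, 1, 4, 6, 8], False)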


@ -0,0 +1,43 @@
from fairseq.data import FairseqDataset
from torch.utils.data.dataloader import default_collate
class ExtendedWrapperDataset(FairseqDataset):
def __init__(self, dataset):
super().__init__()
self.dataset = dataset
def __getitem__(self, index):
return self.dataset[index]
def __len__(self):
return len(self.dataset)
def collater(self, samples):
if hasattr(self.dataset, "collater"):
return self.dataset.collater(samples)
else:
return default_collate(samples)
@property
def sizes(self):
return self.dataset.sizes
def num_tokens(self, index):
return self.dataset.sizes[index]
def size(self, index):
return self.dataset.sizes[index]
@property
def supports_prefetch(self):
return getattr(self.dataset, "supports_prefetch", False)
def attr(self, attr: str, index: int):
return self.dataset.attr(attr, index)
def set_epoch(self, epoch):
super().set_epoch(epoch)
if hasattr(self.dataset, "set_epoch"):
self.dataset.set_epoch(epoch)


@ -0,0 +1,14 @@
import torch
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
class MusicMonolingualDataset(BaseWrapperDataset):
    def __init__(self, dataset):
super().__init__(dataset)
def __getitem__(self, index):
sample = self.dataset[index] # (len,)
return torch.cat((sample[-1:], sample[:-1]), dim=0), sample
def collater(self, samples):
raise NotImplementedError
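# --- Quick illustration (added for clarity; not part of the original file) ---
# The last token (usually <eos>) is rotated to the front, so the model predicts
# sample[t] from the shifted input up to and including position t.
def _demo_shift():
    sample = torch.tensor([11, 12, 13, 2])  # 2 standing in for <eos>
    shifted = torch.cat((sample[-1:], sample[:-1]), dim=0)
    return shifted.tolist(), sample.tolist()  # ([2, 11, 12, 13], [11, 12, 13, 2])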


@ -0,0 +1,185 @@
import logging
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from fairseq.data import data_utils
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
logger = logging.getLogger(__name__)
class PostProcessDataset(BaseWrapperDataset):
def __init__(self, dataset):
super().__init__(dataset)
def __len__(self):
return len(self.dataset)
def __iter__(self):
for idx in range(len(self)):
yield self[idx]
def __getitem__(self, index):
src, tgt, chunk_points, beat_ids, num_chunks, num_complete_chunks, num_prefix = self.dataset[index]
seq_len = src.shape[0]
new_sample = {
'id': index,
'src_tokens': src,
'src_length': seq_len,
'target': tgt,
'chunk_points': chunk_points,
'num_chunks': num_chunks,
'num_complete_chunks': num_complete_chunks,
'num_pref': num_prefix,
'beat_ids': beat_ids,
}
return new_sample
@property
def sizes(self):
return self.dataset.sizes
def size(self, index):
return self.dataset.size(index)
def num_tokens(self, index):
return self.dataset.size(index)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self), dtype=np.int64)
def collater(self, samples):
if len(samples) == 0:
return {}
bsz = len(samples)
sample_id = torch.tensor([s['id'] for s in samples])
src_tokens = [s['src_tokens'] for s in samples]
src_lengths = [s['src_length'] for s in samples]
target = [s['target'] for s in samples]
chunk_points = [s['chunk_points'] for s in samples]
num_chunks = [s['num_chunks'] for s in samples]
num_complete_chunks = [s['num_complete_chunks'] for s in samples]
num_prefix = [s['num_pref'] for s in samples]
beat_ids = [s['beat_ids'] for s in samples]
ntokens = sum(src_lengths)
src_tokens = pad_sequence(src_tokens, batch_first=True, padding_value=0)
src_lengths = torch.tensor(src_lengths, dtype=torch.long)
target = pad_sequence(target, batch_first=True, padding_value=0)
chunk_points = data_utils.collate_tokens(
chunk_points, 0
)
num_chunks = torch.tensor(num_chunks, dtype=torch.long)
num_complete_chunks = torch.tensor(num_complete_chunks, dtype=torch.long)
num_prefix = torch.tensor(num_prefix, dtype=torch.long)
beat_ids = pad_sequence(beat_ids, batch_first=True, padding_value=0)
batch = {
'id': sample_id,
'nsentences': bsz,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'chunk_points': chunk_points,
'num_chunks': num_chunks,
'num_complete_chunks': num_complete_chunks,
'num_pref': num_prefix,
'beat_ids': beat_ids,
},
'target': target,
}
return batch
def filter_indices_by_size(self, indices, max_sizes):
"""
Filter a list of sample indices. Remove those that are longer than
specified in *max_sizes*.
WARNING: don't update, override method in child classes
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
ignored = indices[self.sizes[indices] > max_sizes].tolist()
indices = indices[self.sizes[indices] <= max_sizes]
elif (
hasattr(self, "sizes")
and isinstance(self.sizes, list)
and len(self.sizes) == 1
):
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
indices = indices[self.sizes[0][indices] <= max_sizes]
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
        if len(ignored) > 0:
            logger.warning(
                'Filtered out %d samples whose sizes exceed max_sizes=%s.' % (len(ignored), str(max_sizes))
            )
return indices, ignored
def batch_by_size(
self,
indices,
max_tokens=None,
max_sentences=None,
required_batch_size_multiple=1,
):
"""
Given an ordered set of indices, return batches according to
*max_tokens*, *max_sentences* and *required_batch_size_multiple*.
"""
from fairseq.data import data_utils
fixed_shapes = self.get_batch_shapes()
if fixed_shapes is not None:
def adjust_bsz(bsz, num_tokens):
if bsz is None:
assert max_tokens is not None, "Must specify --max-tokens"
bsz = max_tokens // num_tokens
if max_sentences is not None:
bsz = min(bsz, max_sentences)
elif (
bsz >= required_batch_size_multiple
and bsz % required_batch_size_multiple != 0
):
bsz -= bsz % required_batch_size_multiple
return bsz
fixed_shapes = np.array(
[
[adjust_bsz(bsz, num_tokens), num_tokens]
for (bsz, num_tokens) in fixed_shapes
]
)
return data_utils.batch_by_size(
indices,
num_tokens_fn=self.num_tokens,
max_tokens=max_tokens,
max_sentences=max_sentences,
required_batch_size_multiple=required_batch_size_multiple,
fixed_shapes=fixed_shapes,
)


@ -0,0 +1,15 @@
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
class PrefixTagDataset(BaseWrapperDataset):
def __init__(self, dataset, num_prefix):
super().__init__(dataset)
self.dataset = dataset
self.num_prefix = num_prefix
def __getitem__(self, idx):
sample = self.dataset[idx]
return (*sample, self.num_prefix)
def collater(self, samples):
raise NotImplementedError("Dataset class %s is not designed for collating samples." % self.__class__.__name__)


@ -0,0 +1,53 @@
# Author: Botao Yu
import numpy as np
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
class RemoveLongSizeSamplesDataset(BaseWrapperDataset):
def __init__(self, dataset, max_size):
super().__init__(dataset)
self.max_size = max_size
max_token_select = self.dataset.sizes <= self.max_size
final_select = max_token_select
self.selected_index = np.nonzero(final_select)[0]
def __getitem__(self, index):
origin_index = self.selected_index[index]
return self.dataset[origin_index]
def __len__(self):
return len(self.selected_index)
def __iter__(self):
len_dataset = len(self)
for idx in range(len_dataset):
yield self[idx]
@property
def sizes(self):
return self.dataset.sizes[self.selected_index]
def size(self, index):
origin_index = self.selected_index[index]
return self.dataset.size(origin_index)
def num_tokens(self, index):
origin_index = self.selected_index[index]
return self.dataset.num_tokens(origin_index)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self), dtype=np.int64)
def MayRemoveLongSizeSamplesDataset(dataset, max_size):
if max_size is None:
return dataset
new_dataset = RemoveLongSizeSamplesDataset(dataset, max_size)
if len(new_dataset) == len(dataset):
del new_dataset
return dataset
return new_dataset


@ -0,0 +1,61 @@
# Author: Botao Yu
import numpy as np
from fairseq.data import FairseqDataset
class RemoveShortSizeSamplesDataset(FairseqDataset):
def __init__(self, dataset, min_size):
super().__init__()
self.dataset = dataset
self.min_size = min_size
min_token_select = self.dataset.sizes >= self.min_size
final_select = min_token_select
self.selected_index = np.nonzero(final_select)[0]
def __getitem__(self, index):
origin_index = self.selected_index[index]
return self.dataset[origin_index]
def __len__(self):
return len(self.selected_index)
def __iter__(self):
len_dataset = len(self)
for idx in range(len_dataset):
yield self[idx]
@property
def sizes(self):
return self.dataset.sizes[self.selected_index]
def size(self, index):
origin_index = self.selected_index[index]
return self.dataset.size(origin_index)
def num_tokens(self, index):
origin_index = self.selected_index[index]
return self.dataset.num_tokens(origin_index)
def ordered_indices(self):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
return np.arange(len(self), dtype=np.int64)
def MayRemoveShortSizeSamplesDataset(dataset, min_size):
if min_size is None or min_size == 0:
return dataset
new_dataset = RemoveShortSizeSamplesDataset(dataset, min_size)
if len(new_dataset) == len(dataset):
del new_dataset
return dataset
return new_dataset


@ -0,0 +1,44 @@
import torch
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
class TruncateMusicDataset(BaseWrapperDataset):
def __init__(self, dataset, truncated_length):
super().__init__(dataset)
assert truncated_length is None or truncated_length > 1
self.truncated_length = truncated_length
        self._sizes = self.dataset.sizes.copy()
        if self.truncated_length is not None:
            self._sizes[self._sizes > self.truncated_length] = self.truncated_length
def __getitem__(self, idx):
src, tgt, chunk_points, beat_ids = self.dataset[idx]
num_chunks = len(chunk_points) - 1
if self.truncated_length is None or self.dataset.size(idx) <= self.truncated_length:
return src, tgt, chunk_points, beat_ids, num_chunks, num_chunks # num_complete_chunks
src = src[:self.truncated_length]
tgt = tgt[:self.truncated_length]
beat_ids = beat_ids[:self.truncated_length]
chunk_points = chunk_points[chunk_points.le(self.truncated_length)]
if chunk_points[-1] == self.truncated_length:
num_chunks = len(chunk_points) - 1
num_complete_chunks = num_chunks
else:
num_chunks = len(chunk_points)
num_complete_chunks = num_chunks - 1
chunk_points = torch.cat((chunk_points, chunk_points.new_tensor([self.truncated_length])))
return src, tgt, chunk_points, beat_ids, num_chunks, num_complete_chunks
@property
def sizes(self):
return self._sizes
def size(self, index):
return self._sizes[index]
def num_tokens(self, index):
return self._sizes[index]
def collater(self, samples):
raise NotImplementedError


@ -0,0 +1,51 @@
import re
from fairseq.data import Dictionary
class CompoundDictionary(Dictionary):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.type_dict = None
self.special_pad_word = None
self.special_pad_index = None
@classmethod
def load(cls, f):
d = cls()
d.add_from_file(f)
d.construct_types()
return d
def construct_types(self):
type_dict = {}
type_token = None
current_begin = None
considered_type_token = set()
for idx, symbol in enumerate(self.symbols):
            match = re.fullmatch(r'([a-z]+)-\d+', symbol)
if match is None:
if current_begin is not None:
type_dict[type_token] = (current_begin, idx)
type_token = None
current_begin = None
else:
now_type_token = match.group(1)
if current_begin is not None:
if type_token == now_type_token:
continue
else:
type_dict[type_token] = (current_begin, idx)
assert now_type_token not in considered_type_token
type_token = now_type_token
considered_type_token.add(type_token)
current_begin = idx
else:
assert now_type_token not in considered_type_token
type_token = now_type_token
considered_type_token.add(type_token)
current_begin = idx
if current_begin is not None:
type_dict[type_token] = (current_begin, len(self.symbols))
self.type_dict = type_dict
return type_dict
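# --- Usage sketch (added for clarity; not part of the original file; assumes fairseq is installed) ---
# The four default special symbols occupy indices 0-3, so typed tokens start at
# index 4 and construct_types records a [begin, end) index range per token type.
def _demo_construct_types():
    d = CompoundDictionary()
    for sym in ('b-1', 'o-0', 'o-1', 't-0'):
        d.add_symbol(sym)
    return d.construct_types()  # {'b': (4, 5), 'o': (5, 7), 't': (7, 8)}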


@ -0,0 +1,6 @@
import torch.nn as nn
class EmbeddingLayer(nn.Module):
def __init__(self):
        super().__init__()


@ -0,0 +1,25 @@
import torch
import torch.nn as nn
def generate_sinusoid_embeddings(max_len, dim):
inv_freq = 1 / (10000 ** (torch.arange(0.0, dim, 2.0) / dim))
    position = torch.arange(max_len, dtype=torch.float)  # (max_len,)
sinusoid_inp = torch.ger(position, inv_freq)
pe = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
return pe
def generate_sinusoid_position_embedding_with_padding(num_positions, embed_dim):
embedding_weight = generate_sinusoid_embeddings(num_positions, embed_dim)
embedding_weight = torch.cat(
(torch.zeros(1, embed_dim), embedding_weight), dim=0
)
return embedding_weight
def generate_randomly_initialized_position_embedding_with_padding(num_positions, embed_dim):
embedding_weight = torch.empty(num_positions + 1, embed_dim)
nn.init.normal_(embedding_weight)
nn.init.zeros_(embedding_weight[0])
return embedding_weight
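# --- Quick illustration (added for clarity; not part of the original file) ---
# Both helpers prepend an all-zero row at index 0 so that padding positions map
# to a zero vector; real positions are looked up starting from index 1.
def _demo_padded_sinusoid_table():
    table = generate_sinusoid_position_embedding_with_padding(8, 16)
    return table.shape, bool(table[0].abs().sum() == 0)  # (torch.Size([9, 16]), True)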


@ -0,0 +1,71 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class SinusoidalEmbedding(nn.Module):
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.num_embedding = init_size
weights = self.get_embedding(
init_size, embedding_dim, padding_idx
)
self.register_buffer('weights', weights, persistent=False)
@staticmethod
def get_embedding(
num_embeddings: int, embedding_dim: int, padding_idx=None
):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
1
) * emb.unsqueeze(0)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
num_embeddings, -1
)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, x):
if x.numel() > 0:
x_max = int(x.max())
if x_max >= self.num_embedding:
self.num_embedding = x_max + 32
weights = self.get_embedding(
self.num_embedding, self.embedding_dim, self.padding_idx
)
self.weights = weights.to(self.weights)
r = F.embedding(x, self.weights, padding_idx=self.padding_idx)
return r
def TaoEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
learned: bool = False,
):
if learned:
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
if padding_idx is not None:
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalEmbedding(
embedding_dim,
padding_idx,
init_size=num_embeddings,
)
return m


@ -0,0 +1 @@
from .main import block_fill_


@ -0,0 +1,32 @@
// #pragma once
#include <torch/extension.h>
// Declaration of the CUDA kernel launcher
int cudaForwardLauncher(
const at::Tensor&, const at::Tensor&,
const long, const long, const long, const long, const long, const long, const bool, at::Tensor&
);
// C++ wrapper that forwards to the CUDA launcher
int cuda_forward(const at::Tensor& block_ranges,
const at::Tensor& head_masks,
const long num_lines,
const long max_query,
const long max_key,
const long num_heads,
const long seq_len_1,
const long seq_len_2,
const bool fill_value,
at::Tensor& out) {
at::DeviceGuard guard(block_ranges.device());
cudaForwardLauncher(block_ranges, head_masks, num_lines, max_query, max_key, num_heads,
seq_len_1, seq_len_2, fill_value, out);
return 0;
}
// Python bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("cuda_forward", &cuda_forward, "cuda_forward");
}


@ -0,0 +1,107 @@
//#include <cstdio>
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
const long THREADS_PER_BLOCK = 1024;
const long MAX_GRID_NUM = 2147483647;
inline long GET_BLOCKS(const long N) {
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
return long(min(optimal_block_num, MAX_GRID_NUM));
}
template <typename scalar_t>
__global__ void cudaForward(
const scalar_t * block_ranges,
const bool * head_masks,
const long input_size,
const long num_lines, const long max_query, const long max_key, const long num_heads,
const long seq_len_2,
const long line_stride, const long query_stride, const long key_stride, const long seq_len_square,
const bool fill_value,
bool * out) {
// block_ranges: (num_lines, 4)
// head_masks: (num_lines, num_heads)
// input_size == num_lines * max_query * max_key * num_heads
// line_stride == max_query * max_key * num_heads
// query_stride == max_key * num_heads
// key_stride == num_heads
// out: (num_heads, seq_len_1, seq_len_2)
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
if (index >= input_size) return;
const long line_index = index / line_stride;
assert (line_index < num_lines);
const long query_index = (index % line_stride) / query_stride;
assert (query_index < max_query);
const long key_index = (index % query_stride) / key_stride;
assert (key_index < max_key);
const long head_index = (index % key_stride);
const long line_start = line_index * 4;
const long query_begin = block_ranges[line_start];
const long query_end = block_ranges[line_start + 1];
const long query_index_end = query_index + query_begin;
if (!(query_index_end < query_end)) return;
const long key_begin = block_ranges[line_start + 2];
const long key_end = block_ranges[line_start + 3];
const long key_index_end = key_index + key_begin;
if (!(key_index_end < key_end)) return;
if (head_masks[line_index * num_heads + head_index])
out[head_index * seq_len_square + query_index_end * seq_len_2 + key_index_end] = fill_value;
}
int cudaForwardLauncher(
const at::Tensor& block_ranges,
const at::Tensor& head_masks,
const long num_lines,
const long max_query,
const long max_key,
const long num_heads,
const long seq_len_1,
const long seq_len_2,
const bool fill_value,
at::Tensor& out
) {
const long input_size = num_lines * max_query * max_key * num_heads;
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
const long key_stride = num_heads;
const long query_stride = max_key * key_stride;
const long line_stride = max_query * query_stride;
const long seq_len_square = seq_len_1 * seq_len_2;
AT_DISPATCH_INTEGRAL_TYPES(
block_ranges.type(), "cudaForward",
([&] {
const scalar_t *block_ranges_ = block_ranges.data_ptr<scalar_t>();
const bool *head_masks_ = head_masks.data_ptr<bool>();
bool *out_ = out.data_ptr<bool>();
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
block_ranges_, head_masks_, input_size,
num_lines, max_query, max_key, num_heads, seq_len_2,
line_stride, query_stride, key_stride, seq_len_square,
fill_value,
out_
);
}
)
);
THCudaCheck(cudaGetLastError());
return 0;
}
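// Editorial note (not in the original source): the launcher assigns one CUDA thread to each
// (block line, query offset, key offset, head) tuple, i.e. num_lines * max_query * max_key * num_heads
// threads in total; each thread writes at most one element of `out`, so no atomics are needed.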

Просмотреть файл

@ -0,0 +1,101 @@
import os
import glob
import logging
import torch
logger = logging.getLogger(__name__)
block_fill_cuda_module = None
def load_cuda_module():
global block_fill_cuda_module
if block_fill_cuda_module is not None:
return block_fill_cuda_module
path = os.path.dirname(os.path.abspath(__file__))
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
sources = list(cpps) + list(cudas)
from torch.utils.cpp_extension import load
module = load(name='block_fill_cuda',
sources=sources,
extra_cflags=['-O2'],
with_cuda=True,
verbose=False)
block_fill_cuda_module = module
return block_fill_cuda_module
def block_fill_(out, block_ranges, head_masks, fill_value, no_cuda_kernel=False):
"""
:param block_ranges: (num_blocks, 4)
:param head_masks: (num_blocks, num_heads)
:param out:
:param fill_value:
:param no_cuda_kernel:
:return:
"""
if block_ranges.shape[0] == 0:
return out
assert isinstance(fill_value, bool)
if block_ranges.is_cuda and not no_cuda_kernel:
return block_fill_cuda(out, block_ranges, head_masks, fill_value)
else:
return block_fill_pytorch(out, block_ranges, head_masks, fill_value)
def _check_input(block_ranges, head_masks, out):
assert head_masks.dtype == torch.bool
assert block_ranges.device == head_masks.device
assert block_ranges.device == out.device
num_blocks, col = block_ranges.shape
num_blocks_2, num_heads = head_masks.shape
assert num_blocks == num_blocks_2
assert col == 4
assert num_heads > 0
return num_blocks, num_heads
def block_fill_cuda(out, block_ranges, head_masks, fill_value):
num_blocks, num_heads = _check_input(block_ranges, head_masks, out)
num_heads_2, seq_len_1, seq_len_2 = out.shape
assert num_heads_2 == num_heads
assert block_ranges.is_contiguous()
assert head_masks.is_contiguous()
assert out.is_contiguous()
max_query = (block_ranges[:, 1] - block_ranges[:, 0]).max().item()
max_key = (block_ranges[:, 3] - block_ranges[:, 2]).max().item()
if max_query <= 0 or max_key <= 0:
return out
module = load_cuda_module()
module.cuda_forward(block_ranges, head_masks, num_blocks, max_query, max_key, num_heads,
seq_len_1, seq_len_2, fill_value, out)
return out
def block_fill_pytorch(out, block_ranges, head_masks, fill_value):
num_blocks, num_heads = _check_input(block_ranges, head_masks, out)
assert out.shape[0] == num_heads
num_heads_2, _, _ = out.shape
assert num_heads_2 == num_heads
for (query_begin, query_end, key_begin, key_end), line_mask in zip(block_ranges, head_masks):
line_head_idx = torch.nonzero(line_mask, as_tuple=False).squeeze(-1) # (num,)
out[line_head_idx, query_begin: query_end, key_begin: key_end] = fill_value
return out
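# Illustrative usage sketch (editorial note, not part of the original file); shapes follow the checks above:
#   num_heads, seq_len = 2, 8
#   out = torch.zeros(num_heads, seq_len, seq_len, dtype=torch.bool)
#   block_ranges = torch.tensor([[0, 4, 0, 4]])       # queries 0..3 may attend to keys 0..3
#   head_masks = torch.tensor([[True, False]])        # apply this block to head 0 only
#   block_fill_(out, block_ranges, head_masks, True)  # out[0, :4, :4] becomes True; head 1 is untouched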

Просмотреть файл

@ -0,0 +1 @@
from .main import range_fill

Просмотреть файл

@ -0,0 +1,27 @@
// #pragma once
#include <torch/extension.h>
// CUDA launcher declaration (implemented in the .cu file)
int cudaForwardLauncher(
const at::Tensor&,
const at::Tensor&,
const long,
at::Tensor&
);
// C++ wrapper function
int cuda_forward(const at::Tensor& ranges,
const at::Tensor& values,
const long num_chunks,
at::Tensor& output) {
at::DeviceGuard guard(ranges.device());
cudaForwardLauncher(ranges, values, num_chunks, output);
return 0;
}
// Python bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("cuda_forward", &cuda_forward, "");
}

Просмотреть файл

@ -0,0 +1,65 @@
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
const long THREADS_PER_BLOCK = 1024;
const long MAX_GRID_NUM = 2147483647;
inline long GET_BLOCKS(const long N) {
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
return long(min(optimal_block_num, MAX_GRID_NUM));
}
template <typename scalar_t>
__global__ void cudaForward(
const long * ranges,
const scalar_t * values,
const long input_size,
scalar_t * output) {
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
if (index >= input_size) return;
const long line_start = index * 2;
const long begin_idx = ranges[line_start];
const long end_idx = ranges[line_start + 1];
long value = values[index];
for (long idx = begin_idx; idx < end_idx; idx++) {
output[idx] = value;
}
}
int cudaForwardLauncher(
const at::Tensor& ranges,
const at::Tensor& values,
const long num_chunks,
at::Tensor& output
) {
const long input_size = num_chunks;
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
AT_DISPATCH_INTEGRAL_TYPES(
values.type(), "cudaForward",
([&] {
const long *ranges_ = ranges.data_ptr<long>();
const scalar_t *values_ = values.data_ptr<scalar_t>();
scalar_t *output_ = output.data_ptr<scalar_t>();
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
ranges_, values_, input_size, output_
);
}
)
);
THCudaCheck(cudaGetLastError());
return 0;
}

Просмотреть файл

@ -0,0 +1,81 @@
import os
import glob
import logging
import torch
logger = logging.getLogger(__name__)
range_fill_cuda_module = None
def load_cuda_module():
global range_fill_cuda_module
if range_fill_cuda_module is not None:
return range_fill_cuda_module
path = os.path.dirname(os.path.abspath(__file__))
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
sources = list(cpps) + list(cudas)
from torch.utils.cpp_extension import load
module = load(name='range_fill_cuda',
sources=sources,
extra_cflags=['-O2'],
with_cuda=True,
verbose=False)
range_fill_cuda_module = module
return range_fill_cuda_module
def range_fill(
ranges: torch.Tensor,
values: torch.Tensor,
seq_len: int,
pad_value,
dtype=torch.long,
no_cuda_kernel: bool = False
):
"""
Fill a set of values into some ranges of 1D vector.
Note: The ranges cannot be overlapped, or some unexpected behavior may appear. Current code does not check it.
:param ranges: (num_ranges, 2) begin, end
:param values: (num_ranges,)
:param seq_len:
:param pad_value:
:param dtype:
:param no_cuda_kernel:
:return:
"""
num_ranges, dim = ranges.shape
values_shape, = values.shape
assert num_ranges == values_shape
assert dim == 2
assert ranges[:, 0].lt(ranges[:, 1]).all()
if ranges.is_cuda and not no_cuda_kernel:
return range_fill_cuda(ranges, values, seq_len, pad_value, dtype=dtype)
else:
return range_fill_pytorch(ranges, values, seq_len, pad_value, dtype=dtype)
def range_fill_cuda(ranges, values, seq_len, pad_value, dtype=torch.long):
if dtype not in (torch.long,):
raise NotImplementedError
num_ranges = ranges.shape[0]
out = torch.full((seq_len,), pad_value, dtype=dtype, device=values.device)
if num_ranges == 0:
return out
module = load_cuda_module()
module.cuda_forward(ranges, values, num_ranges, out)
return out
def range_fill_pytorch(ranges, values, seq_len, pad_value, dtype=torch.long):
out = torch.full((seq_len,), pad_value, dtype=dtype, device=values.device)
for idx, (begin, end) in enumerate(ranges):
out[begin: end] = values[idx]
return out
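# Illustrative usage sketch (editorial note, not part of the original file):
#   ranges = torch.tensor([[0, 3], [5, 7]])
#   values = torch.tensor([1, 2])
#   range_fill(ranges, values, seq_len=8, pad_value=0)
#   # -> tensor([1, 1, 1, 0, 0, 2, 2, 0])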

Просмотреть файл

@ -0,0 +1 @@
from .main import segment_arange

Просмотреть файл

@ -0,0 +1,27 @@
// #pragma once
#include <torch/extension.h>
// CUDA launcher declaration (implemented in the .cu file)
int cudaForwardLauncher(
const at::Tensor&,
const long,
const long,
at::Tensor&
);
// C++ wrapper function
int cuda_forward(const at::Tensor& ranges,
const long start,
const long num_chunks,
at::Tensor& output) {
at::DeviceGuard guard(ranges.device());
cudaForwardLauncher(ranges, start, num_chunks, output);
return 0;
}
// Python bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("cuda_forward", &cuda_forward, "");
}

Просмотреть файл

@ -0,0 +1,65 @@
#include <torch/torch.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCAtomics.cuh>
const long THREADS_PER_BLOCK = 1024;
const long MAX_GRID_NUM = 2147483647;
inline long GET_BLOCKS(const long N) {
long optimal_block_num = (N + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
return long(min(optimal_block_num, MAX_GRID_NUM));
}
template <typename scalar_t>
__global__ void cudaForward(
const long * ranges,
const long start,
const long input_size,
scalar_t * output) {
const long index = long(blockIdx.x) * long(blockDim.x) + long(threadIdx.x);
if (index >= input_size) return;
const long line_start = index * 2;
const long begin_idx = ranges[line_start];
const long end_idx = ranges[line_start + 1];
long cur = 0;
for (long idx = begin_idx; idx < end_idx; idx++) {
output[idx] = cur + start;
cur++;
}
}
int cudaForwardLauncher(
const at::Tensor& ranges,
const long start,
const long num_chunks,
at::Tensor& output
) {
const long input_size = num_chunks;
assert (input_size <= THREADS_PER_BLOCK * MAX_GRID_NUM);
AT_DISPATCH_INTEGRAL_TYPES(
ranges.type(), "cudaForward",
([&] {
const long *ranges_ = ranges.data_ptr<long>();
scalar_t *output_ = output.data_ptr<scalar_t>();
cudaForward<<<GET_BLOCKS(input_size), THREADS_PER_BLOCK>>>(
ranges_, start, input_size, output_
);
}
)
);
THCudaCheck(cudaGetLastError());
return 0;
}

Просмотреть файл

@ -0,0 +1,80 @@
import os
import glob
import logging
import torch
logger = logging.getLogger(__name__)
segment_arange_cuda_module = None
def load_cuda_module():
global segment_arange_cuda_module
if segment_arange_cuda_module is not None:
return segment_arange_cuda_module
path = os.path.dirname(os.path.abspath(__file__))
cpps = glob.glob(os.path.join(path, "cuda_src/*.cpp"))
cudas = glob.glob(os.path.join(path, "cuda_src/*.cu"))
sources = list(cpps) + list(cudas)
from torch.utils.cpp_extension import load
module = load(name='segment_arange_cuda',
sources=sources,
extra_cflags=['-O2'],
with_cuda=True,
verbose=False)
segment_arange_cuda_module = module
return segment_arange_cuda_module
def segment_arange(
ranges: torch.Tensor,
start: int,
seq_len: int,
pad_value,
dtype=torch.long,
no_cuda_kernel: bool = False
):
"""
Fill a set of values into some ranges of 1D vector.
Note: The ranges cannot be overlapped, or some unexpected behavior may appear. Current code does not check it.
:param ranges: (num_ranges, 2) begin, end
:param start: int
:param seq_len:
:param pad_value:
:param dtype:
:param no_cuda_kernel:
:return:
"""
# Todo: Verify the effect of the segment_arange kernel.
num_ranges, dim = ranges.shape
assert dim == 2
assert ranges[:, 0].le(ranges[:, 1]).all()
if ranges.is_cuda and not no_cuda_kernel:
return segment_arange_cuda(ranges, start, seq_len, pad_value, dtype=dtype)
else:
return segment_arange_pytorch(ranges, start, seq_len, pad_value, dtype=dtype)
def segment_arange_cuda(ranges, start, seq_len, pad_value, dtype=torch.long):
if dtype not in (torch.long,):
raise NotImplementedError
num_ranges = ranges.shape[0]
out = torch.full((seq_len,), pad_value, dtype=dtype, device=ranges.device)
if num_ranges == 0:
return out
module = load_cuda_module()
module.cuda_forward(ranges, start, num_ranges, out)
return out
def segment_arange_pytorch(ranges, start, seq_len, pad_value, dtype=torch.long):
out = torch.full((seq_len,), pad_value, dtype=dtype, device=ranges.device)
for idx, (begin, end) in enumerate(ranges):
out[begin: end] = torch.arange(start, start + (end - begin), dtype=dtype, device=ranges.device)
return out
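# Illustrative usage sketch (editorial note, not part of the original file):
#   ranges = torch.tensor([[0, 3], [5, 7]])
#   segment_arange(ranges, start=1, seq_len=8, pad_value=0)
#   # -> tensor([1, 2, 3, 0, 0, 1, 2, 0])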

Просмотреть файл

@ -0,0 +1,794 @@
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from fairseq import utils
from fairseq.models import FairseqDecoder
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.layer_norm import LayerNorm
from .tools import instant_info_construction as iic
from .tools import arg_tools, computation_tools
from .data_structures.four_dim_pocket import FourDimPocket
from .kernels.range_fill import range_fill
from .kernels.segment_arange import segment_arange
from .museformer_decoder_layer import MuseformerDecoderLayer
from .data_structures.attention_scheme import AttentionScheme
from .embedding.tao_embedding import TaoEmbedding
from .datasets.add_beat_dataset import get_beat_ids
def construct_reg_bar_ids(chunk_ranges, num_chunks, reg_len):
device = chunk_ranges.device
bar_ids = []
for sample_ranges, sample_num_chunk in zip(chunk_ranges, num_chunks):
if sample_num_chunk == 0:
sample_bar_ids = torch.zeros(reg_len, dtype=torch.long, device=device)
else:
sample_bar_ids = range_fill(sample_ranges, torch.arange(1, sample_num_chunk + 1, device=device),
reg_len, pad_value=0)
bar_ids.append(sample_bar_ids)
bar_ids = torch.stack(bar_ids, dim=0)
return bar_ids # (bsz, reg_len)
def construct_reg_token_in_chunk_ids(chunk_ranges, num_chunks, reg_len):
device = chunk_ranges.device
ids = []
for sample_ranges, sample_num_chunk in zip(chunk_ranges, num_chunks):
if sample_num_chunk == 0:
sample_ids = torch.zeros(reg_len, dtype=torch.long, device=device)
else:
sample_ranges = sample_ranges[:sample_num_chunk]
sample_ids = segment_arange(sample_ranges, 1, reg_len, 0, dtype=torch.long, no_cuda_kernel=False)
ids.append(sample_ids)
ids = torch.stack(ids, dim=0)
return ids
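# Editorial example (not in the original source) of what the two helpers above produce.
# With reg_len = 8 and a single sample whose chunk (bar) ranges are [[0, 3], [3, 7]]:
#   construct_reg_bar_ids            -> [[1, 1, 1, 2, 2, 2, 2, 0]]  (1-based bar index, 0 = padding)
#   construct_reg_token_in_chunk_ids -> [[1, 2, 3, 1, 2, 3, 4, 0]]  (1-based position inside each bar)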
class MuseformerDecoder(FairseqDecoder):
_submodules = (MuseformerDecoderLayer,)
@classmethod
def add_args(cls, parser):
# === Implementation ===
parser.add_argument('--attention-impl', choices=('mask', 'triton', 'sparta'))
parser.add_argument('--block-size', type=int, choices=(64, 32))
parser.add_argument('--attention-mode', choices=('v2s1',))
# === Transformer ===
parser.add_argument('--attention-embed-dim', type=int)
parser.add_argument('--num-layers', type=int)
parser.add_argument('--num-attention-heads', type=eval)
parser.add_argument('--normalize-before', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--ffn-embed-dim', type=int)
parser.add_argument('--dropout', type=float)
parser.add_argument('--attention-dropout', type=float)
parser.add_argument('--activation-fn', choices=utils.get_available_activation_fns())
parser.add_argument('--no-final-norm', type=arg_tools.str_bool_with_default_error)
# === Attention Scheme ===
parser.add_argument('--con2con', type=eval)
parser.add_argument('--con2con-self', choices=('default', 'none', 'full'))
parser.add_argument('--con2sum', type=eval)
parser.add_argument('--con2sum-self', choices=('default', 'none', 'full'))
parser.add_argument('--sum2con', type=eval)
parser.add_argument('--sum2con-self', choices=('default', 'none', 'full'))
parser.add_argument('--sum2sum', type=eval)
parser.add_argument('--sum2sum-self', choices=('default', 'none', 'full'))
parser.add_argument('--con2con-causal', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--pref2pref-mode', choices=('full', 'lt', 'none'))
parser.add_argument('--sum2pref-mode', choices=('none',))
parser.add_argument('--pref2sum-mode', choices=('none',))
parser.add_argument('--con2pref-mode', choices=('full',))
parser.add_argument('--pref2con-mode', choices=('none',))
# === Summary Tokens ===
parser.add_argument('--num-summary-tokens-each-chunk', type=int)
parser.add_argument('--max-summary-tokens', type=int)
# === Embedding ===
parser.add_argument('--concat-sum-embeddings', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--concat-reg-embeddings', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--proj-sum-embeddings', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--proj-reg-embeddings', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--share-embedding-projection', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--share-input-output-embed', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--no-scale-embedding', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--layernorm-embedding', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--share-layernorm-embedding', type=arg_tools.str_bool_with_default_error)
# token embedding
parser.add_argument('--token-embed-dim', type=int)
# absolute token in-chunk-position embedding
parser.add_argument('--use-token-in-chunk-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--learned-token-in-chunk-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--max-token-in-chunk-abs-pos', type=int)
parser.add_argument('--token-in-chunk-abs-pos-embed-dim', type=int)
# absolute token position embedding
parser.add_argument('--use-token-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--learned-token-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--max-token-abs-pos', type=int)
parser.add_argument('--token-abs-pos-embed-dim', type=int)
# absolute bar position embedding
parser.add_argument('--use-bar-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--learned-bar-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--max-bar-abs-pos', type=int)
parser.add_argument('--bar-abs-pos-embed-dim', type=int)
parser.add_argument('--valid-parts-for-bar-abs-pos', type=arg_tools.comma_split_tuple_of_specific_type(str))
# absolute beat position embedding
parser.add_argument('--use-beat-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--learned-beat-abs-pos', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--max-beat-abs-pos', type=int)
parser.add_argument('--beat-abs-pos-embed-dim', type=int)
parser.add_argument('--beat-abs-pos-padding-on-sum', type=arg_tools.str_bool_with_default_error)
# === Gradient Checkpointing ===
parser.add_argument('--gradient-checkpointing', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--gradient-checkpointing-every-n-layer', type=int)
parser.add_argument('--gradient-checkpointing-layers',
type=lambda x: tuple([int(item) for item in x.split(',')]))
arg_tools.add_submodule_args(cls, parser)
def __init__(self, args, dictionary):
super().__init__(dictionary)
self.args = args
# === Meta Preparations ===
# pocket to store things that can be used everywhere
self.pocket = FourDimPocket()
self.pocket['constant'] = self.pocket_constant = {}
self.pocket['instant'] = self.pocket_instant = {}
# --- control parameters ---
self.attention_impl = self.args.attention_impl
self.block_size = getattr(self.args, "block_size", None)
if self.attention_impl == 'triton':
if self.block_size is None:
self.block_size = 64
self.attention_mode = getattr(self.args, "attention_mode", "simu")
self.attn_mask_gen_parts = None # which parts are involved for attn_mask
if self.attention_mode == 'v2s1':
self.attn_mask_combination = {'sx': ('ss', 'sr'), 'rx': ('rs', 'rr')}
else:
raise NotImplementedError
self.need_embedding_summary_tokens = True
self.pocket_constant['attention_impl'] = self.attention_impl
self.pocket_constant['block_size'] = self.block_size
self.pocket_constant['attention_mode'] = self.attention_mode
self.pocket_constant['attn_mask_combination'] = self.attn_mask_combination
self.pocket_constant['attn_mask_gen_parts'] = self.attn_mask_gen_parts
self.pocket_constant['need_embedding_summary_tokens'] = self.need_embedding_summary_tokens
# === Basic Parameters ===
self.embed_dim = self.args.attention_embed_dim
self.num_layers = self.args.num_layers
self.num_attention_heads = arg_tools.possibly_extend_tuple(
self.args.num_attention_heads, self.num_layers
) # tuple of size num_layers
self.valid_dictionary_len = getattr(self.args, 'valid_dict_len', len(dictionary))
# === Embedding ===
self.embed_scale = None if args.no_scale_embedding else math.sqrt(self.args.token_embed_dim)
# --- Regular Token Embedding ---
self.embed_regular_tokens = nn.Embedding(
self.valid_dictionary_len, self.args.token_embed_dim, self.dictionary.pad()
)
# --- Summary Token Embedding ---
self.num_summary = self.args.num_summary_tokens_each_chunk
assert self.num_summary >= 0
if self.num_summary == 0 or not self.need_embedding_summary_tokens:
self.embed_summary_tokens = None
else:
self.embed_summary_tokens = nn.Embedding(
getattr(self.args, 'max_summary_tokens', self.num_summary) + 1,
self.args.token_embed_dim,
padding_idx=0
)
# --- Absolute Token In-Chunk Absolute Position Embedding ---
self.use_token_in_chunk_abs_pos = getattr(self.args, 'use_token_in_chunk_abs_pos', False)
if self.use_token_in_chunk_abs_pos:
token_in_chunk_abs_pos_embed_dim = getattr(self.args, "token_in_chunk_abs_pos_embed_dim",
self.args.token_embed_dim)
self.embed_token_in_chunk_abs_pos = TaoEmbedding(
self.args.max_token_in_chunk_abs_pos + 1, token_in_chunk_abs_pos_embed_dim, padding_idx=0,
learned=getattr(self.args, 'learned_token_in_chunk_abs_pos', False)
)
else:
self.embed_token_in_chunk_abs_pos = None
token_in_chunk_abs_pos_embed_dim = 0
# --- Absolute Token Position Embedding ---
self.use_token_abs_pos = getattr(self.args, 'use_token_abs_pos', False)
if self.use_token_abs_pos:
token_abs_pos_embed_dim = getattr(self.args, "token_abs_pos_embed_dim", self.args.token_embed_dim)
if getattr(self.args, 'learned_token_abs_pos', False):
raise NotImplementedError("The support for overly long token absolute position embedding is not done.")
self.embed_token_abs_pos = TaoEmbedding(
self.args.max_token_abs_pos, token_abs_pos_embed_dim, padding_idx=0,
learned=getattr(self.args, 'learned_token_abs_pos', False)
)
else:
self.embed_token_abs_pos = None
token_abs_pos_embed_dim = 0
# --- Absolute Bar Position Embedding ---
self.use_bar_abs_pos = getattr(self.args, 'use_bar_abs_pos', False)
if self.use_bar_abs_pos:
bar_abs_pos_embed_dim = getattr(self.args, "bar_abs_pos_embed_dim", self.args.token_embed_dim)
# if getattr(self.args, 'learned_bar_abs_pos', False):
# raise NotImplementedError("The support for overly long bar absolute position embedding is not done.")
self.embed_bar_abs_pos = TaoEmbedding(
self.args.max_bar_abs_pos + 1, bar_abs_pos_embed_dim, padding_idx=0,
learned=getattr(self.args, 'learned_bar_abs_pos', False)
)
self.valid_parts_for_bar_abs_pos = getattr(self.args, 'valid_parts_for_bar_abs_pos', None)
else:
self.embed_bar_abs_pos = None
bar_abs_pos_embed_dim = 0
self.valid_parts_for_bar_abs_pos = None
# --- Absolute Beat Position Embedding ---
self.use_beat_abs_pos = getattr(self.args, 'use_beat_abs_pos', False)
if self.use_beat_abs_pos:
beat_abs_pos_embed_dim = getattr(self.args, "beat_abs_pos_embed_dim", self.args.token_embed_dim)
# if getattr(self.args, 'learned_beat_abs_pos', False):
# raise NotImplementedError("The support for overly long beat absolute position embedding is not done.")
self.embed_beat_abs_pos = TaoEmbedding(
self.args.max_beat_abs_pos + 1, beat_abs_pos_embed_dim, padding_idx=0,
learned=getattr(self.args, 'learned_beat_abs_pos', False)
)
self.valid_parts_for_beat_abs_pos = getattr(self.args, 'valid_parts_for_beat_abs_pos', None)
else:
self.embed_beat_abs_pos = None
beat_abs_pos_embed_dim = 0
self.valid_parts_for_beat_abs_pos = None
# --- Conclude Embedding ---
self.sum_proj_embeddings = None
self.reg_proj_embeddings = None
self.concat_reg_embeddings = getattr(self.args, "concat_reg_embeddings", False)
if self.concat_reg_embeddings:
if getattr(self.args, 'proj_reg_embeddings', False):
self.reg_proj_embeddings = nn.Linear(
self.args.token_embed_dim + token_in_chunk_abs_pos_embed_dim +
token_abs_pos_embed_dim +
(
bar_abs_pos_embed_dim
if (self.valid_parts_for_bar_abs_pos is None
or 'r' in self.valid_parts_for_bar_abs_pos)
else 0
) + beat_abs_pos_embed_dim,
self.embed_dim
)
self.concat_sum_embeddings = False
self.beat_abs_pos_padding_on_sum = getattr(self.args, 'beat_abs_pos_padding_on_sum', False)
if self.num_summary > 0 and self.need_embedding_summary_tokens:
self.concat_sum_embeddings = getattr(self.args, "concat_sum_embeddings", False)
if getattr(self.args, 'proj_sum_embeddings', False):
if getattr(self.args, 'share_embedding_projection', False):
self.sum_proj_embeddings = self.reg_proj_embeddings
else:
self.sum_proj_embeddings = nn.Linear(
self.args.token_embed_dim + (
bar_abs_pos_embed_dim
if (self.valid_parts_for_bar_abs_pos is None
or 's' in self.valid_parts_for_bar_abs_pos)
else 0
) +
(
beat_abs_pos_embed_dim if self.beat_abs_pos_padding_on_sum else 0
),
self.embed_dim
)
self.reg_layernorm_embedding = None
self.sum_layernorm_embedding = None
if getattr(args, "layernorm_embedding", False):
self.reg_layernorm_embedding = LayerNorm(self.embed_dim)
if self.num_summary > 0 and self.need_embedding_summary_tokens:
if getattr(self.args, "share_layernorm_embedding", False):
self.sum_layernorm_embedding = self.reg_layernorm_embedding
else:
self.sum_layernorm_embedding = LayerNorm(self.embed_dim)
# === Attention Scheme ===
self.attn_scheme = AttentionScheme(
self.args.sum2sum, self.args.sum2sum_self,
self.args.sum2con, self.args.sum2con_self,
self.args.con2sum, self.args.con2sum_self,
self.args.con2con, self.args.con2con_self,
self.args.con2con_causal,
self.args.pref2pref_mode,
self.args.sum2pref_mode,
self.args.pref2sum_mode,
self.args.con2pref_mode,
self.args.pref2con_mode,
self.num_layers, self.num_attention_heads
)
self.layer_to_sv = {}
self.sv_to_layers = {}
attn_scheme_set = set(self.attn_scheme)
for layer_idx, layer_scheme in enumerate(self.attn_scheme):
for sv, unique_scheme in enumerate(attn_scheme_set):
if layer_scheme == unique_scheme:
self.layer_to_sv[layer_idx] = sv
if sv not in self.sv_to_layers:
self.sv_to_layers[sv] = []
self.sv_to_layers[sv].append(layer_idx)
for sv in self.sv_to_layers:
self.sv_to_layers[sv] = set(self.sv_to_layers[sv])
self.pocket_constant['layer_to_sv'] = self.layer_to_sv
self.pocket_constant['sv_to_layers'] = self.sv_to_layers
# === Transformer Blocks ===
self.dropout_module = FairseqDropout(
self.args.dropout, module_name=self.__class__.__name__
)
self.layers = nn.ModuleList([])
self.layers.extend(
[
self.build_decoder_layer(
args, layer_idx, self.num_attention_heads[layer_idx],
self.attn_scheme[layer_idx],
args.attention_dropout
) for layer_idx in range(self.num_layers)
]
)
if args.normalize_before and not getattr(args, "no_final_norm", False):
self.layer_norm = LayerNorm(self.embed_dim)
else:
self.layer_norm = None
# === Output ===
if self.args.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_regular_tokens.weight.shape[1],
self.embed_regular_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_regular_tokens.weight
else:
self.output_projection = nn.Linear(
self.embed_dim, self.valid_dictionary_len, bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.embed_dim ** -0.5
)
self.gradient_checkpointing = getattr(self.args, 'gradient_checkpointing', False)
if self.gradient_checkpointing:
checkpointing_layers = getattr(self.args, 'gradient_checkpointing_layers', None)
if checkpointing_layers is None:
gradient_checkpointing_every_n_layer = getattr(self.args, 'gradient_checkpointing_every_n_layer', 1)
checkpointing_layers = tuple(range(0, self.num_layers, gradient_checkpointing_every_n_layer))
self.checkpointing_layers = checkpointing_layers
def build_decoder_layer(
self,
args,
layer_idx,
num_attention_heads,
layer_attention_scheme,
attention_dropout,
**kwargs
):
return MuseformerDecoderLayer(
args,
layer_idx,
num_attention_heads,
layer_attention_scheme,
attention_dropout,
**kwargs
)
def forward(
self,
src_tokens, # (batch, reg_len)
src_lengths=None,
chunk_points=None, # (batch, max_chunk + 1)
num_chunks=None, # (batch,)
num_complete_chunks=None, # (batch,)
num_pref=None, # (batch,)
beat_ids=None, # (batch, reg_len)
last_state_only=False,
features_only=False,
**kwargs
):
# print('Beg', src_tokens.shape)
x, extra = self.extract_features(
src_tokens,
src_lengths=src_lengths,
reg_chunk_points=chunk_points,
num_chunks=num_chunks,
num_complete_chunks=num_complete_chunks,
num_pref=num_pref,
beat_ids=beat_ids,
last_state_only=last_state_only,
**kwargs
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
reg_tokens, # (batch, reg_len)
src_lengths=None,
reg_chunk_points=None, # (batch, max_chunk + 1)
num_chunks=None, # (batch,)
num_complete_chunks=None, # (batch,)
num_pref=None, # (batch,)
beat_ids=None,
last_state_only=False,
**kwargs
):
self.pocket_instant.clear()
bsz, reg_len = reg_tokens.shape
device = reg_tokens.device
# ===== Instant Info Construction ===
# Construct the missing input. Only for inference.
if not self.training:
if all([item is None for item in (reg_chunk_points, num_chunks, num_complete_chunks, num_pref)]):
if src_lengths is None:
src_lengths = reg_tokens.ne(self.dictionary.pad()).sum(dim=-1)
reg_chunk_points, num_chunks, num_complete_chunks, num_pref = iic.construct_bar_chunk_info(
reg_tokens, src_lengths,
self.dictionary.index(getattr(self.args, 'eob_token', 'b-1')),
begin_idx=1,
only_bos_prefix=True,
device=device
)
if getattr(self.args, 'take_bos_as_bar', False):
assert reg_chunk_points.shape[0] == 1
reg_chunk_points = torch.cat(
(reg_chunk_points.new_tensor([[0]]), reg_chunk_points), dim=-1
)
num_pref = num_pref.new_zeros(1,)
num_chunks = num_chunks + 1
num_complete_chunks = num_complete_chunks + 1
if getattr(self.args, 'use_beat_abs_pos', False) and beat_ids is None:
beat_ids = get_beat_ids(
reg_tokens[0], self.dictionary,
ts_instead_of_tempo=getattr(self.args, 'beat_mask_ts', False)
)[None]
# ===== Input Checking ===
if num_complete_chunks is None:
num_complete_chunks = num_chunks # (bsz,)
else:
temp_check = num_chunks - num_complete_chunks
assert (temp_check.eq(0) | temp_check.eq(1)).all()
del temp_check
max_chunk = int(num_chunks.max())
assert reg_chunk_points.shape == (bsz, max_chunk + 1)
reg_chunk_ranges = computation_tools.transfer_chunk_points_to_ranges(reg_chunk_points) # (bsz, max_chunk, 2)
assert reg_chunk_ranges.shape == (bsz, max_chunk, 2)
del reg_chunk_points
# ===== Summary Sequence and Embeddings =====
max_comp_chunk = int(num_complete_chunks.max())
sum_len = max_comp_chunk * self.num_summary
# ===== Padding for Blocksparse Computation =====
sum_pad_len = 0
reg_pad_len = 0
real_sum_len = sum_len
real_reg_len = reg_len
if self.block_size is not None:
if sum_len > 0:
sum_pad_len = self.block_size - sum_len % self.block_size
if sum_pad_len == self.block_size:
sum_pad_len = 0
if sum_pad_len > 0:
sum_len = real_sum_len + sum_pad_len
reg_pad_len = self.block_size - reg_len % self.block_size
if reg_pad_len == self.block_size:
reg_pad_len = 0
if reg_pad_len > 0:
reg_len = real_reg_len + reg_pad_len
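# Editorial example (not in the original source): with block_size = 64 and real_reg_len = 100,
# reg_pad_len = 64 - 100 % 64 = 28, so reg_len is padded up to 128 (the next multiple of 64);
# sum_len is rounded up in the same way whenever there are summary tokens.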
# ===== Embedding Layer =====
# --- Summary and Regular Token Embeddings ---
sum_x = None
sum_key_padding_mask = None
sum_token_ids = None
if self.num_summary > 0:
sum_key_padding_mask = torch.arange(max_comp_chunk, device=device)[None].expand(
bsz, -1).ge(num_complete_chunks[:, None])[:, :, None].expand(-1, -1, self.num_summary).reshape(
bsz, real_sum_len
) # (bsz, real_sum_len)
if sum_pad_len > 0:
sum_key_padding_mask = torch.cat(
(sum_key_padding_mask, sum_key_padding_mask.new_ones(bsz, sum_pad_len)), dim=1
) # (bsz, sum_len)
sum_token_ids = torch.arange(1, self.num_summary + 1, device=device)[None, None].repeat(
bsz, max_comp_chunk, 1
).reshape(bsz, real_sum_len)
if sum_pad_len > 0:
sum_token_ids = torch.cat(
(sum_token_ids, sum_token_ids.new_zeros(bsz, sum_pad_len)), dim=1
)
sum_token_ids.masked_fill_(sum_key_padding_mask, 0)
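# Editorial example (not in the original source): with num_summary = 2 and 3 complete chunks,
# sum_token_ids for one sample is [1, 2, 1, 2, 1, 2] (plus zeros for block padding and masked positions).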
if self.need_embedding_summary_tokens:
sum_x = self.embed_summary_tokens(sum_token_ids.transpose(0, 1))
if self.embed_scale is not None:
sum_x.mul_(self.embed_scale)
self.pocket_instant['sum_token_ids'] = sum_token_ids
if reg_pad_len > 0:
reg_tokens = torch.cat((reg_tokens, reg_tokens.new_full((bsz, reg_pad_len), self.dictionary.pad())), dim=1)
beat_ids = torch.cat((beat_ids, beat_ids.new_full((bsz, reg_pad_len), 0)), dim=1)
reg_key_padding_mask = reg_tokens.eq(self.embed_regular_tokens.padding_idx) # (bsz, reg_len)
reg_x = self.embed_regular_tokens(reg_tokens.transpose(0, 1)) # (reg_len, bsz, token_embed_dim)
del reg_tokens
if self.embed_scale is not None:
reg_x.mul_(self.embed_scale)
# --- Absolute Token In-Chunk Position Embedding ---
if self.embed_token_in_chunk_abs_pos is not None:
token_in_chunk_abs_pos = construct_reg_token_in_chunk_ids(reg_chunk_ranges, num_chunks, reg_len)
token_in_chunk_abs_pos_embed = self.embed_token_in_chunk_abs_pos(token_in_chunk_abs_pos).transpose(0, 1)  # (reg_len, bsz, dim)
del token_in_chunk_abs_pos
else:
token_in_chunk_abs_pos_embed = None
# --- Absolute Beat Position Embedding ---
reg_beat_abs_pos_embed = None
sum_beat_abs_pos_embed = None
if self.embed_beat_abs_pos is not None:
reg_beat_abs_pos_embed = self.embed_beat_abs_pos(beat_ids.transpose(0, 1)) # (l, bsz, dim)
sum_beat_abs_pos_embed = None
if sum_x is not None and self.beat_abs_pos_padding_on_sum:
sum_beat_abs_pos_embed = reg_x.new_zeros(sum_x.shape[0], bsz, self.embed_beat_abs_pos.embedding_dim)
del beat_ids
# --- Absolute Token Position Embedding ---
token_abs_pos_ids = None
if self.use_token_abs_pos:
token_abs_pos_ids = torch.arange(1, reg_len + 1, device=device)[None].repeat(bsz, 1)
token_abs_pos_ids.masked_fill_(reg_key_padding_mask, 0)
if self.embed_token_abs_pos is not None:
token_abs_pos_embed = self.embed_token_abs_pos(token_abs_pos_ids.transpose(0, 1))
else:
token_abs_pos_embed = None
del token_abs_pos_ids
# --- Absolute Bar Position Embedding ---
sum_bar_pos_ids = None
reg_bar_pos_ids = None
if self.use_bar_abs_pos:
reg_bar_pos_ids = construct_reg_bar_ids(reg_chunk_ranges, num_chunks, reg_len) # (bsz, reg_len)
if self.num_summary > 0:
sum_bar_pos_ids = torch.arange(1, max_comp_chunk + 1, device=device) # (max_comp_chunk,)
sum_bar_pos_ids = sum_bar_pos_ids[None, :, None].expand(bsz, -1, self.num_summary).reshape(
bsz, real_sum_len
) # (bsz, sum_len)
sum_bar_pos_ids = torch.cat((sum_bar_pos_ids, sum_bar_pos_ids.new_zeros(bsz, sum_pad_len)), dim=1)
sum_bar_pos_ids.masked_fill_(sum_key_padding_mask, 0)
sum_bar_abs_pos_embed = None
reg_bar_abs_pos_embed = None
if self.embed_bar_abs_pos is not None:
if (
self.num_summary > 0 and self.need_embedding_summary_tokens and
(self.valid_parts_for_bar_abs_pos is None or 's' in self.valid_parts_for_bar_abs_pos)
):
sum_bar_abs_pos_embed = self.embed_bar_abs_pos(sum_bar_pos_ids.transpose(0, 1)) # (l, bsz, dim)
if self.valid_parts_for_bar_abs_pos is None or 'r' in self.valid_parts_for_bar_abs_pos:
reg_bar_abs_pos_embed = self.embed_bar_abs_pos(reg_bar_pos_ids).transpose(0, 1) # (l, bsz, dim)
del sum_bar_pos_ids, reg_bar_pos_ids
# --- Conclude Embeddings ---
sum_x = [item for item in (sum_x, sum_bar_abs_pos_embed, sum_beat_abs_pos_embed) if item is not None]
if len(sum_x) == 0:
sum_x = None
reg_x = [item for item in (
reg_x, token_in_chunk_abs_pos_embed, token_abs_pos_embed, reg_bar_abs_pos_embed, reg_beat_abs_pos_embed
) if item is not None]
del sum_bar_abs_pos_embed
del token_in_chunk_abs_pos_embed, token_abs_pos_embed, reg_bar_abs_pos_embed
if self.concat_reg_embeddings:
reg_x = torch.cat(reg_x, dim=-1)
if self.reg_proj_embeddings is not None:
reg_x = self.reg_proj_embeddings(reg_x)
else:
reg_x = sum(reg_x)
if self.num_summary > 0 and self.need_embedding_summary_tokens:
if self.concat_sum_embeddings:
sum_x = torch.cat(sum_x, dim=-1)
if self.sum_proj_embeddings is not None:
sum_x = self.sum_proj_embeddings(sum_x)
else:
sum_x = sum(sum_x)
if self.sum_layernorm_embedding is not None and sum_x is not None:
sum_x = self.sum_layernorm_embedding(sum_x)
if self.reg_layernorm_embedding is not None:
reg_x = self.reg_layernorm_embedding(reg_x)
if sum_x is not None:
sum_x = self.dropout_module(sum_x)
reg_x = self.dropout_module(reg_x)
key_padding_mask = computation_tools.may_bi_cat(sum_key_padding_mask, reg_key_padding_mask, dim=1)
if key_padding_mask is not None and not key_padding_mask.any():
key_padding_mask = None
del sum_key_padding_mask, reg_key_padding_mask
# with open('meta.bin', 'wb') as f:
# torch.save((sum_len, reg_len), f)
# print('saved meta')
# === Transformer Layers ===
(sum_x, reg_x), inner_states = self.run_layers(
(sum_x, reg_x),
reg_chunk_ranges=reg_chunk_ranges,
num_chunks=num_chunks,
num_complete_chunks=num_complete_chunks,
num_pref=num_pref,
sum_len=sum_len,
reg_len=reg_len,
key_padding_mask=key_padding_mask,
attn_mask=None,
need_weights=False,
need_head_weights=False,
last_state_only=last_state_only,
)
if sum_x is not None:
sum_x = sum_x[:real_sum_len]
sum_len = real_sum_len
sum_x = sum_x.transpose(0, 1)
assert sum_x.shape == (bsz, sum_len, self.embed_dim), (sum_x.shape, (bsz, sum_len, self.embed_dim))
reg_x = reg_x[:real_reg_len]
reg_len = real_reg_len
if self.layer_norm is not None:
reg_x = self.layer_norm(reg_x)
reg_x = reg_x.transpose(0, 1)
assert reg_x.shape == (bsz, reg_len, self.embed_dim), (reg_x.shape, (bsz, reg_len, self.embed_dim))
others = {
# Uncomment if needed
# 'summary': sum_x,
'attn': None,
# 'inner_states': inner_states,
}
return reg_x, others
def run_layers(
self,
x,
reg_chunk_ranges,
num_chunks,
num_complete_chunks,
num_pref,
sum_len,
reg_len,
key_padding_mask,
attn_mask,
need_weights,
need_head_weights,
last_state_only
):
inner_states = []
if not last_state_only:
inner_states.append(x)
for layer in self.layers:
layer_idx = layer.layer_idx
if (
getattr(self.args, "gradient_checkpointing", False) and self.training and
layer_idx in self.checkpointing_layers
):
x, _ = checkpoint(
layer,
x,
reg_chunk_ranges,
num_chunks,
num_complete_chunks,
num_pref,
sum_len,
reg_len,
key_padding_mask,
None if attn_mask is None else attn_mask[layer_idx],
need_weights,
need_head_weights,
)
else:
x, _ = layer(
x,
reg_chunk_ranges,
num_chunks,
num_complete_chunks,
num_pref,
sum_len,
reg_len,
key_padding_mask=key_padding_mask,
attn_mask=None if attn_mask is None else attn_mask[layer_idx],
need_weights=need_weights,
need_head_weights=need_head_weights,
)
if attn_mask is not None:
attn_mask[layer_idx] = None
if not last_state_only:
inner_states.append(x)
if last_state_only:
inner_states = [x]
return x, inner_states
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
return self.output_projection(features)
def max_positions(self):
"""Maximum output length supported by the decoder."""
if getattr(self.args, 'use_token_abs_pos', False) and getattr(self.args, 'learned_token_abs_pos', False):
return min(self.args.max_target_positions, self.args.max_token_abs_pos)
return self.args.max_target_positions
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
# assert sample is not None
logits, others = net_output # (batch, seq_len, vocab_len)
vocab_size = logits.shape[-1]
if sample is not None and 'target_mask' in sample:
target_mask = sample['target_mask']
else:
target_mask = None
if target_mask is not None:
logits = logits.masked_select(target_mask.unsqueeze(-1)).view(-1, vocab_size)
if log_probs:
return utils.log_softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
else:
return utils.softmax(logits, dim=-1, onnx_trace=self.onnx_trace)
def get_targets(self, sample, net_output):
target = sample['target']
_, others = net_output
if 'target_mask' in sample:
target_mask = sample['target_mask']
else:
target_mask = None
if target_mask is not None:
target = target.masked_select(target_mask)
return target

Просмотреть файл

@ -0,0 +1,549 @@
import logging
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from fairseq import utils
from fairseq.modules import FairseqDropout, LayerNorm
from .tools import arg_tools, computation_tools
from . import attention
from .attention_mask_generation.attention_mask_generation import LayerAttentionMaskGeneration, \
combine_part_masks, transfer_attn_mask_to_block_layout, combine_attn_mask_and_key_padding_mask_
from .data_structures.four_dim_pocket import FourDimPocket
from .embedding.embedding_weight_generation import \
generate_sinusoid_position_embedding_with_padding, generate_randomly_initialized_position_embedding_with_padding
logger = logging.getLogger(__name__)
def generate_layer_attn_mask(
layer_gen: LayerAttentionMaskGeneration,
reg_chunk_ranges, num_chunks, num_complete_chunks,
num_summary, num_pref, sum_len, reg_len,
key_padding_mask=None,
highlight_prefix=False,
combine_parts=None,
block_size=None,
avoid_empty_row_or_column=True
):
# all_seq_len = sum_len + reg_len
bsz = reg_chunk_ranges.shape[0]
attn_mask = layer_gen(
reg_chunk_ranges, num_chunks, num_complete_chunks,
num_summary, num_pref, sum_len, reg_len
) # each part: (bsz, 1, part_tgt_len, part_src_len)
# assert (~attn_mask['ss']).any()
# combine key_padding_mask into attn_mask
if key_padding_mask is not None:
sum_key_padding_mask = key_padding_mask[:, :sum_len] # (bsz, sum_len)
if sum_key_padding_mask.any():
for key in ('ss', 'rs'):
if key not in attn_mask:
continue
combine_attn_mask_and_key_padding_mask_(attn_mask[key], sum_key_padding_mask)
reg_key_padding_mask = key_padding_mask[:, sum_len:] # (bsz, reg_len)
if reg_key_padding_mask.any():
for key in ('sr', 'rr'):
if key not in attn_mask:
continue
combine_attn_mask_and_key_padding_mask_(attn_mask[key], reg_key_padding_mask)
del sum_key_padding_mask, reg_key_padding_mask
del key_padding_mask
if highlight_prefix:
raise NotImplementedError
rr_attn_mask = attn_mask['rr']
assert rr_attn_mask.shape[1:] == (1, reg_len, reg_len)
rr_attn_mask = rr_attn_mask.type(torch.uint8) # (bsz, 1, reg_len, reg_len)
num_unique_pref = len(torch.unique(num_pref))
if num_unique_pref == 1:
unique_prefix_number = num_pref[0].item()
change_2_mask = rr_attn_mask[:, :, :, :unique_prefix_number].eq(0)
rr_attn_mask[:, :, :, :unique_prefix_number][change_2_mask] = 2
else:
raise NotImplementedError
attn_mask['rr'] = rr_attn_mask
# for key in attn_mask:
# print(key)
# temp_mask = attn_mask[key]
# if temp_mask is not None:
# print(temp_mask.dtype)
# temp_mask = temp_mask.long()
# temp_mask = temp_mask[0, 0]
#
# for line in temp_mask:
# for col in line:
# print(int(col), end=' ')
# print()
# else:
# print('None')
if combine_parts == 'all':
attn_mask = combine_part_masks(attn_mask) # (bsz, 1, all_seq_len, all_seq_len)
attn_mask = {'all': attn_mask}
elif isinstance(combine_parts, dict):
new_attn_mask = {}
processed_parts = set()
for key in combine_parts:
part1, part2 = combine_parts[key]
assert part1 not in processed_parts
assert part2 not in processed_parts
assert part1[0] == part2[0] # only ss + sr, or rs + rr can be combined
new_attn_mask[key] = computation_tools.may_bi_cat(attn_mask[part1], attn_mask[part2], dim=3)
processed_parts.add(part1)
processed_parts.add(part2)
for key in attn_mask:
if key in processed_parts:
continue
new_attn_mask[key] = attn_mask[key]
attn_mask = new_attn_mask
del processed_parts, new_attn_mask
elif combine_parts is None:
pass
else:
raise ValueError('combine_parts has wrong value:', combine_parts)
if block_size is None:
return attn_mask
for key in attn_mask:
batch_key_attn_mask = []
for sample_idx in range(bsz):
if attn_mask[key] is None:
r = None
else:
r = transfer_attn_mask_to_block_layout(attn_mask[key][sample_idx], block_size,
avoid_empty_row_or_column=avoid_empty_row_or_column)
batch_key_attn_mask.append(r)
attn_mask[key] = batch_key_attn_mask
# real_part_sample_mask: layout, block_mask
# layout: (head, num_tgt_blocks, num_src_blocks)
# block_mask: (-1, block, block)
return attn_mask
class MuseformerDecoderLayer(nn.Module):
_submodules = (attention,)
@classmethod
def add_args(cls, parser):
parser.add_argument('--share-self-attention-layer-norm', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--share-ffn', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--share-final-layer-norm', type=arg_tools.str_bool_with_default_error)
parser.add_argument('--chunk-ffn', type=int)
arg_tools.add_submodule_args(cls, parser)
def __init__(self,
args,
layer_idx,
num_attention_heads,
layer_attention_scheme,
attention_dropout,
**kwargs):
super().__init__()
self.pocket = FourDimPocket()
self.pocket_constant = self.pocket['constant']
self.pocket_instant = self.pocket['instant']
# === Basic Settings ===
self.args = args
self.layer_idx = layer_idx
self.attention_impl = self.pocket_constant['attention_impl']
self.block_size = self.pocket_constant['block_size']
self.attention_mode = self.pocket_constant['attention_mode']
self.attn_mask_combination = self.pocket_constant['attn_mask_combination']
self.attn_mask_gen_parts = self.pocket_constant['attn_mask_gen_parts']
self.need_embedding_summary_tokens = self.pocket_constant['need_embedding_summary_tokens']
self.layer_to_sv = self.pocket_constant['layer_to_sv']
self.sv_to_layers = self.pocket_constant['sv_to_layers']
self.layer_sv = self.layer_to_sv[self.layer_idx]
self.num_attention_heads = num_attention_heads
self.embed_dim = args.attention_embed_dim
self.num_summary = args.num_summary_tokens_each_chunk
self.normalize_before = args.normalize_before
self.chunk_ffn = getattr(args, 'chunk_ffn', None)
if self.chunk_ffn is not None:
if self.chunk_ffn <= 0:
self.chunk_ffn = None
# === Layer Attention Mask Generator ===
self.layer_attention_scheme = layer_attention_scheme
layer_attention_mask_generator_label = ('layer_attention_mask_generator', self.layer_sv)
if layer_attention_mask_generator_label in self.pocket_constant:
self.layer_attention_mask_generator = self.pocket_constant[layer_attention_mask_generator_label]
else:
self.layer_attention_mask_generator = LayerAttentionMaskGeneration(self.layer_attention_scheme,
gen_parts=self.attn_mask_gen_parts)
self.pocket_constant[layer_attention_mask_generator_label] = self.layer_attention_mask_generator
# === Construct Relative Embeddings ===
self.use_token_rel_pos = getattr(self.args, 'use_token_rel_pos', False)
if self.use_token_rel_pos:
self.max_token_rel_pos = args.max_token_rel_pos
token_rel_pos_embed_dim = getattr(self.args, 'token_rel_pos_embed_dim', self.embed_dim)
if getattr(self.args, 'learned_token_rel_pos', False):
token_rel_embed = generate_randomly_initialized_position_embedding_with_padding(
self.max_token_rel_pos + 1, token_rel_pos_embed_dim
)
self.register_parameter('token_rel_embed', nn.Parameter(token_rel_embed, requires_grad=True))
else:
token_rel_embed = generate_sinusoid_position_embedding_with_padding(
self.max_token_rel_pos + 1, token_rel_pos_embed_dim
)
self.register_buffer('token_rel_embed', token_rel_embed, persistent=False)
else:
self.token_rel_embed = None
self.no_token_rel_pos_for_prefix = getattr(self.args, 'no_token_rel_pos_for_prefix', False)
self.use_bar_rel_pos = getattr(self.args, 'use_bar_rel_pos', False)
if self.use_bar_rel_pos:
self.max_bar_rel_pos = args.max_bar_rel_pos
bar_rel_pos_embed_dim = getattr(self.args, 'bar_rel_pos_embed_dim', self.embed_dim)
if getattr(self.args, 'learned_bar_rel_pos', False):
bar_rel_embed = generate_randomly_initialized_position_embedding_with_padding(
self.max_bar_rel_pos + 1, bar_rel_pos_embed_dim
)
self.register_parameter('bar_rel_embed', nn.Parameter(bar_rel_embed, requires_grad=True))
else:
bar_rel_embed = generate_sinusoid_position_embedding_with_padding(
self.max_bar_rel_pos + 1, bar_rel_pos_embed_dim
)
self.register_buffer('bar_rel_embed', bar_rel_embed, persistent=False)
else:
self.bar_rel_embed = None
# === Self Attention ===
self.self_attn = self.build_self_attention(
self.args,
self.embed_dim,
self.num_attention_heads,
attention_dropout,
token_rel_pos_embeddings=self.token_rel_embed,
bar_rel_pos_embeddings=self.bar_rel_embed
)
# === Other Modules ===
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.activation_fn = utils.get_activation_fn(
activation=str(args.activation_fn)
if getattr(args, "activation_fn", None) is not None
else "relu"
)
# activation_dropout_p = getattr(args, "activation_dropout", None)
# if activation_dropout_p is not None:
# assert activation_dropout_p > 0
# self.activation_dropout_module = FairseqDropout(
# float(activation_dropout_p), module_name=self.__class__.__name__
# )
# else:
# self.activation_dropout_module = None
self.reg_self_attn_layer_norm = LayerNorm(self.embed_dim, export=False)
if self.need_embedding_summary_tokens:
if getattr(self.args, 'share_self_attention_layer_norm', False):
self.sum_self_attn_layer_norm = self.reg_self_attn_layer_norm
else:
self.sum_self_attn_layer_norm = LayerNorm(self.embed_dim, export=False)
else:
self.sum_self_attn_layer_norm = None
self.reg_fc1 = nn.Linear(self.embed_dim, args.ffn_embed_dim)
self.reg_fc2 = nn.Linear(args.ffn_embed_dim, self.embed_dim)
share_ffn = getattr(self.args, 'share_ffn', False)
if self.need_embedding_summary_tokens:
if share_ffn:
self.sum_fc1 = self.reg_fc1
self.sum_fc2 = self.reg_fc2
else:
self.sum_fc1 = nn.Linear(self.embed_dim, args.ffn_embed_dim)
self.sum_fc2 = nn.Linear(args.ffn_embed_dim, self.embed_dim)
else:
self.sum_fc1 = None
self.sum_fc2 = None
self.reg_final_layer_norm = LayerNorm(self.embed_dim, export=False)
if self.need_embedding_summary_tokens:
if getattr(self.args, 'share_final_layer_norm', False):
self.sum_final_layer_norm = self.reg_final_layer_norm
else:
self.sum_final_layer_norm = LayerNorm(self.embed_dim, export=False)
else:
self.sum_final_layer_norm = None
def build_self_attention(
self,
args,
embed_dim,
num_attention_heads,
dropout,
token_rel_pos_embeddings,
bar_rel_pos_embeddings,
**kwargs,
):
rel_embeddings = []
no_rel_projections = []
if self.use_token_rel_pos:
rel_embeddings.append(token_rel_pos_embeddings)
no_rel_projections.append(getattr(args, 'no_token_rel_pos_proj', False))
if self.use_bar_rel_pos:
rel_embeddings.append(bar_rel_pos_embeddings)
no_rel_projections.append(getattr(args, 'no_bar_rel_pos_proj', False))
return attention.create_attention(
implementation=self.attention_impl,
attention_mode=self.attention_mode,
block_size=self.block_size,
embed_dim=embed_dim,
num_heads=num_attention_heads,
num_summary=self.args.num_summary_tokens_each_chunk,
rel_embeddings=rel_embeddings,
layer_idx=self.layer_idx,
dropout=dropout,
query_proj_bias=getattr(args, 'attn_query_proj_bias', False),
key_proj_bias=getattr(args, 'attn_key_proj_bias', False),
value_proj_bias=getattr(args, 'attn_value_proj_bias', False),
out_proj_bias=getattr(args, 'attn_out_proj_bias', False),
no_rel_proj=no_rel_projections,
rel_proj_bias=getattr(args, 'rel_proj_bias', False),
single_head_masks=self.layer_attention_mask_generator.single_head,
# For v1, v2, v2.1
add_different_kqv_bias_for_sum_and_reg=getattr(args, 'add_different_kqv_bias_for_sum_and_reg', False),
add_different_out_bias_for_sum_and_reg=getattr(args, 'add_different_out_bias_for_sum_and_reg', False),
# For v2, v2.1
sum_key2_proj_bias=getattr(args, 'attn_sum_key2_proj_bias', False),
sum_value2_proj_bias=getattr(args, 'attn_sum_value2_proj_bias', False),
share_query_proj=getattr(args, 'attn_share_query_proj', False),
share_key_proj=getattr(args, 'attn_share_key_proj', False),
share_value_proj=getattr(args, 'attn_share_value_proj', False),
share_out_proj=getattr(args, 'attn_share_out_proj', False),
share_key2_value2_proj_weight=getattr(args, 'attn_share_key2_value2_proj_weight', False),
no_sum_out=((self.layer_idx == self.args.num_layers - 1)
if getattr(args, 'add_different_out_bias_for_sum_and_reg', False) else False
), # to make compatible with previous checkpoints
# For v5 (sum_then_reg_3)
share_reg_kv_proj=getattr(args, 'share_reg_kv_proj', False),
#
# key_rel_proj_bias=getattr(args, 'attn_key_rel_proj_bias', True), # renamed to rel_proj_bias
# add_global_rel_bias=getattr(args, 'attn_add_global_rel_bias', True),
)
def forward(
self,
x, # tuple of (Optional[Tensor], Tensor) # (sum_len, bsz, embed_dim) (reg_len, bsz, embed_dim)
reg_chunk_ranges, # (bsz, max_chunk, 2)
num_chunks, # (bsz,)
num_complete_chunks, # (bsz,)
num_pref, # (batch,)
sum_len,
reg_len,
key_padding_mask=None, # (bsz, all_seq_len)
attn_mask=None, # (bsz, num_heads, all_seq_len, all_seq_len)
need_weights=False,
need_head_weights=False,
):
if need_head_weights:
need_weights = True
if attn_mask is None:
attn_mask_label = 'attn_mask_sv%d' % self.layer_sv
if attn_mask_label in self.pocket_instant:
attn_mask = self.pocket_instant[attn_mask_label]
else:
attn_mask = generate_layer_attn_mask(
self.layer_attention_mask_generator,
reg_chunk_ranges, num_chunks, num_complete_chunks,
self.args.num_summary_tokens_each_chunk,
num_pref, sum_len, reg_len,
key_padding_mask=key_padding_mask,
highlight_prefix=False,
combine_parts=self.attn_mask_combination,
block_size=self.block_size if self.attention_impl == 'triton' else None,
avoid_empty_row_or_column=True
)
self.pocket_instant[attn_mask_label] = attn_mask
else:
raise NotImplementedError('Passing attn_mask into a Museformer layer is not supported yet.')
key_padding_mask = None  # no longer needed once it has been combined into attn_mask
token_rel_indices = None
bar_rel_indices = None
residual = x
if self.normalize_before:
x = computation_tools.may_bi_op(
self.sum_self_attn_layer_norm, self.reg_self_attn_layer_norm,
x, sum_len, reg_len, self.embed_dim, as_tuple=True
)
# print(attn_mask)
# print(token_rel_indices)
# print(bar_rel_indices)
# st = time.time()
# try:
x, attn = self.run_self_attn(
x, x, x,
sum_len, reg_len,
key_padding_mask=key_padding_mask, attn_mask=attn_mask,
incremental_state=None,
need_weights=need_weights, need_head_weights=need_head_weights,
token_rel_indices=token_rel_indices, bar_rel_indices=bar_rel_indices,
)
# except RuntimeError:
# print('x:', x.shape)
# print('num_heads:', self.num_attention_heads)
# print('attn_mask:', attn_mask.shape)
# print('key_padding_mask:', key_padding_mask.shape)
# raise
# et = time.time()
# if debug:
# logger.info('Run Self-Attn: %f' % (et - st))
x = computation_tools.may_bi_op(self.dropout_module, self.dropout_module, x,
sum_len, reg_len, self.embed_dim, as_tuple=True)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = computation_tools.may_bi_op(
self.sum_self_attn_layer_norm, self.reg_self_attn_layer_norm,
x, sum_len, reg_len, self.embed_dim, as_tuple=True
)
residual = x
if self.normalize_before:
x = computation_tools.may_bi_op(
self.sum_final_layer_norm, self.reg_final_layer_norm,
x, sum_len, reg_len, self.embed_dim, as_tuple=True
)
x = self.run_ffn(x, sum_len, reg_len)
x = computation_tools.may_bi_op(self.dropout_module, self.dropout_module, x,
sum_len, reg_len, self.embed_dim, as_tuple=True)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = computation_tools.may_bi_op(
self.sum_final_layer_norm, self.reg_final_layer_norm,
x, sum_len, reg_len, self.embed_dim, as_tuple=True
)
return x, attn
@staticmethod
def residual_connection(x, residual):
return residual[0] + x[0] if (residual[0] is not None and x[0] is not None) else None, residual[1] + x[1]
def run_self_attn(
self,
query, # (all_seq_len, bsz, embed_dim)
key, # (all_seq_len, bsz, embed_dim)
value, # (all_seq_len, bsz, embed_dim)
sum_len,
reg_len,
key_padding_mask=None,
attn_mask=None,
incremental_state=None,
need_weights=True,
need_head_weights=False,
token_rel_indices=None,
bar_rel_indices=None,
**kwargs,
):
assert incremental_state is None
rel_indices = []
if self.use_token_rel_pos:
rel_indices.append(token_rel_indices)
if self.use_bar_rel_pos:
rel_indices.append(bar_rel_indices)
r, weight = self.self_attn(
query,
self.pocket_instant['sum_token_ids'],
sum_len, reg_len,
rel_indices,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask,
need_weights=need_weights,
need_head_weights=need_head_weights,
**kwargs,
)
return r, weight
def run_ffn(self, x, sum_len, reg_len):
sum_x, reg_x = x
del x
def reg_ffn(input_x):
input_x = self.reg_fc1(input_x)
input_x = self.activation_fn(input_x)
input_x = self.reg_fc2(input_x)
return input_x
if sum_x is None:
sum_ffn = None
else:
if getattr(self.args, 'share_ffn', False):
sum_ffn = reg_ffn
else:
def sum_ffn(input_x):
input_x = self.sum_fc1(input_x)
input_x = self.activation_fn(input_x)
input_x = self.sum_fc2(input_x)
return input_x
if self.chunk_ffn is None:
if sum_x is not None:
sum_x = sum_ffn(sum_x)
reg_x = reg_ffn(reg_x)
else:
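            # Apply the FFN to slices along the sequence dimension to bound peak activation memory;
            # when gradients are needed and the FFNs are not shared, the regular stream is additionally
            # wrapped in gradient checkpointing.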
def do_chunk_ffn(input_x, ffn, split_size):
input_x = torch.split(input_x, split_size, dim=0)
result = []
for chunk_x in input_x:
chunk_x = ffn(chunk_x)
result.append(chunk_x)
del input_x
result = torch.cat(result, dim=0)
return result
if reg_x.requires_grad:
if id(reg_ffn) == id(sum_ffn):
def do_multi_chunk_ffn(input_x, ffn, split_size):
return [do_chunk_ffn(one_x, ffn, split_size) for one_x in input_x]
sum_x, reg_x = do_multi_chunk_ffn((sum_x, reg_x), reg_ffn, self.chunk_ffn)
else:
reg_x = checkpoint(do_chunk_ffn, reg_x, reg_ffn, self.chunk_ffn)
if sum_x is not None:
sum_x = do_chunk_ffn(sum_x, sum_ffn, self.chunk_ffn)
else:
reg_x = do_chunk_ffn(reg_x, reg_ffn, self.chunk_ffn)
if sum_x is not None:
sum_x = do_chunk_ffn(sum_x, sum_ffn, self.chunk_ffn)
return sum_x, reg_x

View file

@ -0,0 +1,158 @@
from dataclasses import dataclass
from typing import Optional
from omegaconf import II
from fairseq.dataclass import FairseqDataclass
from fairseq.models import FairseqLanguageModel, register_model, register_model_architecture
from .tools import arg_tools
from .museformer_decoder import MuseformerDecoder
DEFAULT_MAX_TARGET_POSITIONS = 100000
@dataclass
class MuseformerLanguageModelConfig(FairseqDataclass):
tokens_per_sample: int = II("task.tokens_per_sample")
max_target_positions: Optional[int] = II("task.max_target_positions")
eob_token: str = II("task.eob_token")
eoc_token: str = II("task.eoc_token")
chunking_scheme: str = II("task.chunking_scheme")
    fixed_chunking_length: Optional[int] = II("task.fixed_chunking_length")
@register_model('museformer_lm', dataclass=MuseformerLanguageModelConfig)
class MuseformerLanguageModel(FairseqLanguageModel):
_submodules = (MuseformerDecoder,)
@classmethod
def add_args(cls, parser):
super(MuseformerLanguageModel, cls).add_args(parser)
arg_tools.add_submodule_args(cls, parser)
def __init__(self, decoder):
super().__init__(decoder)
@classmethod
def build_model(cls, args, task):
decoder = MuseformerDecoder(args, task.target_dictionary)
print(args)
return cls(decoder)
def get_targets(self, sample, net_output):
return self.decoder.get_targets(sample, net_output)
def base_lm_scheme(args):
args.pref2pref_mode = getattr(args, "pref2pref_mode", 'lt')
args.sum2pref_mode = getattr(args, "sum2pref_mode", 'none')
args.pref2sum_mode = getattr(args, "pref2sum_mode", "none")
args.con2pref_mode = getattr(args, "con2pref_mode", 'full')
args.pref2con_mode = getattr(args, "pref2con_mode", 'none')
def b1248_lm_scheme(args):
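    # Fine-grained attention (con2con): each bar directly attends to a selection of previous bars
    # (the nearest ones plus those 4, 8, 12, 16, 24 and 32 bars back); the remaining past bars are
    # visible only through their summary tokens (con2sum), i.e. the coarse-grained attention.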
args.con2con = getattr(args, "con2con", ((((-2, 0), -4, -8, -12, -16, -24, -32),),))
args.con2con_self = getattr(args, "con2con_self", "full")
args.con2con_causal = getattr(args, "con2con_causal", True)
args.con2sum = getattr(args, "con2sum",
((((None, -32), (-31, -24), (-23, -16), (-15, -12), (-11, -8), (-7, -4), -3,),),))
args.con2sum_self = getattr(args, "con2sum_self", "none")
args.sum2con = getattr(args, "sum2con", ((None,),))
args.sum2con_self = getattr(args, "sum2con_self", "full")
args.sum2sum = getattr(args, "sum2sum", ((None,),))
args.sum2sum_self = getattr(args, "sum2sum_self", "full")
base_lm_scheme(args)
def b1234_lm_scheme(args):
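    # A local-window variant: each bar directly attends to the previous 8 bars (con2con range (-8, 0))
    # and reaches all earlier bars only through their summary tokens (con2sum range (None, -8)).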
args.con2con = getattr(args, "con2con", ((((-8, 0),),),))
args.con2con_self = getattr(args, "con2con_self", "full")
args.con2con_causal = getattr(args, "con2con_causal", True)
args.con2sum = getattr(args, "con2sum", ((((None, -8),),),))
args.con2sum_self = getattr(args, "con2sum_self", "none")
args.sum2con = getattr(args, "sum2con", ((None,),))
args.sum2con_self = getattr(args, "sum2con_self", "full")
args.sum2sum = getattr(args, "sum2sum", ((None,),))
args.sum2sum_self = getattr(args, "sum2sum_self", "full")
base_lm_scheme(args)
def share_sum_reg_params(args):
args.share_layernorm_embedding = getattr(args, 'share_layernorm_embedding', True)
args.attn_share_query_proj = getattr(args, 'attn_share_query_proj', True)
args.attn_share_key_proj = getattr(args, 'attn_share_key_proj', True)
args.attn_share_value_proj = getattr(args, 'attn_share_value_proj', True)
args.attn_share_out_proj = getattr(args, 'attn_share_out_proj', True)
args.share_self_attention_layer_norm = getattr(args, 'share_self_attention_layer_norm', True)
args.share_ffn = getattr(args, 'share_ffn', True)
args.share_final_layer_norm = getattr(args, 'share_final_layer_norm', True)
def base_lm_architecture(args):
args.attention_embed_dim = getattr(args, "attention_embed_dim", 512)
args.num_layers = getattr(args, "num_layers", 4)
args.num_attention_heads = getattr(args, "num_attention_heads", (8,))
args.normalize_before = getattr(args, "normalize_before", True)
args.ffn_embed_dim = getattr(args, "ffn_embed_dim", 2048)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
args.no_final_norm = getattr(args, "no_final_norm", False)
args.take_bos_as_bar = getattr(args, 'take_bos_as_bar', True)
args.num_summary_tokens_each_chunk = getattr(args, "num_summary_tokens_each_chunk", 1)
args.share_input_output_embed = getattr(args, "share_input_output_embed", False)
args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
args.token_embed_dim = getattr(args, "token_embed_dim", args.attention_embed_dim)
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
if args.num_summary_tokens_each_chunk <= 0:
args.num_summary_tokens_each_chunk = 0
args.con2sum = ((None,),)
args.sum2con = ((None,),)
args.sum2sum = ((None,),)
args.attention_impl = getattr(args, "attention_impl", "triton")
if args.attention_impl == 'triton':
args.block_size = getattr(args, "block_size", 64)
@register_model_architecture("museformer_lm", "museformer_lm")
def museformer_lm_v2s1(args):
args.attention_mode = getattr(args, 'attention_mode', 'v2s1')
b1248_lm_scheme(args)
share_sum_reg_params(args)
args.tokens_per_sample = getattr(args, 'tokens_per_sample', 100000)
args.use_token_in_chunk_abs_pos = getattr(args, 'use_token_in_chunk_abs_pos', False)
args.use_token_abs_pos = getattr(args, 'use_token_abs_pos', False)
args.use_bar_abs_pos = getattr(args, 'use_bar_abs_pos', True)
args.max_bar_abs_pos = getattr(args, 'max_bar_abs_pos', 512)
args.bar_abs_pos_embed_dim = getattr(args, 'bar_abs_pos_embed_dim', 256)
args.use_beat_abs_pos = getattr(args, 'use_beat_abs_pos', True)
args.max_beat_abs_pos = getattr(args, 'max_beat_abs_pos', 64)
args.beat_abs_pos_embed_dim = getattr(args, 'beat_abs_pos_embed_dim', 128)
args.concat_reg_embeddings = getattr(args, 'concat_reg_embeddings', True)
args.proj_reg_embeddings = getattr(args, 'proj_reg_embeddings', True)
args.concat_sum_embeddings = getattr(args, 'concat_sum_embeddings', True)
args.proj_sum_embeddings = getattr(args, 'proj_sum_embeddings', True)
args.attn_query_proj_bias = getattr(args, 'attn_query_proj_bias', True)
args.attn_key_proj_bias = getattr(args, 'attn_key_proj_bias', True)
args.sum_key2_proj_bias = getattr(args, 'attn_sum_key2_proj_bias', True)
args.attn_value_proj_bias = getattr(args, 'attn_value_proj_bias', True)
args.sum_value2_proj_bias = getattr(args, 'attn_sum_value2_proj_bias', True)
args.attn_out_proj_bias = getattr(args, 'attn_out_proj_bias', True)
args.add_different_kqv_bias_for_sum_and_reg = getattr(args, 'add_different_kqv_bias_for_sum_and_reg', True)
base_lm_architecture(args)

View file

@ -0,0 +1,159 @@
import os
import logging
from dataclasses import dataclass, field
from typing import Optional
from fairseq import utils
from fairseq.tasks.language_modeling import LanguageModelingTask
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from omegaconf import II
from fairseq.data import data_utils
from fairseq.data.indexed_dataset import get_available_dataset_impl
from fairseq.tasks import register_task
from .dictionary.compound_dictionary import CompoundDictionary
from .tools import arg_tools
from .datasets.extended_wrapper_dataset import ExtendedWrapperDataset
from .datasets.remove_short_size_samples_dataset import MayRemoveShortSizeSamplesDataset
from .datasets.chunk_sequence_dataset_2 import ChunkSequenceDataset as ChunkSequenceDataset2
from .datasets.prefix_tag_dataset import PrefixTagDataset
from .datasets.music_monolingual_dataset_2 import MusicMonolingualDataset as MusicMonolingualDataset2
from .datasets.remove_long_size_samples_dataset import MayRemoveLongSizeSamplesDataset
from .datasets.truncate_music_dataset_2 import TruncateMusicDataset as TruncateMusicDataset2
from .datasets.post_process_dataset_2 import PostProcessDataset as PostProcessDataset2
from .datasets.add_beat_dataset import AddBeatDataset
from .sequence_generator import MuseformerSequenceGenerator
logger = logging.getLogger(__name__)
@dataclass
class MuseformerLanguageModelingConfig(FairseqDataclass):
data: Optional[str] = field(
default=None, metadata={"help": "path to data directory"}
)
tokens_per_sample: int = field(
default=1024,
metadata={"help": "max number of tokens per sample for LM dataset"},
)
max_target_positions: Optional[int] = field(
default=None, metadata={"help": "max number of tokens in the target sequence"}
)
output_dictionary_size: int = field(
default=-1, metadata={"help": "limit the size of output dictionary"}
)
seed: int = II("params.common.seed")
dataset_impl: Optional[ChoiceEnum(get_available_dataset_impl())] = II(
"params.dataset.dataset_impl"
)
@register_task('museformer_language_modeling', dataclass=MuseformerLanguageModelingConfig)
class MuseformerLanguageModelingTask(LanguageModelingTask):
@classmethod
def add_args(cls, parser):
super(MuseformerLanguageModelingTask, cls).add_args(parser)
# Basic
parser.add_argument('--eob-token', default='b-1')
parser.add_argument('--eoc-token', type=arg_tools.str_to_type_with_specific_word_as_none(str, 'None'),
default='e-1')
parser.add_argument('--chunking-scheme', choices=('bar_aware', 'fixed'), default='bar_aware')
parser.add_argument('--fixed-chunking-length', type=int, default=None)
parser.add_argument('--max-size-train', type=int)
parser.add_argument('--max-size-valid', type=int)
parser.add_argument('--max-size-test', type=int)
parser.add_argument('--truncate-train', type=int)
parser.add_argument('--truncate-valid', type=int)
parser.add_argument('--truncate-test', type=int)
parser.add_argument('--take-bos-as-bar', type=arg_tools.str_bool_with_default_error, default=False)
parser.add_argument('--beat-mask-ts', type=arg_tools.str_bool_with_default_error, default=False)
@classmethod
def setup_dictionary(cls, args, **kwargs):
dictionary = None
output_dictionary = None
if args.data:
paths = utils.split_paths(args.data)
assert len(paths) > 0
dictionary = CompoundDictionary.load(os.path.join(paths[0], "dict.txt"))
logger.info("dictionary: {} types".format(len(dictionary)))
output_dictionary = dictionary
if args.output_dictionary_size >= 0:
raise NotImplementedError
# output_dictionary = TruncatedDictionary(
# dictionary, args.output_dictionary_size
# )
return (dictionary, output_dictionary)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
paths = utils.split_paths(self.args.data)
assert len(paths) > 0
data_path = paths[(epoch - 1) % len(paths)]
split_path = os.path.join(data_path, split)
dataset = data_utils.load_indexed_dataset(
split_path, self.dictionary, self.args.dataset_impl, combine=combine
)
if dataset is None:
raise FileNotFoundError(
"Dataset not found: {} ({})".format(split, split_path)
)
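        # The indexed dataset is wrapped by a chain of dataset transforms below: drop empty samples,
        # split sequences into bar-aware chunks, attach beat information, remove or truncate over-long
        # samples, prepend the prefix tag, and finally post-process for batching.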
dataset = ExtendedWrapperDataset(dataset)
dataset = MayRemoveShortSizeSamplesDataset(dataset, 2) # Delete empty samples
assert self.args.eob_token in self.dictionary
eob_index = self.dictionary.index(self.args.eob_token)
eoc_index = None
data_name = self.args.data.replace('\\', '/')
data_name = data_name[:-1] if self.args.data.endswith('/') else data_name
data_name = data_name.split('/')[-1]
take_bos_as_bar = getattr(self.args, 'take_bos_as_bar', False)
dataset = MusicMonolingualDataset2(dataset)
dataset = ChunkSequenceDataset2(
dataset, self.source_dictionary,
eob_index, eoc_index,
chunking_scheme=self.args.chunking_scheme,
chunking_length=getattr(self.args, 'fixed_chunking_length', None),
dataset_name=split,
cache_data_label=data_name,
cache_sequence=True,
offset=1 if take_bos_as_bar else 0,
take_bos_as_bar=take_bos_as_bar,
bos_index=self.source_dictionary.eos_index
)
dataset = AddBeatDataset(dataset, self.source_dictionary, cache_data_label=data_name, dataset_name=split,
mask_ts_instead_of_tempo=self.args.beat_mask_ts)
max_size_split = getattr(self.args, 'max_size_%s' % split, None)
dataset = MayRemoveLongSizeSamplesDataset(dataset, max_size_split)
truncate_length = getattr(self.args, 'truncate_%s' % split, None)
dataset = TruncateMusicDataset2(dataset, truncate_length)
dataset = PrefixTagDataset(dataset, 0 if take_bos_as_bar or self.args.chunking_scheme != 'bar_aware' else 1)
dataset = PostProcessDataset2(dataset)
self.datasets[split] = dataset
logger.info('loaded %d samples for %s' % (len(self.datasets[split]), split))
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
):
if seq_gen_cls is None:
seq_gen_cls = MuseformerSequenceGenerator
return super().build_generator(models, args, seq_gen_cls=seq_gen_cls, extra_gen_cls_kwargs=extra_gen_cls_kwargs)

View file

@ -0,0 +1,8 @@
from fairseq.sequence_generator import SequenceGenerator
class MuseformerSequenceGenerator(SequenceGenerator):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.vocab_size = self.model.single_model.decoder.valid_dictionary_len
self.beam_size = min(self.beam_size, self.vocab_size - 1)
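        # Cap the search width at the decoder's valid output vocabulary so that beam search / top-k
        # sampling never requests more candidates than there are predictable tokens.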

View file

@ -0,0 +1,48 @@
def add_submodule_args(cls, parser, submodules_attr_name='_submodules'):
submodules = getattr(cls, submodules_attr_name, None)
if submodules is None:
return
for submodule in submodules:
if hasattr(submodule, 'add_args'):
submodule.add_args(parser)
def str_bool(c):
c_lower = c.lower()
if c_lower in ('true', 'yes', 'y'):
return True
elif c_lower in ('false', 'no', 'n'):
return False
else:
return None
def str_bool_with_default_error(c):
r = str_bool(c)
if r is None:
raise ValueError('Value "%s" is not valid.' % c)
else:
return r
def str_to_type_with_specific_word_as_none(type_func, none_word):
def f(x):
return None if x == none_word else type_func(x)
return f
def comma_split_tuple_of_specific_type(type_func):
def inner(x):
return tuple([type_func(item) for item in x.split(',')])
return inner
def possibly_extend_tuple(c, n):
assert isinstance(c, tuple)
if len(c) == 1 and n > 1:
c = c * n
else:
assert len(c) == n, \
"%s for %d layers, len(c) == %d, type(c) == %s" % \
(str(c), n, len(c), str(type(c)))
return c
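# Usage sketch (illustrative): possibly_extend_tuple((8,), 4) -> (8, 8, 8, 8), e.g. to broadcast a
# single per-layer setting such as num_attention_heads across all decoder layers.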

View file

@ -0,0 +1,158 @@
from typing import Optional
import torch
def may_bi_op(
op1,
op2,
x: (torch.Tensor, tuple),
seq_len_1: int,
seq_len_2: int,
out_dim: Optional[int],
as_tuple=False,
):
"""
Do two different projections over two different parts of x.
:param op1:
:param op2:
:param x: input, either tensor or tuple of two tensors
:param seq_len_1: int
:param seq_len_2: int
:param out_dim: int
:param as_tuple: bool
:return:
"""
if isinstance(x, torch.Tensor):
seq_len, bsz, _ = x.shape
assert seq_len == seq_len_1 + seq_len_2
if id(op1) == id(op2):
if as_tuple:
r = op1(x)
return r[:seq_len_1], r[seq_len_1:]
return op1(x)
x1 = x[:seq_len_1]
x2 = x[seq_len_1:]
assert x2.shape[0] == seq_len_2
else:
x1, x2 = x
assert x1 is None or x1.shape[0] == seq_len_1
assert x2 is None or x2.shape[0] == seq_len_2
bsz = x2.shape[1]
if as_tuple:
return op1(x1) if op1 is not None and (x1 is not None and seq_len_1 > 0) else x1, \
op2(x2) if op2 is not None and x2 is not None else x2
out = x1.new_empty(seq_len_1 + seq_len_2, bsz, out_dim)
if seq_len_1 > 0:
out[:seq_len_1] = op1(x1) if op1 is not None else x1
if seq_len_2 > 0:
out[seq_len_1:] = op2(x2) if op2 is not None else x2
return out
def bi_projection(
proj1: torch.nn.Linear,
proj2: torch.nn.Linear,
x: (torch.Tensor, tuple),
seq_len_1: int,
seq_len_2: int,
as_tuple=False,
):
"""
Do two different projections over two different parts of x.
:param proj1: Linear instance
:param proj2: Linear instance
:param x: input, either tensor or tuple of two tensors
:param seq_len_1: int
:param seq_len_2: int
:param as_tuple:
:return:
"""
out_dim = proj2.weight.shape[0]
return may_bi_op(proj1, proj2, x, seq_len_1, seq_len_2, out_dim, as_tuple=as_tuple)
def may_bi_add(x, add1, add2, seq_len_1, seq_len_2):
if add1 is None and add2 is None:
return x
if isinstance(x, torch.Tensor):
seq_len, bsz = x.shape[:2]
assert seq_len == seq_len_1 + seq_len_2
x1 = x[:seq_len_1]
x2 = x[seq_len_1:]
assert x2.shape[0] == seq_len_2
else:
x1, x2 = x
assert x1.shape[0] == seq_len_1
assert x2.shape[0] == seq_len_2
bsz = x2.shape[1]
out = x1.new_empty(seq_len_1 + seq_len_2, bsz, *x1.shape[2:])
if add1 is None:
out[:seq_len_1] = x1
else:
out[:seq_len_1] = x1 + add1
if add2 is None:
out[seq_len_1:] = x2
else:
out[seq_len_1:] = x2 + add2
return out
def may_bi_cat(x1, x2, dim=0):
if x1 is None or x1.shape[dim] == 0:
return x2
    if x2 is None or x2.shape[dim] == 0:
return x1
return torch.cat((x1, x2), dim=dim)
def pad_embed_first_dim(x, pad_len):
if pad_len is None or pad_len == 0:
return x
shape = x.shape
first_dim = shape[0]
new_shape = ((first_dim + pad_len,) + shape[1:])
r = x.new_zeros(*new_shape)
r[:first_dim] = x
return r
def pad_2d(x, pad_len_1, pad_len_2, pad_value):
if pad_len_1 is None:
pad_len_1 = 0
if pad_len_2 is None:
pad_len_2 = 0
if pad_len_1 == 0 and pad_len_2 == 0:
return x
bsz, len_1, len_2 = x.shape
r = x.new_full((bsz, len_1 + pad_len_1, len_2 + pad_len_2), pad_value)
r[:, :len_1, :len_2] = x
return r
def projection_and_pad(proj: torch.nn.Linear, x: torch.Tensor, padded_len: int):
"""
Do projection and pad to a specific length. Combine together to save memory.
:param proj:
:param x: (seq_len, bsz, in_dim)
:param padded_len: target seq_len
:return:
"""
seq_len, bsz = x.shape[:2]
if seq_len == padded_len:
return proj(x)
assert padded_len > seq_len
out_dim = proj.weight.shape[0]
out = x.new_zeros((padded_len, bsz, out_dim))
out[:seq_len] = proj(x)
return out
def transfer_chunk_points_to_ranges(chunk_points):
"""
:param chunk_points: (bsz, max_chunk + 1)
:return:
"""
return torch.stack((chunk_points[:, :-1], chunk_points[:, 1:]), dim=-1) # (bsz, max_chunk, 2)

View file

@ -0,0 +1,41 @@
import torch
from fairseq.data import data_utils
from ..datasets.chunk_sequence_dataset_2 import get_bar_chunk_points
def construct_bar_chunk_info(reg_tokens, src_lengths, eob, begin_idx=0, only_bos_prefix=True, device='cpu'):
"""
:param reg_tokens: (bsz, seq_len)
:param src_lengths: (bsz,)
:param eob:
:param begin_idx:
:param only_bos_prefix:
:param device:
:return:
"""
assert only_bos_prefix
chunk_points = []
num_chunks = []
should_minus_one = []
bsz = len(reg_tokens)
for idx, (sample_reg_tokens, sample_length) in enumerate(zip(reg_tokens, src_lengths)):
sample_reg_tokens = sample_reg_tokens[:sample_length]
sample_chunk_points, sample_is_complete_bar = get_bar_chunk_points(
sample_reg_tokens,
eob, begin_idx=begin_idx
        )
chunk_points.append(sample_chunk_points)
sample_num_chunks = len(sample_chunk_points) - 1
num_chunks.append(sample_num_chunks)
should_minus_one.append(not sample_is_complete_bar and sample_num_chunks > 0)
chunk_points = data_utils.collate_tokens(
chunk_points, 0
).to(device) # (bsz, max_chunk + 1)
num_chunks = torch.tensor(num_chunks, device=device)
should_minus_one = torch.tensor(should_minus_one, device=device).long()
num_complete_chunks = num_chunks - should_minus_one
num_pref = torch.tensor((1,) * bsz, device=device)
return chunk_points, num_chunks, num_complete_chunks, num_pref

View file

@ -0,0 +1,8 @@
def singleton(cls):
_instance = {}
def inner(*args, **kwargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kwargs)
return _instance[cls]
return inner
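# Usage sketch (illustrative): decorate a class with @singleton so that every construction returns
# the same shared instance.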

View file

@ -0,0 +1,15 @@
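# Interactive generation from a trained Museformer checkpoint via fairseq-interactive.
# Usage (inferred from the positional arguments below): bash <this script> <model-id> <checkpoint-file> <seed>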
DATA_DIR=data-bin/lmd6remi
MODEL_NAME=mf-lmd6remi-$1
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
NUM_WORKERS=$OMP_NUM_THREADS
fairseq-interactive $DATA_DIR \
--path checkpoints/$MODEL_NAME/$2 \
--user-dir museformer \
--task museformer_language_modeling \
--sampling --sampling-topk 8 --beam 1 --nbest 1 \
--min-len 8192 \
--max-len-b 20480 \
--num-workers $NUM_WORKERS \
--seed $3

View file

@ -0,0 +1,82 @@
import os
import argparse
import re
midi_processor_cls = None
def load_midi_processor():
    # Import midiprocessor lazily and cache the module so repeated calls do not re-import it.
    global midi_processor_cls
    if midi_processor_cls is not None:
        return midi_processor_cls
    import midiprocessor
    midi_processor_cls = midiprocessor
    return midi_processor_cls
class GenerationLogExtractor(object):
@classmethod
def add_args(cls, parser):
parser.add_argument('--base_name', default='random')
parser.add_argument('--start_idx', type=int, default=1)
parser.add_argument('--process-token-str-method', default='default')
@classmethod
def build(cls, args):
return cls(base_name=args.base_name, start_idx=args.start_idx,
process_token_str_method=args.process_token_str_method)
def __init__(self, base_name='random', start_idx=1, process_token_str_method='default'):
self.base_name = base_name
self.start_idx = start_idx
self.process_token_str_method = process_token_str_method
def do(self, log_path, token_output_dir, base_name=None, start_idx=None, process_token_str_method=None):
if base_name is None:
base_name = self.base_name
if start_idx is None:
start_idx = self.start_idx
if process_token_str_method is None:
process_token_str_method = self.process_token_str_method
process_token_str_func = self.get_process_token_str_func(process_token_str_method)
return self.extract_midi_tokens_from_output_log(
log_path, token_output_dir, base_name, start_idx, process_token_str_func
)
@classmethod
def get_process_token_str_func(cls, method):
if method == 'default':
return cls.default_process_token_str
else:
raise ValueError(method)
@staticmethod
def default_process_token_str(token_str):
return token_str.strip()
@staticmethod
def extract_midi_tokens_from_output_log(log_path, token_output_dir, base_name, start_idx, process_token_str):
with open(log_path, 'r') as f:
s = f.read()
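        # fairseq writes each hypothesis as "D-<sample_id>\t<score>\t<token string>"; keep only the tokens.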
        r = re.findall(r'D-\d+?\t.+?\t(.+?)\n', s)
os.makedirs(token_output_dir, exist_ok=True)
for idx, token_str in enumerate(r, start=start_idx):
token_str = process_token_str(token_str)
            with open(os.path.join(token_output_dir, '%s-%d.txt' % (base_name, idx)), 'w') as f:
f.write(token_str)
num_songs = len(r)
print('Extract %d songs from log. (%s-%d ~ %s-%d)' %
(num_songs, base_name, start_idx, base_name, start_idx + num_songs - 1))
def main():
parser = argparse.ArgumentParser()
GenerationLogExtractor.add_args(parser)
parser.add_argument('log_path')
parser.add_argument('token_output_dir')
args = parser.parse_args()
generation_log_extractor = GenerationLogExtractor.build(args)
generation_log_extractor.do(args.log_path, args.token_output_dir)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,102 @@
import os
import argparse
import midiprocessor as mp
class MidiGenerator(object):
@classmethod
def add_args(cls, parser):
parser.add_argument('--encoding-method', required=True)
parser.add_argument('--process-token-str-list-method', default='default')
parser.add_argument('--input-dir', required=True)
parser.add_argument('--output-dir')
parser.add_argument('--suffix', default='.txt')
parser.add_argument('--skip-error', type=lambda x: x.lower() == 'true', default=True)
@classmethod
def build(cls, args):
return cls(args.encoding_method, process_token_str_list_method=args.process_token_str_list_method)
def __init__(self, encoding_method, process_token_str_list_method='default'):
self.encoding_method = encoding_method
self.process_token_str_list_method = process_token_str_list_method
self.process_token_str_list_func = self.get_process_token_str_list_func(self.process_token_str_list_method)
self.midi_decoder = mp.MidiDecoder(self.encoding_method)
def do(self, token_text_path, output_midi_path):
with open(token_text_path, 'r') as f:
x = f.read()
line = self._process_token_str(x)
midi_obj = self.midi_decoder.decode_from_token_str_list(line)
dir_name = os.path.dirname(output_midi_path)
if dir_name not in ('', '.'):
os.makedirs(dir_name, exist_ok=True)
midi_obj.dump(output_midi_path)
def do_batch(self, token_output_dir, midi_output_dir, suffix='.txt', skip_error=True):
count = 0
error_count = 0
for root_dir, dirs, files in os.walk(token_output_dir):
for file_name in files:
if not file_name.endswith(suffix):
continue
file_path = os.path.join(root_dir, file_name)
relative_file_path = os.path.relpath(file_path, token_output_dir)
base_name = file_name
save_dir = os.path.dirname(os.path.join(midi_output_dir, relative_file_path))
save_path = os.path.join(save_dir, base_name + '.mid')
print('parsing', file_path)
# noinspection PyBroadException
try:
self.do(file_path, save_path)
except KeyboardInterrupt:
raise
except:
print('Error:', file_path)
error_count += 1
import traceback
traceback.print_exc()
if skip_error:
continue
else:
count += 1
return count, error_count
def _process_token_str(self, x):
x = x.split('\n')[0]
x = x.strip().split(' ')
x = self.process_token_str_list_func(x)
return x
@classmethod
def get_process_token_str_list_func(cls, method):
if method == 'default':
return cls.default_process_token_str_list
else:
raise ValueError(method)
@staticmethod
def default_process_token_str_list(x):
return x
def main():
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
MidiGenerator.add_args(parser)
args = parser.parse_args()
midi_generator = MidiGenerator.build(args)
count, error_count = midi_generator.do_batch(
args.input_dir, getattr(args, 'output_dir', args.input_dir),
suffix=args.suffix, skip_error=args.skip_error
)
print('Done. %d succeed! %d failed.' % (count, error_count))
if __name__ == '__main__':
main()

View file

@ -0,0 +1,46 @@
import os
import argparse
from tqdm import tqdm
def read_file_list(file_list):
path_list = []
with open(file_list, 'r') as f:
for line in f:
line = line.strip()
if line == '':
continue
path_list.append(line)
return path_list
def main(file_list, token_dir, save_dir, file_list_suffix='.txt'):
file_list_base = os.path.basename(file_list)[:-len(file_list_suffix)]
file_list = read_file_list(file_list)
# print('processing %d files...' % len(file_list))
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, file_list_base + '.data'), 'w') as save_f:
for file_path in tqdm(file_list):
file_name = os.path.basename(file_path)
token_path = os.path.join(token_dir, file_name + '.txt')
with open(token_path, 'r') as f:
for line in f:
line = line.strip()
if line == '':
continue
save_f.write(line + '\n')
break
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('file_list')
parser.add_argument('token_dir')
parser.add_argument('save_dir')
parser.add_argument('--file_list_suffix', default='.txt')
args = parser.parse_args()
main(args.file_list, args.token_dir, args.save_dir, args.file_list_suffix)
print('Done')

View file

@ -0,0 +1,44 @@
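# Training script for Museformer; the data, checkpoint and log paths below assume the default
# repository layout (data-bin/, checkpoints/, log/, tb_log/).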
MODEL_NAME='mf-lmd6remi-1'
DATA_DIR=data-bin/lmd6remi # Data dir
# 4 GPUs
UPDATE_FREQ=1
PEAK_LR=5e-4 # Peak learning rate, adjust as needed
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
ulimit -n 4096
mkdir -p log
fairseq-train \
$DATA_DIR \
--user-dir museformer \
--task museformer_language_modeling \
--arch museformer_lm \
--num-layers 4 \
--truncate-train 15360 \
--truncate-valid 10240 \
--batch-size 1 \
--update-freq $UPDATE_FREQ \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--adam-eps 1e-9 \
--weight-decay 0.01 \
--lr $PEAK_LR \
--lr-scheduler inverse_sqrt \
--warmup-updates $WARMUP_UPDATES \
--max-update 1000000 \
--validate-interval 1000000000 \
--save-interval 1000000000 \
--save-interval-updates 5000 \
--fp16 \
--log-format simple \
--log-interval 10 \
--tensorboard-logdir tb_log/$MODEL_NAME \
--num-workers "$OMP_NUM_THREADS" \
--save-dir checkpoints/$MODEL_NAME \
| tee log/${MODEL_NAME}.log

View file

@ -0,0 +1,13 @@
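# Compute test-set perplexity with fairseq-validate; positional arguments (inferred from their use
# below): <model-id> <checkpoint-file> <truncation length for the test split>.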
DATA_DIR=data-bin/lmd6remi # Data dir
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
fairseq-validate \
$DATA_DIR \
--user-dir museformer \
--task museformer_language_modeling \
--path checkpoints/mf-lmd6remi-$1/$2 \
--batch-size 1 \
--truncate-test $3 \
--valid-subset test \
--num-workers $OMP_NUM_THREADS