зеркало из https://github.com/microsoft/muzic.git
emogen
This commit is contained in:
Родитель
1fb05b46f4
Коммит
e93f6117de
|
@ -0,0 +1,40 @@
|
|||
# models parameters
|
||||
datasets_name="Piano"
|
||||
control_mode="embedding_v2"
|
||||
feature_num=100
|
||||
bucket_num=2
|
||||
command_subset="emotion_rank_100"
|
||||
rank=0
|
||||
|
||||
gpus=1
|
||||
|
||||
model_name="${datasets_name}"
|
||||
MODEL_NAME="${model_name}"
|
||||
|
||||
|
||||
tgt=${1}
|
||||
|
||||
checkpoint_name="checkpoint_best"
|
||||
checkpoint_path="checkpoints/${model_name}/${checkpoint_name}.pt"
|
||||
command_path="data/infer_input/inference_command.npy"
|
||||
save_root="generation/${model_name}-${checkpoint_name}/Q${tgt}"
|
||||
mkdir -p ${save_root}
|
||||
export CUDA_VISIBLE_DEVICES=${rank}
|
||||
echo "generating from ${checkpoint_path} with emotion Q${tgt}!"
|
||||
python interactive.py \
|
||||
data/${datasets_name}/data-bin \
|
||||
--task language_modeling_control \
|
||||
--path $checkpoint_path \
|
||||
--ctrl_command_path $command_path \
|
||||
--save_root $save_root \
|
||||
--tgt_emotion $tgt \
|
||||
--need_num 2 \
|
||||
--max-len-b 1280 \
|
||||
--min-len 512 \
|
||||
--sampling \
|
||||
--beam 1 \
|
||||
--sampling-topp 0.9 \
|
||||
--temperature 1.0 \
|
||||
--buffer-size 2 \
|
||||
--batch-size 2
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# models parameters
|
||||
datasets_name="Piano"
|
||||
control_mode="embedding_v2"
|
||||
feature_num=100
|
||||
bucket_num=2
|
||||
command_subset="emotion_rank_100"
|
||||
rank=0
|
||||
|
||||
gpus=1
|
||||
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
|
||||
PEAK_LR=1e-4 # Peak learning rate, adjust as needed
|
||||
model_name="${datasets_name}-${control_mode}-bucket${bucket_num}-${command_subset}"
|
||||
MODEL_NAME="${model_name}"
|
||||
|
||||
|
||||
|
||||
echo "training the model: ${MODEL_NAME}!"
|
||||
|
||||
BATCH_SIZE=8 # BATCH
|
||||
UPDATE_FREQ=1
|
||||
|
||||
DATA_DIR=data/${datasets_name} # Data dir
|
||||
|
||||
export MKL_THREADING_LAYER=GNU
|
||||
|
||||
|
||||
|
||||
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
|
||||
let "port_rank=$rank+666"
|
||||
export CUDA_VISIBLE_DEVICES=${rank}
|
||||
#python -m torch.distributed.launch \
|
||||
#--nproc_per_node=${gpus} \
|
||||
#--master_port=${port_rank} \
|
||||
python train.py \
|
||||
$DATA_DIR/data-bin \
|
||||
--truncated_length 1280 \
|
||||
--task language_modeling_control \
|
||||
--arch linear_transformer_lm_std \
|
||||
--control_mode $control_mode \
|
||||
--command_path $DATA_DIR \
|
||||
--feature_num $feature_num \
|
||||
--bucket_num $bucket_num \
|
||||
--sample-break-mode eos \
|
||||
--tokens-per-sample 10000000 \
|
||||
--max-tokens 10000000 \
|
||||
--batch-size $BATCH_SIZE \
|
||||
--batch-size-valid $BATCH_SIZE \
|
||||
--update-freq $UPDATE_FREQ \
|
||||
--optimizer adam \
|
||||
--adam-betas '(0.9, 0.98)' \
|
||||
--adam-eps 1e-9 \
|
||||
--weight-decay 0.01 \
|
||||
--lr $PEAK_LR \
|
||||
--lr-scheduler inverse_sqrt \
|
||||
--warmup-updates $WARMUP_UPDATES \
|
||||
--log-format simple \
|
||||
--log-interval 10 \
|
||||
--tensorboard-logdir tb_log/$MODEL_NAME \
|
||||
--num-workers "$OMP_NUM_THREADS" \
|
||||
--max-update 600000 \
|
||||
--validate-interval 1 \
|
||||
--save-interval-updates 2000 \
|
||||
--save-dir checkpoints/$MODEL_NAME \
|
||||
--no-epoch-checkpoints \
|
||||
--find-unused-parameters \
|
||||
--patience 20
|
|
@ -0,0 +1,39 @@
|
|||
# models parameters
|
||||
datasets_name="TopMAGD"
|
||||
control_mode="embedding_v2"
|
||||
bucket_num=2
|
||||
command_subset="emotion_rank_100"
|
||||
rank=0
|
||||
|
||||
gpus=1
|
||||
|
||||
model_name="${datasets_name}"
|
||||
MODEL_NAME="${model_name}"
|
||||
|
||||
|
||||
tgt=${1}
|
||||
|
||||
checkpoint_name="checkpoint_best"
|
||||
checkpoint_path="checkpoints/${model_name}/${checkpoint_name}.pt"
|
||||
command_path="data/infer_input/inference_command.npy"
|
||||
save_root="generation/${model_name}-${checkpoint_name}/Q${tgt}"
|
||||
mkdir -p ${save_root}
|
||||
export CUDA_VISIBLE_DEVICES=${rank}
|
||||
echo "generating from ${checkpoint_path} with emotion Q${tgt}!"
|
||||
python interactive.py \
|
||||
data/${datasets_name}/data-bin \
|
||||
--task language_modeling_control \
|
||||
--path $checkpoint_path \
|
||||
--ctrl_command_path $command_path \
|
||||
--save_root $save_root \
|
||||
--tgt_emotion $tgt \
|
||||
--need_num 2 \
|
||||
--max-len-b 2560 \
|
||||
--min-len 512 \
|
||||
--sampling \
|
||||
--beam 1 \
|
||||
--sampling-topp 0.9 \
|
||||
--temperature 1.0 \
|
||||
--buffer-size 2 \
|
||||
--batch-size 2
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
# models parameters
|
||||
datasets_name="TopMAGD"
|
||||
control_mode="embedding_v2"
|
||||
feature_num=100
|
||||
bucket_num=2
|
||||
command_subset="emotion_rank_100"
|
||||
rank=0
|
||||
|
||||
gpus=1
|
||||
|
||||
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
|
||||
PEAK_LR=1e-4 # Peak learning rate, adjust as needed
|
||||
|
||||
model_name="${datasets_name}-${control_mode}-bucket${bucket_num}-${command_subset}"
|
||||
MODEL_NAME="${model_name}"
|
||||
|
||||
|
||||
echo "training the model: ${MODEL_NAME}!"
|
||||
|
||||
BATCH_SIZE=1 # BATCH
|
||||
UPDATE_FREQ=1
|
||||
|
||||
DATA_DIR=data/${datasets_name} # Data dir
|
||||
|
||||
export MKL_THREADING_LAYER=GNU # for "import numpy" "import torch" order bug
|
||||
|
||||
|
||||
|
||||
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
|
||||
let "port_rank=$rank+666"
|
||||
#python -m torch.distributed.launch \
|
||||
#--nproc_per_node=${gpus} \
|
||||
#--master_port=${port_rank} \
|
||||
python train.py \
|
||||
$DATA_DIR/data-bin \
|
||||
--truncated_length 2560 \
|
||||
--task language_modeling_control \
|
||||
--arch linear_transformer_lm_std \
|
||||
--control_mode $control_mode \
|
||||
--command_path $DATA_DIR \
|
||||
--feature_num $feature_num \
|
||||
--bucket_num $bucket_num \
|
||||
--sample-break-mode eos \
|
||||
--tokens-per-sample 10000000 \
|
||||
--max-tokens 10000000 \
|
||||
--batch-size $BATCH_SIZE \
|
||||
--batch-size-valid $BATCH_SIZE \
|
||||
--update-freq $UPDATE_FREQ \
|
||||
--optimizer adam \
|
||||
--adam-betas '(0.9, 0.98)' \
|
||||
--adam-eps 1e-9 \
|
||||
--weight-decay 0.01 \
|
||||
--lr $PEAK_LR \
|
||||
--lr-scheduler inverse_sqrt \
|
||||
--warmup-updates $WARMUP_UPDATES \
|
||||
--log-format simple \
|
||||
--log-interval 10 \
|
||||
--tensorboard-logdir tb_log/$MODEL_NAME \
|
||||
--num-workers "$OMP_NUM_THREADS" \
|
||||
--max-update 600000 \
|
||||
--validate-interval 2 \
|
||||
--save-interval-updates 2000 \
|
||||
--save-dir checkpoints/$MODEL_NAME \
|
||||
--no-epoch-checkpoints \
|
||||
--find-unused-parameters \
|
||||
--patience 20
|
|
@ -0,0 +1,458 @@
|
|||
b-1 1
|
||||
d-0 1
|
||||
d-1 1
|
||||
d-2 1
|
||||
d-3 1
|
||||
d-4 1
|
||||
d-5 1
|
||||
d-6 1
|
||||
d-7 1
|
||||
d-8 1
|
||||
d-9 1
|
||||
d-10 1
|
||||
d-11 1
|
||||
d-12 1
|
||||
d-13 1
|
||||
d-14 1
|
||||
d-15 1
|
||||
d-16 1
|
||||
d-17 1
|
||||
d-18 1
|
||||
d-19 1
|
||||
d-20 1
|
||||
d-21 1
|
||||
d-22 1
|
||||
d-23 1
|
||||
d-24 1
|
||||
d-25 1
|
||||
d-26 1
|
||||
d-27 1
|
||||
d-28 1
|
||||
d-29 1
|
||||
d-30 1
|
||||
d-31 1
|
||||
d-32 1
|
||||
d-33 1
|
||||
d-34 1
|
||||
d-35 1
|
||||
d-36 1
|
||||
d-37 1
|
||||
d-38 1
|
||||
d-39 1
|
||||
d-40 1
|
||||
d-41 1
|
||||
d-42 1
|
||||
d-43 1
|
||||
d-44 1
|
||||
d-45 1
|
||||
d-46 1
|
||||
d-47 1
|
||||
d-48 1
|
||||
d-49 1
|
||||
d-50 1
|
||||
d-51 1
|
||||
d-52 1
|
||||
d-53 1
|
||||
d-54 1
|
||||
d-55 1
|
||||
d-56 1
|
||||
d-57 1
|
||||
d-58 1
|
||||
d-59 1
|
||||
d-60 1
|
||||
d-61 1
|
||||
d-62 1
|
||||
d-63 1
|
||||
d-64 1
|
||||
d-65 1
|
||||
d-66 1
|
||||
d-67 1
|
||||
d-68 1
|
||||
d-69 1
|
||||
d-70 1
|
||||
d-71 1
|
||||
d-72 1
|
||||
d-73 1
|
||||
d-74 1
|
||||
d-75 1
|
||||
d-76 1
|
||||
d-77 1
|
||||
d-78 1
|
||||
d-79 1
|
||||
d-80 1
|
||||
d-81 1
|
||||
d-82 1
|
||||
d-83 1
|
||||
d-84 1
|
||||
d-85 1
|
||||
d-86 1
|
||||
d-87 1
|
||||
d-88 1
|
||||
d-89 1
|
||||
d-90 1
|
||||
d-91 1
|
||||
d-92 1
|
||||
d-93 1
|
||||
d-94 1
|
||||
d-95 1
|
||||
i-0 1
|
||||
i-1 1
|
||||
o-0 1
|
||||
o-1 1
|
||||
o-2 1
|
||||
o-3 1
|
||||
o-4 1
|
||||
o-5 1
|
||||
o-6 1
|
||||
o-7 1
|
||||
o-8 1
|
||||
o-9 1
|
||||
o-10 1
|
||||
o-11 1
|
||||
o-12 1
|
||||
o-13 1
|
||||
o-14 1
|
||||
o-15 1
|
||||
o-16 1
|
||||
o-17 1
|
||||
o-18 1
|
||||
o-19 1
|
||||
o-20 1
|
||||
o-21 1
|
||||
o-22 1
|
||||
o-23 1
|
||||
o-24 1
|
||||
o-25 1
|
||||
o-26 1
|
||||
o-27 1
|
||||
o-28 1
|
||||
o-29 1
|
||||
o-30 1
|
||||
o-31 1
|
||||
o-32 1
|
||||
o-33 1
|
||||
o-34 1
|
||||
o-35 1
|
||||
o-36 1
|
||||
o-37 1
|
||||
o-38 1
|
||||
o-39 1
|
||||
o-40 1
|
||||
o-41 1
|
||||
o-42 1
|
||||
o-43 1
|
||||
o-44 1
|
||||
o-45 1
|
||||
o-46 1
|
||||
o-47 1
|
||||
o-48 1
|
||||
o-49 1
|
||||
o-50 1
|
||||
o-51 1
|
||||
o-52 1
|
||||
o-53 1
|
||||
o-54 1
|
||||
o-55 1
|
||||
o-56 1
|
||||
o-57 1
|
||||
o-58 1
|
||||
o-59 1
|
||||
o-60 1
|
||||
o-61 1
|
||||
o-62 1
|
||||
o-63 1
|
||||
o-64 1
|
||||
o-65 1
|
||||
o-66 1
|
||||
o-67 1
|
||||
o-68 1
|
||||
o-69 1
|
||||
o-70 1
|
||||
o-71 1
|
||||
o-72 1
|
||||
o-73 1
|
||||
o-74 1
|
||||
o-75 1
|
||||
o-76 1
|
||||
o-77 1
|
||||
o-78 1
|
||||
o-79 1
|
||||
o-80 1
|
||||
o-81 1
|
||||
o-82 1
|
||||
o-83 1
|
||||
o-84 1
|
||||
o-85 1
|
||||
o-86 1
|
||||
o-87 1
|
||||
o-88 1
|
||||
o-89 1
|
||||
o-90 1
|
||||
o-91 1
|
||||
o-92 1
|
||||
o-93 1
|
||||
o-94 1
|
||||
o-95 1
|
||||
p-0 1
|
||||
p-1 1
|
||||
p-2 1
|
||||
p-3 1
|
||||
p-4 1
|
||||
p-5 1
|
||||
p-6 1
|
||||
p-7 1
|
||||
p-8 1
|
||||
p-9 1
|
||||
p-10 1
|
||||
p-11 1
|
||||
p-12 1
|
||||
p-13 1
|
||||
p-14 1
|
||||
p-15 1
|
||||
p-16 1
|
||||
p-17 1
|
||||
p-18 1
|
||||
p-19 1
|
||||
p-20 1
|
||||
p-21 1
|
||||
p-22 1
|
||||
p-23 1
|
||||
p-24 1
|
||||
p-25 1
|
||||
p-26 1
|
||||
p-27 1
|
||||
p-28 1
|
||||
p-29 1
|
||||
p-30 1
|
||||
p-31 1
|
||||
p-32 1
|
||||
p-33 1
|
||||
p-34 1
|
||||
p-35 1
|
||||
p-36 1
|
||||
p-37 1
|
||||
p-38 1
|
||||
p-39 1
|
||||
p-40 1
|
||||
p-41 1
|
||||
p-42 1
|
||||
p-43 1
|
||||
p-44 1
|
||||
p-45 1
|
||||
p-46 1
|
||||
p-47 1
|
||||
p-48 1
|
||||
p-49 1
|
||||
p-50 1
|
||||
p-51 1
|
||||
p-52 1
|
||||
p-53 1
|
||||
p-54 1
|
||||
p-55 1
|
||||
p-56 1
|
||||
p-57 1
|
||||
p-58 1
|
||||
p-59 1
|
||||
p-60 1
|
||||
p-61 1
|
||||
p-62 1
|
||||
p-63 1
|
||||
p-64 1
|
||||
p-65 1
|
||||
p-66 1
|
||||
p-67 1
|
||||
p-68 1
|
||||
p-69 1
|
||||
p-70 1
|
||||
p-71 1
|
||||
p-72 1
|
||||
p-73 1
|
||||
p-74 1
|
||||
p-75 1
|
||||
p-76 1
|
||||
p-77 1
|
||||
p-78 1
|
||||
p-79 1
|
||||
p-80 1
|
||||
p-81 1
|
||||
p-82 1
|
||||
p-83 1
|
||||
p-84 1
|
||||
p-85 1
|
||||
p-86 1
|
||||
p-87 1
|
||||
p-88 1
|
||||
p-89 1
|
||||
p-90 1
|
||||
p-91 1
|
||||
p-92 1
|
||||
p-93 1
|
||||
p-94 1
|
||||
p-95 1
|
||||
p-96 1
|
||||
p-97 1
|
||||
p-98 1
|
||||
p-99 1
|
||||
p-100 1
|
||||
p-101 1
|
||||
p-102 1
|
||||
p-103 1
|
||||
p-104 1
|
||||
p-105 1
|
||||
p-106 1
|
||||
p-107 1
|
||||
p-108 1
|
||||
p-109 1
|
||||
p-110 1
|
||||
p-111 1
|
||||
p-112 1
|
||||
p-113 1
|
||||
p-114 1
|
||||
p-115 1
|
||||
p-116 1
|
||||
p-117 1
|
||||
p-118 1
|
||||
p-119 1
|
||||
p-120 1
|
||||
p-122 1
|
||||
p-123 1
|
||||
p-124 1
|
||||
p-125 1
|
||||
p-126 1
|
||||
p-127 1
|
||||
s-1 1
|
||||
s-2 1
|
||||
s-3 1
|
||||
s-4 1
|
||||
s-5 1
|
||||
s-6 1
|
||||
s-7 1
|
||||
s-8 1
|
||||
s-9 1
|
||||
s-10 1
|
||||
s-11 1
|
||||
s-12 1
|
||||
s-13 1
|
||||
s-14 1
|
||||
s-15 1
|
||||
s-16 1
|
||||
s-17 1
|
||||
s-18 1
|
||||
s-19 1
|
||||
s-20 1
|
||||
s-21 1
|
||||
s-22 1
|
||||
s-23 1
|
||||
s-24 1
|
||||
s-25 1
|
||||
s-27 1
|
||||
s-28 1
|
||||
s-29 1
|
||||
s-30 1
|
||||
s-31 1
|
||||
s-32 1
|
||||
s-33 1
|
||||
s-34 1
|
||||
s-35 1
|
||||
s-36 1
|
||||
s-38 1
|
||||
s-39 1
|
||||
s-40 1
|
||||
s-41 1
|
||||
s-42 1
|
||||
s-44 1
|
||||
s-45 1
|
||||
s-46 1
|
||||
s-48 1
|
||||
s-49 1
|
||||
s-50 1
|
||||
s-52 1
|
||||
s-57 1
|
||||
s-62 1
|
||||
s-70 1
|
||||
s-74 1
|
||||
s-76 1
|
||||
s-78 1
|
||||
s-82 1
|
||||
s-127 1
|
||||
t-0 1
|
||||
t-1 1
|
||||
t-2 1
|
||||
t-3 1
|
||||
t-4 1
|
||||
t-5 1
|
||||
t-6 1
|
||||
t-7 1
|
||||
t-8 1
|
||||
t-9 1
|
||||
t-10 1
|
||||
t-11 1
|
||||
t-12 1
|
||||
t-13 1
|
||||
t-14 1
|
||||
t-15 1
|
||||
t-16 1
|
||||
t-17 1
|
||||
t-18 1
|
||||
t-19 1
|
||||
t-20 1
|
||||
t-21 1
|
||||
t-22 1
|
||||
t-23 1
|
||||
t-24 1
|
||||
t-25 1
|
||||
t-26 1
|
||||
t-27 1
|
||||
t-28 1
|
||||
t-29 1
|
||||
t-30 1
|
||||
t-31 1
|
||||
t-32 1
|
||||
t-33 1
|
||||
t-34 1
|
||||
t-35 1
|
||||
t-36 1
|
||||
t-37 1
|
||||
t-38 1
|
||||
t-39 1
|
||||
t-40 1
|
||||
t-41 1
|
||||
t-42 1
|
||||
t-43 1
|
||||
t-44 1
|
||||
t-45 1
|
||||
t-46 1
|
||||
t-47 1
|
||||
t-48 1
|
||||
v-0 1
|
||||
v-1 1
|
||||
v-2 1
|
||||
v-3 1
|
||||
v-4 1
|
||||
v-5 1
|
||||
v-6 1
|
||||
v-7 1
|
||||
v-8 1
|
||||
v-9 1
|
||||
v-10 1
|
||||
v-11 1
|
||||
v-12 1
|
||||
v-13 1
|
||||
v-14 1
|
||||
v-15 1
|
||||
v-16 1
|
||||
v-17 1
|
||||
v-18 1
|
||||
v-19 1
|
||||
v-20 1
|
||||
v-21 1
|
||||
v-22 1
|
||||
v-23 1
|
||||
v-24 1
|
||||
v-25 1
|
||||
v-26 1
|
||||
v-27 1
|
||||
v-28 1
|
||||
v-29 1
|
||||
v-30 1
|
||||
v-31 1
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,688 @@
|
|||
b-1 1
|
||||
d-0 1
|
||||
d-1 1
|
||||
d-2 1
|
||||
d-3 1
|
||||
d-4 1
|
||||
d-5 1
|
||||
d-6 1
|
||||
d-7 1
|
||||
d-8 1
|
||||
d-9 1
|
||||
d-10 1
|
||||
d-11 1
|
||||
d-12 1
|
||||
d-13 1
|
||||
d-14 1
|
||||
d-15 1
|
||||
d-16 1
|
||||
d-17 1
|
||||
d-18 1
|
||||
d-19 1
|
||||
d-20 1
|
||||
d-21 1
|
||||
d-22 1
|
||||
d-23 1
|
||||
d-24 1
|
||||
d-25 1
|
||||
d-26 1
|
||||
d-27 1
|
||||
d-28 1
|
||||
d-29 1
|
||||
d-30 1
|
||||
d-31 1
|
||||
d-32 1
|
||||
d-33 1
|
||||
d-34 1
|
||||
d-35 1
|
||||
d-36 1
|
||||
d-37 1
|
||||
d-38 1
|
||||
d-39 1
|
||||
d-40 1
|
||||
d-41 1
|
||||
d-42 1
|
||||
d-43 1
|
||||
d-44 1
|
||||
d-45 1
|
||||
d-46 1
|
||||
d-47 1
|
||||
d-48 1
|
||||
d-49 1
|
||||
d-50 1
|
||||
d-51 1
|
||||
d-52 1
|
||||
d-53 1
|
||||
d-54 1
|
||||
d-55 1
|
||||
d-56 1
|
||||
d-57 1
|
||||
d-58 1
|
||||
d-59 1
|
||||
d-60 1
|
||||
d-61 1
|
||||
d-62 1
|
||||
d-63 1
|
||||
d-64 1
|
||||
d-65 1
|
||||
d-66 1
|
||||
d-67 1
|
||||
d-68 1
|
||||
d-69 1
|
||||
d-70 1
|
||||
d-71 1
|
||||
d-72 1
|
||||
d-73 1
|
||||
d-74 1
|
||||
d-75 1
|
||||
d-76 1
|
||||
d-77 1
|
||||
d-78 1
|
||||
d-79 1
|
||||
d-80 1
|
||||
d-81 1
|
||||
d-82 1
|
||||
d-83 1
|
||||
d-84 1
|
||||
d-85 1
|
||||
d-86 1
|
||||
d-87 1
|
||||
d-88 1
|
||||
d-89 1
|
||||
d-90 1
|
||||
d-91 1
|
||||
d-93 1
|
||||
d-94 1
|
||||
d-95 1
|
||||
i-0 1
|
||||
i-1 1
|
||||
i-2 1
|
||||
i-3 1
|
||||
i-4 1
|
||||
i-5 1
|
||||
i-6 1
|
||||
i-7 1
|
||||
i-8 1
|
||||
i-9 1
|
||||
i-10 1
|
||||
i-11 1
|
||||
i-12 1
|
||||
i-13 1
|
||||
i-14 1
|
||||
i-15 1
|
||||
i-16 1
|
||||
i-17 1
|
||||
i-18 1
|
||||
i-19 1
|
||||
i-20 1
|
||||
i-21 1
|
||||
i-22 1
|
||||
i-23 1
|
||||
i-24 1
|
||||
i-25 1
|
||||
i-26 1
|
||||
i-27 1
|
||||
i-28 1
|
||||
i-29 1
|
||||
i-30 1
|
||||
i-31 1
|
||||
i-32 1
|
||||
i-33 1
|
||||
i-34 1
|
||||
i-35 1
|
||||
i-36 1
|
||||
i-37 1
|
||||
i-38 1
|
||||
i-39 1
|
||||
i-40 1
|
||||
i-41 1
|
||||
i-42 1
|
||||
i-43 1
|
||||
i-44 1
|
||||
i-45 1
|
||||
i-46 1
|
||||
i-47 1
|
||||
i-48 1
|
||||
i-49 1
|
||||
i-50 1
|
||||
i-51 1
|
||||
i-52 1
|
||||
i-53 1
|
||||
i-54 1
|
||||
i-55 1
|
||||
i-56 1
|
||||
i-57 1
|
||||
i-58 1
|
||||
i-59 1
|
||||
i-60 1
|
||||
i-61 1
|
||||
i-62 1
|
||||
i-63 1
|
||||
i-64 1
|
||||
i-65 1
|
||||
i-66 1
|
||||
i-67 1
|
||||
i-68 1
|
||||
i-69 1
|
||||
i-70 1
|
||||
i-71 1
|
||||
i-72 1
|
||||
i-73 1
|
||||
i-74 1
|
||||
i-75 1
|
||||
i-76 1
|
||||
i-77 1
|
||||
i-78 1
|
||||
i-79 1
|
||||
i-80 1
|
||||
i-81 1
|
||||
i-82 1
|
||||
i-83 1
|
||||
i-84 1
|
||||
i-85 1
|
||||
i-86 1
|
||||
i-87 1
|
||||
i-88 1
|
||||
i-89 1
|
||||
i-90 1
|
||||
i-91 1
|
||||
i-92 1
|
||||
i-93 1
|
||||
i-94 1
|
||||
i-95 1
|
||||
i-96 1
|
||||
i-97 1
|
||||
i-98 1
|
||||
i-99 1
|
||||
i-100 1
|
||||
i-101 1
|
||||
i-102 1
|
||||
i-103 1
|
||||
i-104 1
|
||||
i-105 1
|
||||
i-106 1
|
||||
i-107 1
|
||||
i-108 1
|
||||
i-109 1
|
||||
i-110 1
|
||||
i-111 1
|
||||
i-112 1
|
||||
i-113 1
|
||||
i-114 1
|
||||
i-115 1
|
||||
i-116 1
|
||||
i-117 1
|
||||
i-118 1
|
||||
i-119 1
|
||||
i-120 1
|
||||
i-121 1
|
||||
i-122 1
|
||||
i-123 1
|
||||
i-124 1
|
||||
i-125 1
|
||||
i-126 1
|
||||
i-127 1
|
||||
i-128 1
|
||||
o-0 1
|
||||
o-1 1
|
||||
o-2 1
|
||||
o-3 1
|
||||
o-4 1
|
||||
o-5 1
|
||||
o-6 1
|
||||
o-7 1
|
||||
o-8 1
|
||||
o-9 1
|
||||
o-10 1
|
||||
o-11 1
|
||||
o-12 1
|
||||
o-13 1
|
||||
o-14 1
|
||||
o-15 1
|
||||
o-16 1
|
||||
o-17 1
|
||||
o-18 1
|
||||
o-19 1
|
||||
o-20 1
|
||||
o-21 1
|
||||
o-22 1
|
||||
o-23 1
|
||||
o-24 1
|
||||
o-25 1
|
||||
o-26 1
|
||||
o-27 1
|
||||
o-28 1
|
||||
o-29 1
|
||||
o-30 1
|
||||
o-31 1
|
||||
o-32 1
|
||||
o-33 1
|
||||
o-34 1
|
||||
o-35 1
|
||||
o-36 1
|
||||
o-37 1
|
||||
o-38 1
|
||||
o-39 1
|
||||
o-40 1
|
||||
o-41 1
|
||||
o-42 1
|
||||
o-43 1
|
||||
o-44 1
|
||||
o-45 1
|
||||
o-46 1
|
||||
o-47 1
|
||||
o-48 1
|
||||
o-49 1
|
||||
o-50 1
|
||||
o-51 1
|
||||
o-52 1
|
||||
o-53 1
|
||||
o-54 1
|
||||
o-55 1
|
||||
o-56 1
|
||||
o-57 1
|
||||
o-58 1
|
||||
o-59 1
|
||||
o-60 1
|
||||
o-61 1
|
||||
o-62 1
|
||||
o-63 1
|
||||
o-64 1
|
||||
o-65 1
|
||||
o-66 1
|
||||
o-67 1
|
||||
o-68 1
|
||||
o-69 1
|
||||
o-70 1
|
||||
o-71 1
|
||||
o-72 1
|
||||
o-73 1
|
||||
o-74 1
|
||||
o-75 1
|
||||
o-76 1
|
||||
o-77 1
|
||||
o-78 1
|
||||
o-79 1
|
||||
o-80 1
|
||||
o-81 1
|
||||
o-82 1
|
||||
o-83 1
|
||||
o-84 1
|
||||
o-85 1
|
||||
o-86 1
|
||||
o-87 1
|
||||
o-88 1
|
||||
o-89 1
|
||||
o-90 1
|
||||
o-91 1
|
||||
o-92 1
|
||||
o-93 1
|
||||
o-94 1
|
||||
o-95 1
|
||||
p-0 1
|
||||
p-1 1
|
||||
p-2 1
|
||||
p-3 1
|
||||
p-4 1
|
||||
p-5 1
|
||||
p-6 1
|
||||
p-7 1
|
||||
p-8 1
|
||||
p-9 1
|
||||
p-10 1
|
||||
p-11 1
|
||||
p-12 1
|
||||
p-13 1
|
||||
p-14 1
|
||||
p-15 1
|
||||
p-16 1
|
||||
p-17 1
|
||||
p-18 1
|
||||
p-19 1
|
||||
p-20 1
|
||||
p-21 1
|
||||
p-22 1
|
||||
p-23 1
|
||||
p-24 1
|
||||
p-25 1
|
||||
p-26 1
|
||||
p-27 1
|
||||
p-28 1
|
||||
p-29 1
|
||||
p-30 1
|
||||
p-31 1
|
||||
p-32 1
|
||||
p-33 1
|
||||
p-34 1
|
||||
p-35 1
|
||||
p-36 1
|
||||
p-37 1
|
||||
p-38 1
|
||||
p-39 1
|
||||
p-40 1
|
||||
p-41 1
|
||||
p-42 1
|
||||
p-43 1
|
||||
p-44 1
|
||||
p-45 1
|
||||
p-46 1
|
||||
p-47 1
|
||||
p-48 1
|
||||
p-49 1
|
||||
p-50 1
|
||||
p-51 1
|
||||
p-52 1
|
||||
p-53 1
|
||||
p-54 1
|
||||
p-55 1
|
||||
p-56 1
|
||||
p-57 1
|
||||
p-58 1
|
||||
p-59 1
|
||||
p-60 1
|
||||
p-61 1
|
||||
p-62 1
|
||||
p-63 1
|
||||
p-64 1
|
||||
p-65 1
|
||||
p-66 1
|
||||
p-67 1
|
||||
p-68 1
|
||||
p-69 1
|
||||
p-70 1
|
||||
p-71 1
|
||||
p-72 1
|
||||
p-73 1
|
||||
p-74 1
|
||||
p-75 1
|
||||
p-76 1
|
||||
p-77 1
|
||||
p-78 1
|
||||
p-79 1
|
||||
p-80 1
|
||||
p-81 1
|
||||
p-82 1
|
||||
p-83 1
|
||||
p-84 1
|
||||
p-85 1
|
||||
p-86 1
|
||||
p-87 1
|
||||
p-88 1
|
||||
p-89 1
|
||||
p-90 1
|
||||
p-91 1
|
||||
p-92 1
|
||||
p-93 1
|
||||
p-94 1
|
||||
p-95 1
|
||||
p-96 1
|
||||
p-97 1
|
||||
p-98 1
|
||||
p-99 1
|
||||
p-100 1
|
||||
p-101 1
|
||||
p-102 1
|
||||
p-103 1
|
||||
p-104 1
|
||||
p-105 1
|
||||
p-106 1
|
||||
p-107 1
|
||||
p-108 1
|
||||
p-109 1
|
||||
p-110 1
|
||||
p-111 1
|
||||
p-112 1
|
||||
p-113 1
|
||||
p-114 1
|
||||
p-115 1
|
||||
p-116 1
|
||||
p-117 1
|
||||
p-118 1
|
||||
p-119 1
|
||||
p-120 1
|
||||
p-121 1
|
||||
p-122 1
|
||||
p-123 1
|
||||
p-124 1
|
||||
p-125 1
|
||||
p-126 1
|
||||
p-127 1
|
||||
p-128 1
|
||||
p-129 1
|
||||
p-130 1
|
||||
p-131 1
|
||||
p-132 1
|
||||
p-133 1
|
||||
p-134 1
|
||||
p-135 1
|
||||
p-136 1
|
||||
p-137 1
|
||||
p-140 1
|
||||
p-141 1
|
||||
p-142 1
|
||||
p-143 1
|
||||
p-144 1
|
||||
p-145 1
|
||||
p-146 1
|
||||
p-147 1
|
||||
p-148 1
|
||||
p-149 1
|
||||
p-150 1
|
||||
p-151 1
|
||||
p-152 1
|
||||
p-153 1
|
||||
p-154 1
|
||||
p-155 1
|
||||
p-156 1
|
||||
p-157 1
|
||||
p-158 1
|
||||
p-159 1
|
||||
p-160 1
|
||||
p-161 1
|
||||
p-162 1
|
||||
p-163 1
|
||||
p-164 1
|
||||
p-165 1
|
||||
p-166 1
|
||||
p-167 1
|
||||
p-168 1
|
||||
p-169 1
|
||||
p-170 1
|
||||
p-171 1
|
||||
p-172 1
|
||||
p-173 1
|
||||
p-174 1
|
||||
p-175 1
|
||||
p-176 1
|
||||
p-177 1
|
||||
p-178 1
|
||||
p-179 1
|
||||
p-180 1
|
||||
p-181 1
|
||||
p-182 1
|
||||
p-183 1
|
||||
p-184 1
|
||||
p-185 1
|
||||
p-186 1
|
||||
p-187 1
|
||||
p-188 1
|
||||
p-189 1
|
||||
p-190 1
|
||||
p-191 1
|
||||
p-192 1
|
||||
p-193 1
|
||||
p-194 1
|
||||
p-195 1
|
||||
p-196 1
|
||||
p-197 1
|
||||
p-198 1
|
||||
p-199 1
|
||||
p-200 1
|
||||
p-201 1
|
||||
p-202 1
|
||||
p-203 1
|
||||
p-204 1
|
||||
p-205 1
|
||||
p-206 1
|
||||
p-207 1
|
||||
p-208 1
|
||||
p-209 1
|
||||
p-210 1
|
||||
p-211 1
|
||||
p-212 1
|
||||
p-213 1
|
||||
p-214 1
|
||||
p-215 1
|
||||
p-216 1
|
||||
p-217 1
|
||||
p-218 1
|
||||
p-219 1
|
||||
p-220 1
|
||||
p-221 1
|
||||
p-222 1
|
||||
p-223 1
|
||||
p-224 1
|
||||
p-225 1
|
||||
p-226 1
|
||||
p-227 1
|
||||
p-228 1
|
||||
p-230 1
|
||||
p-231 1
|
||||
p-232 1
|
||||
p-233 1
|
||||
p-234 1
|
||||
p-235 1
|
||||
p-236 1
|
||||
p-238 1
|
||||
p-239 1
|
||||
p-242 1
|
||||
p-243 1
|
||||
p-244 1
|
||||
p-248 1
|
||||
p-249 1
|
||||
p-250 1
|
||||
p-251 1
|
||||
p-252 1
|
||||
p-253 1
|
||||
p-254 1
|
||||
p-255 1
|
||||
s-3 1
|
||||
s-4 1
|
||||
s-5 1
|
||||
s-6 1
|
||||
s-7 1
|
||||
s-8 1
|
||||
s-9 1
|
||||
s-10 1
|
||||
s-11 1
|
||||
s-12 1
|
||||
s-13 1
|
||||
s-14 1
|
||||
s-15 1
|
||||
s-16 1
|
||||
s-17 1
|
||||
s-18 1
|
||||
s-19 1
|
||||
s-20 1
|
||||
s-21 1
|
||||
s-22 1
|
||||
s-23 1
|
||||
s-24 1
|
||||
s-25 1
|
||||
s-26 1
|
||||
s-28 1
|
||||
s-31 1
|
||||
s-33 1
|
||||
s-34 1
|
||||
s-38 1
|
||||
s-40 1
|
||||
s-41 1
|
||||
s-44 1
|
||||
s-45 1
|
||||
s-78 1
|
||||
t-0 1
|
||||
t-1 1
|
||||
t-2 1
|
||||
t-3 1
|
||||
t-4 1
|
||||
t-5 1
|
||||
t-6 1
|
||||
t-7 1
|
||||
t-8 1
|
||||
t-9 1
|
||||
t-10 1
|
||||
t-11 1
|
||||
t-12 1
|
||||
t-13 1
|
||||
t-14 1
|
||||
t-15 1
|
||||
t-16 1
|
||||
t-17 1
|
||||
t-18 1
|
||||
t-19 1
|
||||
t-20 1
|
||||
t-21 1
|
||||
t-22 1
|
||||
t-23 1
|
||||
t-24 1
|
||||
t-25 1
|
||||
t-26 1
|
||||
t-27 1
|
||||
t-28 1
|
||||
t-29 1
|
||||
t-30 1
|
||||
t-31 1
|
||||
t-32 1
|
||||
t-33 1
|
||||
t-34 1
|
||||
t-35 1
|
||||
t-36 1
|
||||
t-37 1
|
||||
t-38 1
|
||||
t-39 1
|
||||
t-40 1
|
||||
t-41 1
|
||||
t-42 1
|
||||
t-43 1
|
||||
t-44 1
|
||||
t-45 1
|
||||
t-46 1
|
||||
t-47 1
|
||||
t-48 1
|
||||
v-0 1
|
||||
v-1 1
|
||||
v-2 1
|
||||
v-3 1
|
||||
v-4 1
|
||||
v-5 1
|
||||
v-6 1
|
||||
v-7 1
|
||||
v-8 1
|
||||
v-9 1
|
||||
v-10 1
|
||||
v-11 1
|
||||
v-12 1
|
||||
v-13 1
|
||||
v-14 1
|
||||
v-15 1
|
||||
v-16 1
|
||||
v-17 1
|
||||
v-18 1
|
||||
v-19 1
|
||||
v-20 1
|
||||
v-21 1
|
||||
v-22 1
|
||||
v-23 1
|
||||
v-24 1
|
||||
v-25 1
|
||||
v-26 1
|
||||
v-27 1
|
||||
v-28 1
|
||||
v-29 1
|
||||
v-30 1
|
||||
v-31 1
|
||||
Q1 1
|
||||
Q2 1
|
||||
Q3 1
|
||||
Q4 1
|
||||
None 1
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,85 @@
|
|||
import json
|
||||
import multiprocessing
|
||||
import random
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
import miditoolkit
|
||||
import numpy.random
|
||||
import pandas as pd
|
||||
import os, collections, pickle, shutil
|
||||
import numpy as np
|
||||
from MidiProcessor.midiprocessor import MidiEncoder, midi_utils, MidiDecoder
|
||||
from tqdm import tqdm
|
||||
import gc
|
||||
from multiprocessing import Pool, Manager, Lock
|
||||
import math, random
|
||||
from typing import List, Dict
|
||||
from functools import partial
|
||||
from jSymbolic_lib.jSymbolic_util import read_all_feature
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
import joblib
|
||||
|
||||
random.seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
|
||||
def binarize_data(path_root):
|
||||
save_root = path_root + "/data-bin"
|
||||
dict_path = save_root + f"/dict.txt"
|
||||
command = f"fairseq-preprocess --only-source --destdir {save_root} --srcdict {dict_path} "\
|
||||
f"--validpref {path_root}/valid.txt --testpref {path_root}/test.txt --trainpref {path_root}/train.txt --workers 4 "
|
||||
text = os.popen(command).read()
|
||||
print(text)
|
||||
|
||||
|
||||
def binarize_command(command, thresholds):
|
||||
discrete_feature = []
|
||||
for k in range(command.shape[0]):
|
||||
thres = thresholds[k]
|
||||
discrete_feature.append(np.searchsorted(thres, command[k]))
|
||||
return discrete_feature
|
||||
|
||||
def gen_split_data(path_root):
|
||||
feature_index = np.load("../data/feature_index.npy", allow_pickle=True)
|
||||
thresholds = np.load("../data/threshold.npy", allow_pickle=True)
|
||||
save_root = path_root
|
||||
os.makedirs(save_root, exist_ok=True)
|
||||
fn_list = os.listdir(path_root + "/remi")
|
||||
random.shuffle(fn_list)
|
||||
for split in ["train", "valid", "test"]:
|
||||
split_command = []
|
||||
if split == "train":
|
||||
s, e = 0, int(len(fn_list)*0.8)
|
||||
elif split == "valid":
|
||||
s,e = int(len(fn_list)*0.8), int(len(fn_list)*0.9)
|
||||
else:
|
||||
s,e = int(len(fn_list)*0.9), len(fn_list)
|
||||
|
||||
with open(path_root + f"/{split}.txt", "w") as split_txt:
|
||||
split_fn_list = []
|
||||
j = 0
|
||||
for i, fn in enumerate(tqdm(fn_list[s:e])):
|
||||
fn_name = fn.split(".")[0]
|
||||
try:
|
||||
jS_feature = read_all_feature(path_root + f"/feature/{fn_name}.xml")
|
||||
except:
|
||||
continue
|
||||
jS_feature = np.array(jS_feature)
|
||||
if len(jS_feature) != 1495:
|
||||
continue
|
||||
jS_feature = np.array(jS_feature)
|
||||
binary_command = binarize_command(jS_feature[feature_index], thresholds)
|
||||
split_command.append(binary_command)
|
||||
split_fn_list.append(fn)
|
||||
with open(path_root + f'/remi/{fn}', "r") as f:
|
||||
remi_tokens = f.read().strip("\n").strip(" ")
|
||||
split_txt.write(remi_tokens + "\n")
|
||||
j += 1
|
||||
split_command = np.array(split_command)
|
||||
np.save(save_root + f"/{split}_fn_list.npy", split_fn_list)
|
||||
np.save(save_root + f'/{split}_command.npy', split_command)
|
||||
assert len(split_fn_list) == len(split_command), "length dismatch!"
|
||||
|
||||
if __name__ == "__main__":
|
||||
gen_split_data("../data/Piano")
|
||||
binarize_data("../data/Piano")
|
|
@ -0,0 +1,70 @@
|
|||
import os
|
||||
import random
|
||||
import shutil
|
||||
import json
|
||||
from multiprocessing import Process, Pool
|
||||
from functools import partial
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
from MidiProcessor.midiprocessor import MidiEncoder, MidiDecoder, enc_remigen_utils, const
|
||||
from MidiProcessor.midiprocessor.keys_normalization import get_notes_from_pos_info, get_pitch_shift
|
||||
import math, pickle,miditoolkit
|
||||
import json
|
||||
|
||||
random.seed(2022)
|
||||
np.random.seed(2022)
|
||||
|
||||
|
||||
encoder = MidiEncoder("REMIGEN2")
|
||||
|
||||
|
||||
def midi_encoding(midi_path, save_root, prefix):
|
||||
try:
|
||||
name_split = midi_path.replace("//", "/").replace("\\", "/").split("/")
|
||||
midi_name = name_split[-1][:-4] # skip ".mid"
|
||||
# save_name = "_".join(name_split[-3:-1]) + "_" + f"{midi_name}.txt"
|
||||
midi_name = midi_name.replace(" ", "_")
|
||||
if prefix is not None:
|
||||
save_name = f"{prefix}_{midi_name}.txt"
|
||||
else:
|
||||
save_name = f"{midi_name}.txt"
|
||||
midi_obj = miditoolkit.MidiFile(midi_path)
|
||||
remi_token = encoder.encode_file(file_path = midi_path, midi_obj = midi_obj, remove_empty_bars=True)
|
||||
encoder.dump_token_lists(token_lists=remi_token, file_path=os.path.join(save_root, save_name))
|
||||
return 1
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(1)
|
||||
except BaseException:
|
||||
print(midi_path, "error")
|
||||
return 0
|
||||
|
||||
def midi_encoding_generate(data_path):
|
||||
|
||||
midi_path_list = os.listdir(data_path + f"/midi")
|
||||
for i in range(len(midi_path_list)):
|
||||
midi_path_list[i] = data_path + f"/midi/{midi_path_list[i]}"
|
||||
|
||||
save_root = data_path + f"/remi"
|
||||
if not os.path.exists(save_root):
|
||||
os.mkdir(save_root)
|
||||
with Pool(processes=8) as pool:
|
||||
result = iter(tqdm(pool.imap(partial(midi_encoding, save_root = save_root, prefix = None), midi_path_list), total=len(midi_path_list)))
|
||||
for i in range(len(midi_path_list)):
|
||||
# try:
|
||||
next(result)
|
||||
# except BaseException as e:
|
||||
# if isinstance(e, KeyboardInterrupt):
|
||||
# print("error")
|
||||
# pool.terminate()
|
||||
# else:
|
||||
# print(midi_path_list[i], "error!")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
midi_encoding_generate("../data/Piano")
|
||||
|
|
@ -0,0 +1,633 @@
|
|||
import json
|
||||
import multiprocessing
|
||||
import random
|
||||
import sys
|
||||
sys.path.append("..")
|
||||
import miditoolkit
|
||||
import numpy.random
|
||||
import pandas as pd
|
||||
import os, collections, pickle, shutil
|
||||
import numpy as np
|
||||
from reference.MidiProcessor.midiprocessor import MidiEncoder, midi_utils
|
||||
from tqdm import tqdm
|
||||
import gc
|
||||
from multiprocessing import Pool, Manager, Lock
|
||||
import math, random
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
|
||||
# INSTR_NAME_LIST = [0, 1, 4, 5, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 48, 49, 50, 52, 53, 56, 57, 60, 61, 65, 66, 71, 73, 128]
|
||||
INSTR_NAME_LIST = [0, 25, 32, 48, 80, 128]
|
||||
# INSTR_NAME_LIST = [0, 25, 32, 48, 128]
|
||||
#
|
||||
#
|
||||
#
|
||||
# # Banlanced_INST = [1, 4, 5, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 48, 49, 50, 52, 53, 56, 57, 60, 61, 65, 66, 71, 73]
|
||||
INSTR_NAME2_index = dict(zip(INSTR_NAME_LIST, range(len(INSTR_NAME_LIST))))
|
||||
sample_INSTR_dist = dict(zip(INSTR_NAME_LIST, [0]*len(INSTR_NAME_LIST)))
|
||||
file_path = './test.mid'
|
||||
# sample_list = dict.fromkeys(INSTR_NAME_LIST, 0)
|
||||
|
||||
|
||||
encoder = MidiEncoder('REMIGEN2')
|
||||
|
||||
ignore_inst = False
|
||||
|
||||
def seed_everything():
|
||||
np.random.seed(42)
|
||||
random.seed(42)
|
||||
|
||||
|
||||
# def gen_encode_data(): # generate command_vector from txt file,这个不用了,产生小样本用的
|
||||
# data_root = "./data/lmd_encode"
|
||||
# start_token, end_token = "<s>", "</s>"
|
||||
# dict = []
|
||||
#
|
||||
# # with open("./data/txt/dict.txt", "a") as f:
|
||||
# # for i in range(129):
|
||||
# # f.write(f"i-{i} 1\n")
|
||||
# # bos = "<s>",
|
||||
# # pad="<pad>",
|
||||
# # eos="</s>",
|
||||
# # dict = [] # use dict from longformer_exps
|
||||
# for target_prefix in ["train", "valid", "test"]:
|
||||
# file_count = 0
|
||||
# target_file = open(f"./data/txt/{target_prefix}.txt", "w")
|
||||
# # cf = open(f"./data/txt/command_{target_prefix}.txt", "w")
|
||||
# cf = []
|
||||
# index_f = open(f"./data/txt/index_{target_prefix}.id", "w")
|
||||
# for file in os.listdir(data_root):
|
||||
# # print(target_prefix, file_count)
|
||||
# with open(os.path.join(data_root, file), "r") as f:
|
||||
# remi_str = f.read().strip("\n")
|
||||
# target_file.write(" ".join([start_token, remi_str, end_token]) + "\n")
|
||||
# command_vector = [0]*len(INSTR_NAME_LIST)
|
||||
# for i in remi_str.split(" "):
|
||||
# if i not in dict:
|
||||
# dict.append(i)
|
||||
# if i[0] == "i":
|
||||
# instru_name_now = eval(i[2:])
|
||||
# assert instru_name_now in INSTR_NAME_LIST, f"{file} illegal!"
|
||||
# command_vector[INSTR_NAME2_index[instru_name_now]] = 1
|
||||
#
|
||||
# cf.append(command_vector)
|
||||
# # cf.write(" ".join(command_vector) + "\n")
|
||||
# index_f.write(file.split(".")[0] + "\n")
|
||||
# file_count += 1
|
||||
# if file_count >= 117:
|
||||
# break
|
||||
# cf = np.array(cf)
|
||||
# np.save(f"./data/data_bin/command_{target_prefix}.npy", cf)
|
||||
# target_file.close()
|
||||
# # dict = sorted(dict, key = lambda x: (x[0], eval(x.split("-")[-1])))
|
||||
# # with open("./data/txt/dict.txt", "w") as f:
|
||||
# # for i in dict:
|
||||
# # f.write(i + " 1\n")
|
||||
# write_dict(dict, "./data/txt/dict.txt")
|
||||
# print("done!")
|
||||
|
||||
def check_ts(midi_obj):
|
||||
for i in midi_obj.time_signature_changes:
|
||||
a, b = i.numerator, i.denominator
|
||||
if not(b == 4 and a in [4,8,16,32,64,128,256]):
|
||||
return False
|
||||
return True
|
||||
|
||||
# def ergodic_a_file(path_root:str, path_list: List[str], dict_list:List[str], locker, process_index: int, processor_num: int, finished_num):
|
||||
#
|
||||
# # for i in range(1000):
|
||||
# # locker.acquire()
|
||||
# # finished_num.value += 1
|
||||
# # locker.release()
|
||||
# # print(process_index, finished_num.value)
|
||||
#
|
||||
# size = math.ceil(len(path_list) / processor_num)
|
||||
# start = size * process_index
|
||||
# end = (process_index + 1) * size if (process_index + 1) * size < len(path_list) else len(path_list)
|
||||
# for path in path_list[start:end]:
|
||||
# with open(os.path.join(path_root, path), "r") as f:
|
||||
# line = f.read().strip("\n").split(" ")
|
||||
# for token in line:
|
||||
# if token not in dict_list:
|
||||
# locker.acquire()
|
||||
# dict_list.append(token)
|
||||
# locker.release()
|
||||
#
|
||||
# # if process_index == 2:
|
||||
# locker.acquire()
|
||||
# finished_num.value += 1
|
||||
# locker.release()
|
||||
# print(process_index, finished_num.value)
|
||||
# # if finished_num.value > 100:
|
||||
# # break
|
||||
#
|
||||
# def generate_dict(dataset_path:str, save_root:str):
|
||||
# # dataset_path: 输入的remi编码的根目录, 文件夹里是remi的txt序列
|
||||
# # save_root: 存储的dict 根目录
|
||||
#
|
||||
# file_name_list = os.listdir(dataset_path)
|
||||
# # dict_list = []
|
||||
# manager = Manager()
|
||||
# dict_list = manager.list()
|
||||
# locker = manager.Lock()
|
||||
# finished_num = manager.Value("i", 0)
|
||||
# processor_num = 4
|
||||
#
|
||||
#
|
||||
# process_list = []
|
||||
# for i in range(processor_num):
|
||||
# # res.append(pools.apply_async(pass_para, args=(file_name_list, i, processor_num, sample_list, finished_num,)))
|
||||
# process_list.append(multiprocessing.Process(target=ergodic_a_file, args=(
|
||||
# dataset_path, file_name_list, dict_list, locker, i, processor_num, finished_num,)))
|
||||
# process_list[-1].start()
|
||||
# print(str(i) + ' processor started !')
|
||||
#
|
||||
# for i in process_list:
|
||||
# i.join()
|
||||
# print("over!")
|
||||
#
|
||||
# # for file_name in tqdm(file_name_list):
|
||||
# # file_path = os.path.join(dataset_path, file_name)
|
||||
# dict_list = sorted(dict_list, key=lambda x: (x[0], eval(x.split("-")[-1])))
|
||||
# with open(os.path.join(save_root, "dict.txt"), "w") as f:
|
||||
# for i in dict_list:
|
||||
# f.write(f"{i} 1\n")
|
||||
# return dict_list
|
||||
|
||||
def midi2REMI(midi_obj, save_path, sample_list = None):
|
||||
if len(midi_obj.instruments) > 0:
|
||||
encoder.encode_file(
|
||||
file_path,
|
||||
midi_obj=midi_obj,
|
||||
remove_empty_bars=True,
|
||||
ignore_inst=ignore_inst,
|
||||
ignore_ts=True,
|
||||
ignore_tempo=False,
|
||||
save_path = save_path
|
||||
)
|
||||
|
||||
def pass_para(file_in, index, p_num, sample_list, finished_num):
|
||||
# if eval(file_in[1]):
|
||||
# return
|
||||
# file_name = file_in[0]
|
||||
# save_name = "./data/lmd_encode/" + file_name.replace("/", "_")[:-4] + ".txt"
|
||||
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name)
|
||||
# midi2REMI(midi_obj, save_name)
|
||||
size = math.ceil(len(file_in) / p_num)
|
||||
start = size * index
|
||||
end = (index + 1) * size if (index + 1) * size < len(file_in) else len(file_in)
|
||||
temp_data = file_in[start:end]
|
||||
for j in temp_data:
|
||||
# if eval(j[1]):
|
||||
# continue
|
||||
file_name = j
|
||||
# save_name = r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode/remi_seq/" + file_name.replace("/", "_")[:-4] + ".txt"
|
||||
save_name = r"D:\ProjectData\ControlGeneration\0813_attributes_generation\lmd_6tracks_clean\full_song\remi_seq/" + file_name.split(".")[0] + ".txt"
|
||||
# if os.path.exists(save_name):
|
||||
# if index == 2:
|
||||
# print("skip")
|
||||
# continue
|
||||
midi_obj = miditoolkit.MidiFile(os.path.join(r"D:\ProjectData\datasets\lmd_6tracks\midi_6tracks", file_name))
|
||||
if check_ts(midi_obj):
|
||||
midi_obj.time_signature_changes = [miditoolkit.TimeSignature(4,4,0)]
|
||||
try:
|
||||
midi2REMI(midi_obj, save_name, sample_list)
|
||||
except:
|
||||
print(file_name, " illegal!!!")
|
||||
finished_num.value += 1
|
||||
if index == 2:
|
||||
print(finished_num.value)
|
||||
|
||||
def mutiprocess_REMI_encode(file_name_list):
|
||||
# ex_item = pd.read_excel("./data/LMD_full_statistics.xlsx", header = [0,1])
|
||||
# file_name_list = np.load("./data/file_path.npy")
|
||||
# for i in tqdm(range(50000)):
|
||||
# if not eval(file_name_list[i,1]):
|
||||
# save_name = "./data/lmd_encode/" + file_name_list[i,0].replace("/", "_")[:-4] + ".txt"
|
||||
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name_list[i,0])
|
||||
# midi2REMI(midi_obj, save_name)
|
||||
# del midi_obj
|
||||
# gc.collect()
|
||||
|
||||
manager = Manager()
|
||||
sample_list = manager.dict()
|
||||
finished_num = manager.Value("i", 0)
|
||||
|
||||
processor_num = 4
|
||||
process_list = []
|
||||
for i in range(processor_num):
|
||||
# res.append(pools.apply_async(pass_para, args=(file_name_list, i, processor_num, sample_list, finished_num,)))
|
||||
process_list.append(multiprocessing.Process(target=pass_para, args=(
|
||||
file_name_list, i, processor_num, sample_list, finished_num,)))
|
||||
process_list[-1].start()
|
||||
print(str(i) + ' processor started !')
|
||||
|
||||
for i in process_list:
|
||||
i.join()
|
||||
print("over!")
|
||||
# save_dict = {}
|
||||
# for i in sample_list.items():
|
||||
# print("key:", type(i[0]))
|
||||
# save_dict[int(i[0])] = i[1]
|
||||
# json.dump(save_dict, open("./data/sample_list.json", "w"))
|
||||
|
||||
# def sample_midi_data():
|
||||
# # sample_list = json.load(open("./data/sample_list.json", "r"))
|
||||
# file_name_list = np.load("./data/file_instru.npy", allow_pickle=True)
|
||||
#
|
||||
# chosen_file = []
|
||||
# total_dist = dict(zip(range(129), [0] * 129))
|
||||
# for file in tqdm(file_name_list):
|
||||
# if file[1]:
|
||||
# continue
|
||||
# # midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file[0])
|
||||
# isSelected = False
|
||||
# for instru in eval(file[2]):
|
||||
# if instru in Banlanced_INST:
|
||||
# if not isSelected:
|
||||
# chosen_file.append(file)
|
||||
# isSelected = True
|
||||
#
|
||||
# total_dist[min(instru, 128)] += 1
|
||||
# if isSelected:
|
||||
# for instru in eval(file[2]):
|
||||
# if instru in INSTR_NAME_LIST or instru >= 128:
|
||||
# sample_INSTR_dist[min(instru, 128)] += 1
|
||||
# np.save("./data/chosen_files.npy", chosen_file)
|
||||
# np.save("./data/selected_dist.npy", sample_INSTR_dist)
|
||||
# # {0: 46250, 1: 15661, 4: 13041, 5: 10634, 24: 23789, 25: 30324, 26: 16262, 27: 21368, 28: 13430, 29: 16726, 30: 16890, 32: 22399, 33: 36293,
|
||||
# # 35: 21158, 48: 37318, 49: 20172, 50: 12922, 52: 20990, 53: 10138, 56: 17569, 57: 13767, 60: 11161, 61: 13343, 65: 14791, 66: 11622, 71: 10262, 73: 16688, 128: 97225}
|
||||
|
||||
# def gen_train_valid_test_txt():
|
||||
# np.random.seed(42)
|
||||
# # chosen_file = np.load("./data/chosen_files.npy", allow_pickle=True)
|
||||
# root = "E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean"
|
||||
# chosen_file = os.listdir(root)
|
||||
# split_point = [int(0.88*len(chosen_file)), int(0.95*len(chosen_file))]
|
||||
# chosen_index = np.arange(0,len(chosen_file))
|
||||
# np.random.shuffle(chosen_index)
|
||||
# train_file_index = chosen_index[:split_point[0]]
|
||||
# valid_file_index = chosen_index[split_point[0]:split_point[1]]
|
||||
# test_file_index = chosen_index[split_point[1]:]
|
||||
# # dict = []
|
||||
# # np.save("./data/total_txt/split_index.npy", [train_file_index, valid_file_index, test_file_index])
|
||||
# def get_split_txt(split, indexes):
|
||||
# # command_list = []
|
||||
# f = open(rf"E:\AIMusic\ControlGeneration\instrument_control_bar\attribute_control\txt/{split}.txt", "w")
|
||||
#
|
||||
# num = 0
|
||||
# print(split, " start!")
|
||||
# for i in tqdm(indexes):
|
||||
# file_path = os.path.join(root, chosen_file[i])
|
||||
# s_encode = open(file_path, "r").read()
|
||||
# if s_encode[-1:] != "\n":
|
||||
# s_encode = s_encode + "\n"
|
||||
# f.write(s_encode)
|
||||
# num += 1
|
||||
# # if num > 100:
|
||||
# # break
|
||||
# f.close()
|
||||
# # np.save(f"./data/total_txt/command_{split}.npy", np.array(command_list))
|
||||
# get_split_txt("train", train_file_index)
|
||||
# get_split_txt("valid", valid_file_index)
|
||||
# get_split_txt("test", test_file_index)
|
||||
# # write_dict(dict, "./data/total_txt/dict.txt")
|
||||
# print("done!")
|
||||
#
|
||||
# def check(check_index):
|
||||
# command = np.load("./data/total_txt/command_train.npy")
|
||||
# index = 0
|
||||
# with open("./data/total_txt/train.txt", "r") as f:
|
||||
# line = f.readline().strip("\n")
|
||||
# while 1:
|
||||
# # for i in line.split(" "):
|
||||
# # if i[0] == "i":
|
||||
# # instru_number = eval(i[2:])
|
||||
# # assert command[index][INSTR_NAME2_index[instru_number]] == 1, "error!"
|
||||
#
|
||||
# if check_index == index:
|
||||
# print(line)
|
||||
# break
|
||||
# index += 1
|
||||
# line = f.readline().strip("\n")
|
||||
# if index % 1000 == 0:
|
||||
# print(index)
|
||||
#
|
||||
# # def generate_command_vector(token_list:list):
|
||||
# # command_vector = [0] * (len(INSTR_NAME_LIST))
|
||||
# # length = 0
|
||||
# # for i in token_list:
|
||||
# # if i == "":
|
||||
# # continue
|
||||
# # if i[0] == "i":
|
||||
# # j = eval(i[2:])
|
||||
# # command_vector[INSTR_NAME2_index[min(j, 128)]] = 1
|
||||
# # length += 1
|
||||
# # # command_vector = np.array([command_vector])
|
||||
# # # command_vector = np.repeat(command_vector, length, axis = 0)
|
||||
# # command_vector.append(length + 1) # for b-1 token
|
||||
# # return command_vector
|
||||
#
|
||||
# def get_8bars_train_valid_test_txt():
|
||||
# # 随机截取8个bar长度的token试一试
|
||||
# # chosen_file = np.load("./data/chosen_files.npy", allow_pickle=True)
|
||||
# # data_path = "E:/AIMusic/ControlGeneration/lmd_encode/"
|
||||
# # save_root = "E:/AIMusic/ControlGeneration/8-bars-txt/"
|
||||
# # data_path = "./data/sample_6tracks/remi_seq/"
|
||||
# # save_root = "./data/sample_6tracks/txt/"
|
||||
# data_path = "E:/AIMusic/ControlGeneration/lmd_6_tracks_encode/remi_seq/"
|
||||
# save_root = "E:/AIMusic/ControlGeneration/lmd_6_tracks_encode/txt/"
|
||||
#
|
||||
# chosen_file = [i for i in os.listdir(data_path)]
|
||||
#
|
||||
# def write(code, target_f):
|
||||
# for i in range(len(code)):
|
||||
# if code[i] == "s-9":
|
||||
# continue
|
||||
# target_f.write(code[i])
|
||||
# if i != len(code)-1:
|
||||
# target_f.write(" ")
|
||||
# target_f.write("\n")
|
||||
# def get_bar_split_txt(split, start, end):
|
||||
# command_list = []
|
||||
# split_files = chosen_file[start:end]
|
||||
# f = open(save_root + f"./{split}.txt", "w")
|
||||
# num = 0
|
||||
# print(split, " start!")
|
||||
# max_length = 3000
|
||||
# sample_num = 0
|
||||
# for file_name in tqdm(split_files):
|
||||
# file_path = data_path + file_name.split(".")[0].replace("/", "_") + ".txt"
|
||||
# s_encode = open(file_path, "r").read().strip("\n")
|
||||
# if s_encode[:3] != "s-9":
|
||||
# continue
|
||||
# bar_pos = []
|
||||
# remi_tokens = s_encode.split(" ")
|
||||
# for i, token in enumerate(remi_tokens):
|
||||
# if token == "b-1":
|
||||
# bar_pos.append(i)
|
||||
#
|
||||
# if len(bar_pos) <= 8:
|
||||
# cur_length = len(remi_tokens) -1
|
||||
# if cur_length + 2 <= max_length:
|
||||
# write(remi_tokens[0:], f)
|
||||
# sample_num += 1
|
||||
# continue
|
||||
#
|
||||
# coordinate_pool = []
|
||||
# start = 0
|
||||
# for i, pos in enumerate(bar_pos[:-7]):
|
||||
# cur_length = bar_pos[i+7] + 1 - start
|
||||
# if cur_length -8 + 2 <= max_length:
|
||||
# coordinate_pool.append((i, start, start + cur_length))
|
||||
# start = pos + 1
|
||||
# if len(coordinate_pool) == 0:
|
||||
# continue
|
||||
# elif len(coordinate_pool) == 1:
|
||||
# arr = coordinate_pool[0]
|
||||
# write(remi_tokens[arr[1]:arr[2]], f)
|
||||
# sample_num += 1
|
||||
# else:
|
||||
# chosen_bars = np.random.choice(len(coordinate_pool), 2, replace = False)
|
||||
# for i in chosen_bars:
|
||||
# arr = coordinate_pool[i]
|
||||
# write(remi_tokens[arr[1]:arr[2]], f)
|
||||
# sample_num += 1
|
||||
# print(sample_num)
|
||||
# f.close()
|
||||
# # np.save(f"./data/total_txt/command_{split}.npy", np.array(command_list))
|
||||
# get_bar_split_txt("train", 0, 70000)
|
||||
# get_bar_split_txt("valid", 70000, 75000)
|
||||
# get_bar_split_txt("test", 75000, 80000)
|
||||
# def statics_length():
|
||||
# # 统计token长度
|
||||
# f = open("./data/total_txt/train.txt", "r")
|
||||
# line = f.readline()
|
||||
# length_dict = {}
|
||||
# num = 0
|
||||
# while line:
|
||||
# length = len(line.split(" "))
|
||||
# if length not in length_dict:
|
||||
# length_dict[length] = 1
|
||||
# else:
|
||||
# length_dict[length] += 1
|
||||
# line = f.readline()
|
||||
# num += 1
|
||||
# print("\r", num, end = "")
|
||||
#
|
||||
# import matplotlib.pyplot as plt
|
||||
#
|
||||
# X = length_dict.keys()
|
||||
# Y = length_dict.values()
|
||||
# fig = plt.figure()
|
||||
# plt.bar(X, Y, 0.4, color="green")
|
||||
# plt.xlabel("X-axis")
|
||||
# plt.ylabel("Y-axis")
|
||||
# plt.title("bar chart")
|
||||
# plt.show()
|
||||
# ave = 0
|
||||
# for j in length_dict.items():
|
||||
# ave += j[0]*j[1]/num
|
||||
# ave = 14995.75
|
||||
#
|
||||
# def dict_generation(path):
|
||||
# dict = []
|
||||
# for file in os.listdir(path):
|
||||
# f = open(os.path.join(path, file))
|
||||
# line = f.readline()
|
||||
# while line:
|
||||
# tokens = line.strip("\n").split(" ")
|
||||
# assert len(tokens) <= 3010, "too long token sequence"
|
||||
# for j in tokens:
|
||||
# if j not in dict:
|
||||
# dict.append(j)
|
||||
# line = f.readline()
|
||||
# f.close()
|
||||
# g = open(os.path.join(path, "dict.txt"), "w")
|
||||
# dict = sorted(dict, key=lambda x: (x[0], eval(x.split("-")[-1])))
|
||||
# for i in dict:
|
||||
# g.write(i + " 1\n")
|
||||
# g.close()
|
||||
#
|
||||
# def lmd_6tracks_clean():
|
||||
# # 清理lmd_6tracks的代码
|
||||
# chosen_files = np.load(r"./data/lmd_full_not_duplicate_files.npy", allow_pickle=True)
|
||||
# all_files = np.load(r"E:\AIMusic\ControlGeneration\全局的控制数据-28类乐器\file_path.npy", allow_pickle=True)
|
||||
# all_file_names = []
|
||||
# for i in tqdm(range(len(all_files))):
|
||||
# cur_name = all_files[i][0]
|
||||
# all_file_names.append(cur_name)
|
||||
#
|
||||
# out_of_lmd_full_num = 0
|
||||
# successs_num = 0
|
||||
# error_name_list = []
|
||||
# for tracks6_name in tqdm(os.listdir(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\remi_seq")):
|
||||
# search_name = tracks6_name.split("_")
|
||||
#
|
||||
# search_name = search_name[-1][0] + "/" + search_name[-1].split(".")[0] + ".mid"
|
||||
# if len(search_name) != 38:
|
||||
# error_name_list.append(tracks6_name)
|
||||
# if search_name not in all_file_names:
|
||||
# out_of_lmd_full_num += 1
|
||||
# else:
|
||||
# successs_num += 1
|
||||
# source_path = os.path.join(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\remi_seq", tracks6_name)
|
||||
# with open(source_path, "r") as f:
|
||||
# line = f.readline().strip("\n")
|
||||
# line = line.split(" ")
|
||||
# is4_4 = True
|
||||
# for j in line:
|
||||
# if j[0] == "s" and eval(j[2:]) != 9:
|
||||
# is4_4 = False
|
||||
# break
|
||||
# if is4_4:
|
||||
# target_path = os.path.join(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean", tracks6_name)
|
||||
# shutil.copy(source_path, target_path)
|
||||
# print(out_of_lmd_full_num, successs_num)
|
||||
#
|
||||
# length = []
|
||||
# for i in all_file_names:
|
||||
# if len(i) not in length:
|
||||
# length.append((len(i)))
|
||||
#
|
||||
# length_list = {}
|
||||
# length_33 = []
|
||||
# for name in error_name_list:
|
||||
# if len(name)-4 not in length_list.keys():
|
||||
# length_list[len(name)-4] = 1
|
||||
# else:
|
||||
# length_list[len(name)- 4] += 1
|
||||
# if len(name)-4 == 33:
|
||||
# length_33.append(name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# dict = generate_dict(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean_4_4", r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode")
|
||||
# x = 1
|
||||
# name_list = os.listdir("./data/lmd_6tracks/midi_6tracks/")
|
||||
#
|
||||
# name_list = np.random.choice(name_list, 10000, replace=False)
|
||||
# np.save("./data/chosen_files_6tracks.npy", name_list)
|
||||
|
||||
# file_name_list = np.load("data/not_duplicate_list.npy")
|
||||
# file_name_list = os.listdir(r"D:\ProjectData\datasets\lmd_6tracks\midi_6tracks")
|
||||
# print(len(file_name_list))
|
||||
file_name_list = np.load(r"D:\ProjectData\datasets\lmd_6tracks_clean_4_4\file_name_list.npy")
|
||||
mutiprocess_REMI_encode(file_name_list)
|
||||
|
||||
# gen_train_valid_test_txt()
|
||||
|
||||
# file_name_list = os.listdir("E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean_4_4")
|
||||
|
||||
|
||||
# test_mid_obj = miditoolkit.MidiFile("./test.mid")
|
||||
# midi2REMI(test_mid_obj, )
|
||||
# file_name_list = np.load("./data/file_path.npy")
|
||||
# for file_name in tqdm(file_name_list[:,0]):
|
||||
# save_name = r"E:\AIMusic\ControlGeneration\lmd_encode/" + file_name.replace("/", "_")[:-4] + ".txt"
|
||||
# if os.path.exists(save_name):
|
||||
# continue
|
||||
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name)
|
||||
# midi2REMI(midi_obj, save_name, {})
|
||||
|
||||
# get_8bars_train_valid_test_txt()
|
||||
# dict_generation(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\txt/")
|
||||
# command_generation(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\txt/")
|
||||
|
||||
|
||||
# path = "./data/sample_6tracks/midi"
|
||||
# save_root = "./data/sample_6tracks/remi_seq/"
|
||||
# for i, file in enumerate(os.listdir(path)):
|
||||
# midi_obj = miditoolkit.MidiFile(os.path.join(path, file))
|
||||
# midi2REMI(midi_obj, os.path.join(save_root, file.split(".")[0] +".txt"))
|
||||
|
||||
|
||||
|
||||
# fairseq bug:
|
||||
# g = open("./data/total_txt/train.txt", "w")
|
||||
# with open("./data/total_txt/train_total.txt", "r") as f:
|
||||
# line = f.readline()
|
||||
# index = 0
|
||||
# while 1:
|
||||
# # for i in line.split(" "):
|
||||
# # if i[0] == "i":
|
||||
# # instru_number = eval(i[2:])
|
||||
# # assert command[index][INSTR_NAME2_index[instru_number]] == 1, "error!"
|
||||
# g.write(line)
|
||||
# index += 1
|
||||
# line = f.readline()
|
||||
# if index % 1000 == 0:
|
||||
# print(index)
|
||||
# if index > 70000:
|
||||
# break
|
||||
# gen_train_valid_test_txt()
|
||||
# check(90000)
|
||||
# def _warmup_mmap_file(path):
|
||||
# with open(path, "rb") as stream:
|
||||
# while stream.read(100 * 1024 * 1024):
|
||||
# pass
|
||||
# import struct
|
||||
# dtypes = {
|
||||
# 1: np.uint8,
|
||||
# 2: np.int8,
|
||||
# 3: np.int16,
|
||||
# 4: np.int32,
|
||||
# 5: np.int64,
|
||||
# 6: np.float,
|
||||
# 7: np.double,
|
||||
# 8: np.uint16,
|
||||
# }
|
||||
# path ="./data/total-data-bin/train.idx"
|
||||
# _HDR_MAGIC = b"MMIDIDX\x00\x00"
|
||||
# with open(path, "rb") as stream:
|
||||
# magic_test = stream.read(9)
|
||||
# assert _HDR_MAGIC == magic_test, (
|
||||
# "Index file doesn't match expected format. "
|
||||
# "Make sure that --dataset-impl is configured properly."
|
||||
# )
|
||||
# version = struct.unpack("<Q", stream.read(8))
|
||||
# assert (1,) == version
|
||||
#
|
||||
# (dtype_code,) = struct.unpack("<B", stream.read(1))
|
||||
# _dtype = dtypes[dtype_code]
|
||||
# _dtype_size = _dtype().itemsize
|
||||
#
|
||||
# _len = struct.unpack("<Q", stream.read(8))[0]
|
||||
# offset = stream.tell()
|
||||
#
|
||||
# _warmup_mmap_file(path)
|
||||
#
|
||||
# _bin_buffer_mmap = np.memmap(path, mode="r", order="C")
|
||||
# _bin_buffer = memoryview(_bin_buffer_mmap)
|
||||
# _sizes = np.frombuffer(
|
||||
# _bin_buffer, dtype=np.int32, count=_len, offset=offset
|
||||
# )
|
||||
# _pointers = np.frombuffer(
|
||||
# _bin_buffer,
|
||||
# dtype=np.int64,
|
||||
# count=_len,
|
||||
# offset=offset + _sizes.nbytes,
|
||||
# )
|
||||
|
||||
# fairseq-preprocess --only-source --srcdict data/total_txt/dict.txt --trainpref data/total_txt/train.txt --validpref data/total_txt/valid.txt --testpref data/total_txt/test.txt --destdir data/total-data-bin --workers 8
|
||||
|
||||
# midi_root_path = r"D:\ProjectData\datasets\lmd_full"
|
||||
# save_root = r"D:\Project\ControlGeneration\StyleCtrl\jSymbolic_lib\datasets\TopMAGD/midi"
|
||||
# midi_to_genre = json.load(open("./midi_genre_map.json", "r"))
|
||||
# top_magd_genre = midi_to_genre["topmagd"]
|
||||
# for key in tqdm(top_magd_genre.keys()):
|
||||
# midi_name = key + ".mid"
|
||||
# midi_src_path = os.path.join(midi_root_path, midi_name[0], midi_name)
|
||||
# midi_save_path = os.path.join(save_root, midi_name)
|
||||
# shutil.copyfile(midi_src_path, midi_save_path)
|
||||
|
||||
# midi_root_path = r"D:\ProjectData\datasets\lmd_full"
|
||||
# save_root = r"D:\Project\ControlGeneration\StyleCtrl\jSymbolic_lib\datasets\MASD/midi"
|
||||
# midi_to_genre = json.load(open("./midi_genre_map.json", "r"))
|
||||
# masd_genre = midi_to_genre["masd"]
|
||||
# for key in tqdm(masd_genre.keys()):
|
||||
# midi_name = key + ".mid"
|
||||
# midi_src_path = os.path.join(midi_root_path, midi_name[0], midi_name)
|
||||
# midi_save_path = os.path.join(save_root, midi_name)
|
||||
# shutil.copyfile(midi_src_path, midi_save_path)
|
|
@ -0,0 +1,360 @@
|
|||
#!/usr/bin/env python3 -u
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
Translate raw text with a trained model. Batches data on-the-fly.
|
||||
"""
|
||||
|
||||
import fileinput
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
|
||||
from fairseq.data import encoders
|
||||
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
|
||||
from fairseq_cli.generate import get_symbols_to_strip_from_output
|
||||
import linear_decoder.controlled_task
|
||||
from MidiProcessor.midiprocessor import MidiDecoder
|
||||
import shutil, random
|
||||
from fairseq_cli import interactive
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||
stream=sys.stdout,
|
||||
)
|
||||
logger = logging.getLogger("fairseq_cli.interactive")
|
||||
|
||||
|
||||
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
|
||||
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
|
||||
|
||||
|
||||
def seed_everything(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def buffered_read(input, buffer_size):
|
||||
buffer = []
|
||||
with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
|
||||
for src_str in h:
|
||||
buffer.append(src_str.strip())
|
||||
if len(buffer) >= buffer_size:
|
||||
yield buffer
|
||||
buffer = []
|
||||
|
||||
if len(buffer) > 0:
|
||||
yield buffer
|
||||
|
||||
|
||||
def make_batches(lines, args, task, max_positions, encode_fn):
|
||||
def encode_fn_target(x):
|
||||
return encode_fn(x)
|
||||
|
||||
if args.constraints:
|
||||
# Strip (tab-delimited) contraints, if present, from input lines,
|
||||
# store them in batch_constraints
|
||||
batch_constraints = [list() for _ in lines]
|
||||
for i, line in enumerate(lines):
|
||||
if "\t" in line:
|
||||
lines[i], *batch_constraints[i] = line.split("\t")
|
||||
|
||||
# Convert each List[str] to List[Tensor]
|
||||
for i, constraint_list in enumerate(batch_constraints):
|
||||
batch_constraints[i] = [
|
||||
task.target_dictionary.encode_line(
|
||||
encode_fn_target(constraint),
|
||||
append_eos=False,
|
||||
add_if_not_exist=False,
|
||||
)
|
||||
for constraint in constraint_list
|
||||
]
|
||||
|
||||
tokens = [
|
||||
task.source_dictionary.encode_line(
|
||||
encode_fn(src_str), add_if_not_exist=False
|
||||
).long()
|
||||
for src_str in lines
|
||||
]
|
||||
|
||||
if args.constraints:
|
||||
constraints_tensor = pack_constraints(batch_constraints)
|
||||
else:
|
||||
constraints_tensor = None
|
||||
|
||||
lengths = [t.numel() for t in tokens]
|
||||
itr = task.get_batch_iterator(
|
||||
dataset=task.build_dataset_for_inference(
|
||||
tokens, lengths, constraints=constraints_tensor
|
||||
),
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.batch_size,
|
||||
max_positions=max_positions,
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
for batch in itr:
|
||||
ids = batch["id"]
|
||||
src_tokens = batch["net_input"]["src_tokens"]
|
||||
src_lengths = batch["net_input"]["src_lengths"]
|
||||
constraints = batch.get("constraints", None)
|
||||
|
||||
yield Batch(
|
||||
ids=ids,
|
||||
src_tokens=src_tokens,
|
||||
src_lengths=src_lengths,
|
||||
constraints=constraints,
|
||||
)
|
||||
|
||||
def attribute_generate(args):
|
||||
assert args.tgt_emotion in [1, 2, 3, 4], f"Error emotion index for {args.tgt_emotion}."
|
||||
start_time = time.time()
|
||||
total_translate_time = 0
|
||||
|
||||
utils.import_user_module(args)
|
||||
|
||||
if args.buffer_size < 1:
|
||||
args.buffer_size = 1
|
||||
if args.max_tokens is None and args.batch_size is None:
|
||||
args.batch_size = 1
|
||||
|
||||
assert (
|
||||
not args.sampling or args.nbest == args.beam
|
||||
), "--sampling requires --nbest to be equal to --beam"
|
||||
assert (
|
||||
not args.batch_size or args.batch_size <= args.buffer_size
|
||||
), "--batch-size cannot be larger than --buffer-size"
|
||||
|
||||
logger.info(args)
|
||||
|
||||
# Fix seed for stochastic decoding
|
||||
if args.seed is not None and not args.no_seed_provided:
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
|
||||
# Setup task, e.g., translation_control
|
||||
task = tasks.setup_task(args)
|
||||
|
||||
# Load ensemble
|
||||
logger.info("loading model(s) from {}".format(args.path))
|
||||
models, _model_args = checkpoint_utils.load_model_ensemble(
|
||||
args.path.split(os.pathsep),
|
||||
arg_overrides=eval(args.model_overrides),
|
||||
task=task,
|
||||
suffix=getattr(args, "checkpoint_suffix", ""),
|
||||
strict=(args.checkpoint_shard_count == 1),
|
||||
num_shards=args.checkpoint_shard_count,
|
||||
)
|
||||
|
||||
# Set dictionaries
|
||||
src_dict = task.source_dictionary
|
||||
tgt_dict = task.target_dictionary
|
||||
|
||||
# Optimize ensemble for generation
|
||||
for model in models:
|
||||
if args.fp16:
|
||||
model.half()
|
||||
if use_cuda and not args.pipeline_model_parallel:
|
||||
model.cuda()
|
||||
model.prepare_for_inference_(args)
|
||||
|
||||
# Initialize generator
|
||||
generator = task.build_generator(models, args)
|
||||
|
||||
# Handle tokenization and BPE
|
||||
tokenizer = encoders.build_tokenizer(args)
|
||||
bpe = encoders.build_bpe(args)
|
||||
|
||||
def encode_fn(x):
|
||||
if tokenizer is not None:
|
||||
x = tokenizer.encode(x)
|
||||
if bpe is not None:
|
||||
x = bpe.encode(x)
|
||||
return x
|
||||
|
||||
def decode_fn(x):
|
||||
if bpe is not None:
|
||||
x = bpe.decode(x)
|
||||
if tokenizer is not None:
|
||||
x = tokenizer.decode(x)
|
||||
return x
|
||||
|
||||
# Load alignment dictionary for unknown word replacement
|
||||
# (None if no unknown word replacement, empty if no path to align dictionary)
|
||||
align_dict = utils.load_align_dict(args.replace_unk)
|
||||
|
||||
max_positions = utils.resolve_max_positions(
|
||||
task.max_positions(), *[model.max_positions() for model in models]
|
||||
)
|
||||
|
||||
if args.constraints:
|
||||
logger.warning(
|
||||
"NOTE: Constrained decoding currently assumes a shared subword vocabulary."
|
||||
)
|
||||
|
||||
if args.buffer_size > 1:
|
||||
logger.info("Sentence buffer size: %s", args.buffer_size)
|
||||
logger.info("NOTE: hypothesis and token scores are output in base 2")
|
||||
logger.info("Type the input sentence and press return:")
|
||||
start_id = 0
|
||||
command_list = np.load(args.ctrl_command_path)
|
||||
|
||||
# for inputs in buffered_read(args.input, args.buffer_size):
|
||||
save_root = args.save_root
|
||||
os.makedirs(save_root, exist_ok=True)
|
||||
midi_decoder = MidiDecoder("REMIGEN2")
|
||||
|
||||
|
||||
for command_index in [args.tgt_emotion - 1]:
|
||||
os.makedirs(save_root + f"/midi", exist_ok=True)
|
||||
os.makedirs(save_root + f"/remi", exist_ok=True)
|
||||
sample_scores = {}
|
||||
for gen_times in range(1000):
|
||||
if len(os.listdir(save_root + f"/midi")) > args.need_num:
|
||||
break
|
||||
if os.path.exists(save_root + f"/remi/{gen_times*args.batch_size}.txt"):
|
||||
print(f"command_id: {command_index} sample:{gen_times*args.batch_size} already exists. Skip this batch!")
|
||||
continue
|
||||
start_tokens = [""]
|
||||
for inputs in [start_tokens*args.batch_size]: # "" for none prefix input
|
||||
results = []
|
||||
for batch in make_batches(inputs, args, task, max_positions, encode_fn):
|
||||
bsz = batch.src_tokens.size(0)
|
||||
src_tokens = batch.src_tokens
|
||||
src_lengths = batch.src_lengths
|
||||
constraints = batch.constraints
|
||||
command_input = []
|
||||
for i in range(args.batch_size):
|
||||
command_input.append(command_list[command_index])
|
||||
command_input = torch.tensor(command_input)
|
||||
|
||||
if use_cuda:
|
||||
src_tokens = src_tokens.cuda()
|
||||
src_lengths = src_lengths.cuda()
|
||||
command_input = command_input.cuda()
|
||||
if constraints is not None:
|
||||
constraints = constraints.cuda()
|
||||
|
||||
sample = {
|
||||
"net_input": {
|
||||
"src_tokens": src_tokens,
|
||||
"src_lengths": src_lengths,
|
||||
"command_input":command_input,
|
||||
},
|
||||
}
|
||||
translate_start_time = time.time()
|
||||
translations = task.inference_step(
|
||||
generator, models, sample, constraints=constraints
|
||||
)
|
||||
translate_time = time.time() - translate_start_time
|
||||
total_translate_time += translate_time
|
||||
list_constraints = [[] for _ in range(bsz)]
|
||||
if args.constraints:
|
||||
list_constraints = [unpack_constraints(c) for c in constraints]
|
||||
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
|
||||
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
|
||||
constraints = list_constraints[i]
|
||||
results.append(
|
||||
(
|
||||
start_id + id,
|
||||
src_tokens_i,
|
||||
hypos,
|
||||
{
|
||||
"constraints": constraints,
|
||||
"time": translate_time / len(translations),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# sort output to match input order
|
||||
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
|
||||
if src_dict is not None:
|
||||
src_str = src_dict.string(src_tokens, args.remove_bpe)
|
||||
# print("S-{}\t{}".format(id_, src_str))
|
||||
# print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
|
||||
# for constraint in info["constraints"]:
|
||||
# print(
|
||||
# "C-{}\t{}".format(
|
||||
# id_, tgt_dict.string(constraint, args.remove_bpe)
|
||||
# )
|
||||
# )
|
||||
|
||||
# Process top predictions
|
||||
for hypo in hypos[: min(len(hypos), args.nbest)]:
|
||||
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
|
||||
hypo_tokens=hypo["tokens"].int().cpu(),
|
||||
src_str=src_str,
|
||||
alignment=hypo["alignment"],
|
||||
align_dict=align_dict,
|
||||
tgt_dict=tgt_dict,
|
||||
remove_bpe=args.remove_bpe,
|
||||
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
|
||||
)
|
||||
detok_hypo_str = decode_fn(hypo_str)
|
||||
score = hypo["score"] / math.log(2) # convert to base 2
|
||||
command_id = command_index
|
||||
save_id = id_ + gen_times*args.batch_size
|
||||
sample_scores[save_id] = score.detach().cpu().numpy().item()
|
||||
print(f"command_id: {command_id} sample:{save_id} over with length {len(hypo_str.split(' '))}")
|
||||
with open(save_root + f"/remi/{save_id}.txt", "w") as f:
|
||||
f.write(hypo_str)
|
||||
remi_token = hypo_str.split(" ")
|
||||
# try:
|
||||
midi_obj = midi_decoder.decode_from_token_str_list(remi_token)
|
||||
midi_obj.dump(save_root + f"/midi/{save_id}.mid")
|
||||
# except:
|
||||
# pass
|
||||
|
||||
def cli_main():
|
||||
parser = options.get_interactive_generation_parser()
|
||||
parser.add_argument("--ctrl_command_path", type=str)
|
||||
parser.add_argument("--save_root", type=str)
|
||||
parser.add_argument("--need_num", type=int, default=50)
|
||||
parser.add_argument("--tgt_emotion", type=int)
|
||||
args = options.parse_args_and_arch(parser)
|
||||
attribute_generate(args)
|
||||
|
||||
# if args.ctrl_command_path != "None":
|
||||
# attributes(args)
|
||||
# else:
|
||||
# label_embedding(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
seed_everything(2023)
|
||||
cli_main()
|
||||
'''
|
||||
../StyleCtrlData/train_data/1026_EMO/data-bin
|
||||
--task
|
||||
language_modeling_control
|
||||
--path
|
||||
checkpoints/controlled/checkpoint_best.pt
|
||||
--ctrl_command_path
|
||||
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_rank_100/EMOPIA_inference_nearest.npy
|
||||
--save_root
|
||||
generation
|
||||
--max-len-b
|
||||
500
|
||||
--sampling
|
||||
--beam
|
||||
1
|
||||
--sampling-topk
|
||||
8
|
||||
--buffer-size
|
||||
8
|
||||
--batch-size
|
||||
8
|
||||
'''
|
Двоичный файл не отображается.
|
@ -0,0 +1,64 @@
|
|||
import shutil
|
||||
|
||||
import pandas as pd
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import xmltodict
|
||||
import random
|
||||
import numpy as np
|
||||
from multiprocessing import Process, Pool
|
||||
import json
|
||||
from jSymbolic_util import read_pitch_feature, read_all_feature
|
||||
import xmltodict
|
||||
import subprocess
|
||||
from functools import partial
|
||||
import argparse
|
||||
|
||||
command_prefix = "java -Xmx6g -jar ./jSymbolic_2_2_user/jSymbolic2.jar -configrun ./jSymbolic_2_2_user/jSymbolicDefaultConfigs.txt"
|
||||
|
||||
|
||||
|
||||
def rename_midi_path(root):
|
||||
midi_name_list = os.listdir(root)
|
||||
for midi_name in midi_name_list:
|
||||
os.rename(root + "/" + midi_name, root + f"/{midi_name.replace(' ', '_')}")
|
||||
|
||||
def get_jSymbolic_feature(file_name, root):
|
||||
midi_name = file_name[:-4]
|
||||
|
||||
midi_path = root + f"/midi/{midi_name}.mid"
|
||||
path = os.path.join(root, "feature/" + f"{midi_name.replace(' ', '_')}.xml")
|
||||
# print(midi_path)
|
||||
if os.path.exists(path):
|
||||
return 0
|
||||
if not os.path.exists(path):
|
||||
new_command = " ".join([command_prefix, midi_path, path,
|
||||
"./test_def.xml"])
|
||||
os.system(new_command)
|
||||
return 0
|
||||
|
||||
np.random.seed(42)
|
||||
random.seed(42)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
data_path = "../data/Piano"
|
||||
midi_sampled = os.listdir(data_path + f"/midi")
|
||||
os.makedirs(data_path +f"/feature", exist_ok=True)
|
||||
with Pool(processes=8) as pool:
|
||||
result = iter(tqdm(pool.imap(partial(get_jSymbolic_feature, root = data_path), midi_sampled),
|
||||
total=len(midi_sampled)))
|
||||
for i in range(len(midi_sampled)):
|
||||
# try:
|
||||
next(result)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
import os
|
||||
import xmltodict
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
nan_feature_list = {}
|
||||
split_name = ["Melodic Interval Histogram", "Vertical Interval Histogram"]
|
||||
end_at = "Initial Time Signature"
|
||||
# split_pos = [['Melodic Interval Histogram', 190], ['Vertical Interval Histogram', 342]]
|
||||
|
||||
feature_name = []
|
||||
histogram_feature_name = []
|
||||
def read_pitch_feature_without_long_histogram(path):
|
||||
data = xmltodict.parse(open(path, "r").read())
|
||||
data = data['feature_vector_file']["data_set"]["feature"]
|
||||
ret = []
|
||||
histogram_ret = []
|
||||
for f in data:
|
||||
if f["name"] == end_at:
|
||||
break
|
||||
if "histogram" not in f["name"].lower():
|
||||
if f["v"] == "NaN":
|
||||
ret.append(0)
|
||||
if f["name"] not in nan_feature_list.keys():
|
||||
nan_feature_list[f["name"]] = 1
|
||||
else:
|
||||
nan_feature_list[f["name"]] += 1
|
||||
elif isinstance(f["v"], list):
|
||||
ret.extend([eval(i) for i in f["v"]])
|
||||
else:
|
||||
ret.append(eval(f["v"]))
|
||||
# feature_name.append(f["name"])
|
||||
else:
|
||||
if len(f["v"]) < 20:
|
||||
# histogram_feature_name.append(f["name"])
|
||||
histogram_ret.extend([eval(i) for i in f["v"]])
|
||||
|
||||
return ret, histogram_ret
|
||||
|
||||
def read_pitch_feature(path):
|
||||
data = xmltodict.parse(open(path, "r").read())
|
||||
data = data['feature_vector_file']["data_set"]["feature"]
|
||||
ret = []
|
||||
for f in data:
|
||||
# print(f["name"])
|
||||
# if f["name"] in split_name:
|
||||
# split_pos.append([f["name"], len(ret)])
|
||||
if f["name"] == end_at:
|
||||
break
|
||||
if "histogram" not in f["name"].lower():
|
||||
if f["v"] == "NaN":
|
||||
ret.append(0)
|
||||
if f["name"] not in nan_feature_list.keys():
|
||||
nan_feature_list[f["name"]] = 1
|
||||
else:
|
||||
nan_feature_list[f["name"]] += 1
|
||||
elif isinstance(f["v"], list):
|
||||
ret.extend([eval(i) for i in f["v"]])
|
||||
else:
|
||||
ret.append(eval(f["v"]))
|
||||
else:
|
||||
ret.extend([eval(i) for i in f["v"]])
|
||||
|
||||
return ret
|
||||
|
||||
def read_all_feature(path, need_name = False):
|
||||
data = open(path, "r").read()
|
||||
if len(data) == 0:
|
||||
return None
|
||||
data = xmltodict.parse(open(path, "r").read())
|
||||
data = data['feature_vector_file']["data_set"]["feature"]
|
||||
ret = []
|
||||
feature_names = []
|
||||
for f in data:
|
||||
if "histogram" not in f["name"].lower():
|
||||
if f["v"] == "NaN":
|
||||
ret.append(0)
|
||||
if need_name:
|
||||
feature_names.append(f["name"])
|
||||
if f["name"] not in nan_feature_list.keys():
|
||||
nan_feature_list[f["name"]] = 1
|
||||
else:
|
||||
nan_feature_list[f["name"]] += 1
|
||||
elif isinstance(f["v"], list):
|
||||
ret.extend([eval(i) for i in f["v"]])
|
||||
if need_name:
|
||||
feature_names.extend(f["name"] + f"_{i}" for i in range(len(f["v"])))
|
||||
else:
|
||||
ret.append(eval(f["v"]))
|
||||
if need_name:
|
||||
feature_names.append(f["name"])
|
||||
else:
|
||||
ret.extend([eval(i) for i in f["v"]])
|
||||
if need_name:
|
||||
feature_names.extend(f["name"] + f"_{i}" for i in range(len(f["v"])))
|
||||
if need_name:
|
||||
assert len(ret) == len(feature_names)
|
||||
return ret, feature_names
|
||||
else:
|
||||
return ret
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,211 @@
|
|||
import random
|
||||
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.preprocessing import StandardScaler, MinMaxScaler
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from sklearn.manifold import TSNE
|
||||
from scipy.stats import pearsonr
|
||||
from sklearn.feature_selection import SelectKBest, f_classif, VarianceThreshold
|
||||
# Import libraries
|
||||
from sklearn import preprocessing
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
from sklearn.model_selection import train_test_split, StratifiedKFold
|
||||
from sklearn import tree
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import f1_score, roc_auc_score
|
||||
import joblib
|
||||
import lightgbm as lgb
|
||||
|
||||
# 从数据集中选取feature的阈值
|
||||
def EMOPIA_threshold():
|
||||
feature_table = np.load(r"../../StyleCtrlData/jSymbolic_lib\datasets\EMOPIA\data\feature_table.npy")
|
||||
X, Y = feature_table[:, :-1], feature_table[:, -1]
|
||||
|
||||
# for i in range(X.shape[1]):
|
||||
# medium_value = np.percentile(X[:, i], 50)
|
||||
# thresholds.append(medium_value)
|
||||
# X_discrete.append(np.searchsorted([medium_value], X[:, i])[:, np.newaxis])
|
||||
for bucket_num in [3, 5, 8, 12, 16]:
|
||||
thresholds = []
|
||||
for i in range(X.shape[1]):
|
||||
thres = []
|
||||
for j in range(1, bucket_num):
|
||||
thres.append(np.percentile(X[:, i], 100.0/bucket_num*j))
|
||||
thresholds.append(thres)
|
||||
# X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
|
||||
thresholds = np.array(thresholds)
|
||||
# X_discrete = np.concatenate(X_discrete, axis = 1)
|
||||
# cv_folds = 5
|
||||
# skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
|
||||
# y_preds = np.zeros(Y.shape) - 1
|
||||
# for train_index, test_index in skf.split(X_discrete, Y):
|
||||
# X_train, y_train = X_discrete[train_index], Y[train_index]
|
||||
# X_test, y_test = X_discrete[test_index], Y[test_index]
|
||||
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
# clf.fit(X_train, y_train)
|
||||
# y_pred_score = clf.predict_proba(X_test)
|
||||
# y_pred = np.argmax(y_pred_score, axis=1)
|
||||
# y_pred += 1 # 预测第几个象限
|
||||
# y_preds[test_index] = y_pred
|
||||
# print(f"离散化后的{X_discrete.shape[1]}维特征:", f1_score(Y, y_preds, average="micro"))
|
||||
np.save(f"../../StyleCtrlData/jSymbolic_lib/datasets/EMOPIA/data/threshold_{bucket_num}.npy", thresholds)
|
||||
# 二值化之后的分类精度为0.6762523191094619
|
||||
|
||||
# feature_selected = np.load("./data/selected_feature/emotion_And_select_17.npy")
|
||||
# feature_name = np.load("./data/feature_name.npy")
|
||||
# name2index = dict(zip(feature_name, range(len(feature_name))))
|
||||
# feature_selected_index = [name2index[i] for i in feature_selected]
|
||||
# X_discrete = X_discrete[:, feature_selected_index]
|
||||
# y_preds = np.zeros(Y.shape) - 1
|
||||
# for train_index, test_index in skf.split(X_discrete, Y):
|
||||
# X_train, y_train = X_discrete[train_index], Y[train_index]
|
||||
# X_test, y_test = X_discrete[test_index], Y[test_index]
|
||||
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
# clf.fit(X_train, y_train)
|
||||
# y_pred_score = clf.predict_proba(X_test)
|
||||
# y_pred = np.argmax(y_pred_score, axis=1)
|
||||
# y_pred += 1 # 预测第几个象限
|
||||
# y_preds[test_index] = y_pred
|
||||
# print(f"离散化后的{X_discrete.shape[1]}维特征:", f1_score(Y, y_preds, average="micro"))
|
||||
# 分为两个桶:
|
||||
# 离散化后的1495维特征: 0.6762523191094619
|
||||
# 离散化后的39维特征: 0.62430426716141
|
||||
# 离散化后的17维特征: 0.5092764378478665
|
||||
# 分为三个桶
|
||||
# 离散化后的1495维特征: 0.6725417439703154
|
||||
# 离散化后的39维特征: 0.6447124304267161
|
||||
# 离散化后的17维特征: 0.5445269016697588
|
||||
def VGMIDI_feedback():
|
||||
threshold = np.load("./datasets/EMOPIA/data/threshold_2.npy")
|
||||
feature_table = np.load("./datasets/VGMIDI/data/feature_table.npy")
|
||||
X = feature_table[:, :-1]
|
||||
Y = feature_table[:, -1]
|
||||
X_discrete = []
|
||||
for i in range(X.shape[1]):
|
||||
X_discrete.append(np.searchsorted([threshold[i]], X[:,i])[:, np.newaxis])
|
||||
X_discrete = np.concatenate(X_discrete, axis = 1)
|
||||
cv_folds = 5
|
||||
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
|
||||
y_preds_continuous = np.zeros(Y.shape) - 1
|
||||
for train_index, test_index in skf.split(X, Y):
|
||||
X_train, y_train = X[train_index], Y[train_index]
|
||||
X_test, y_test = X[test_index], Y[test_index]
|
||||
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred_score = clf.predict_proba(X_test)
|
||||
y_pred = np.argmax(y_pred_score, axis=1)
|
||||
y_pred += 1 # 预测第几个象限
|
||||
y_preds_continuous[test_index] = y_pred
|
||||
print(f1_score(Y, y_preds_continuous, average="micro"))
|
||||
|
||||
def TopMAGD_threshold():
|
||||
feature_table = []
|
||||
labels = []
|
||||
path_root = r"../..\StyleCtrlData\data\1004_TopMAGD\truncated_2560\split_data"
|
||||
for split in ["train","valid", "test"]:
|
||||
feature_table.append(np.load(path_root + f"/1495/{split}_raw_command_1495.npy"))
|
||||
labels.append(np.load(path_root + f"/{split}_style_labels.npy"))
|
||||
feature_table = np.vstack(feature_table)
|
||||
labels = np.vstack(labels)
|
||||
X = feature_table
|
||||
Y = labels
|
||||
|
||||
feature_selected = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\jSymbolic_lib\data\style_select_feature\style_or_10_select_103.npy")
|
||||
feature_names = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\jSymbolic_lib\data\feature_name.npy")
|
||||
feature2id = dict(zip(feature_names, range(len(feature_names))))
|
||||
select_features = [feature2id[i] for i in feature_selected]
|
||||
id2genre = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\data\1004_TopMAGD\truncated_2560\split_data\id2genre.npy")
|
||||
X = X[:, select_features]
|
||||
|
||||
# 不离散化的feature
|
||||
cv_folds = 5
|
||||
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
|
||||
y_preds = np.zeros(Y.shape) - 1
|
||||
for index, style_name in enumerate(id2genre):
|
||||
Y_style_label = Y[:, index]
|
||||
for train_index, test_index in skf.split(X, Y_style_label):
|
||||
X_train, y_train = X[train_index], Y_style_label[train_index]
|
||||
X_test, y_test = X[test_index], Y_style_label[test_index]
|
||||
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred_score = clf.predict_proba(X_test)
|
||||
y_pred = np.argmax(y_pred_score, axis=1)
|
||||
y_pred += 1 # 预测第几个象限
|
||||
y_preds[test_index] = y_pred
|
||||
print(f"连续的{X.shape[1]}维特征:", style_name, f1_score(Y_style_label, y_preds, average="micro"), roc_auc_score(Y_style_label, y_preds))
|
||||
|
||||
X_discrete = []
|
||||
for bucket_num in [2, 3, 5, 8, 12, 16]:
|
||||
thresholds = []
|
||||
for i in range(X.shape[1]):
|
||||
thres = []
|
||||
for j in range(1, bucket_num):
|
||||
thres.append(np.percentile(X[:, i], 100.0 / bucket_num * j))
|
||||
thresholds.append(thres)
|
||||
X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
|
||||
# thresholds = np.array(thresholds)
|
||||
X_discrete = np.concatenate(X_discrete, axis = 1)
|
||||
cv_folds = 5
|
||||
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
|
||||
y_preds = np.zeros(Y.shape) - 1
|
||||
for train_index, test_index in skf.split(X_discrete, Y):
|
||||
X_train, y_train = X_discrete[train_index], Y[train_index]
|
||||
X_test, y_test = X_discrete[test_index], Y[test_index]
|
||||
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred_score = clf.predict_proba(X_test)
|
||||
y_pred = np.argmax(y_pred_score, axis=1)
|
||||
y_pred += 1 # 预测第几个象限
|
||||
y_preds[test_index] = y_pred
|
||||
print(f"离散化后的{X_discrete.shape[1]}维特征:", f1_score(Y, y_preds, average="micro"), roc_auc_score(Y, y_preds))
|
||||
# np.save(path_root + f"/threshold_{bucket_num}.npy", thresholds)
|
||||
# 二值化之后的分类精度为0.6762523191094619
|
||||
|
||||
# feature_selected = np.load("./data/selected_feature/emotion_And_select_17.npy")
|
||||
# feature_name = np.load("./data/feature_name.npy")
|
||||
# name2index = dict(zip(feature_name, range(len(feature_name))))
|
||||
# feature_selected_index = [name2index[i] for i in feature_selected]
|
||||
# X_discrete = X_discrete[:, feature_selected_index]
|
||||
# y_preds = np.zeros(Y.shape) - 1
|
||||
# for train_index, test_index in skf.split(X_discrete, Y):
|
||||
# X_train, y_train = X_discrete[train_index], Y[train_index]
|
||||
# X_test, y_test = X_discrete[test_index], Y[test_index]
|
||||
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
|
||||
# clf.fit(X_train, y_train)
|
||||
# y_pred_score = clf.predict_proba(X_test)
|
||||
# y_pred = np.argmax(y_pred_score, axis=1)
|
||||
# y_pred += 1 # 预测第几个象限
|
||||
# y_preds[test_index] = y_pred
|
||||
# print(f"离散化后的{X_discrete.shape[1]}维特征:", f1_score(Y, y_preds, average="micro"))
|
||||
|
||||
def YMDB_threshold():
|
||||
feature_table = np.load(r"../../StyleCtrlData/jSymbolic_lib/datasets/1224_YMDB/data/fea_table.npy")
|
||||
X = feature_table
|
||||
# X, Y = feature_table[:, :-1], feature_table[:, -1]
|
||||
|
||||
# for i in range(X.shape[1]):
|
||||
# medium_value = np.percentile(X[:, i], 50)
|
||||
# thresholds.append(medium_value)
|
||||
# X_discrete.append(np.searchsorted([medium_value], X[:, i])[:, np.newaxis])
|
||||
for bucket_num in [2, 3, 5, 8, 12, 16]:
|
||||
thresholds = []
|
||||
for i in range(X.shape[1]):
|
||||
thres = []
|
||||
for j in range(1, bucket_num):
|
||||
thres.append(np.percentile(X[:, i], 100.0 / bucket_num * j))
|
||||
thresholds.append(thres)
|
||||
# X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
|
||||
thresholds = np.array(thresholds)
|
||||
np.save(f"../../StyleCtrlData/jSymbolic_lib/datasets/1224_YMDB/data/threshold_{bucket_num}.npy", thresholds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# EMOPIA_threshold()
|
||||
# TopMAGD_threshold()
|
||||
YMDB_threshold()
|
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,913 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fairseq import search, utils
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.models import FairseqIncrementalDecoder
|
||||
from fairseq.models.fairseq_encoder import EncoderOut
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
|
||||
class CommandSequenceGenerator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
models,
|
||||
tgt_dict,
|
||||
beam_size=1,
|
||||
max_len_a=0,
|
||||
max_len_b=200,
|
||||
min_len=1,
|
||||
normalize_scores=True,
|
||||
len_penalty=1.0,
|
||||
unk_penalty=0.0,
|
||||
temperature=1.0,
|
||||
match_source_len=False,
|
||||
no_repeat_ngram_size=0,
|
||||
search_strategy=None,
|
||||
eos=None,
|
||||
symbols_to_strip_from_output=None,
|
||||
lm_model=None,
|
||||
lm_weight=1.0,
|
||||
):
|
||||
"""Generates translations of a given source sentence.
|
||||
|
||||
Args:
|
||||
models (List[~fairseq.models.FairseqModel]): ensemble of models,
|
||||
currently support fairseq.models.TransformerModel for scripting
|
||||
beam_size (int, optional): beam width (default: 1)
|
||||
max_len_a/b (int, optional): generate sequences of maximum length
|
||||
ax + b, where x is the source length
|
||||
min_len (int, optional): the minimum length of the generated output
|
||||
(not including end-of-sentence)
|
||||
normalize_scores (bool, optional): normalize scores by the length
|
||||
of the output (default: True)
|
||||
len_penalty (float, optional): length penalty, where <1.0 favors
|
||||
shorter, >1.0 favors longer sentences (default: 1.0)
|
||||
unk_penalty (float, optional): unknown word penalty, where <0
|
||||
produces more unks, >0 produces fewer (default: 0.0)
|
||||
temperature (float, optional): temperature, where values
|
||||
>1.0 produce more uniform samples and values <1.0 produce
|
||||
sharper samples (default: 1.0)
|
||||
match_source_len (bool, optional): outputs should match the source
|
||||
length (default: False)
|
||||
"""
|
||||
super().__init__()
|
||||
if isinstance(models, EnsembleModel):
|
||||
self.model = models
|
||||
else:
|
||||
self.model = EnsembleModel(models)
|
||||
self.tgt_dict = tgt_dict
|
||||
self.pad = tgt_dict.pad()
|
||||
self.unk = tgt_dict.unk()
|
||||
self.eos = tgt_dict.eos() if eos is None else eos
|
||||
self.symbols_to_strip_from_output = (
|
||||
symbols_to_strip_from_output.union({self.eos})
|
||||
if symbols_to_strip_from_output is not None
|
||||
else {self.eos}
|
||||
)
|
||||
self.vocab_size = len(tgt_dict)
|
||||
self.beam_size = beam_size
|
||||
# the max beam size is the dictionary size - 1, since we never select pad
|
||||
self.beam_size = min(beam_size, self.vocab_size - 1)
|
||||
self.max_len_a = max_len_a
|
||||
self.max_len_b = max_len_b
|
||||
self.min_len = min_len
|
||||
|
||||
self.normalize_scores = normalize_scores
|
||||
self.len_penalty = len_penalty
|
||||
self.unk_penalty = unk_penalty
|
||||
self.temperature = temperature
|
||||
self.match_source_len = match_source_len
|
||||
self.no_repeat_ngram_size = no_repeat_ngram_size
|
||||
assert temperature > 0, "--temperature must be greater than 0"
|
||||
|
||||
self.search = (
|
||||
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
|
||||
)
|
||||
# We only need to set src_lengths in LengthConstrainedBeamSearch.
|
||||
# As a module attribute, setting it would break in multithread
|
||||
# settings when the model is shared.
|
||||
self.should_set_src_lengths = (
|
||||
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
|
||||
)
|
||||
|
||||
self.model.eval()
|
||||
|
||||
self.lm_model = lm_model
|
||||
self.lm_weight = lm_weight
|
||||
if self.lm_model is not None:
|
||||
self.lm_model.eval()
|
||||
|
||||
def cuda(self):
|
||||
self.model.cuda()
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
sample: Dict[str, Dict[str, Tensor]],
|
||||
prefix_tokens: Optional[Tensor] = None,
|
||||
bos_token: Optional[int] = None,
|
||||
):
|
||||
"""Generate a batch of translations.
|
||||
|
||||
Args:
|
||||
sample (dict): batch
|
||||
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
||||
with these tokens
|
||||
bos_token (int, optional): beginning of sentence token
|
||||
(default: self.eos)
|
||||
"""
|
||||
return self._generate(sample, prefix_tokens, bos_token=bos_token)
|
||||
|
||||
# TODO(myleott): unused, deprecate after pytorch-translate migration
|
||||
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
|
||||
"""Iterate over a batched dataset and yield individual translations.
|
||||
Args:
|
||||
cuda (bool, optional): use GPU for generation
|
||||
timer (StopwatchMeter, optional): time generations
|
||||
"""
|
||||
for sample in data_itr:
|
||||
s = utils.move_to_cuda(sample) if cuda else sample
|
||||
if "net_input" not in s:
|
||||
continue
|
||||
input = s["net_input"]
|
||||
# model.forward normally channels prev_output_tokens into the decoder
|
||||
# separately, but SequenceGenerator directly calls model.encoder
|
||||
encoder_input = {
|
||||
k: v for k, v in input.items() if k != "prev_output_tokens"
|
||||
}
|
||||
if timer is not None:
|
||||
timer.start()
|
||||
with torch.no_grad():
|
||||
hypos = self.generate(encoder_input)
|
||||
if timer is not None:
|
||||
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
|
||||
for i, id in enumerate(s["id"].data):
|
||||
# remove padding
|
||||
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
|
||||
ref = (
|
||||
utils.strip_pad(s["target"].data[i, :], self.pad)
|
||||
if s["target"] is not None
|
||||
else None
|
||||
)
|
||||
yield id, src, ref, hypos[i]
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
|
||||
"""Generate translations. Match the api of other fairseq generators.
|
||||
|
||||
Args:
|
||||
models (List[~fairseq.models.FairseqModel]): ensemble of models
|
||||
sample (dict): batch
|
||||
prefix_tokens (torch.LongTensor, optional): force decoder to begin
|
||||
with these tokens
|
||||
constraints (torch.LongTensor, optional): force decoder to include
|
||||
the list of constraints
|
||||
bos_token (int, optional): beginning of sentence token
|
||||
(default: self.eos)
|
||||
"""
|
||||
return self._generate(sample, **kwargs)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
sample: Dict[str, Dict[str, Tensor]],
|
||||
prefix_tokens: Optional[Tensor] = None,
|
||||
constraints: Optional[Tensor] = None,
|
||||
bos_token: Optional[int] = None,
|
||||
):
|
||||
incremental_states = torch.jit.annotate(
|
||||
List[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
[
|
||||
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
|
||||
for i in range(self.model.models_size)
|
||||
],
|
||||
)
|
||||
net_input = sample["net_input"]
|
||||
command_input = net_input["command_input"]
|
||||
if "src_tokens" in net_input:
|
||||
src_tokens = net_input["src_tokens"]
|
||||
# length of the source text being the character length except EndOfSentence and pad
|
||||
src_lengths = (
|
||||
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
|
||||
)
|
||||
elif "source" in net_input:
|
||||
src_tokens = net_input["source"]
|
||||
src_lengths = (
|
||||
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
|
||||
if net_input["padding_mask"] is not None
|
||||
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
|
||||
)
|
||||
else:
|
||||
raise Exception("expected src_tokens or source in net input")
|
||||
|
||||
# bsz: total number of sentences in beam
|
||||
# Note that src_tokens may have more than 2 dimenions (i.e. audio features)
|
||||
bsz, src_len = src_tokens.size()[:2]
|
||||
beam_size = self.beam_size
|
||||
|
||||
if constraints is not None and not self.search.supports_constraints:
|
||||
raise NotImplementedError(
|
||||
"Target-side constraints were provided, but search method doesn't support them"
|
||||
)
|
||||
|
||||
# Initialize constraints, when active
|
||||
self.search.init_constraints(constraints, beam_size)
|
||||
|
||||
max_len: int = -1
|
||||
if self.match_source_len:
|
||||
max_len = src_lengths.max().item()
|
||||
else:
|
||||
max_len = min(
|
||||
int(self.max_len_a * src_len + self.max_len_b),
|
||||
# exclude the EOS marker
|
||||
self.model.max_decoder_positions() - 1,
|
||||
)
|
||||
assert (
|
||||
self.min_len <= max_len
|
||||
), "min_len cannot be larger than max_len, please adjust these!"
|
||||
# compute the encoder output for each beam
|
||||
encoder_outs = self.model.forward_encoder(net_input)
|
||||
|
||||
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
|
||||
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
|
||||
new_order = new_order.to(src_tokens.device).long()
|
||||
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
|
||||
# ensure encoder_outs is a List.
|
||||
assert encoder_outs is not None
|
||||
|
||||
# initialize buffers
|
||||
scores = (
|
||||
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
|
||||
) # +1 for eos; pad is never chosen for scoring
|
||||
tokens = (
|
||||
torch.zeros(bsz * beam_size, max_len + 2)
|
||||
.to(src_tokens)
|
||||
.long()
|
||||
.fill_(self.pad)
|
||||
) # +2 for eos and pad
|
||||
tokens[:, 0] = self.eos if bos_token is None else bos_token
|
||||
attn: Optional[Tensor] = None
|
||||
|
||||
# A list that indicates candidates that should be ignored.
|
||||
# For example, suppose we're sampling and have already finalized 2/5
|
||||
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
|
||||
# so that we only finalize the remaining 3 samples.
|
||||
cands_to_ignore = (
|
||||
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
|
||||
) # forward and backward-compatible False mask
|
||||
|
||||
# list of completed sentences
|
||||
finalized = torch.jit.annotate(
|
||||
List[List[Dict[str, Tensor]]],
|
||||
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
|
||||
) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step
|
||||
|
||||
finished = [
|
||||
False for i in range(bsz)
|
||||
] # a boolean array indicating if the sentence at the index is finished or not
|
||||
num_remaining_sent = bsz # number of sentences remaining
|
||||
|
||||
# number of candidate hypos per step
|
||||
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
|
||||
|
||||
# offset arrays for converting between different indexing schemes
|
||||
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
|
||||
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
|
||||
|
||||
reorder_state: Optional[Tensor] = None
|
||||
batch_idxs: Optional[Tensor] = None
|
||||
|
||||
original_batch_idxs: Optional[Tensor] = None
|
||||
if "id" in sample and isinstance(sample["id"], Tensor):
|
||||
original_batch_idxs = sample["id"]
|
||||
else:
|
||||
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
|
||||
|
||||
for step in range(max_len + 1): # one extra step for EOS marker
|
||||
# reorder decoder internal states based on the prev choice of beams
|
||||
# print(f'step: {step}')
|
||||
if reorder_state is not None:
|
||||
if batch_idxs is not None:
|
||||
# update beam indices to take into account removed sentences
|
||||
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
|
||||
batch_idxs
|
||||
)
|
||||
reorder_state.view(-1, beam_size).add_(
|
||||
corr.unsqueeze(-1) * beam_size
|
||||
)
|
||||
original_batch_idxs = original_batch_idxs[batch_idxs]
|
||||
self.model.reorder_incremental_state(incremental_states, reorder_state)
|
||||
encoder_outs = self.model.reorder_encoder_out(
|
||||
encoder_outs, reorder_state
|
||||
)
|
||||
|
||||
lprobs, avg_attn_scores = self.model.forward_decoder(
|
||||
tokens[:, : step + 1],
|
||||
command_input,
|
||||
encoder_outs,
|
||||
incremental_states,
|
||||
self.temperature,
|
||||
)
|
||||
|
||||
if self.lm_model is not None:
|
||||
lm_out = self.lm_model(tokens[:, : step + 1])
|
||||
probs = self.lm_model.get_normalized_probs(
|
||||
lm_out, log_probs=True, sample=None
|
||||
)
|
||||
probs = probs[:, -1, :] * self.lm_weight
|
||||
lprobs += probs
|
||||
|
||||
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
|
||||
|
||||
lprobs[:, self.pad] = -math.inf # never select pad
|
||||
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
|
||||
|
||||
# handle max length constraint
|
||||
if step >= max_len:
|
||||
lprobs[:, : self.eos] = -math.inf
|
||||
lprobs[:, self.eos + 1 :] = -math.inf
|
||||
|
||||
# handle prefix tokens (possibly with different lengths)
|
||||
if (
|
||||
prefix_tokens is not None
|
||||
and step < prefix_tokens.size(1)
|
||||
and step < max_len
|
||||
):
|
||||
lprobs, tokens, scores = self._prefix_tokens(
|
||||
step, lprobs, scores, tokens, prefix_tokens, beam_size
|
||||
)
|
||||
elif step < self.min_len:
|
||||
# minimum length constraint (does not apply if using prefix_tokens)
|
||||
lprobs[:, self.eos] = -math.inf
|
||||
|
||||
# Record attention scores, only support avg_attn_scores is a Tensor
|
||||
if avg_attn_scores is not None:
|
||||
if attn is None:
|
||||
attn = torch.empty(
|
||||
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
|
||||
).to(scores)
|
||||
attn[:, :, step + 1].copy_(avg_attn_scores)
|
||||
|
||||
scores = scores.type_as(lprobs)
|
||||
eos_bbsz_idx = torch.empty(0).to(
|
||||
tokens
|
||||
) # indices of hypothesis ending with eos (finished sentences)
|
||||
eos_scores = torch.empty(0).to(
|
||||
scores
|
||||
) # scores of hypothesis ending with eos (finished sentences)
|
||||
|
||||
if self.should_set_src_lengths:
|
||||
self.search.set_src_lengths(src_lengths)
|
||||
|
||||
if self.no_repeat_ngram_size > 0:
|
||||
lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)
|
||||
|
||||
# Shape: (batch, cand_size)
|
||||
cand_scores, cand_indices, cand_beams = self.search.step(
|
||||
step,
|
||||
lprobs.view(bsz, -1, self.vocab_size),
|
||||
scores.view(bsz, beam_size, -1)[:, :, :step],
|
||||
tokens[:, : step + 1],
|
||||
original_batch_idxs,
|
||||
)
|
||||
|
||||
# cand_bbsz_idx contains beam indices for the top candidate
|
||||
# hypotheses, with a range of values: [0, bsz*beam_size),
|
||||
# and dimensions: [bsz, cand_size]
|
||||
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
||||
|
||||
# finalize hypotheses that end in eos
|
||||
# Shape of eos_mask: (batch size, beam size)
|
||||
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
|
||||
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
|
||||
|
||||
# only consider eos when it's among the top beam_size indices
|
||||
# Now we know what beam item(s) to finish
|
||||
# Shape: 1d list of absolute-numbered
|
||||
eos_bbsz_idx = torch.masked_select(
|
||||
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
|
||||
)
|
||||
|
||||
finalized_sents: List[int] = []
|
||||
if eos_bbsz_idx.numel() > 0:
|
||||
eos_scores = torch.masked_select(
|
||||
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
|
||||
)
|
||||
|
||||
finalized_sents = self.finalize_hypos(
|
||||
step,
|
||||
eos_bbsz_idx,
|
||||
eos_scores,
|
||||
tokens,
|
||||
scores,
|
||||
finalized,
|
||||
finished,
|
||||
beam_size,
|
||||
attn,
|
||||
src_lengths,
|
||||
max_len,
|
||||
)
|
||||
num_remaining_sent -= len(finalized_sents)
|
||||
|
||||
assert num_remaining_sent >= 0
|
||||
if num_remaining_sent == 0:
|
||||
break
|
||||
if self.search.stop_on_max_len and step >= max_len:
|
||||
break
|
||||
assert step < max_len
|
||||
|
||||
# Remove finalized sentences (ones for which {beam_size}
|
||||
# finished hypotheses have been generated) from the batch.
|
||||
if len(finalized_sents) > 0:
|
||||
new_bsz = bsz - len(finalized_sents)
|
||||
|
||||
# construct batch_idxs which holds indices of batches to keep for the next pass
|
||||
batch_mask = torch.ones(
|
||||
bsz, dtype=torch.bool, device=cand_indices.device
|
||||
)
|
||||
batch_mask[finalized_sents] = False
|
||||
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
|
||||
batch_idxs = torch.arange(
|
||||
bsz, device=cand_indices.device
|
||||
).masked_select(batch_mask)
|
||||
|
||||
# Choose the subset of the hypothesized constraints that will continue
|
||||
self.search.prune_sentences(batch_idxs)
|
||||
|
||||
eos_mask = eos_mask[batch_idxs]
|
||||
cand_beams = cand_beams[batch_idxs]
|
||||
bbsz_offsets.resize_(new_bsz, 1)
|
||||
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
|
||||
cand_scores = cand_scores[batch_idxs]
|
||||
cand_indices = cand_indices[batch_idxs]
|
||||
|
||||
command_input = command_input[batch_idxs]
|
||||
|
||||
if prefix_tokens is not None:
|
||||
prefix_tokens = prefix_tokens[batch_idxs]
|
||||
src_lengths = src_lengths[batch_idxs]
|
||||
cands_to_ignore = cands_to_ignore[batch_idxs]
|
||||
|
||||
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
||||
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
|
||||
if attn is not None:
|
||||
attn = attn.view(bsz, -1)[batch_idxs].view(
|
||||
new_bsz * beam_size, attn.size(1), -1
|
||||
)
|
||||
bsz = new_bsz
|
||||
else:
|
||||
batch_idxs = None
|
||||
|
||||
# Set active_mask so that values > cand_size indicate eos hypos
|
||||
# and values < cand_size indicate candidate active hypos.
|
||||
# After, the min values per row are the top candidate active hypos
|
||||
|
||||
# Rewrite the operator since the element wise or is not supported in torchscript.
|
||||
|
||||
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
|
||||
active_mask = torch.add(
|
||||
eos_mask.type_as(cand_offsets) * cand_size,
|
||||
cand_offsets[: eos_mask.size(1)],
|
||||
)
|
||||
|
||||
# get the top beam_size active hypotheses, which are just
|
||||
# the hypos with the smallest values in active_mask.
|
||||
# {active_hypos} indicates which {beam_size} hypotheses
|
||||
# from the list of {2 * beam_size} candidates were
|
||||
# selected. Shapes: (batch size, beam size)
|
||||
new_cands_to_ignore, active_hypos = torch.topk(
|
||||
active_mask, k=beam_size, dim=1, largest=False
|
||||
)
|
||||
|
||||
# update cands_to_ignore to ignore any finalized hypos.
|
||||
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
|
||||
# Make sure there is at least one active item for each sentence in the batch.
|
||||
assert (~cands_to_ignore).any(dim=1).all()
|
||||
|
||||
# update cands_to_ignore to ignore any finalized hypos
|
||||
|
||||
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
|
||||
# can be selected more than once).
|
||||
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
|
||||
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
|
||||
|
||||
active_bbsz_idx = active_bbsz_idx.view(-1)
|
||||
active_scores = active_scores.view(-1)
|
||||
|
||||
# copy tokens and scores for active hypotheses
|
||||
|
||||
# Set the tokens for each beam (can select the same row more than once)
|
||||
tokens[:, : step + 1] = torch.index_select(
|
||||
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
|
||||
)
|
||||
# Select the next token for each of them
|
||||
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
|
||||
cand_indices, dim=1, index=active_hypos
|
||||
)
|
||||
if step > 0:
|
||||
scores[:, :step] = torch.index_select(
|
||||
scores[:, :step], dim=0, index=active_bbsz_idx
|
||||
)
|
||||
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
|
||||
cand_scores, dim=1, index=active_hypos
|
||||
)
|
||||
|
||||
# Update constraints based on which candidates were selected for the next beam
|
||||
self.search.update_constraints(active_hypos)
|
||||
|
||||
# copy attention for active hypotheses
|
||||
if attn is not None:
|
||||
attn[:, :, : step + 2] = torch.index_select(
|
||||
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
|
||||
)
|
||||
|
||||
# reorder incremental state in decoder
|
||||
reorder_state = active_bbsz_idx
|
||||
|
||||
# sort by score descending
|
||||
for sent in range(len(finalized)):
|
||||
scores = torch.tensor(
|
||||
[float(elem["score"].item()) for elem in finalized[sent]]
|
||||
)
|
||||
_, sorted_scores_indices = torch.sort(scores, descending=True)
|
||||
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
|
||||
finalized[sent] = torch.jit.annotate(
|
||||
List[Dict[str, Tensor]], finalized[sent]
|
||||
)
|
||||
return finalized
|
||||
|
||||
def _prefix_tokens(
|
||||
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
|
||||
):
|
||||
"""Handle prefix tokens"""
|
||||
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
|
||||
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
|
||||
prefix_mask = prefix_toks.ne(self.pad)
|
||||
lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
|
||||
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
|
||||
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
|
||||
)
|
||||
# if prefix includes eos, then we should make sure tokens and
|
||||
# scores are the same across all beams
|
||||
eos_mask = prefix_toks.eq(self.eos)
|
||||
if eos_mask.any():
|
||||
# validate that the first beam matches the prefix
|
||||
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
|
||||
:, 0, 1 : step + 1
|
||||
]
|
||||
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
|
||||
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
|
||||
assert (first_beam == target_prefix).all()
|
||||
|
||||
# copy tokens, scores and lprobs from the first beam to all beams
|
||||
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
|
||||
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
|
||||
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
|
||||
return lprobs, tokens, scores
|
||||
|
||||
def replicate_first_beam(self, tensor, mask, beam_size: int):
|
||||
tensor = tensor.view(-1, beam_size, tensor.size(-1))
|
||||
tensor[mask] = tensor[mask][:, :1, :]
|
||||
return tensor.view(-1, tensor.size(-1))
|
||||
|
||||
def finalize_hypos(
|
||||
self,
|
||||
step: int,
|
||||
bbsz_idx,
|
||||
eos_scores,
|
||||
tokens,
|
||||
scores,
|
||||
finalized: List[List[Dict[str, Tensor]]],
|
||||
finished: List[bool],
|
||||
beam_size: int,
|
||||
attn: Optional[Tensor],
|
||||
src_lengths,
|
||||
max_len: int,
|
||||
):
|
||||
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
|
||||
A sentence is finalized when {beam_size} finished items have been collected for it.
|
||||
|
||||
Returns number of sentences (not beam items) being finalized.
|
||||
These will be removed from the batch and not processed further.
|
||||
Args:
|
||||
bbsz_idx (Tensor):
|
||||
"""
|
||||
assert bbsz_idx.numel() == eos_scores.numel()
|
||||
|
||||
# clone relevant token and attention tensors.
|
||||
# tokens is (batch * beam, max_len). So the index_select
|
||||
# gets the newly EOS rows, then selects cols 1..{step + 2}
|
||||
tokens_clone = tokens.index_select(0, bbsz_idx)[
|
||||
:, 1 : step + 2
|
||||
] # skip the first index, which is EOS
|
||||
|
||||
tokens_clone[:, step] = self.eos
|
||||
attn_clone = (
|
||||
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
|
||||
if attn is not None
|
||||
else None
|
||||
)
|
||||
|
||||
# compute scores per token position
|
||||
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
|
||||
pos_scores[:, step] = eos_scores
|
||||
# convert from cumulative to per-position scores
|
||||
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
|
||||
|
||||
# normalize sentence-level scores
|
||||
if self.normalize_scores:
|
||||
eos_scores /= (step + 1) ** self.len_penalty
|
||||
|
||||
# cum_unfin records which sentences in the batch are finished.
|
||||
# It helps match indexing between (a) the original sentences
|
||||
# in the batch and (b) the current, possibly-reduced set of
|
||||
# sentences.
|
||||
cum_unfin: List[int] = []
|
||||
prev = 0
|
||||
for f in finished:
|
||||
if f:
|
||||
prev += 1
|
||||
else:
|
||||
cum_unfin.append(prev)
|
||||
|
||||
# set() is not supported in script export
|
||||
|
||||
# The keys here are of the form "{sent}_{unfin_idx}", where
|
||||
# "unfin_idx" is the index in the current (possibly reduced)
|
||||
# list of sentences, and "sent" is the index in the original,
|
||||
# unreduced batch
|
||||
sents_seen: Dict[str, Optional[Tensor]] = {}
|
||||
|
||||
# For every finished beam item
|
||||
for i in range(bbsz_idx.size()[0]):
|
||||
idx = bbsz_idx[i]
|
||||
score = eos_scores[i]
|
||||
# sentence index in the current (possibly reduced) batch
|
||||
unfin_idx = idx // beam_size
|
||||
# sentence index in the original (unreduced) batch
|
||||
sent = unfin_idx + cum_unfin[unfin_idx]
|
||||
# print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}")
|
||||
# Cannot create dict for key type '(int, int)' in torchscript.
|
||||
# The workaround is to cast int to string
|
||||
seen = str(sent.item()) + "_" + str(unfin_idx.item())
|
||||
if seen not in sents_seen:
|
||||
sents_seen[seen] = None
|
||||
|
||||
if self.match_source_len and step > src_lengths[unfin_idx]:
|
||||
score = torch.tensor(-math.inf).to(score)
|
||||
|
||||
# An input sentence (among those in a batch) is finished when
|
||||
# beam_size hypotheses have been collected for it
|
||||
if len(finalized[sent]) < beam_size:
|
||||
if attn_clone is not None:
|
||||
# remove padding tokens from attn scores
|
||||
hypo_attn = attn_clone[i]
|
||||
else:
|
||||
hypo_attn = torch.empty(0)
|
||||
|
||||
finalized[sent].append(
|
||||
{
|
||||
"tokens": tokens_clone[i],
|
||||
"score": score,
|
||||
"attention": hypo_attn, # src_len x tgt_len
|
||||
"alignment": torch.empty(0),
|
||||
"positional_scores": pos_scores[i],
|
||||
}
|
||||
)
|
||||
|
||||
newly_finished: List[int] = []
|
||||
|
||||
for seen in sents_seen.keys():
|
||||
# check termination conditions for this sentence
|
||||
sent: int = int(float(seen.split("_")[0]))
|
||||
unfin_idx: int = int(float(seen.split("_")[1]))
|
||||
|
||||
if not finished[sent] and self.is_finished(
|
||||
step, unfin_idx, max_len, len(finalized[sent]), beam_size
|
||||
):
|
||||
finished[sent] = True
|
||||
newly_finished.append(unfin_idx)
|
||||
|
||||
return newly_finished
|
||||
|
||||
def is_finished(
|
||||
self,
|
||||
step: int,
|
||||
unfin_idx: int,
|
||||
max_len: int,
|
||||
finalized_sent_len: int,
|
||||
beam_size: int,
|
||||
):
|
||||
"""
|
||||
Check whether decoding for a sentence is finished, which
|
||||
occurs when the list of finalized sentences has reached the
|
||||
beam size, or when we reach the maximum length.
|
||||
"""
|
||||
assert finalized_sent_len <= beam_size
|
||||
if finalized_sent_len == beam_size or step == max_len:
|
||||
return True
|
||||
return False
|
||||
|
||||
def calculate_banned_tokens(
|
||||
self,
|
||||
tokens,
|
||||
step: int,
|
||||
gen_ngrams: List[Dict[str, List[int]]],
|
||||
no_repeat_ngram_size: int,
|
||||
bbsz_idx: int,
|
||||
):
|
||||
tokens_list: List[int] = tokens[
|
||||
bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
|
||||
].tolist()
|
||||
# before decoding the next token, prevent decoding of ngrams that have already appeared
|
||||
ngram_index = ",".join([str(x) for x in tokens_list])
|
||||
return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))
|
||||
|
||||
def transpose_list(self, l: List[List[int]]):
|
||||
# GeneratorExp aren't supported in TS so ignoring the lint
|
||||
min_len = min([len(x) for x in l]) # noqa
|
||||
l2 = [[row[i] for row in l] for i in range(min_len)]
|
||||
return l2
|
||||
|
||||
def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
|
||||
# for each beam and batch sentence, generate a list of previous ngrams
|
||||
gen_ngrams: List[Dict[str, List[int]]] = [
|
||||
torch.jit.annotate(Dict[str, List[int]], {})
|
||||
for bbsz_idx in range(bsz * beam_size)
|
||||
]
|
||||
cpu_tokens = tokens.cpu()
|
||||
for bbsz_idx in range(bsz * beam_size):
|
||||
gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
|
||||
for ngram in self.transpose_list(
|
||||
[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
|
||||
):
|
||||
key = ",".join([str(x) for x in ngram[:-1]])
|
||||
gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
|
||||
key, torch.jit.annotate(List[int], [])
|
||||
) + [ngram[-1]]
|
||||
|
||||
if step + 2 - self.no_repeat_ngram_size >= 0:
|
||||
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
|
||||
banned_tokens = [
|
||||
self.calculate_banned_tokens(
|
||||
tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
|
||||
)
|
||||
for bbsz_idx in range(bsz * beam_size)
|
||||
]
|
||||
else:
|
||||
banned_tokens = [
|
||||
torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
|
||||
]
|
||||
for bbsz_idx in range(bsz * beam_size):
|
||||
lprobs[bbsz_idx][
|
||||
torch.tensor(banned_tokens[bbsz_idx]).long()
|
||||
] = torch.tensor(-math.inf).to(lprobs)
|
||||
return lprobs
|
||||
|
||||
class EnsembleModel(nn.Module):
|
||||
"""A wrapper around an ensemble of models."""
|
||||
|
||||
def __init__(self, models):
|
||||
super().__init__()
|
||||
self.models_size = len(models)
|
||||
# method '__len__' is not supported in ModuleList for torch script
|
||||
self.single_model = models[0]
|
||||
self.models = nn.ModuleList(models)
|
||||
|
||||
self.has_incremental: bool = False
|
||||
if all(
|
||||
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
|
||||
for m in models
|
||||
):
|
||||
self.has_incremental = True
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
def has_encoder(self):
|
||||
return hasattr(self.single_model, "encoder")
|
||||
|
||||
def has_incremental_states(self):
|
||||
return self.has_incremental
|
||||
|
||||
def max_decoder_positions(self):
|
||||
return min([m.max_decoder_positions() for m in self.models])
|
||||
|
||||
@torch.jit.export
|
||||
def forward_encoder(self, net_input: Dict[str, Tensor]):
|
||||
if not self.has_encoder():
|
||||
return None
|
||||
return [model.encoder.forward_torchscript(net_input) for model in self.models]
|
||||
|
||||
@torch.jit.export
|
||||
def forward_decoder(
|
||||
self,
|
||||
tokens,
|
||||
command_input,
|
||||
encoder_outs: List[EncoderOut],
|
||||
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
log_probs = []
|
||||
avg_attn: Optional[Tensor] = None
|
||||
encoder_out: Optional[EncoderOut] = None
|
||||
for i, model in enumerate(self.models):
|
||||
if self.has_encoder():
|
||||
encoder_out = encoder_outs[i]
|
||||
# decode each model
|
||||
if self.has_incremental_states():
|
||||
decoder_out = model.decoder.forward(
|
||||
tokens,
|
||||
command_input,
|
||||
encoder_out=encoder_out,
|
||||
incremental_state=incremental_states[i],
|
||||
)
|
||||
else:
|
||||
decoder_out = model.decoder.forward(tokens, command_input, encoder_out=encoder_out)
|
||||
|
||||
attn: Optional[Tensor] = None
|
||||
decoder_len = len(decoder_out)
|
||||
if decoder_len > 1 and decoder_out[1] is not None:
|
||||
if isinstance(decoder_out[1], Tensor):
|
||||
attn = decoder_out[1]
|
||||
else:
|
||||
attn_holder = decoder_out[1]["attn"]
|
||||
if isinstance(attn_holder, Tensor):
|
||||
attn = attn_holder
|
||||
elif attn_holder is not None:
|
||||
attn = attn_holder[0]
|
||||
if attn is not None:
|
||||
attn = attn[:, -1, :]
|
||||
|
||||
decoder_out_tuple = (
|
||||
decoder_out[0][:, -1:, :].div_(temperature),
|
||||
None if decoder_len <= 1 else decoder_out[1],
|
||||
)
|
||||
|
||||
probs = model.get_normalized_probs(
|
||||
decoder_out_tuple, log_probs=True, sample=None
|
||||
)
|
||||
probs = probs[:, -1, :]
|
||||
if self.models_size == 1:
|
||||
return probs, attn
|
||||
|
||||
log_probs.append(probs)
|
||||
if attn is not None:
|
||||
if avg_attn is None:
|
||||
avg_attn = attn
|
||||
else:
|
||||
avg_attn.add_(attn)
|
||||
|
||||
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
|
||||
self.models_size
|
||||
)
|
||||
|
||||
if avg_attn is not None:
|
||||
avg_attn.div_(self.models_size)
|
||||
return avg_probs, avg_attn
|
||||
|
||||
@torch.jit.export
|
||||
def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order):
|
||||
"""
|
||||
Reorder encoder output according to *new_order*.
|
||||
|
||||
Args:
|
||||
encoder_out: output from the ``forward()`` method
|
||||
new_order (LongTensor): desired order
|
||||
|
||||
Returns:
|
||||
*encoder_out* rearranged according to *new_order*
|
||||
"""
|
||||
new_outs: List[EncoderOut] = []
|
||||
if not self.has_encoder():
|
||||
return new_outs
|
||||
for i, model in enumerate(self.models):
|
||||
assert encoder_outs is not None
|
||||
new_outs.append(
|
||||
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
|
||||
)
|
||||
return new_outs
|
||||
|
||||
@torch.jit.export
|
||||
def reorder_incremental_state(
|
||||
self,
|
||||
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
new_order,
|
||||
):
|
||||
if not self.has_incremental_states():
|
||||
return
|
||||
for i, model in enumerate(self.models):
|
||||
model.decoder.reorder_incremental_state_scripting(
|
||||
incremental_states[i], new_order
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,330 @@
|
|||
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
|
||||
import numpy as np
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig
|
||||
from fairseq.tasks import register_task
|
||||
|
||||
import logging
|
||||
from .linear import transformer_lm
|
||||
|
||||
from fairseq import search
|
||||
|
||||
from fairseq.models.transformer_lm import (
|
||||
TransformerLanguageModel,
|
||||
TransformerLanguageModelConfig,
|
||||
base_lm_architecture,
|
||||
transformer_lm_gpt,
|
||||
DEFAULT_MAX_TARGET_POSITIONS
|
||||
)
|
||||
|
||||
from fairseq.models import (
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
|
||||
class CommandDataset(BaseWrapperDataset):
|
||||
def __init__(self, dataset, command_data, args = None):
|
||||
super().__init__(dataset)
|
||||
self._sizes = self.dataset.sizes.copy()
|
||||
self.command_data = command_data # need change to np.mmap
|
||||
self.args = args
|
||||
|
||||
def __getitem__(self, index):
|
||||
sample = self.dataset[index]
|
||||
assert len(sample["source"]) <= self.args.truncated_length + 2, f"The maximum length exceeds {self.args.truncated_length}. Please resample the dataset."
|
||||
return {
|
||||
"id": index,
|
||||
"source": sample["source"],
|
||||
"target": sample["target"],
|
||||
"command": torch.from_numpy(np.array(self.command_data[index])).to(sample["source"].device)
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self._sizes
|
||||
|
||||
def size(self, index):
|
||||
return self._sizes[index]
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self._sizes[index]
|
||||
|
||||
def filter_indices_by_size(self, indices, max_sizes):
|
||||
"""
|
||||
Filter a list of sample indices. Remove those that are longer than
|
||||
specified in *max_sizes*.
|
||||
|
||||
WARNING: don't update, override method in child classes
|
||||
|
||||
Args:
|
||||
indices (np.array): original array of sample indices
|
||||
max_sizes (int or list[int] or tuple[int]): max sample size,
|
||||
can be defined separately for src and tgt (then list or tuple)
|
||||
|
||||
Returns:
|
||||
np.array: filtered sample array
|
||||
list: list of removed indices
|
||||
"""
|
||||
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
|
||||
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
|
||||
ignored = indices[self.sizes[indices] > max_sizes].tolist()
|
||||
indices = indices[self.sizes[indices] <= max_sizes]
|
||||
elif (
|
||||
hasattr(self, "sizes")
|
||||
and isinstance(self.sizes, list)
|
||||
and len(self.sizes) == 1
|
||||
):
|
||||
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
|
||||
indices = indices[self.sizes[0][indices] <= max_sizes]
|
||||
else:
|
||||
indices, ignored = data_utils._filter_by_size_dynamic(
|
||||
indices, self.size, max_sizes
|
||||
)
|
||||
else:
|
||||
indices, ignored = data_utils._filter_by_size_dynamic(
|
||||
indices, self.size, max_sizes
|
||||
)
|
||||
if len(ignored) > 0:
|
||||
print(self.sizes)
|
||||
print(ignored)
|
||||
print(max_sizes)
|
||||
return indices, ignored
|
||||
|
||||
def collater(self, samples):
|
||||
# samples: list->dict, just as __get_item__(index) rets
|
||||
# print("samples type:", type(samples))
|
||||
|
||||
return self.collate_helper(samples, self.dataset.vocab.pad(), self.dataset.vocab.eos())
|
||||
def collate_helper(self, samples, pad_idx, eos_idx):
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
def merge(key, is_list=False):
|
||||
if is_list:
|
||||
res = []
|
||||
for i in range(len(samples[0][key])):
|
||||
res.append(
|
||||
data_utils.collate_tokens(
|
||||
[s[key][i] for s in samples],
|
||||
pad_idx,
|
||||
eos_idx,
|
||||
left_pad=False,
|
||||
)
|
||||
)
|
||||
return res
|
||||
else:
|
||||
return data_utils.collate_tokens(
|
||||
[s[key] for s in samples],
|
||||
pad_idx,
|
||||
eos_idx,
|
||||
left_pad=False,
|
||||
)
|
||||
|
||||
src_tokens = merge("source")
|
||||
if samples[0]["command"] is not None:
|
||||
command_tokens = merge("command")
|
||||
else:
|
||||
command_tokens = None
|
||||
if samples[0]["target"] is not None:
|
||||
is_target_list = isinstance(samples[0]["target"], list)
|
||||
target = merge("target", is_target_list)
|
||||
else:
|
||||
target = src_tokens
|
||||
|
||||
return {
|
||||
"id": torch.LongTensor([s["id"] for s in samples]),
|
||||
"nsentences": len(samples),
|
||||
"ntokens": sum(len(s["source"]) for s in samples),
|
||||
"net_input": {
|
||||
"src_tokens": src_tokens,
|
||||
"command_input": command_tokens,
|
||||
"src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
|
||||
},
|
||||
"target": target,
|
||||
}
|
||||
|
||||
|
||||
@register_task("language_modeling_control", dataclass=LanguageModelingConfig)
|
||||
class LanguageModelingTaskWithControl(LanguageModelingTask):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
super().add_args(parser)
|
||||
parser.add_argument("--command_in_dim", type=int)
|
||||
parser.add_argument("--command_out_dim", type=int)
|
||||
parser.add_argument("--truncated_length", type=int, default=8192)
|
||||
parser.add_argument("--feature_num", type=int, default=3)
|
||||
parser.add_argument("--control_mode", type=str)
|
||||
parser.add_argument("--command_path", type=str)
|
||||
parser.add_argument("--bucket_num", type=int)
|
||||
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
super().load_dataset(split, epoch=epoch, combine=combine, **kwargs)
|
||||
command_dataset = np.load(f"{self.args.command_path}/{split}_command.npy", mmap_mode="r")
|
||||
assert command_dataset.shape[0] == len(self.datasets[split]), f"error command sample num for {split}!"
|
||||
assert command_dataset.shape[1] == self.args.feature_num, "command feature_num isn't the same as args feature_num"
|
||||
logger.info(f'Load CommandSourceTargetDataset for {split} from {self.args.command_path}, truncated length: {self.args.truncated_length}')
|
||||
self.datasets[split] = CommandDataset(self.datasets[split], command_dataset, self.args)
|
||||
|
||||
|
||||
def get_batch_iterator(
|
||||
self,
|
||||
dataset,
|
||||
max_tokens=None,
|
||||
max_sentences=None,
|
||||
max_positions=None,
|
||||
ignore_invalid_inputs=False,
|
||||
required_batch_size_multiple=1,
|
||||
seed=1,
|
||||
num_shards=1,
|
||||
shard_id=0,
|
||||
num_workers=0,
|
||||
epoch=1,
|
||||
data_buffer_size=0,
|
||||
disable_iterator_cache=False,
|
||||
):
|
||||
split = None
|
||||
if 'train' in self.datasets and dataset == self.datasets['train']:
|
||||
split = 'train'
|
||||
elif 'valid' in self.datasets and dataset == self.datasets['valid']:
|
||||
split = 'valid'
|
||||
elif 'test' in self.datasets and dataset == self.datasets['test']:
|
||||
split = 'test'
|
||||
|
||||
max_positions_split = getattr(self.args, 'max_positions_%s' % split, None)
|
||||
if max_positions_split is None:
|
||||
max_positions_split = getattr(self.args, 'truncate_%s' % split, None)
|
||||
if max_positions_split is not None:
|
||||
max_positions = max_positions_split
|
||||
logger.info('Using max_positions limit (%d) for %s' % (max_positions,
|
||||
split if split is not None else 'unknown'))
|
||||
|
||||
return super().get_batch_iterator(
|
||||
dataset,
|
||||
max_tokens=max_tokens,
|
||||
max_sentences=max_sentences,
|
||||
max_positions=max_positions,
|
||||
ignore_invalid_inputs=ignore_invalid_inputs,
|
||||
required_batch_size_multiple=required_batch_size_multiple,
|
||||
seed=seed,
|
||||
num_shards=num_shards,
|
||||
shard_id=shard_id,
|
||||
num_workers=num_workers,
|
||||
epoch=epoch,
|
||||
data_buffer_size=data_buffer_size,
|
||||
disable_iterator_cache=disable_iterator_cache
|
||||
)
|
||||
def build_generator(
|
||||
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
|
||||
):
|
||||
if getattr(args, "score_reference", False):
|
||||
from fairseq.sequence_scorer import SequenceScorer
|
||||
|
||||
return SequenceScorer(
|
||||
self.target_dictionary,
|
||||
compute_alignment=getattr(args, "print_alignment", False),
|
||||
)
|
||||
|
||||
from .command_seq_generator import CommandSequenceGenerator
|
||||
|
||||
# Choose search strategy. Defaults to Beam Search.
|
||||
sampling = getattr(args, "sampling", False)
|
||||
sampling_topk = getattr(args, "sampling_topk", -1)
|
||||
sampling_topp = getattr(args, "sampling_topp", -1.0)
|
||||
diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
|
||||
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
|
||||
match_source_len = getattr(args, "match_source_len", False)
|
||||
diversity_rate = getattr(args, "diversity_rate", -1)
|
||||
constrained = getattr(args, "constraints", False)
|
||||
prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
|
||||
if (
|
||||
sum(
|
||||
int(cond)
|
||||
for cond in [
|
||||
sampling,
|
||||
diverse_beam_groups > 0,
|
||||
match_source_len,
|
||||
diversity_rate > 0,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError("Provided Search parameters are mutually exclusive.")
|
||||
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
|
||||
assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
|
||||
|
||||
if sampling:
|
||||
search_strategy = search.Sampling(
|
||||
self.target_dictionary, sampling_topk, sampling_topp
|
||||
)
|
||||
elif diverse_beam_groups > 0:
|
||||
search_strategy = search.DiverseBeamSearch(
|
||||
self.target_dictionary, diverse_beam_groups, diverse_beam_strength
|
||||
)
|
||||
elif match_source_len:
|
||||
# this is useful for tagging applications where the output
|
||||
# length should match the input length, so we hardcode the
|
||||
# length constraints for simplicity
|
||||
search_strategy = search.LengthConstrainedBeamSearch(
|
||||
self.target_dictionary,
|
||||
min_len_a=1,
|
||||
min_len_b=0,
|
||||
max_len_a=1,
|
||||
max_len_b=0,
|
||||
)
|
||||
elif diversity_rate > -1:
|
||||
search_strategy = search.DiverseSiblingsSearch(
|
||||
self.target_dictionary, diversity_rate
|
||||
)
|
||||
elif constrained:
|
||||
search_strategy = search.LexicallyConstrainedBeamSearch(
|
||||
self.target_dictionary, args.constraints
|
||||
)
|
||||
elif prefix_allowed_tokens_fn:
|
||||
search_strategy = search.PrefixConstrainedBeamSearch(
|
||||
self.target_dictionary, prefix_allowed_tokens_fn
|
||||
)
|
||||
else:
|
||||
search_strategy = search.BeamSearch(self.target_dictionary)
|
||||
|
||||
if seq_gen_cls is None:
|
||||
if getattr(args, "print_alignment", False):
|
||||
raise ImportError("SequenceGeneratorWithAlignment is not allowed!")
|
||||
# seq_gen_cls = SequenceGeneratorWithAlignment
|
||||
else:
|
||||
seq_gen_cls = CommandSequenceGenerator
|
||||
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
|
||||
return seq_gen_cls(
|
||||
models,
|
||||
self.target_dictionary,
|
||||
beam_size=getattr(args, "beam", 5),
|
||||
max_len_a=getattr(args, "max_len_a", 0),
|
||||
max_len_b=getattr(args, "max_len_b", 200),
|
||||
min_len=getattr(args, "min_len", 1),
|
||||
normalize_scores=(not getattr(args, "unnormalized", False)),
|
||||
len_penalty=getattr(args, "lenpen", 1),
|
||||
unk_penalty=getattr(args, "unkpen", 0),
|
||||
temperature=getattr(args, "temperature", 1.0),
|
||||
match_source_len=getattr(args, "match_source_len", False),
|
||||
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
|
||||
search_strategy=search_strategy,
|
||||
**extra_gen_cls_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
Двоичный файл не отображается.
Двоичные данные
emogen/linear_decoder/linear/__pycache__/causal_linear_attention.cpython-38.pyc
Normal file
Двоичные данные
emogen/linear_decoder/linear/__pycache__/causal_linear_attention.cpython-38.pyc
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
Двоичные данные
emogen/linear_decoder/linear/__pycache__/transformer_layer.cpython-38.pyc
Normal file
Двоичные данные
emogen/linear_decoder/linear/__pycache__/transformer_layer.cpython-38.pyc
Normal file
Двоичный файл не отображается.
Двоичный файл не отображается.
|
@ -0,0 +1,123 @@
|
|||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
"""The base attention layer performs all the query key value projections and
|
||||
output projections leaving the implementation of the attention to the inner
|
||||
attention module.
|
||||
|
||||
The transformer layers, however, are agnostic of the attention implementation
|
||||
and any layer that implements the same interface can substitute for the
|
||||
attention layer.
|
||||
"""
|
||||
import math
|
||||
import torch.nn as nn
|
||||
from torch.nn import Linear, Module
|
||||
|
||||
|
||||
class AttentionLayer(Module):
|
||||
"""Implement the attention layer. Namely project the inputs to multi-head
|
||||
queries, keys and values, call the attention implementation and then
|
||||
reproject the output.
|
||||
|
||||
It can be thought of as a decorator (see decorator design patter) of an
|
||||
attention layer.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
attention: Specific inner attention implementation that just computes a
|
||||
weighted average of values given a similarity of queries and
|
||||
keys.
|
||||
d_model: The input feature dimensionality
|
||||
n_heads: The number of heads for the multi head attention
|
||||
d_keys: The dimensionality of the keys/queries
|
||||
(default: d_model/n_heads)
|
||||
d_values: The dimensionality of the values (default: d_model/n_heads)
|
||||
event_dispatcher: str or EventDispatcher instance to be used by this
|
||||
module for dispatching events (default: the default
|
||||
global dispatcher)
|
||||
"""
|
||||
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
|
||||
super(AttentionLayer, self).__init__()
|
||||
|
||||
# Fill d_keys and d_values
|
||||
d_keys = d_keys or (d_model//n_heads)
|
||||
d_values = d_values or (d_model//n_heads)
|
||||
|
||||
self.inner_attention = attention
|
||||
self.q_proj = Linear(d_model, d_keys * n_heads)
|
||||
self.k_proj = Linear(d_model, d_keys * n_heads)
|
||||
self.v_proj = Linear(d_model, d_values * n_heads)
|
||||
self.out_proj = Linear(d_values * n_heads, d_model)
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.qkv_same_dim = d_keys == d_values == d_model // n_heads
|
||||
|
||||
def reset_parameters(self):
|
||||
if self.qkv_same_dim:
|
||||
# Empirically observed the convergence to be much better with
|
||||
# the scaled initialization
|
||||
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
||||
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
||||
else:
|
||||
nn.init.xavier_uniform_(self.k_proj.weight)
|
||||
nn.init.xavier_uniform_(self.v_proj.weight)
|
||||
nn.init.xavier_uniform_(self.q_proj.weight)
|
||||
|
||||
nn.init.xavier_uniform_(self.out_proj.weight)
|
||||
if self.out_proj.bias is not None:
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask=None, key_padding_mask=None):
|
||||
"""Apply attention to the passed in queries/keys/values after
|
||||
projecting them to multiple heads.
|
||||
|
||||
In the argument description we make use of the following sizes
|
||||
|
||||
- N: the batch size
|
||||
- L: The maximum length of the queries
|
||||
- S: The maximum length of the keys (the actual length per sequence
|
||||
is given by the length mask)
|
||||
- D: The input feature dimensionality passed in the constructor as
|
||||
'd_model'
|
||||
|
||||
Arguments
|
||||
---------
|
||||
queries: (N, L, D) The tensor containing the queries
|
||||
keys: (N, S, D) The tensor containing the keys
|
||||
values: (N, S, D) The tensor containing the values
|
||||
attn_mask: An implementation of BaseMask that encodes where each
|
||||
query can attend to
|
||||
query_lengths: An implementation of BaseMask that encodes how
|
||||
many queries each sequence in the batch consists of
|
||||
key_lengths: An implementation of BaseMask that encodes how
|
||||
many queries each sequence in the batch consists of
|
||||
|
||||
Returns
|
||||
-------
|
||||
The new value for each query as a tensor of shape (N, L, D).
|
||||
"""
|
||||
# Extract the dimensions into local variables
|
||||
N, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
H = self.n_heads
|
||||
|
||||
# Project the queries/keys/values
|
||||
queries = self.q_proj(queries).view(N, L, H, -1)
|
||||
keys = self.k_proj(keys).view(N, S, H, -1)
|
||||
values = self.v_proj(values).view(N, S, H, -1)
|
||||
|
||||
# Compute the attention
|
||||
new_values = self.inner_attention(
|
||||
queries,
|
||||
keys,
|
||||
values,
|
||||
attn_mask,
|
||||
key_padding_mask
|
||||
).view(N, L, -1)
|
||||
|
||||
# Project the output and return
|
||||
return self.out_proj(new_values)
|
|
@ -0,0 +1,76 @@
|
|||
#
|
||||
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
|
||||
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
|
||||
# Apoorv Vyas <avyas@idiap.ch>
|
||||
#
|
||||
|
||||
"""Implement causally masked linear attention."""
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from fast_transformers.causal_product import causal_dot_product
|
||||
from fast_transformers.feature_maps import elu_feature_map
|
||||
|
||||
|
||||
|
||||
|
||||
def causal_linear(Q, K, V):
|
||||
dtype = Q.dtype
|
||||
Q = Q.permute(0,2,1,3).float().contiguous() # [bs, n_head, seq_len, d_hidden]
|
||||
K = K.permute(0,2,1,3).float().contiguous()
|
||||
V = V.permute(0,2,1,3).float().contiguous()
|
||||
V_new = causal_dot_product(Q, K, V)
|
||||
return V_new.permute(0,2,1,3).type(dtype).contiguous() # [bs, seq_len, n_head, d_hidden]
|
||||
|
||||
|
||||
class CausalLinearAttention(Module):
|
||||
"""Implement causally masked attention using dot product of feature maps in
|
||||
O(N D^2) complexity.
|
||||
|
||||
See fast_transformers.attention.linear_attention.LinearAttention for the
|
||||
general concept of replacing the softmax with feature maps. In addition to
|
||||
that, we also make use of the fact that causal masking is a triangular mask
|
||||
which allows us to apply the masking and still compute the attention in O(N
|
||||
D^2) complexity.
|
||||
|
||||
Arguments
|
||||
---------
|
||||
feature_map: callable, a callable that applies the feature map to the
|
||||
last dimension of a tensor (default: elu(x)+1)
|
||||
eps: float, a small number to ensure the numerical stability of the
|
||||
denominator (default: 1e-6)
|
||||
event_dispatcher: str or EventDispatcher instance to be used by this
|
||||
module for dispatching events (default: the default
|
||||
global dispatcher)
|
||||
"""
|
||||
def __init__(self, query_dimensions, feature_map=None, eps=1e-6):
|
||||
super(CausalLinearAttention, self).__init__()
|
||||
self.feature_map = (
|
||||
feature_map(query_dimensions) if feature_map else
|
||||
elu_feature_map(query_dimensions)
|
||||
)
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, queries, keys, values, attn_mask=None, key_padding_mask=None):
|
||||
# Apply the feature map to the queries and keys
|
||||
self.feature_map.new_feature_map(queries.device)
|
||||
Q = self.feature_map.forward_queries(queries) # [bs, seq_len, n_head, d_hidden]
|
||||
K = self.feature_map.forward_keys(keys)
|
||||
|
||||
assert attn_mask is None, "Cannot assign attn_mask for %s" % self.__class__.__name__
|
||||
|
||||
if key_padding_mask is not None:
|
||||
K = K * key_padding_mask.type(queries.dtype)[:, :, None, None]
|
||||
|
||||
# Compute the normalizers
|
||||
Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
|
||||
|
||||
# Compute the unnormalized result
|
||||
V = causal_linear(
|
||||
Q,
|
||||
K,
|
||||
values
|
||||
)
|
||||
|
||||
return V * Z[:, :, :, None]
|
|
@ -0,0 +1,477 @@
|
|||
from fairseq.models.transformer import Linear
|
||||
from fairseq.models import FairseqDecoder
|
||||
|
||||
import math, gc
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from fairseq import utils
|
||||
from fairseq.models.fairseq_encoder import EncoderOut
|
||||
from fairseq.modules import (
|
||||
AdaptiveSoftmax,
|
||||
FairseqDropout,
|
||||
LayerDropModuleList,
|
||||
LayerNorm,
|
||||
PositionalEmbedding,
|
||||
SinusoidalPositionalEmbedding,
|
||||
)
|
||||
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
|
||||
from torch import Tensor
|
||||
|
||||
from .transformer_layer import LinearTransformerDecoderLayer
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LinearTransformerDecoder(FairseqDecoder):
|
||||
"""
|
||||
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
||||
is a :class:`TransformerDecoderLayer`.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): parsed command-line arguments
|
||||
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
||||
embed_tokens (torch.nn.Embedding): output embedding
|
||||
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
||||
(default: False).
|
||||
"""
|
||||
|
||||
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
|
||||
self.args = args
|
||||
super().__init__(dictionary)
|
||||
self.register_buffer("version", torch.Tensor([3]))
|
||||
self._future_mask = torch.empty(0)
|
||||
|
||||
self.dropout_module = FairseqDropout(
|
||||
args.dropout, module_name=self.__class__.__name__
|
||||
)
|
||||
self.decoder_layerdrop = args.decoder_layerdrop
|
||||
self.share_input_output_embed = args.share_decoder_input_output_embed
|
||||
|
||||
input_embed_dim = embed_tokens.embedding_dim
|
||||
embed_dim = args.decoder_embed_dim
|
||||
self.embed_dim = embed_dim
|
||||
self.output_embed_dim = args.decoder_output_dim
|
||||
|
||||
self.padding_idx = embed_tokens.padding_idx
|
||||
self.max_target_positions = args.max_target_positions
|
||||
|
||||
self.embed_tokens = embed_tokens
|
||||
|
||||
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
|
||||
|
||||
# add new layers:
|
||||
if self.args.control_mode == "embedding_v2":
|
||||
self.feature_embedding_layer_list = nn.ModuleList()
|
||||
for i in range(self.args.feature_num):
|
||||
if self.args.bucket_num > 0:
|
||||
self.feature_embedding_layer_list.append(
|
||||
nn.Embedding(self.args.bucket_num, self.args.command_embed_dim))
|
||||
else:
|
||||
self.feature_embedding_layer_list.append(
|
||||
nn.Linear(1, self.args.command_embed_dim))
|
||||
self.command_project = nn.Linear(self.args.command_embed_dim * self.args.feature_num, input_embed_dim)
|
||||
else:
|
||||
raise ValueError("unknown control mode, please chosen from [embedding_v2]")
|
||||
|
||||
|
||||
if not args.adaptive_input and args.quant_noise_pq > 0:
|
||||
self.quant_noise = apply_quant_noise_(
|
||||
nn.Linear(embed_dim, embed_dim, bias=False),
|
||||
args.quant_noise_pq,
|
||||
args.quant_noise_pq_block_size,
|
||||
)
|
||||
else:
|
||||
self.quant_noise = None
|
||||
|
||||
self.project_in_dim = (
|
||||
Linear(input_embed_dim, embed_dim, bias=False)
|
||||
if embed_dim != input_embed_dim
|
||||
else None
|
||||
)
|
||||
|
||||
self.embed_positions = (
|
||||
PositionalEmbedding(
|
||||
args.truncated_length,
|
||||
embed_dim,
|
||||
self.padding_idx,
|
||||
learned=args.decoder_learned_pos,
|
||||
)
|
||||
if not args.no_token_positional_embeddings
|
||||
else None
|
||||
)
|
||||
|
||||
if getattr(args, "layernorm_embedding", False):
|
||||
self.layernorm_embedding = LayerNorm(embed_dim)
|
||||
else:
|
||||
self.layernorm_embedding = None
|
||||
|
||||
self.cross_self_attention = getattr(args, "cross_self_attention", False)
|
||||
|
||||
if self.decoder_layerdrop > 0.0:
|
||||
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
|
||||
else:
|
||||
self.layers = nn.ModuleList([])
|
||||
self.layers.extend(
|
||||
[
|
||||
self.build_decoder_layer(args, no_encoder_attn)
|
||||
for _ in range(args.decoder_layers)
|
||||
]
|
||||
)
|
||||
self.num_layers = len(self.layers)
|
||||
|
||||
if args.decoder_normalize_before and not getattr(
|
||||
args, "no_decoder_final_norm", False
|
||||
):
|
||||
self.layer_norm = LayerNorm(embed_dim)
|
||||
else:
|
||||
self.layer_norm = None
|
||||
|
||||
self.project_out_dim = (
|
||||
Linear(embed_dim, self.output_embed_dim, bias=False)
|
||||
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
|
||||
else None
|
||||
)
|
||||
|
||||
self.adaptive_softmax = None
|
||||
self.output_projection = None
|
||||
if args.adaptive_softmax_cutoff is not None:
|
||||
self.adaptive_softmax = AdaptiveSoftmax(
|
||||
len(dictionary),
|
||||
self.output_embed_dim,
|
||||
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
|
||||
dropout=args.adaptive_softmax_dropout,
|
||||
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
|
||||
factor=args.adaptive_softmax_factor,
|
||||
tie_proj=args.tie_adaptive_proj,
|
||||
)
|
||||
elif self.share_input_output_embed:
|
||||
self.output_projection = nn.Linear(
|
||||
self.embed_tokens.weight.shape[1],
|
||||
self.embed_tokens.weight.shape[0],
|
||||
bias=False,
|
||||
)
|
||||
self.output_projection.weight = self.embed_tokens.weight
|
||||
else:
|
||||
self.output_projection = nn.Linear(
|
||||
self.output_embed_dim, len(dictionary), bias=False
|
||||
)
|
||||
nn.init.normal_(
|
||||
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = getattr(self.args, 'gradient_checkpointing', False)
|
||||
if self.gradient_checkpointing:
|
||||
checkpointing_layers = getattr(self.args, 'gradient_checkpointing_layers', None)
|
||||
if checkpointing_layers is None:
|
||||
gradient_checkpointing_every_n_layer = getattr(self.args, 'gradient_checkpointing_every_n_layer', 1)
|
||||
checkpointing_layers = tuple(range(0, self.num_layers, gradient_checkpointing_every_n_layer))
|
||||
self.checkpointing_layers = checkpointing_layers
|
||||
|
||||
|
||||
def build_decoder_layer(self, args, no_encoder_attn=False):
|
||||
return LinearTransformerDecoderLayer(args, no_encoder_attn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prev_output_tokens,
|
||||
command_input,
|
||||
encoder_out: Optional[EncoderOut] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
features_only: bool = False,
|
||||
full_context_alignment: bool = False,
|
||||
alignment_layer: Optional[int] = None,
|
||||
alignment_heads: Optional[int] = None,
|
||||
src_lengths: Optional[Any] = None,
|
||||
return_all_hiddens: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
prev_output_tokens (LongTensor): previous decoder outputs of shape
|
||||
`(batch, tgt_len)`, for teacher forcing
|
||||
encoder_out (optional): output from the encoder, used for
|
||||
encoder-side attention
|
||||
incremental_state (dict): dictionary used for storing state during
|
||||
:ref:`Incremental decoding`
|
||||
features_only (bool, optional): only return features without
|
||||
applying output layer (default: False).
|
||||
full_context_alignment (bool, optional): don't apply
|
||||
auto-regressive mask to self-attention (default: False).
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- the decoder's output of shape `(batch, tgt_len, vocab)`
|
||||
- a dictionary with any model-specific outputs
|
||||
"""
|
||||
x, extra = self.extract_features(
|
||||
prev_output_tokens,
|
||||
command_input,
|
||||
encoder_out=encoder_out,
|
||||
incremental_state=incremental_state,
|
||||
full_context_alignment=full_context_alignment,
|
||||
alignment_layer=alignment_layer,
|
||||
alignment_heads=alignment_heads,
|
||||
)
|
||||
if not features_only:
|
||||
x = self.output_layer(x)
|
||||
return x, extra
|
||||
|
||||
def extract_features(
|
||||
self,
|
||||
prev_output_tokens,
|
||||
command_input,
|
||||
encoder_out: Optional[EncoderOut] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
full_context_alignment: bool = False,
|
||||
alignment_layer: Optional[int] = None,
|
||||
alignment_heads: Optional[int] = None,
|
||||
):
|
||||
return self.extract_features_scriptable(
|
||||
prev_output_tokens,
|
||||
command_input,
|
||||
encoder_out,
|
||||
incremental_state,
|
||||
full_context_alignment,
|
||||
alignment_layer,
|
||||
alignment_heads,
|
||||
)
|
||||
|
||||
"""
|
||||
A scriptable subclass of this class has an extract_features method and calls
|
||||
super().extract_features, but super() is not supported in torchscript. Aa copy of
|
||||
this function is made to be used in the subclass instead.
|
||||
"""
|
||||
|
||||
def extract_features_scriptable(
|
||||
self,
|
||||
prev_output_tokens,
|
||||
command_input,
|
||||
encoder_out: Optional[EncoderOut] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
full_context_alignment: bool = False,
|
||||
alignment_layer: Optional[int] = None,
|
||||
alignment_heads: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Similar to *forward* but only return features.
|
||||
|
||||
Includes several features from "Jointly Learning to Align and
|
||||
Translate with Transformer Models" (Garg et al., EMNLP 2019).
|
||||
|
||||
Args:
|
||||
full_context_alignment (bool, optional): don't apply
|
||||
auto-regressive mask to self-attention (default: False).
|
||||
alignment_layer (int, optional): return mean alignment over
|
||||
heads at this layer (default: last layer).
|
||||
alignment_heads (int, optional): only average alignment over
|
||||
this many heads (default: all heads).
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
||||
- a dictionary with any model-specific outputs
|
||||
"""
|
||||
if alignment_layer is None:
|
||||
alignment_layer = self.num_layers - 1
|
||||
|
||||
# embed positions
|
||||
positions = (
|
||||
self.embed_positions(
|
||||
prev_output_tokens, incremental_state=incremental_state
|
||||
)
|
||||
if self.embed_positions is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if incremental_state is not None:
|
||||
prev_output_tokens = prev_output_tokens[:, -1:]
|
||||
if positions is not None:
|
||||
positions = positions[:, -1:]
|
||||
|
||||
# embed tokens and positions
|
||||
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
||||
|
||||
|
||||
command_emb_list = []
|
||||
|
||||
command_input = command_input.to(torch.int64)
|
||||
for i in range(self.args.feature_num):
|
||||
if command_input[0, i] != -1:
|
||||
command_emb_list.append(
|
||||
self.feature_embedding_layer_list[i](command_input[:, i])[:, None, :]) # [bs, 1, emb_size]
|
||||
else:
|
||||
command_emb_list.append(torch.zeros(command_emb_list[0].shape).to(prev_output_tokens.device))
|
||||
|
||||
command_emb = torch.cat(command_emb_list, dim=1)[:, None, :, :] # [bs, seq_len=1, feature_num, emb_size]
|
||||
seq_len = prev_output_tokens.shape[1]
|
||||
gc.collect()
|
||||
batch_size = command_emb.shape[0]
|
||||
|
||||
command_emb = torch.repeat_interleave(command_emb, repeats=seq_len,
|
||||
dim=1) # [bs, seq_len, feature_num, emb_size]
|
||||
bs, seq_len, feature_num, emb_size = command_emb.shape
|
||||
|
||||
command_emb = command_emb.reshape([bs, seq_len, feature_num * emb_size])
|
||||
command_emb = self.command_project(command_emb)
|
||||
x += command_emb
|
||||
|
||||
if self.quant_noise is not None:
|
||||
x = self.quant_noise(x)
|
||||
|
||||
if self.project_in_dim is not None:
|
||||
x = self.project_in_dim(x)
|
||||
|
||||
if positions is not None:
|
||||
x += positions
|
||||
|
||||
if self.layernorm_embedding is not None:
|
||||
x = self.layernorm_embedding(x)
|
||||
|
||||
x = self.dropout_module(x)
|
||||
|
||||
# B x T x C -> T x B x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
self_attn_padding_mask: Optional[Tensor] = None
|
||||
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
|
||||
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
|
||||
|
||||
# decoder layers
|
||||
attn: Optional[Tensor] = None
|
||||
inner_states: List[Optional[Tensor]] = [x]
|
||||
|
||||
gradient_checkpointing_every_n_layer = getattr(self.args, "gradient_checkpointing_every_n_layer", 1)
|
||||
|
||||
for idx, layer in enumerate(self.layers):
|
||||
# if incremental_state is None and not full_context_alignment:
|
||||
# self_attn_mask = self.buffered_future_mask(x)
|
||||
# else:
|
||||
# self_attn_mask = None
|
||||
|
||||
self_attn_mask = None # Casual Linear Attention does not need this
|
||||
|
||||
if (
|
||||
getattr(self.args, "gradient_checkpointing", False) and self.training and
|
||||
idx in self.checkpointing_layers
|
||||
):
|
||||
x, layer_attn, _ = checkpoint(
|
||||
layer,
|
||||
x,
|
||||
encoder_out.encoder_out if encoder_out is not None else None,
|
||||
encoder_out.encoder_padding_mask if encoder_out is not None else None,
|
||||
incremental_state,
|
||||
None,
|
||||
None,
|
||||
self_attn_mask,
|
||||
self_attn_padding_mask,
|
||||
bool((idx == alignment_layer)),
|
||||
bool((idx == alignment_layer)),
|
||||
)
|
||||
else:
|
||||
x, layer_attn, _ = layer(
|
||||
x,
|
||||
encoder_out.encoder_out if encoder_out is not None else None,
|
||||
encoder_out.encoder_padding_mask if encoder_out is not None else None,
|
||||
incremental_state,
|
||||
self_attn_mask=self_attn_mask,
|
||||
self_attn_padding_mask=self_attn_padding_mask,
|
||||
need_attn=bool((idx == alignment_layer)),
|
||||
need_head_weights=bool((idx == alignment_layer)),
|
||||
)
|
||||
inner_states.append(x)
|
||||
if layer_attn is not None and idx == alignment_layer:
|
||||
attn = layer_attn.float().to(x)
|
||||
|
||||
if attn is not None:
|
||||
if alignment_heads is not None:
|
||||
attn = attn[:alignment_heads]
|
||||
|
||||
# average probabilities over heads
|
||||
attn = attn.mean(dim=0)
|
||||
|
||||
if self.layer_norm is not None:
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# T x B x C -> B x T x C
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
if self.project_out_dim is not None:
|
||||
x = self.project_out_dim(x)
|
||||
|
||||
return x, {"attn": [attn], "inner_states": inner_states}
|
||||
|
||||
def output_layer(self, features, **kwargs):
|
||||
"""Project features to the vocabulary size."""
|
||||
if self.adaptive_softmax is None:
|
||||
# project back to size of vocabulary
|
||||
return self.output_projection(features, **kwargs)
|
||||
else:
|
||||
return features
|
||||
|
||||
def max_positions(self):
|
||||
"""Maximum output length supported by the decoder."""
|
||||
if self.embed_positions is None:
|
||||
return self.max_target_positions
|
||||
return min(self.max_target_positions, self.embed_positions.max_positions)
|
||||
|
||||
def buffered_future_mask(self, tensor):
|
||||
dim = tensor.size(0)
|
||||
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
|
||||
if (
|
||||
self._future_mask.size(0) == 0
|
||||
or (not self._future_mask.device == tensor.device)
|
||||
or self._future_mask.size(0) < dim
|
||||
):
|
||||
self._future_mask = torch.triu(
|
||||
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
|
||||
)
|
||||
self._future_mask = self._future_mask.to(tensor)
|
||||
return self._future_mask[:dim, :dim]
|
||||
|
||||
def upgrade_state_dict_named(self, state_dict, name):
|
||||
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
|
||||
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
|
||||
weights_key = "{}.embed_positions.weights".format(name)
|
||||
if weights_key in state_dict:
|
||||
del state_dict[weights_key]
|
||||
state_dict[
|
||||
"{}.embed_positions._float_tensor".format(name)
|
||||
] = torch.FloatTensor(1)
|
||||
|
||||
if f"{name}.output_projection.weight" not in state_dict:
|
||||
if self.share_input_output_embed:
|
||||
embed_out_key = f"{name}.embed_tokens.weight"
|
||||
else:
|
||||
embed_out_key = f"{name}.embed_out"
|
||||
if embed_out_key in state_dict:
|
||||
state_dict[f"{name}.output_projection.weight"] = state_dict[
|
||||
embed_out_key
|
||||
]
|
||||
if not self.share_input_output_embed:
|
||||
del state_dict[embed_out_key]
|
||||
|
||||
for i in range(self.num_layers):
|
||||
# update layer norms
|
||||
layer_norm_map = {
|
||||
"0": "self_attn_layer_norm",
|
||||
"1": "encoder_attn_layer_norm",
|
||||
"2": "final_layer_norm",
|
||||
}
|
||||
for old, new in layer_norm_map.items():
|
||||
for m in ("weight", "bias"):
|
||||
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
|
||||
if k in state_dict:
|
||||
state_dict[
|
||||
"{}.layers.{}.{}.{}".format(name, i, new, m)
|
||||
] = state_dict[k]
|
||||
del state_dict[k]
|
||||
|
||||
version_key = "{}.version".format(name)
|
||||
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
|
||||
# earlier checkpoints did not normalize after the stack of layers
|
||||
self.layer_norm = None
|
||||
self.normalize = False
|
||||
state_dict[version_key] = torch.Tensor([1])
|
||||
|
||||
return state_dict
|
|
@ -0,0 +1,191 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from fairseq.modules.transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
|
||||
# from fast_transformers.attention.attention_layer import AttentionLayer
|
||||
# from fast_transformers.attention.causal_linear_attention import CausalLinearAttention
|
||||
# from fast_transformers.masking import TriangularCausalMask, LengthMask, FullMask
|
||||
|
||||
from .causal_linear_attention import CausalLinearAttention
|
||||
from .attention_layer import AttentionLayer
|
||||
# from fast_transformers.transformers import TransformerEncoderLayer
|
||||
from fast_transformers.attention import LinearAttention
|
||||
|
||||
class LinearTransformerDecoderLayer(TransformerDecoderLayer):
|
||||
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
|
||||
super().__init__(args, no_encoder_attn=no_encoder_attn, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn)
|
||||
self.decoder_attention_heads = args.decoder_attention_heads
|
||||
|
||||
def build_self_attention(
|
||||
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
|
||||
):
|
||||
causal_linear_attention = CausalLinearAttention(embed_dim)
|
||||
linear_attention_layer = AttentionLayer(causal_linear_attention,
|
||||
embed_dim, args.decoder_attention_heads)
|
||||
return linear_attention_layer
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
encoder_out: Optional[torch.Tensor] = None,
|
||||
encoder_padding_mask: Optional[torch.Tensor] = None,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
|
||||
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
|
||||
prev_attn_state: Optional[List[torch.Tensor]] = None,
|
||||
self_attn_mask: Optional[torch.Tensor] = None,
|
||||
self_attn_padding_mask: Optional[torch.Tensor] = None,
|
||||
need_attn: bool = False,
|
||||
need_head_weights: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
|
||||
encoder_padding_mask (ByteTensor, optional): binary
|
||||
ByteTensor of shape `(batch, src_len)` where padding
|
||||
elements are indicated by ``1``.
|
||||
need_attn (bool, optional): return attention weights
|
||||
need_head_weights (bool, optional): return attention weights
|
||||
for each head (default: return average over heads).
|
||||
|
||||
Returns:
|
||||
encoded output of shape `(seq_len, batch, embed_dim)`
|
||||
"""
|
||||
if need_head_weights:
|
||||
need_attn = True
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
# if prev_self_attn_state is not None:
|
||||
# prev_key, prev_value = prev_self_attn_state[:2]
|
||||
# saved_state: Dict[str, Optional[Tensor]] = {
|
||||
# "prev_key": prev_key,
|
||||
# "prev_value": prev_value,
|
||||
# }
|
||||
# if len(prev_self_attn_state) >= 3:
|
||||
# saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
|
||||
# assert incremental_state is not None
|
||||
# self.self_attn._set_input_buffer(incremental_state, saved_state)
|
||||
# _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
|
||||
# if self.cross_self_attention and not (
|
||||
# incremental_state is not None
|
||||
# and _self_attn_input_buffer is not None
|
||||
# and "prev_key" in _self_attn_input_buffer
|
||||
# ):
|
||||
# if self_attn_mask is not None:
|
||||
# assert encoder_out is not None
|
||||
# self_attn_mask = torch.cat(
|
||||
# (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
|
||||
# )
|
||||
# if self_attn_padding_mask is not None:
|
||||
# if encoder_padding_mask is None:
|
||||
# assert encoder_out is not None
|
||||
# encoder_padding_mask = self_attn_padding_mask.new_zeros(
|
||||
# encoder_out.size(1), encoder_out.size(0)
|
||||
# )
|
||||
# self_attn_padding_mask = torch.cat(
|
||||
# (encoder_padding_mask, self_attn_padding_mask), dim=1
|
||||
# )
|
||||
# assert encoder_out is not None
|
||||
# y = torch.cat((encoder_out, x), dim=0)
|
||||
# else:
|
||||
y = x
|
||||
|
||||
x, attn = self.run_self_attn(
|
||||
query=x,
|
||||
key=y,
|
||||
value=y,
|
||||
key_padding_mask=self_attn_padding_mask,
|
||||
incremental_state=incremental_state,
|
||||
need_weights=False,
|
||||
attn_mask=self_attn_mask,
|
||||
)
|
||||
x = self.dropout_module(x)
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = self.self_attn_layer_norm(x)
|
||||
|
||||
assert self.encoder_attn is None and encoder_out is None
|
||||
# if self.encoder_attn is not None and encoder_out is not None:
|
||||
# residual = x
|
||||
# if self.normalize_before:
|
||||
# x = self.encoder_attn_layer_norm(x)
|
||||
# if prev_attn_state is not None:
|
||||
# prev_key, prev_value = prev_attn_state[:2]
|
||||
# saved_state: Dict[str, Optional[Tensor]] = {
|
||||
# "prev_key": prev_key,
|
||||
# "prev_value": prev_value,
|
||||
# }
|
||||
# if len(prev_attn_state) >= 3:
|
||||
# saved_state["prev_key_padding_mask"] = prev_attn_state[2]
|
||||
# assert incremental_state is not None
|
||||
# self.encoder_attn._set_input_buffer(incremental_state, saved_state)
|
||||
#
|
||||
# x, attn = self.encoder_attn(
|
||||
# query=x,
|
||||
# key=encoder_out,
|
||||
# value=encoder_out,
|
||||
# key_padding_mask=encoder_padding_mask,
|
||||
# incremental_state=incremental_state,
|
||||
# static_kv=True,
|
||||
# need_weights=need_attn or (not self.training and self.need_attn),
|
||||
# need_head_weights=need_head_weights,
|
||||
# )
|
||||
# x = self.dropout_module(x)
|
||||
# x = self.residual_connection(x, residual)
|
||||
# if not self.normalize_before:
|
||||
# x = self.encoder_attn_layer_norm(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
|
||||
x = self.activation_fn(self.fc1(x))
|
||||
x = self.activation_dropout_module(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout_module(x)
|
||||
x = self.residual_connection(x, residual)
|
||||
if not self.normalize_before:
|
||||
x = self.final_layer_norm(x)
|
||||
if self.onnx_trace and incremental_state is not None:
|
||||
raise NotImplementedError
|
||||
saved_state = self.self_attn._get_input_buffer(incremental_state)
|
||||
assert saved_state is not None
|
||||
if self_attn_padding_mask is not None:
|
||||
self_attn_state = [
|
||||
saved_state["prev_key"],
|
||||
saved_state["prev_value"],
|
||||
saved_state["prev_key_padding_mask"],
|
||||
]
|
||||
else:
|
||||
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
|
||||
return x, attn, self_attn_state
|
||||
return x, attn, None
|
||||
|
||||
def run_self_attn(self, query, key_padding_mask, incremental_state, need_weights, **kwargs):
|
||||
if incremental_state is not None:
|
||||
raise NotImplementedError
|
||||
if need_weights:
|
||||
raise NotImplementedError
|
||||
|
||||
# tgt_len, bsz, embed_dim = query.shape
|
||||
# src_len = tgt_len
|
||||
# num_heads = self.decoder_attention_heads
|
||||
# head_dim = self.embed_dim // num_heads
|
||||
|
||||
# query = query.transpose(0, 1).reshape(bsz, tgt_len, num_heads, head_dim)
|
||||
query = query.transpose(0, 1)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = ~key_padding_mask
|
||||
|
||||
r = self.self_attn(query, query, query, attn_mask=None, key_padding_mask=key_padding_mask)
|
||||
|
||||
r = r.transpose(0, 1)
|
||||
|
||||
return r, None
|
||||
|
||||
# class LinearTransformerEncoderLayer(TransformerEncoderLayer):
|
||||
|
|
@ -0,0 +1,101 @@
|
|||
from fairseq.models.transformer_lm import TransformerLanguageModel, TransformerLanguageModelConfig, \
|
||||
DEFAULT_MAX_TARGET_POSITIONS, transformer_lm_gpt, base_lm_architecture
|
||||
from fairseq import options
|
||||
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
|
||||
from fairseq.models import register_model, register_model_architecture
|
||||
|
||||
from .transformer import LinearTransformerDecoder
|
||||
|
||||
@register_model("linear_transformer_lm", dataclass=TransformerLanguageModelConfig)
|
||||
class LinearTransformerLanguageModel(TransformerLanguageModel):
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
super().add_args(parser)
|
||||
parser.add_argument('--gradient-checkpointing', type=lambda x: x.lower() == 'true', default=False)
|
||||
parser.add_argument('--gradient-checkpointing-every-n-layer', type=int, default=1)
|
||||
parser.add_argument('--gradient-checkpointing-layers',
|
||||
type=lambda x: tuple([int(item) for item in x.split(',')]), default=None)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
|
||||
# make sure all arguments are present in older models
|
||||
base_lm_architecture(args)
|
||||
|
||||
if args.decoder_layers_to_keep:
|
||||
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
|
||||
|
||||
if getattr(args, "max_target_positions", None) is None:
|
||||
args.max_target_positions = getattr(
|
||||
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
|
||||
)
|
||||
|
||||
if args.character_embeddings:
|
||||
embed_tokens = CharacterTokenEmbedder(
|
||||
task.source_dictionary,
|
||||
eval(args.character_filters),
|
||||
args.character_embedding_dim,
|
||||
args.decoder_embed_dim,
|
||||
args.char_embedder_highway_layers,
|
||||
)
|
||||
elif args.adaptive_input:
|
||||
embed_tokens = AdaptiveInput(
|
||||
len(task.source_dictionary),
|
||||
task.source_dictionary.pad(),
|
||||
args.decoder_input_dim,
|
||||
args.adaptive_input_factor,
|
||||
args.decoder_embed_dim,
|
||||
options.eval_str_list(args.adaptive_input_cutoff, type=int),
|
||||
args.quant_noise_pq,
|
||||
args.quant_noise_pq_block_size,
|
||||
)
|
||||
else:
|
||||
embed_tokens = cls.build_embedding(
|
||||
args, task.source_dictionary, args.decoder_input_dim
|
||||
)
|
||||
|
||||
if args.tie_adaptive_weights:
|
||||
assert args.adaptive_input
|
||||
assert args.adaptive_input_factor == args.adaptive_softmax_factor
|
||||
assert (
|
||||
args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
|
||||
), "{} != {}".format(
|
||||
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
|
||||
)
|
||||
assert args.decoder_input_dim == args.decoder_output_dim
|
||||
|
||||
decoder = LinearTransformerDecoder(
|
||||
args, task.target_dictionary, embed_tokens, no_encoder_attn=True
|
||||
)
|
||||
return cls(decoder)
|
||||
|
||||
|
||||
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_gpt")
|
||||
def linear_transformer_lm_gpt_architecture(args):
|
||||
transformer_lm_gpt(args)
|
||||
|
||||
|
||||
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_std")
|
||||
def std_linear_transformer_lm_architecture(args):
|
||||
args.command_embed_dim = 16
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
|
||||
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
|
||||
args.decoder_layers = getattr(args, "decoder_layers", 6)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
base_lm_architecture(args)
|
||||
|
||||
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_debug")
|
||||
def std_linear_transformer_lm_architecture(args):
|
||||
args.command_embed_dim = 16
|
||||
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 32)
|
||||
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 32)
|
||||
args.decoder_layers = getattr(args, "decoder_layers", 2)
|
||||
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
|
||||
args.dropout = getattr(args, "dropout", 0.1)
|
||||
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
||||
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
||||
base_lm_architecture(args)
|
|
@ -0,0 +1,90 @@
|
|||
# EmoGen
|
||||
|
||||
EmoGen: Eliminating Subjective Bias in Emotional Music Generation, by Chenfei Kang, Peiling Lu, Botao Yu, Xu Tan, Wei Ye, Shikun Zhang, Jiang Bian, is an emotional music generation system that leverages a set of emotion-related music attributes as the bridge between emotion and music, and divides the generation into two stages: emotion-to-attribute mapping with supervised clustering, and attribute-to-music generation with self-supervised learning. Both stages are beneficial: in the first stage, the attribute values around the clustering center represent the general emotions of these samples, which help eliminate the impacts of the subjective bias of emotion labels; in the second stage, the generation is completely disentangled from emotion labels and thus free from the subjective bias. Both subjective and objective evaluations show that EMOGEN outperforms previous methods on emotion control accuracy and music quality respectively, which demonstrate our superiority in generating emotional music.
|
||||
|
||||
demo: [link](https://emo-gen.github.io/)
|
||||
|
||||
The following content includes the steps for EMOGEN training and inference.
|
||||
|
||||
### 1. Environment
|
||||
|
||||
- Hardware environment: We recommend Nvidia V100 16GB/32GB.
|
||||
|
||||
- Software environment:
|
||||
|
||||
Please make sure you have `python 3.8` installed. Run the following command to install necessary packages:
|
||||
|
||||
```sh
|
||||
bash setup.sh
|
||||
```
|
||||
|
||||
Also, please follow the instructions in [Installation](https://jmir.sourceforge.net/manuals/jSymbolic_manual/installation_files/installation.html) to install `Java`.
|
||||
|
||||
### 2. Dataset
|
||||
|
||||
We use three datasets: One emotion-labeled dataset namely EMOPIA([link](https://annahung31.github.io/EMOPIA/)) and two unlabeled datasets namely Pop1k7 ([link](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md)) and LMD-Piano, where LMD-Piano is con-
|
||||
structed by using the samples that only contain piano tracks from the Lakh MIDI (LMD) dataset ([link](https://colinraffel.com/projects/lmd/)). To evaluate EMOGEN’s ability to generate emotional music on the arbitrary dataset, we also train EmoGen on TopMAGD([link](http://www.ifs.tuwien.ac.at/mir/msd/download.html)), which is a multi-instrument dataset.
|
||||
|
||||
In principle, you can use any MIDI dataset to train the attribute-to-music generation model. Taking the 'Piano' dataset under the 'data/' folder as an example. Please follow the steps below to process the data.
|
||||
|
||||
1. First, put all MIDI files in 'data/Piano/midi'.
|
||||
|
||||
2. Run the following command to encode MIDI files.
|
||||
|
||||
```shell
|
||||
cd data_process
|
||||
python midi_encoding.py
|
||||
```
|
||||
|
||||
3. Run the following command to extract attributes by jSymbolic.
|
||||
You are required to download the package jSymbolic_2_2_user.zip from https://sourceforge.net/projects/jmir/files/jSymbolic/jSymbolic%202.2/, and extract it into ./jSymbolic_lib.
|
||||
|
||||
```shell
|
||||
cd ../jSymbolic_lib
|
||||
python jSymbolic_feature.py
|
||||
```
|
||||
|
||||
4. Run the following script to prepare train/validation/test dataset.
|
||||
|
||||
```shell
|
||||
cd ../data_process
|
||||
python gen_data.py
|
||||
```
|
||||
|
||||
### 3.Train
|
||||
|
||||
- Emotion-to-attribute mapping
|
||||
|
||||
In this stage, we map four emotion quadrants in Russel's 4Q model to four different attributes. We first compute attribute centers in four quadrants on the EMOPIA dataset and selected the closest attribute vector to the center in each quadrant as the mapping result. The mapped results are stored in `data/infer_input/inference_command.npy`. The emotion quadrants corresponding to these attribute vectors are shown in the following table:
|
||||
|
||||
| Index | Emotion |
|
||||
| -------------------- | ------- |
|
||||
| inference_command[0] | Q1 |
|
||||
| inference_command[1] | Q2 |
|
||||
| inference_command[2] | Q3 |
|
||||
| inference_command[3] | Q4 |
|
||||
|
||||
- Attribute-to-music generation
|
||||
|
||||
Run the following command to train a 6-layer Linear Transformer model on the dataset `data/Piano`:
|
||||
|
||||
```shell
|
||||
bash Piano_train.sh
|
||||
```
|
||||
|
||||
### 4. Inference
|
||||
|
||||
Please put the checkpoints under the folder `checkpoints/`. To generate piano songs, run the following command:
|
||||
|
||||
```shell
|
||||
# usage: bash Piano_gen.sh target_emotion
|
||||
# 1-Q1 2-Q2 3-Q3 4-Q4
|
||||
bash Piano_gen.sh 3 # generate songs with emotion "Q3"
|
||||
```
|
||||
|
||||
Also, to generate multi-instrument songs, please run the following command:
|
||||
|
||||
```shell
|
||||
bash TopMAGD_gen.sh 1 # generate songs with emotion "Q1"
|
||||
```
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
pip install fairseq==0.10.2
|
||||
pip install miditoolkit
|
||||
pip install matplotlib
|
||||
pip install pytorch-fast-transformers
|
||||
pip install chorder
|
||||
pip install plotext
|
||||
pip install scikit-learn
|
||||
pip install scipy
|
||||
git clone https://github.com/btyu/MidiProcessor.git
|
||||
cd MidiProcessor
|
||||
pip install .
|
|
@ -0,0 +1,554 @@
|
|||
#!/usr/bin/env python3 -u
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
Train a new model on one or across multiple GPUs.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq import (
|
||||
checkpoint_utils,
|
||||
distributed_utils,
|
||||
options,
|
||||
quantization_utils,
|
||||
tasks,
|
||||
utils,
|
||||
)
|
||||
from fairseq.data import iterators
|
||||
from fairseq.logging import meters, metrics, progress_bar
|
||||
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
|
||||
from fairseq.trainer import Trainer
|
||||
|
||||
import linear_decoder.controlled_task
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=os.environ.get("LOGLEVEL", "INFO").upper(),
|
||||
stream=sys.stdout,
|
||||
)
|
||||
logger = logging.getLogger("fairseq_cli.train")
|
||||
|
||||
def main(args):
|
||||
|
||||
utils.import_user_module(args)
|
||||
|
||||
assert (
|
||||
args.max_tokens is not None or args.batch_size is not None
|
||||
), "Must specify batch size either with --max-tokens or --batch-size"
|
||||
|
||||
metrics.reset()
|
||||
|
||||
np.random.seed(args.seed)
|
||||
utils.set_torch_seed(args.seed)
|
||||
|
||||
if distributed_utils.is_master(args):
|
||||
checkpoint_utils.verify_checkpoint_directory(args.save_dir)
|
||||
|
||||
# Print args
|
||||
logger.info(args)
|
||||
|
||||
# Setup task, e.g., translation_control, language modeling, etc.
|
||||
task = tasks.setup_task(args)
|
||||
|
||||
# Load valid dataset (we load training data below, based on the latest checkpoint)
|
||||
for valid_sub_split in args.valid_subset.split(","):
|
||||
task.load_dataset(valid_sub_split, combine=False, epoch=1)
|
||||
|
||||
|
||||
model = task.build_model(args)
|
||||
criterion = task.build_criterion(args)
|
||||
logger.info(model)
|
||||
logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
|
||||
logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
|
||||
logger.info(
|
||||
"criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
|
||||
)
|
||||
logger.info(
|
||||
"num. model params: {} (num. trained: {})".format(
|
||||
sum(p.numel() for p in model.parameters()),
|
||||
sum(p.numel() for p in model.parameters() if p.requires_grad),
|
||||
)
|
||||
)
|
||||
|
||||
# (optionally) Configure quantization
|
||||
if args.quantization_config_path is not None:
|
||||
quantizer = quantization_utils.Quantizer(
|
||||
config_path=args.quantization_config_path,
|
||||
max_epoch=args.max_epoch,
|
||||
max_update=args.max_update,
|
||||
)
|
||||
else:
|
||||
quantizer = None
|
||||
|
||||
# Build trainer
|
||||
if args.model_parallel_size == 1:
|
||||
trainer = Trainer(args, task, model, criterion, quantizer)
|
||||
else:
|
||||
trainer = MegatronTrainer(args, task, model, criterion)
|
||||
|
||||
logger.info(
|
||||
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
|
||||
)
|
||||
logger.info(
|
||||
"max tokens per GPU = {} and max sentences per GPU = {}".format(
|
||||
args.max_tokens, args.batch_size
|
||||
)
|
||||
)
|
||||
|
||||
# Load the latest checkpoint if one is available and restore the
|
||||
# corresponding train iterator
|
||||
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
|
||||
args,
|
||||
trainer,
|
||||
# don't cache epoch iterators for sharded datasets
|
||||
disable_iterator_cache=task.has_sharded_data("train"),
|
||||
)
|
||||
|
||||
# Train until the learning rate gets too small
|
||||
max_epoch = args.max_epoch or math.inf
|
||||
lr = trainer.get_lr()
|
||||
train_meter = meters.StopwatchMeter()
|
||||
train_meter.start()
|
||||
|
||||
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
|
||||
# train for one epoch
|
||||
valid_losses, should_stop = train(args, trainer, task, epoch_itr)
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
# only use first validation loss to update the learning rate
|
||||
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
|
||||
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch_itr.next_epoch_idx,
|
||||
# sharded data: get train iterator for next epoch
|
||||
load_dataset=task.has_sharded_data("train"),
|
||||
# don't cache epoch iterators for sharded datasets
|
||||
disable_iterator_cache=task.has_sharded_data("train"),
|
||||
)
|
||||
train_meter.stop()
|
||||
logger.info("done training in {:.1f} seconds".format(train_meter.sum))
|
||||
|
||||
|
||||
def should_stop_early(args, valid_loss):
|
||||
# skip check if no validation was done in the current epoch
|
||||
if valid_loss is None:
|
||||
return False
|
||||
if args.patience <= 0:
|
||||
return False
|
||||
|
||||
def is_better(a, b):
|
||||
return a > b if args.maximize_best_checkpoint_metric else a < b
|
||||
|
||||
prev_best = getattr(should_stop_early, "best", None)
|
||||
if prev_best is None or is_better(valid_loss, prev_best):
|
||||
should_stop_early.best = valid_loss
|
||||
should_stop_early.num_runs = 0
|
||||
return False
|
||||
else:
|
||||
should_stop_early.num_runs += 1
|
||||
if should_stop_early.num_runs >= args.patience:
|
||||
logger.info(
|
||||
"early stop since valid performance hasn't improved for last {} runs".format(
|
||||
args.patience
|
||||
)
|
||||
)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
@metrics.aggregate("train")
|
||||
def train(args, trainer, task, epoch_itr):
|
||||
"""Train the model for one epoch and return validation losses."""
|
||||
# Initialize data iterator
|
||||
itr = epoch_itr.next_epoch_itr(
|
||||
fix_batches_to_gpus=args.fix_batches_to_gpus,
|
||||
shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
|
||||
)
|
||||
update_freq = (
|
||||
args.update_freq[epoch_itr.epoch - 1]
|
||||
if epoch_itr.epoch <= len(args.update_freq)
|
||||
else args.update_freq[-1]
|
||||
)
|
||||
itr = iterators.GroupedIterator(itr, update_freq)
|
||||
if getattr(args, "tpu", False):
|
||||
itr = utils.tpu_data_loader(itr)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
epoch=epoch_itr.epoch,
|
||||
tensorboard_logdir=(
|
||||
args.tensorboard_logdir if distributed_utils.is_master(args) else None
|
||||
),
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
|
||||
)
|
||||
|
||||
trainer.begin_epoch(epoch_itr.epoch)
|
||||
|
||||
valid_losses = [None]
|
||||
valid_subsets = args.valid_subset.split(",")
|
||||
should_stop = False
|
||||
num_updates = trainer.get_num_updates()
|
||||
for i, samples in enumerate(progress):
|
||||
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
|
||||
"train_step-%d" % i
|
||||
):
|
||||
log_output = trainer.train_step(samples)
|
||||
|
||||
if log_output is not None: # not OOM, overflow, ...
|
||||
# log mid-epoch stats
|
||||
num_updates = trainer.get_num_updates()
|
||||
if num_updates % args.log_interval == 0:
|
||||
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
|
||||
progress.log(stats, tag="train_inner", step=num_updates)
|
||||
|
||||
# reset mid-epoch stats after each log interval
|
||||
# the end-of-epoch stats will still be preserved
|
||||
metrics.reset_meters("train_inner")
|
||||
|
||||
end_of_epoch = not itr.has_next()
|
||||
valid_losses, should_stop = validate_and_save(
|
||||
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
|
||||
)
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
# log end-of-epoch stats
|
||||
logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
|
||||
stats = get_training_stats(metrics.get_smoothed_values("train"))
|
||||
progress.print(stats, tag="train", step=num_updates)
|
||||
|
||||
# reset epoch-level meters
|
||||
metrics.reset_meters("train")
|
||||
return valid_losses, should_stop
|
||||
|
||||
|
||||
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
|
||||
num_updates = trainer.get_num_updates()
|
||||
max_update = args.max_update or math.inf
|
||||
do_save = (
|
||||
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.save_interval_updates > 0
|
||||
and num_updates > 0
|
||||
and num_updates % args.save_interval_updates == 0
|
||||
and num_updates >= args.validate_after_updates
|
||||
)
|
||||
)
|
||||
do_validate = (
|
||||
(not end_of_epoch and do_save) # validate during mid-epoch saves
|
||||
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.validate_interval_updates > 0
|
||||
and num_updates > 0
|
||||
and num_updates % args.validate_interval_updates == 0
|
||||
)
|
||||
) and not args.disable_validation
|
||||
|
||||
# Validate
|
||||
valid_losses = [None]
|
||||
if do_validate:
|
||||
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
|
||||
|
||||
# Stopping conditions
|
||||
should_stop = (
|
||||
should_stop_early(args, valid_losses[0])
|
||||
or num_updates >= max_update
|
||||
or (
|
||||
args.stop_time_hours > 0
|
||||
and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
|
||||
)
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
if do_save or should_stop:
|
||||
logger.info("begin save checkpoint")
|
||||
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
|
||||
|
||||
return valid_losses, should_stop
|
||||
|
||||
|
||||
def get_training_stats(stats):
|
||||
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
|
||||
return stats
|
||||
|
||||
|
||||
def validate(args, trainer, task, epoch_itr, subsets):
|
||||
"""Evaluate the model on the validation set(s) and return the losses."""
|
||||
|
||||
if args.fixed_validation_seed is not None:
|
||||
# set fixed seed for every validation
|
||||
utils.set_torch_seed(args.fixed_validation_seed)
|
||||
|
||||
trainer.begin_valid_epoch(epoch_itr.epoch)
|
||||
valid_losses = []
|
||||
for subset in subsets:
|
||||
logger.info('begin validation on "{}" subset'.format(subset))
|
||||
|
||||
# Initialize data iterator
|
||||
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
|
||||
if getattr(args, "tpu", False):
|
||||
itr = utils.tpu_data_loader(itr)
|
||||
progress = progress_bar.progress_bar(
|
||||
itr,
|
||||
log_format=args.log_format,
|
||||
log_interval=args.log_interval,
|
||||
epoch=epoch_itr.epoch,
|
||||
prefix=f"valid on '{subset}' subset",
|
||||
tensorboard_logdir=(
|
||||
args.tensorboard_logdir if distributed_utils.is_master(args) else None
|
||||
),
|
||||
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
|
||||
)
|
||||
|
||||
# create a new root metrics aggregator so validation metrics
|
||||
# don't pollute other aggregators (e.g., train meters)
|
||||
with metrics.aggregate(new_root=True) as agg:
|
||||
for sample in progress:
|
||||
trainer.valid_step(sample)
|
||||
|
||||
# log validation stats
|
||||
stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
|
||||
progress.print(stats, tag=subset, step=trainer.get_num_updates())
|
||||
|
||||
valid_losses.append(stats[args.best_checkpoint_metric])
|
||||
return valid_losses
|
||||
|
||||
|
||||
def get_valid_stats(args, trainer, stats):
|
||||
stats["num_updates"] = trainer.get_num_updates()
|
||||
if hasattr(checkpoint_utils.save_checkpoint, "best"):
|
||||
key = "best_{0}".format(args.best_checkpoint_metric)
|
||||
best_function = max if args.maximize_best_checkpoint_metric else min
|
||||
stats[key] = best_function(
|
||||
checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def cli_main(modify_parser=None):
|
||||
parser = options.get_training_parser()
|
||||
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
|
||||
if args.profile:
|
||||
with torch.cuda.profiler.profile():
|
||||
with torch.autograd.profiler.emit_nvtx():
|
||||
distributed_utils.call_main(args, main)
|
||||
else:
|
||||
distributed_utils.call_main(args, main)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
||||
|
||||
'''
|
||||
../StyleCtrlData/train_data/1026_EMO/data-bin
|
||||
--task
|
||||
language_modeling_control
|
||||
--arch
|
||||
style_model_std_v2
|
||||
--control_mode
|
||||
embedding_v2_bos
|
||||
--command_path
|
||||
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_reselect_29
|
||||
--feature_num
|
||||
29
|
||||
--bucket_num
|
||||
2
|
||||
--mask_prob
|
||||
-1
|
||||
--sample-break-mode
|
||||
eos
|
||||
--tokens-per-sample
|
||||
10000000
|
||||
--max-tokens
|
||||
10000000
|
||||
--batch-size
|
||||
2
|
||||
--batch-size-valid
|
||||
1
|
||||
--update-freq
|
||||
1
|
||||
--optimizer
|
||||
adam
|
||||
--adam-betas
|
||||
(0.9,0.98)
|
||||
--adam-eps
|
||||
1e-9
|
||||
--weight-decay
|
||||
0.01
|
||||
--lr
|
||||
0.0001
|
||||
--lr-scheduler
|
||||
inverse_sqrt
|
||||
--warmup-updates
|
||||
1000
|
||||
--log-format
|
||||
simple
|
||||
--log-interval
|
||||
10
|
||||
--tensorboard-logdir
|
||||
tb_log/controlled-1
|
||||
--num-workers
|
||||
0
|
||||
--max-update
|
||||
600000
|
||||
--validate-interval
|
||||
1000000000
|
||||
--save-interval-updates
|
||||
5000
|
||||
--save-dir
|
||||
checkpoints/controlled
|
||||
--no-epoch-checkpoints
|
||||
--cpu
|
||||
'''
|
||||
# 39169531
|
||||
|
||||
|
||||
'''
|
||||
../StyleCtrlData/train_data/1024_EMO/bucket2_attributes29/data-bin
|
||||
--task
|
||||
language_modeling_control
|
||||
--arch
|
||||
style_model_debug
|
||||
--control_mode
|
||||
prefix_token
|
||||
--command_path
|
||||
None
|
||||
--feature_num
|
||||
29
|
||||
--bucket_num
|
||||
2
|
||||
--mask_prob
|
||||
-1
|
||||
--add_emotion_token
|
||||
0
|
||||
--sample-break-mode
|
||||
eos
|
||||
--tokens-per-sample
|
||||
10000000
|
||||
--max-tokens
|
||||
10000000
|
||||
--batch-size
|
||||
2
|
||||
--batch-size-valid
|
||||
1
|
||||
--update-freq
|
||||
1
|
||||
--optimizer
|
||||
adam
|
||||
--adam-betas
|
||||
(0.9,0.98)
|
||||
--adam-eps
|
||||
1e-9
|
||||
--weight-decay
|
||||
0.01
|
||||
--lr
|
||||
0.0001
|
||||
--lr-scheduler
|
||||
inverse_sqrt
|
||||
--warmup-updates
|
||||
1000
|
||||
--log-format
|
||||
simple
|
||||
--log-interval
|
||||
10
|
||||
--tensorboard-logdir
|
||||
tb_log/controlled-1
|
||||
--num-workers
|
||||
0
|
||||
--max-update
|
||||
600000
|
||||
--validate-interval
|
||||
1000000000
|
||||
--save-interval-updates
|
||||
5000
|
||||
--save-dir
|
||||
checkpoints/controlled
|
||||
--no-epoch-checkpoints
|
||||
--cpu
|
||||
'''
|
||||
|
||||
|
||||
|
||||
'''
|
||||
../StyleCtrlData/train_data/1026_EMO/data-bin
|
||||
--task
|
||||
language_modeling_control
|
||||
--arch
|
||||
linear_transformer_lm_debug
|
||||
--control_mode
|
||||
embedding_v2
|
||||
--command_path
|
||||
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_rank_100
|
||||
--feature_num
|
||||
100
|
||||
--bucket_num
|
||||
2
|
||||
--mask_prob
|
||||
-1
|
||||
--add_emotion_token
|
||||
0
|
||||
--sample-break-mode
|
||||
eos
|
||||
--tokens-per-sample
|
||||
10000000
|
||||
--max-tokens
|
||||
10000000
|
||||
--batch-size
|
||||
2
|
||||
--batch-size-valid
|
||||
1
|
||||
--update-freq
|
||||
1
|
||||
--optimizer
|
||||
adam
|
||||
--adam-betas
|
||||
(0.9,0.98)
|
||||
--adam-eps
|
||||
1e-9
|
||||
--weight-decay
|
||||
0.01
|
||||
--lr
|
||||
0.0001
|
||||
--lr-scheduler
|
||||
inverse_sqrt
|
||||
--warmup-updates
|
||||
1000
|
||||
--log-format
|
||||
simple
|
||||
--log-interval
|
||||
10
|
||||
--tensorboard-logdir
|
||||
tb_log/controlled-1
|
||||
--num-workers
|
||||
1
|
||||
--max-update
|
||||
20
|
||||
--validate-interval
|
||||
1000000000
|
||||
--save-interval-updates
|
||||
10
|
||||
--save-dir
|
||||
checkpoints/controlled
|
||||
--patience
|
||||
10
|
||||
--no-epoch-checkpoints
|
||||
'''
|
||||
|
||||
|
Загрузка…
Ссылка в новой задаче