peillu 2023-06-29 15:14:58 +08:00
Parent 1fb05b46f4
Commit e93f6117de
51 changed files: 7469 additions and 0 deletions

40
emogen/Piano_gen.sh Normal file
View File

@@ -0,0 +1,40 @@
# models parameters
datasets_name="Piano"
control_mode="embedding_v2"
feature_num=100
bucket_num=2
command_subset="emotion_rank_100"
rank=0
gpus=1
model_name="${datasets_name}"
MODEL_NAME="${model_name}"
tgt=${1}
checkpoint_name="checkpoint_best"
checkpoint_path="checkpoints/${model_name}/${checkpoint_name}.pt"
command_path="data/infer_input/inference_command.npy"
save_root="generation/${model_name}-${checkpoint_name}/Q${tgt}"
mkdir -p ${save_root}
export CUDA_VISIBLE_DEVICES=${rank}
echo "generating from ${checkpoint_path} with emotion Q${tgt}!"
python interactive.py \
data/${datasets_name}/data-bin \
--task language_modeling_control \
--path $checkpoint_path \
--ctrl_command_path $command_path \
--save_root $save_root \
--tgt_emotion $tgt \
--need_num 2 \
--max-len-b 1280 \
--min-len 512 \
--sampling \
--beam 1 \
--sampling-topp 0.9 \
--temperature 1.0 \
--buffer-size 2 \
--batch-size 2
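
Usage: the first positional argument is the target emotion quadrant, e.g. "bash Piano_gen.sh 1" for Q1 (interactive.py accepts 1-4). The file passed as --ctrl_command_path holds the attribute-command vectors that condition generation, indexed by tgt_emotion - 1. A minimal Python sketch for inspecting it; the shape comment is an assumption, since the .npy is committed as a binary file further down:

import numpy as np

# Inspect the per-quadrant command vectors consumed by interactive.py.
commands = np.load("data/infer_input/inference_command.npy", allow_pickle=True)
print(commands.shape)   # assumed: one command vector per emotion quadrant Q1-Q4
print(commands[0])      # the vector used as "command_input" when tgt_emotion = 1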

66
emogen/Piano_train.sh Normal file
View File

@@ -0,0 +1,66 @@
# models parameters
datasets_name="Piano"
control_mode="embedding_v2"
feature_num=100
bucket_num=2
command_subset="emotion_rank_100"
rank=0
gpus=1
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
PEAK_LR=1e-4 # Peak learning rate, adjust as needed
model_name="${datasets_name}-${control_mode}-bucket${bucket_num}-${command_subset}"
MODEL_NAME="${model_name}"
echo "training the model: ${MODEL_NAME}!"
BATCH_SIZE=8 # batch size
UPDATE_FREQ=1
DATA_DIR=data/${datasets_name} # Data dir
export MKL_THREADING_LAYER=GNU
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
let "port_rank=$rank+666"
export CUDA_VISIBLE_DEVICES=${rank}
#python -m torch.distributed.launch \
#--nproc_per_node=${gpus} \
#--master_port=${port_rank} \
python train.py \
$DATA_DIR/data-bin \
--truncated_length 1280 \
--task language_modeling_control \
--arch linear_transformer_lm_std \
--control_mode $control_mode \
--command_path $DATA_DIR \
--feature_num $feature_num \
--bucket_num $bucket_num \
--sample-break-mode eos \
--tokens-per-sample 10000000 \
--max-tokens 10000000 \
--batch-size $BATCH_SIZE \
--batch-size-valid $BATCH_SIZE \
--update-freq $UPDATE_FREQ \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--adam-eps 1e-9 \
--weight-decay 0.01 \
--lr $PEAK_LR \
--lr-scheduler inverse_sqrt \
--warmup-updates $WARMUP_UPDATES \
--log-format simple \
--log-interval 10 \
--tensorboard-logdir tb_log/$MODEL_NAME \
--num-workers "$OMP_NUM_THREADS" \
--max-update 600000 \
--validate-interval 1 \
--save-interval-updates 2000 \
--save-dir checkpoints/$MODEL_NAME \
--no-epoch-checkpoints \
--find-unused-parameters \
--patience 20

39
emogen/TopMAGD_gen.sh Normal file
View File

@@ -0,0 +1,39 @@
# models parameters
datasets_name="TopMAGD"
control_mode="embedding_v2"
bucket_num=2
command_subset="emotion_rank_100"
rank=0
gpus=1
model_name="${datasets_name}"
MODEL_NAME="${model_name}"
tgt=${1}
checkpoint_name="checkpoint_best"
checkpoint_path="checkpoints/${model_name}/${checkpoint_name}.pt"
command_path="data/infer_input/inference_command.npy"
save_root="generation/${model_name}-${checkpoint_name}/Q${tgt}"
mkdir -p ${save_root}
export CUDA_VISIBLE_DEVICES=${rank}
echo "generating from ${checkpoint_path} with emotion Q${tgt}!"
python interactive.py \
data/${datasets_name}/data-bin \
--task language_modeling_control \
--path $checkpoint_path \
--ctrl_command_path $command_path \
--save_root $save_root \
--tgt_emotion $tgt \
--need_num 2 \
--max-len-b 2560 \
--min-len 512 \
--sampling \
--beam 1 \
--sampling-topp 0.9 \
--temperature 1.0 \
--buffer-size 2 \
--batch-size 2

66
emogen/TopMAGD_train.sh Normal file
View File

@@ -0,0 +1,66 @@
# models parameters
datasets_name="TopMAGD"
control_mode="embedding_v2"
feature_num=100
bucket_num=2
command_subset="emotion_rank_100"
rank=0
gpus=1
WARMUP_UPDATES=16000 # Warmup the learning rate over this many updates
PEAK_LR=1e-4 # Peak learning rate, adjust as needed
model_name="${datasets_name}-${control_mode}-bucket${bucket_num}-${command_subset}"
MODEL_NAME="${model_name}"
echo "training the model: ${MODEL_NAME}!"
BATCH_SIZE=1 # batch size
UPDATE_FREQ=1
DATA_DIR=data/${datasets_name} # Data dir
export MKL_THREADING_LAYER=GNU # for "import numpy" "import torch" order bug
OMP_NUM_THREADS=$(cat /proc/cpuinfo| grep "processor"| wc -l)
let "port_rank=$rank+666"
#python -m torch.distributed.launch \
#--nproc_per_node=${gpus} \
#--master_port=${port_rank} \
python train.py \
$DATA_DIR/data-bin \
--truncated_length 2560 \
--task language_modeling_control \
--arch linear_transformer_lm_std \
--control_mode $control_mode \
--command_path $DATA_DIR \
--feature_num $feature_num \
--bucket_num $bucket_num \
--sample-break-mode eos \
--tokens-per-sample 10000000 \
--max-tokens 10000000 \
--batch-size $BATCH_SIZE \
--batch-size-valid $BATCH_SIZE \
--update-freq $UPDATE_FREQ \
--optimizer adam \
--adam-betas '(0.9, 0.98)' \
--adam-eps 1e-9 \
--weight-decay 0.01 \
--lr $PEAK_LR \
--lr-scheduler inverse_sqrt \
--warmup-updates $WARMUP_UPDATES \
--log-format simple \
--log-interval 10 \
--tensorboard-logdir tb_log/$MODEL_NAME \
--num-workers "$OMP_NUM_THREADS" \
--max-update 600000 \
--validate-interval 2 \
--save-interval-updates 2000 \
--save-dir checkpoints/$MODEL_NAME \
--no-epoch-checkpoints \
--find-unused-parameters \
--patience 20

View File

@@ -0,0 +1,458 @@
b-1 1
d-0 1
d-1 1
d-2 1
d-3 1
d-4 1
d-5 1
d-6 1
d-7 1
d-8 1
d-9 1
d-10 1
d-11 1
d-12 1
d-13 1
d-14 1
d-15 1
d-16 1
d-17 1
d-18 1
d-19 1
d-20 1
d-21 1
d-22 1
d-23 1
d-24 1
d-25 1
d-26 1
d-27 1
d-28 1
d-29 1
d-30 1
d-31 1
d-32 1
d-33 1
d-34 1
d-35 1
d-36 1
d-37 1
d-38 1
d-39 1
d-40 1
d-41 1
d-42 1
d-43 1
d-44 1
d-45 1
d-46 1
d-47 1
d-48 1
d-49 1
d-50 1
d-51 1
d-52 1
d-53 1
d-54 1
d-55 1
d-56 1
d-57 1
d-58 1
d-59 1
d-60 1
d-61 1
d-62 1
d-63 1
d-64 1
d-65 1
d-66 1
d-67 1
d-68 1
d-69 1
d-70 1
d-71 1
d-72 1
d-73 1
d-74 1
d-75 1
d-76 1
d-77 1
d-78 1
d-79 1
d-80 1
d-81 1
d-82 1
d-83 1
d-84 1
d-85 1
d-86 1
d-87 1
d-88 1
d-89 1
d-90 1
d-91 1
d-92 1
d-93 1
d-94 1
d-95 1
i-0 1
i-1 1
o-0 1
o-1 1
o-2 1
o-3 1
o-4 1
o-5 1
o-6 1
o-7 1
o-8 1
o-9 1
o-10 1
o-11 1
o-12 1
o-13 1
o-14 1
o-15 1
o-16 1
o-17 1
o-18 1
o-19 1
o-20 1
o-21 1
o-22 1
o-23 1
o-24 1
o-25 1
o-26 1
o-27 1
o-28 1
o-29 1
o-30 1
o-31 1
o-32 1
o-33 1
o-34 1
o-35 1
o-36 1
o-37 1
o-38 1
o-39 1
o-40 1
o-41 1
o-42 1
o-43 1
o-44 1
o-45 1
o-46 1
o-47 1
o-48 1
o-49 1
o-50 1
o-51 1
o-52 1
o-53 1
o-54 1
o-55 1
o-56 1
o-57 1
o-58 1
o-59 1
o-60 1
o-61 1
o-62 1
o-63 1
o-64 1
o-65 1
o-66 1
o-67 1
o-68 1
o-69 1
o-70 1
o-71 1
o-72 1
o-73 1
o-74 1
o-75 1
o-76 1
o-77 1
o-78 1
o-79 1
o-80 1
o-81 1
o-82 1
o-83 1
o-84 1
o-85 1
o-86 1
o-87 1
o-88 1
o-89 1
o-90 1
o-91 1
o-92 1
o-93 1
o-94 1
o-95 1
p-0 1
p-1 1
p-2 1
p-3 1
p-4 1
p-5 1
p-6 1
p-7 1
p-8 1
p-9 1
p-10 1
p-11 1
p-12 1
p-13 1
p-14 1
p-15 1
p-16 1
p-17 1
p-18 1
p-19 1
p-20 1
p-21 1
p-22 1
p-23 1
p-24 1
p-25 1
p-26 1
p-27 1
p-28 1
p-29 1
p-30 1
p-31 1
p-32 1
p-33 1
p-34 1
p-35 1
p-36 1
p-37 1
p-38 1
p-39 1
p-40 1
p-41 1
p-42 1
p-43 1
p-44 1
p-45 1
p-46 1
p-47 1
p-48 1
p-49 1
p-50 1
p-51 1
p-52 1
p-53 1
p-54 1
p-55 1
p-56 1
p-57 1
p-58 1
p-59 1
p-60 1
p-61 1
p-62 1
p-63 1
p-64 1
p-65 1
p-66 1
p-67 1
p-68 1
p-69 1
p-70 1
p-71 1
p-72 1
p-73 1
p-74 1
p-75 1
p-76 1
p-77 1
p-78 1
p-79 1
p-80 1
p-81 1
p-82 1
p-83 1
p-84 1
p-85 1
p-86 1
p-87 1
p-88 1
p-89 1
p-90 1
p-91 1
p-92 1
p-93 1
p-94 1
p-95 1
p-96 1
p-97 1
p-98 1
p-99 1
p-100 1
p-101 1
p-102 1
p-103 1
p-104 1
p-105 1
p-106 1
p-107 1
p-108 1
p-109 1
p-110 1
p-111 1
p-112 1
p-113 1
p-114 1
p-115 1
p-116 1
p-117 1
p-118 1
p-119 1
p-120 1
p-122 1
p-123 1
p-124 1
p-125 1
p-126 1
p-127 1
s-1 1
s-2 1
s-3 1
s-4 1
s-5 1
s-6 1
s-7 1
s-8 1
s-9 1
s-10 1
s-11 1
s-12 1
s-13 1
s-14 1
s-15 1
s-16 1
s-17 1
s-18 1
s-19 1
s-20 1
s-21 1
s-22 1
s-23 1
s-24 1
s-25 1
s-27 1
s-28 1
s-29 1
s-30 1
s-31 1
s-32 1
s-33 1
s-34 1
s-35 1
s-36 1
s-38 1
s-39 1
s-40 1
s-41 1
s-42 1
s-44 1
s-45 1
s-46 1
s-48 1
s-49 1
s-50 1
s-52 1
s-57 1
s-62 1
s-70 1
s-74 1
s-76 1
s-78 1
s-82 1
s-127 1
t-0 1
t-1 1
t-2 1
t-3 1
t-4 1
t-5 1
t-6 1
t-7 1
t-8 1
t-9 1
t-10 1
t-11 1
t-12 1
t-13 1
t-14 1
t-15 1
t-16 1
t-17 1
t-18 1
t-19 1
t-20 1
t-21 1
t-22 1
t-23 1
t-24 1
t-25 1
t-26 1
t-27 1
t-28 1
t-29 1
t-30 1
t-31 1
t-32 1
t-33 1
t-34 1
t-35 1
t-36 1
t-37 1
t-38 1
t-39 1
t-40 1
t-41 1
t-42 1
t-43 1
t-44 1
t-45 1
t-46 1
t-47 1
t-48 1
v-0 1
v-1 1
v-2 1
v-3 1
v-4 1
v-5 1
v-6 1
v-7 1
v-8 1
v-9 1
v-10 1
v-11 1
v-12 1
v-13 1
v-14 1
v-15 1
v-16 1
v-17 1
v-18 1
v-19 1
v-20 1
v-21 1
v-22 1
v-23 1
v-24 1
v-25 1
v-26 1
v-27 1
v-28 1
v-29 1
v-30 1
v-31 1
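
The listing above is a fairseq dictionary: one token per line followed by a frequency placeholder (always 1 here); gen_data.py hands it to fairseq-preprocess via --srcdict. The one-letter prefixes are REMI event types from the MidiProcessor encoder (b is bar, i is instrument, and s-9 is the 4/4 time-signature token, as used in data_process/util.py; reading d, o, p, t, v as duration, position, pitch, tempo and velocity is an assumption). A minimal sketch of parsing the format by hand, with a path that follows gen_data.py's layout:

from collections import Counter

# Parse a fairseq-style dictionary file: "<token> <count>" per line (path assumed from gen_data.py).
with open("data/Piano/data-bin/dict.txt") as f:
    tokens = [line.split()[0] for line in f if line.strip()]

# Count tokens per one-letter event-type prefix, e.g. "p-60" -> "p".
by_type = Counter(tok.split("-")[0] for tok in tokens)
print(len(tokens), dict(by_type))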

Binary files added (content not shown):
emogen/data/Piano/midi/0.mid
emogen/data/Piano/midi/1.mid
emogen/data/Piano/midi/2.mid
emogen/data/Piano/midi/3.mid
emogen/data/Piano/midi/4.mid
emogen/data/Piano/midi/5.mid
emogen/data/Piano/midi/6.mid
emogen/data/Piano/midi/7.mid
emogen/data/Piano/midi/8.mid
emogen/data/Piano/midi/9.mid
emogen/data/Piano/midi/10.mid
emogen/data/Piano/midi/11.mid

View File

@@ -0,0 +1,688 @@
b-1 1
d-0 1
d-1 1
d-2 1
d-3 1
d-4 1
d-5 1
d-6 1
d-7 1
d-8 1
d-9 1
d-10 1
d-11 1
d-12 1
d-13 1
d-14 1
d-15 1
d-16 1
d-17 1
d-18 1
d-19 1
d-20 1
d-21 1
d-22 1
d-23 1
d-24 1
d-25 1
d-26 1
d-27 1
d-28 1
d-29 1
d-30 1
d-31 1
d-32 1
d-33 1
d-34 1
d-35 1
d-36 1
d-37 1
d-38 1
d-39 1
d-40 1
d-41 1
d-42 1
d-43 1
d-44 1
d-45 1
d-46 1
d-47 1
d-48 1
d-49 1
d-50 1
d-51 1
d-52 1
d-53 1
d-54 1
d-55 1
d-56 1
d-57 1
d-58 1
d-59 1
d-60 1
d-61 1
d-62 1
d-63 1
d-64 1
d-65 1
d-66 1
d-67 1
d-68 1
d-69 1
d-70 1
d-71 1
d-72 1
d-73 1
d-74 1
d-75 1
d-76 1
d-77 1
d-78 1
d-79 1
d-80 1
d-81 1
d-82 1
d-83 1
d-84 1
d-85 1
d-86 1
d-87 1
d-88 1
d-89 1
d-90 1
d-91 1
d-93 1
d-94 1
d-95 1
i-0 1
i-1 1
i-2 1
i-3 1
i-4 1
i-5 1
i-6 1
i-7 1
i-8 1
i-9 1
i-10 1
i-11 1
i-12 1
i-13 1
i-14 1
i-15 1
i-16 1
i-17 1
i-18 1
i-19 1
i-20 1
i-21 1
i-22 1
i-23 1
i-24 1
i-25 1
i-26 1
i-27 1
i-28 1
i-29 1
i-30 1
i-31 1
i-32 1
i-33 1
i-34 1
i-35 1
i-36 1
i-37 1
i-38 1
i-39 1
i-40 1
i-41 1
i-42 1
i-43 1
i-44 1
i-45 1
i-46 1
i-47 1
i-48 1
i-49 1
i-50 1
i-51 1
i-52 1
i-53 1
i-54 1
i-55 1
i-56 1
i-57 1
i-58 1
i-59 1
i-60 1
i-61 1
i-62 1
i-63 1
i-64 1
i-65 1
i-66 1
i-67 1
i-68 1
i-69 1
i-70 1
i-71 1
i-72 1
i-73 1
i-74 1
i-75 1
i-76 1
i-77 1
i-78 1
i-79 1
i-80 1
i-81 1
i-82 1
i-83 1
i-84 1
i-85 1
i-86 1
i-87 1
i-88 1
i-89 1
i-90 1
i-91 1
i-92 1
i-93 1
i-94 1
i-95 1
i-96 1
i-97 1
i-98 1
i-99 1
i-100 1
i-101 1
i-102 1
i-103 1
i-104 1
i-105 1
i-106 1
i-107 1
i-108 1
i-109 1
i-110 1
i-111 1
i-112 1
i-113 1
i-114 1
i-115 1
i-116 1
i-117 1
i-118 1
i-119 1
i-120 1
i-121 1
i-122 1
i-123 1
i-124 1
i-125 1
i-126 1
i-127 1
i-128 1
o-0 1
o-1 1
o-2 1
o-3 1
o-4 1
o-5 1
o-6 1
o-7 1
o-8 1
o-9 1
o-10 1
o-11 1
o-12 1
o-13 1
o-14 1
o-15 1
o-16 1
o-17 1
o-18 1
o-19 1
o-20 1
o-21 1
o-22 1
o-23 1
o-24 1
o-25 1
o-26 1
o-27 1
o-28 1
o-29 1
o-30 1
o-31 1
o-32 1
o-33 1
o-34 1
o-35 1
o-36 1
o-37 1
o-38 1
o-39 1
o-40 1
o-41 1
o-42 1
o-43 1
o-44 1
o-45 1
o-46 1
o-47 1
o-48 1
o-49 1
o-50 1
o-51 1
o-52 1
o-53 1
o-54 1
o-55 1
o-56 1
o-57 1
o-58 1
o-59 1
o-60 1
o-61 1
o-62 1
o-63 1
o-64 1
o-65 1
o-66 1
o-67 1
o-68 1
o-69 1
o-70 1
o-71 1
o-72 1
o-73 1
o-74 1
o-75 1
o-76 1
o-77 1
o-78 1
o-79 1
o-80 1
o-81 1
o-82 1
o-83 1
o-84 1
o-85 1
o-86 1
o-87 1
o-88 1
o-89 1
o-90 1
o-91 1
o-92 1
o-93 1
o-94 1
o-95 1
p-0 1
p-1 1
p-2 1
p-3 1
p-4 1
p-5 1
p-6 1
p-7 1
p-8 1
p-9 1
p-10 1
p-11 1
p-12 1
p-13 1
p-14 1
p-15 1
p-16 1
p-17 1
p-18 1
p-19 1
p-20 1
p-21 1
p-22 1
p-23 1
p-24 1
p-25 1
p-26 1
p-27 1
p-28 1
p-29 1
p-30 1
p-31 1
p-32 1
p-33 1
p-34 1
p-35 1
p-36 1
p-37 1
p-38 1
p-39 1
p-40 1
p-41 1
p-42 1
p-43 1
p-44 1
p-45 1
p-46 1
p-47 1
p-48 1
p-49 1
p-50 1
p-51 1
p-52 1
p-53 1
p-54 1
p-55 1
p-56 1
p-57 1
p-58 1
p-59 1
p-60 1
p-61 1
p-62 1
p-63 1
p-64 1
p-65 1
p-66 1
p-67 1
p-68 1
p-69 1
p-70 1
p-71 1
p-72 1
p-73 1
p-74 1
p-75 1
p-76 1
p-77 1
p-78 1
p-79 1
p-80 1
p-81 1
p-82 1
p-83 1
p-84 1
p-85 1
p-86 1
p-87 1
p-88 1
p-89 1
p-90 1
p-91 1
p-92 1
p-93 1
p-94 1
p-95 1
p-96 1
p-97 1
p-98 1
p-99 1
p-100 1
p-101 1
p-102 1
p-103 1
p-104 1
p-105 1
p-106 1
p-107 1
p-108 1
p-109 1
p-110 1
p-111 1
p-112 1
p-113 1
p-114 1
p-115 1
p-116 1
p-117 1
p-118 1
p-119 1
p-120 1
p-121 1
p-122 1
p-123 1
p-124 1
p-125 1
p-126 1
p-127 1
p-128 1
p-129 1
p-130 1
p-131 1
p-132 1
p-133 1
p-134 1
p-135 1
p-136 1
p-137 1
p-140 1
p-141 1
p-142 1
p-143 1
p-144 1
p-145 1
p-146 1
p-147 1
p-148 1
p-149 1
p-150 1
p-151 1
p-152 1
p-153 1
p-154 1
p-155 1
p-156 1
p-157 1
p-158 1
p-159 1
p-160 1
p-161 1
p-162 1
p-163 1
p-164 1
p-165 1
p-166 1
p-167 1
p-168 1
p-169 1
p-170 1
p-171 1
p-172 1
p-173 1
p-174 1
p-175 1
p-176 1
p-177 1
p-178 1
p-179 1
p-180 1
p-181 1
p-182 1
p-183 1
p-184 1
p-185 1
p-186 1
p-187 1
p-188 1
p-189 1
p-190 1
p-191 1
p-192 1
p-193 1
p-194 1
p-195 1
p-196 1
p-197 1
p-198 1
p-199 1
p-200 1
p-201 1
p-202 1
p-203 1
p-204 1
p-205 1
p-206 1
p-207 1
p-208 1
p-209 1
p-210 1
p-211 1
p-212 1
p-213 1
p-214 1
p-215 1
p-216 1
p-217 1
p-218 1
p-219 1
p-220 1
p-221 1
p-222 1
p-223 1
p-224 1
p-225 1
p-226 1
p-227 1
p-228 1
p-230 1
p-231 1
p-232 1
p-233 1
p-234 1
p-235 1
p-236 1
p-238 1
p-239 1
p-242 1
p-243 1
p-244 1
p-248 1
p-249 1
p-250 1
p-251 1
p-252 1
p-253 1
p-254 1
p-255 1
s-3 1
s-4 1
s-5 1
s-6 1
s-7 1
s-8 1
s-9 1
s-10 1
s-11 1
s-12 1
s-13 1
s-14 1
s-15 1
s-16 1
s-17 1
s-18 1
s-19 1
s-20 1
s-21 1
s-22 1
s-23 1
s-24 1
s-25 1
s-26 1
s-28 1
s-31 1
s-33 1
s-34 1
s-38 1
s-40 1
s-41 1
s-44 1
s-45 1
s-78 1
t-0 1
t-1 1
t-2 1
t-3 1
t-4 1
t-5 1
t-6 1
t-7 1
t-8 1
t-9 1
t-10 1
t-11 1
t-12 1
t-13 1
t-14 1
t-15 1
t-16 1
t-17 1
t-18 1
t-19 1
t-20 1
t-21 1
t-22 1
t-23 1
t-24 1
t-25 1
t-26 1
t-27 1
t-28 1
t-29 1
t-30 1
t-31 1
t-32 1
t-33 1
t-34 1
t-35 1
t-36 1
t-37 1
t-38 1
t-39 1
t-40 1
t-41 1
t-42 1
t-43 1
t-44 1
t-45 1
t-46 1
t-47 1
t-48 1
v-0 1
v-1 1
v-2 1
v-3 1
v-4 1
v-5 1
v-6 1
v-7 1
v-8 1
v-9 1
v-10 1
v-11 1
v-12 1
v-13 1
v-14 1
v-15 1
v-16 1
v-17 1
v-18 1
v-19 1
v-20 1
v-21 1
v-22 1
v-23 1
v-24 1
v-25 1
v-26 1
v-27 1
v-28 1
v-29 1
v-30 1
v-31 1
Q1 1
Q2 1
Q3 1
Q4 1
None 1

Binary files added (content not shown):
emogen/data/all_feature_name.npy
emogen/data/feature_index.npy
emogen/data/infer_input/inference_command.npy
emogen/data/threshold.npy

View File

@@ -0,0 +1,85 @@
import json
import multiprocessing
import random
import sys
sys.path.append("..")
import miditoolkit
import numpy.random
import pandas as pd
import os, collections, pickle, shutil
import numpy as np
from MidiProcessor.midiprocessor import MidiEncoder, midi_utils, MidiDecoder
from tqdm import tqdm
import gc
from multiprocessing import Pool, Manager, Lock
import math, random
from typing import List, Dict
from functools import partial
from jSymbolic_lib.jSymbolic_util import read_all_feature
from sklearn.preprocessing import StandardScaler
import joblib

random.seed(42)
np.random.seed(42)


def binarize_data(path_root):
    # Binarize the train/valid/test token files with fairseq-preprocess, reusing the existing dictionary.
    save_root = path_root + "/data-bin"
    dict_path = save_root + "/dict.txt"
    command = f"fairseq-preprocess --only-source --destdir {save_root} --srcdict {dict_path} " \
              f"--validpref {path_root}/valid.txt --testpref {path_root}/test.txt --trainpref {path_root}/train.txt --workers 4 "
    text = os.popen(command).read()
    print(text)


def binarize_command(command, thresholds):
    # Map each continuous feature value to a bucket index using its per-feature thresholds.
    discrete_feature = []
    for k in range(command.shape[0]):
        thres = thresholds[k]
        discrete_feature.append(np.searchsorted(thres, command[k]))
    return discrete_feature


def gen_split_data(path_root):
    feature_index = np.load("../data/feature_index.npy", allow_pickle=True)
    thresholds = np.load("../data/threshold.npy", allow_pickle=True)
    save_root = path_root
    os.makedirs(save_root, exist_ok=True)
    fn_list = os.listdir(path_root + "/remi")
    random.shuffle(fn_list)
    for split in ["train", "valid", "test"]:
        split_command = []
        if split == "train":
            s, e = 0, int(len(fn_list) * 0.8)
        elif split == "valid":
            s, e = int(len(fn_list) * 0.8), int(len(fn_list) * 0.9)
        else:
            s, e = int(len(fn_list) * 0.9), len(fn_list)
        with open(path_root + f"/{split}.txt", "w") as split_txt:
            split_fn_list = []
            j = 0
            for i, fn in enumerate(tqdm(fn_list[s:e])):
                fn_name = fn.split(".")[0]
                try:
                    jS_feature = read_all_feature(path_root + f"/feature/{fn_name}.xml")
                except:
                    continue
                if jS_feature is None or len(jS_feature) != 1495:
                    continue
                jS_feature = np.array(jS_feature)
                binary_command = binarize_command(jS_feature[feature_index], thresholds)
                split_command.append(binary_command)
                split_fn_list.append(fn)
                with open(path_root + f'/remi/{fn}', "r") as f:
                    remi_tokens = f.read().strip("\n").strip(" ")
                    split_txt.write(remi_tokens + "\n")
                j += 1
        split_command = np.array(split_command)
        np.save(save_root + f"/{split}_fn_list.npy", split_fn_list)
        np.save(save_root + f'/{split}_command.npy', split_command)
        assert len(split_fn_list) == len(split_command), "length mismatch!"


if __name__ == "__main__":
    gen_split_data("../data/Piano")
    binarize_data("../data/Piano")
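
binarize_command discretizes each selected jSymbolic feature independently: np.searchsorted against that feature's threshold list turns the continuous value into a bucket index (with bucket_num=2 in the training scripts there is a single threshold per feature). A toy illustration with made-up numbers:

import numpy as np

# Two hypothetical features, one threshold each (bucket_num = 2).
thresholds = [np.array([0.5]), np.array([10.0])]
command = np.array([0.73, 4.2])   # made-up continuous feature values for one piece

discrete = [np.searchsorted(thres, value) for thres, value in zip(thresholds, command)]
print(discrete)   # [1, 0]: first feature above its threshold, second below it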

View File

@@ -0,0 +1,70 @@
import os
import random
import shutil
import json
from multiprocessing import Process, Pool
from functools import partial
from typing import List
import numpy as np
import pandas as pd
from tqdm import tqdm
import sys
sys.path.append("..")
from MidiProcessor.midiprocessor import MidiEncoder, MidiDecoder, enc_remigen_utils, const
from MidiProcessor.midiprocessor.keys_normalization import get_notes_from_pos_info, get_pitch_shift
import math, pickle, miditoolkit
import json

random.seed(2022)
np.random.seed(2022)

encoder = MidiEncoder("REMIGEN2")


def midi_encoding(midi_path, save_root, prefix):
    try:
        name_split = midi_path.replace("//", "/").replace("\\", "/").split("/")
        midi_name = name_split[-1][:-4]  # skip ".mid"
        # save_name = "_".join(name_split[-3:-1]) + "_" + f"{midi_name}.txt"
        midi_name = midi_name.replace(" ", "_")
        if prefix is not None:
            save_name = f"{prefix}_{midi_name}.txt"
        else:
            save_name = f"{midi_name}.txt"
        midi_obj = miditoolkit.MidiFile(midi_path)
        remi_token = encoder.encode_file(file_path=midi_path, midi_obj=midi_obj, remove_empty_bars=True)
        encoder.dump_token_lists(token_lists=remi_token, file_path=os.path.join(save_root, save_name))
        return 1
    except KeyboardInterrupt:
        sys.exit(1)
    except BaseException:
        print(midi_path, "error")
        return 0


def midi_encoding_generate(data_path):
    midi_path_list = os.listdir(data_path + "/midi")
    for i in range(len(midi_path_list)):
        midi_path_list[i] = data_path + f"/midi/{midi_path_list[i]}"
    save_root = data_path + "/remi"
    if not os.path.exists(save_root):
        os.mkdir(save_root)
    with Pool(processes=8) as pool:
        result = iter(tqdm(pool.imap(partial(midi_encoding, save_root=save_root, prefix=None), midi_path_list),
                           total=len(midi_path_list)))
        for i in range(len(midi_path_list)):
            # try:
            next(result)
            # except BaseException as e:
            #     if isinstance(e, KeyboardInterrupt):
            #         print("error")
            #         pool.terminate()
            #     else:
            #         print(midi_path_list[i], "error!")


if __name__ == "__main__":
    midi_encoding_generate("../data/Piano")
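
MidiEncoder("REMIGEN2") and the matching MidiDecoder come from the bundled MidiProcessor package. A minimal round-trip sketch using the same calls as midi_encoding above and as interactive.py's decoding step; the input path is hypothetical:

import miditoolkit
from MidiProcessor.midiprocessor import MidiEncoder, MidiDecoder

encoder = MidiEncoder("REMIGEN2")
decoder = MidiDecoder("REMIGEN2")

# Encode one MIDI file to a REMI token text file (same calls as midi_encoding).
midi_path = "test.mid"   # hypothetical input file
midi_obj = miditoolkit.MidiFile(midi_path)
remi_token = encoder.encode_file(file_path=midi_path, midi_obj=midi_obj, remove_empty_bars=True)
encoder.dump_token_lists(token_lists=remi_token, file_path="test_remi.txt")

# Decode the token sequence back to MIDI (same calls interactive.py applies to generated output).
with open("test_remi.txt") as f:
    tokens = f.read().split()
midi_out = decoder.decode_from_token_str_list(tokens)
midi_out.dump("test_roundtrip.mid")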

633
emogen/data_process/util.py Normal file
View File

@@ -0,0 +1,633 @@
import json
import multiprocessing
import random
import sys
sys.path.append("..")
import miditoolkit
import numpy.random
import pandas as pd
import os, collections, pickle, shutil
import numpy as np
from reference.MidiProcessor.midiprocessor import MidiEncoder, midi_utils
from tqdm import tqdm
import gc
from multiprocessing import Pool, Manager, Lock
import math, random
from typing import List, Dict
# INSTR_NAME_LIST = [0, 1, 4, 5, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 48, 49, 50, 52, 53, 56, 57, 60, 61, 65, 66, 71, 73, 128]
INSTR_NAME_LIST = [0, 25, 32, 48, 80, 128]
# INSTR_NAME_LIST = [0, 25, 32, 48, 128]
#
#
#
# # Banlanced_INST = [1, 4, 5, 24, 25, 26, 27, 28, 29, 30, 32, 33, 35, 48, 49, 50, 52, 53, 56, 57, 60, 61, 65, 66, 71, 73]
INSTR_NAME2_index = dict(zip(INSTR_NAME_LIST, range(len(INSTR_NAME_LIST))))
sample_INSTR_dist = dict(zip(INSTR_NAME_LIST, [0]*len(INSTR_NAME_LIST)))
file_path = './test.mid'
# sample_list = dict.fromkeys(INSTR_NAME_LIST, 0)
encoder = MidiEncoder('REMIGEN2')
ignore_inst = False
def seed_everything():
np.random.seed(42)
random.seed(42)
# def gen_encode_data(): # generate command_vector from txt file; no longer used, was only for producing a small sample set
# data_root = "./data/lmd_encode"
# start_token, end_token = "<s>", "</s>"
# dict = []
#
# # with open("./data/txt/dict.txt", "a") as f:
# # for i in range(129):
# # f.write(f"i-{i} 1\n")
# # bos = "<s>",
# # pad="<pad>",
# # eos="</s>",
# # dict = [] # use dict from longformer_exps
# for target_prefix in ["train", "valid", "test"]:
# file_count = 0
# target_file = open(f"./data/txt/{target_prefix}.txt", "w")
# # cf = open(f"./data/txt/command_{target_prefix}.txt", "w")
# cf = []
# index_f = open(f"./data/txt/index_{target_prefix}.id", "w")
# for file in os.listdir(data_root):
# # print(target_prefix, file_count)
# with open(os.path.join(data_root, file), "r") as f:
# remi_str = f.read().strip("\n")
# target_file.write(" ".join([start_token, remi_str, end_token]) + "\n")
# command_vector = [0]*len(INSTR_NAME_LIST)
# for i in remi_str.split(" "):
# if i not in dict:
# dict.append(i)
# if i[0] == "i":
# instru_name_now = eval(i[2:])
# assert instru_name_now in INSTR_NAME_LIST, f"{file} illegal!"
# command_vector[INSTR_NAME2_index[instru_name_now]] = 1
#
# cf.append(command_vector)
# # cf.write(" ".join(command_vector) + "\n")
# index_f.write(file.split(".")[0] + "\n")
# file_count += 1
# if file_count >= 117:
# break
# cf = np.array(cf)
# np.save(f"./data/data_bin/command_{target_prefix}.npy", cf)
# target_file.close()
# # dict = sorted(dict, key = lambda x: (x[0], eval(x.split("-")[-1])))
# # with open("./data/txt/dict.txt", "w") as f:
# # for i in dict:
# # f.write(i + " 1\n")
# write_dict(dict, "./data/txt/dict.txt")
# print("done!")
def check_ts(midi_obj):
for i in midi_obj.time_signature_changes:
a, b = i.numerator, i.denominator
if not(b == 4 and a in [4,8,16,32,64,128,256]):
return False
return True
# def ergodic_a_file(path_root:str, path_list: List[str], dict_list:List[str], locker, process_index: int, processor_num: int, finished_num):
#
# # for i in range(1000):
# # locker.acquire()
# # finished_num.value += 1
# # locker.release()
# # print(process_index, finished_num.value)
#
# size = math.ceil(len(path_list) / processor_num)
# start = size * process_index
# end = (process_index + 1) * size if (process_index + 1) * size < len(path_list) else len(path_list)
# for path in path_list[start:end]:
# with open(os.path.join(path_root, path), "r") as f:
# line = f.read().strip("\n").split(" ")
# for token in line:
# if token not in dict_list:
# locker.acquire()
# dict_list.append(token)
# locker.release()
#
# # if process_index == 2:
# locker.acquire()
# finished_num.value += 1
# locker.release()
# print(process_index, finished_num.value)
# # if finished_num.value > 100:
# # break
#
# def generate_dict(dataset_path:str, save_root:str):
# # dataset_path: root directory of the REMI-encoded input, containing the REMI txt sequences
# # save_root: root directory where dict.txt is saved
#
# file_name_list = os.listdir(dataset_path)
# # dict_list = []
# manager = Manager()
# dict_list = manager.list()
# locker = manager.Lock()
# finished_num = manager.Value("i", 0)
# processor_num = 4
#
#
# process_list = []
# for i in range(processor_num):
# # res.append(pools.apply_async(pass_para, args=(file_name_list, i, processor_num, sample_list, finished_num,)))
# process_list.append(multiprocessing.Process(target=ergodic_a_file, args=(
# dataset_path, file_name_list, dict_list, locker, i, processor_num, finished_num,)))
# process_list[-1].start()
# print(str(i) + ' processor started !')
#
# for i in process_list:
# i.join()
# print("over!")
#
# # for file_name in tqdm(file_name_list):
# # file_path = os.path.join(dataset_path, file_name)
# dict_list = sorted(dict_list, key=lambda x: (x[0], eval(x.split("-")[-1])))
# with open(os.path.join(save_root, "dict.txt"), "w") as f:
# for i in dict_list:
# f.write(f"{i} 1\n")
# return dict_list
def midi2REMI(midi_obj, save_path, sample_list = None):
if len(midi_obj.instruments) > 0:
encoder.encode_file(
file_path,
midi_obj=midi_obj,
remove_empty_bars=True,
ignore_inst=ignore_inst,
ignore_ts=True,
ignore_tempo=False,
save_path = save_path
)
def pass_para(file_in, index, p_num, sample_list, finished_num):
# if eval(file_in[1]):
# return
# file_name = file_in[0]
# save_name = "./data/lmd_encode/" + file_name.replace("/", "_")[:-4] + ".txt"
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name)
# midi2REMI(midi_obj, save_name)
size = math.ceil(len(file_in) / p_num)
start = size * index
end = (index + 1) * size if (index + 1) * size < len(file_in) else len(file_in)
temp_data = file_in[start:end]
for j in temp_data:
# if eval(j[1]):
# continue
file_name = j
# save_name = r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode/remi_seq/" + file_name.replace("/", "_")[:-4] + ".txt"
save_name = r"D:\ProjectData\ControlGeneration\0813_attributes_generation\lmd_6tracks_clean\full_song\remi_seq/" + file_name.split(".")[0] + ".txt"
# if os.path.exists(save_name):
# if index == 2:
# print("skip")
# continue
midi_obj = miditoolkit.MidiFile(os.path.join(r"D:\ProjectData\datasets\lmd_6tracks\midi_6tracks", file_name))
if check_ts(midi_obj):
midi_obj.time_signature_changes = [miditoolkit.TimeSignature(4,4,0)]
try:
midi2REMI(midi_obj, save_name, sample_list)
except:
print(file_name, " illegal!!!")
finished_num.value += 1
if index == 2:
print(finished_num.value)
def mutiprocess_REMI_encode(file_name_list):
# ex_item = pd.read_excel("./data/LMD_full_statistics.xlsx", header = [0,1])
# file_name_list = np.load("./data/file_path.npy")
# for i in tqdm(range(50000)):
# if not eval(file_name_list[i,1]):
# save_name = "./data/lmd_encode/" + file_name_list[i,0].replace("/", "_")[:-4] + ".txt"
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name_list[i,0])
# midi2REMI(midi_obj, save_name)
# del midi_obj
# gc.collect()
manager = Manager()
sample_list = manager.dict()
finished_num = manager.Value("i", 0)
processor_num = 4
process_list = []
for i in range(processor_num):
# res.append(pools.apply_async(pass_para, args=(file_name_list, i, processor_num, sample_list, finished_num,)))
process_list.append(multiprocessing.Process(target=pass_para, args=(
file_name_list, i, processor_num, sample_list, finished_num,)))
process_list[-1].start()
print(str(i) + ' processor started !')
for i in process_list:
i.join()
print("over!")
# save_dict = {}
# for i in sample_list.items():
# print("key:", type(i[0]))
# save_dict[int(i[0])] = i[1]
# json.dump(save_dict, open("./data/sample_list.json", "w"))
# def sample_midi_data():
# # sample_list = json.load(open("./data/sample_list.json", "r"))
# file_name_list = np.load("./data/file_instru.npy", allow_pickle=True)
#
# chosen_file = []
# total_dist = dict(zip(range(129), [0] * 129))
# for file in tqdm(file_name_list):
# if file[1]:
# continue
# # midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file[0])
# isSelected = False
# for instru in eval(file[2]):
# if instru in Banlanced_INST:
# if not isSelected:
# chosen_file.append(file)
# isSelected = True
#
# total_dist[min(instru, 128)] += 1
# if isSelected:
# for instru in eval(file[2]):
# if instru in INSTR_NAME_LIST or instru >= 128:
# sample_INSTR_dist[min(instru, 128)] += 1
# np.save("./data/chosen_files.npy", chosen_file)
# np.save("./data/selected_dist.npy", sample_INSTR_dist)
# # {0: 46250, 1: 15661, 4: 13041, 5: 10634, 24: 23789, 25: 30324, 26: 16262, 27: 21368, 28: 13430, 29: 16726, 30: 16890, 32: 22399, 33: 36293,
# # 35: 21158, 48: 37318, 49: 20172, 50: 12922, 52: 20990, 53: 10138, 56: 17569, 57: 13767, 60: 11161, 61: 13343, 65: 14791, 66: 11622, 71: 10262, 73: 16688, 128: 97225}
# def gen_train_valid_test_txt():
# np.random.seed(42)
# # chosen_file = np.load("./data/chosen_files.npy", allow_pickle=True)
# root = "E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean"
# chosen_file = os.listdir(root)
# split_point = [int(0.88*len(chosen_file)), int(0.95*len(chosen_file))]
# chosen_index = np.arange(0,len(chosen_file))
# np.random.shuffle(chosen_index)
# train_file_index = chosen_index[:split_point[0]]
# valid_file_index = chosen_index[split_point[0]:split_point[1]]
# test_file_index = chosen_index[split_point[1]:]
# # dict = []
# # np.save("./data/total_txt/split_index.npy", [train_file_index, valid_file_index, test_file_index])
# def get_split_txt(split, indexes):
# # command_list = []
# f = open(rf"E:\AIMusic\ControlGeneration\instrument_control_bar\attribute_control\txt/{split}.txt", "w")
#
# num = 0
# print(split, " start!")
# for i in tqdm(indexes):
# file_path = os.path.join(root, chosen_file[i])
# s_encode = open(file_path, "r").read()
# if s_encode[-1:] != "\n":
# s_encode = s_encode + "\n"
# f.write(s_encode)
# num += 1
# # if num > 100:
# # break
# f.close()
# # np.save(f"./data/total_txt/command_{split}.npy", np.array(command_list))
# get_split_txt("train", train_file_index)
# get_split_txt("valid", valid_file_index)
# get_split_txt("test", test_file_index)
# # write_dict(dict, "./data/total_txt/dict.txt")
# print("done!")
#
# def check(check_index):
# command = np.load("./data/total_txt/command_train.npy")
# index = 0
# with open("./data/total_txt/train.txt", "r") as f:
# line = f.readline().strip("\n")
# while 1:
# # for i in line.split(" "):
# # if i[0] == "i":
# # instru_number = eval(i[2:])
# # assert command[index][INSTR_NAME2_index[instru_number]] == 1, "error!"
#
# if check_index == index:
# print(line)
# break
# index += 1
# line = f.readline().strip("\n")
# if index % 1000 == 0:
# print(index)
#
# # def generate_command_vector(token_list:list):
# # command_vector = [0] * (len(INSTR_NAME_LIST))
# # length = 0
# # for i in token_list:
# # if i == "":
# # continue
# # if i[0] == "i":
# # j = eval(i[2:])
# # command_vector[INSTR_NAME2_index[min(j, 128)]] = 1
# # length += 1
# # # command_vector = np.array([command_vector])
# # # command_vector = np.repeat(command_vector, length, axis = 0)
# # command_vector.append(length + 1) # for b-1 token
# # return command_vector
#
# def get_8bars_train_valid_test_txt():
# # experiment: randomly crop 8-bar token segments
# # chosen_file = np.load("./data/chosen_files.npy", allow_pickle=True)
# # data_path = "E:/AIMusic/ControlGeneration/lmd_encode/"
# # save_root = "E:/AIMusic/ControlGeneration/8-bars-txt/"
# # data_path = "./data/sample_6tracks/remi_seq/"
# # save_root = "./data/sample_6tracks/txt/"
# data_path = "E:/AIMusic/ControlGeneration/lmd_6_tracks_encode/remi_seq/"
# save_root = "E:/AIMusic/ControlGeneration/lmd_6_tracks_encode/txt/"
#
# chosen_file = [i for i in os.listdir(data_path)]
#
# def write(code, target_f):
# for i in range(len(code)):
# if code[i] == "s-9":
# continue
# target_f.write(code[i])
# if i != len(code)-1:
# target_f.write(" ")
# target_f.write("\n")
# def get_bar_split_txt(split, start, end):
# command_list = []
# split_files = chosen_file[start:end]
# f = open(save_root + f"./{split}.txt", "w")
# num = 0
# print(split, " start!")
# max_length = 3000
# sample_num = 0
# for file_name in tqdm(split_files):
# file_path = data_path + file_name.split(".")[0].replace("/", "_") + ".txt"
# s_encode = open(file_path, "r").read().strip("\n")
# if s_encode[:3] != "s-9":
# continue
# bar_pos = []
# remi_tokens = s_encode.split(" ")
# for i, token in enumerate(remi_tokens):
# if token == "b-1":
# bar_pos.append(i)
#
# if len(bar_pos) <= 8:
# cur_length = len(remi_tokens) -1
# if cur_length + 2 <= max_length:
# write(remi_tokens[0:], f)
# sample_num += 1
# continue
#
# coordinate_pool = []
# start = 0
# for i, pos in enumerate(bar_pos[:-7]):
# cur_length = bar_pos[i+7] + 1 - start
# if cur_length -8 + 2 <= max_length:
# coordinate_pool.append((i, start, start + cur_length))
# start = pos + 1
# if len(coordinate_pool) == 0:
# continue
# elif len(coordinate_pool) == 1:
# arr = coordinate_pool[0]
# write(remi_tokens[arr[1]:arr[2]], f)
# sample_num += 1
# else:
# chosen_bars = np.random.choice(len(coordinate_pool), 2, replace = False)
# for i in chosen_bars:
# arr = coordinate_pool[i]
# write(remi_tokens[arr[1]:arr[2]], f)
# sample_num += 1
# print(sample_num)
# f.close()
# # np.save(f"./data/total_txt/command_{split}.npy", np.array(command_list))
# get_bar_split_txt("train", 0, 70000)
# get_bar_split_txt("valid", 70000, 75000)
# get_bar_split_txt("test", 75000, 80000)
# def statics_length():
# # token length statistics
# f = open("./data/total_txt/train.txt", "r")
# line = f.readline()
# length_dict = {}
# num = 0
# while line:
# length = len(line.split(" "))
# if length not in length_dict:
# length_dict[length] = 1
# else:
# length_dict[length] += 1
# line = f.readline()
# num += 1
# print("\r", num, end = "")
#
# import matplotlib.pyplot as plt
#
# X = length_dict.keys()
# Y = length_dict.values()
# fig = plt.figure()
# plt.bar(X, Y, 0.4, color="green")
# plt.xlabel("X-axis")
# plt.ylabel("Y-axis")
# plt.title("bar chart")
# plt.show()
# ave = 0
# for j in length_dict.items():
# ave += j[0]*j[1]/num
# ave = 14995.75
#
# def dict_generation(path):
# dict = []
# for file in os.listdir(path):
# f = open(os.path.join(path, file))
# line = f.readline()
# while line:
# tokens = line.strip("\n").split(" ")
# assert len(tokens) <= 3010, "too long token sequence"
# for j in tokens:
# if j not in dict:
# dict.append(j)
# line = f.readline()
# f.close()
# g = open(os.path.join(path, "dict.txt"), "w")
# dict = sorted(dict, key=lambda x: (x[0], eval(x.split("-")[-1])))
# for i in dict:
# g.write(i + " 1\n")
# g.close()
#
# def lmd_6tracks_clean():
# # code for cleaning the lmd_6tracks data
# chosen_files = np.load(r"./data/lmd_full_not_duplicate_files.npy", allow_pickle=True)
# all_files = np.load(r"E:\AIMusic\ControlGeneration\全局的控制数据-28类乐器\file_path.npy", allow_pickle=True)
# all_file_names = []
# for i in tqdm(range(len(all_files))):
# cur_name = all_files[i][0]
# all_file_names.append(cur_name)
#
# out_of_lmd_full_num = 0
# successs_num = 0
# error_name_list = []
# for tracks6_name in tqdm(os.listdir(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\remi_seq")):
# search_name = tracks6_name.split("_")
#
# search_name = search_name[-1][0] + "/" + search_name[-1].split(".")[0] + ".mid"
# if len(search_name) != 38:
# error_name_list.append(tracks6_name)
# if search_name not in all_file_names:
# out_of_lmd_full_num += 1
# else:
# successs_num += 1
# source_path = os.path.join(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\remi_seq", tracks6_name)
# with open(source_path, "r") as f:
# line = f.readline().strip("\n")
# line = line.split(" ")
# is4_4 = True
# for j in line:
# if j[0] == "s" and eval(j[2:]) != 9:
# is4_4 = False
# break
# if is4_4:
# target_path = os.path.join(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean", tracks6_name)
# shutil.copy(source_path, target_path)
# print(out_of_lmd_full_num, successs_num)
#
# length = []
# for i in all_file_names:
# if len(i) not in length:
# length.append((len(i)))
#
# length_list = {}
# length_33 = []
# for name in error_name_list:
# if len(name)-4 not in length_list.keys():
# length_list[len(name)-4] = 1
# else:
# length_list[len(name)- 4] += 1
# if len(name)-4 == 33:
# length_33.append(name)
if __name__ == "__main__":
# dict = generate_dict(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean_4_4", r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode")
# x = 1
# name_list = os.listdir("./data/lmd_6tracks/midi_6tracks/")
#
# name_list = np.random.choice(name_list, 10000, replace=False)
# np.save("./data/chosen_files_6tracks.npy", name_list)
# file_name_list = np.load("data/not_duplicate_list.npy")
# file_name_list = os.listdir(r"D:\ProjectData\datasets\lmd_6tracks\midi_6tracks")
# print(len(file_name_list))
file_name_list = np.load(r"D:\ProjectData\datasets\lmd_6tracks_clean_4_4\file_name_list.npy")
mutiprocess_REMI_encode(file_name_list)
# gen_train_valid_test_txt()
# file_name_list = os.listdir("E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\clean_4_4")
# test_mid_obj = miditoolkit.MidiFile("./test.mid")
# midi2REMI(test_mid_obj, )
# file_name_list = np.load("./data/file_path.npy")
# for file_name in tqdm(file_name_list[:,0]):
# save_name = r"E:\AIMusic\ControlGeneration\lmd_encode/" + file_name.replace("/", "_")[:-4] + ".txt"
# if os.path.exists(save_name):
# continue
# midi_obj = miditoolkit.MidiFile("./data/lmd_full/" + file_name)
# midi2REMI(midi_obj, save_name, {})
# get_8bars_train_valid_test_txt()
# dict_generation(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\txt/")
# command_generation(r"E:\AIMusic\ControlGeneration\lmd_6_tracks_encode\txt/")
# path = "./data/sample_6tracks/midi"
# save_root = "./data/sample_6tracks/remi_seq/"
# for i, file in enumerate(os.listdir(path)):
# midi_obj = miditoolkit.MidiFile(os.path.join(path, file))
# midi2REMI(midi_obj, os.path.join(save_root, file.split(".")[0] +".txt"))
# fairseq bug:
# g = open("./data/total_txt/train.txt", "w")
# with open("./data/total_txt/train_total.txt", "r") as f:
# line = f.readline()
# index = 0
# while 1:
# # for i in line.split(" "):
# # if i[0] == "i":
# # instru_number = eval(i[2:])
# # assert command[index][INSTR_NAME2_index[instru_number]] == 1, "error!"
# g.write(line)
# index += 1
# line = f.readline()
# if index % 1000 == 0:
# print(index)
# if index > 70000:
# break
# gen_train_valid_test_txt()
# check(90000)
# def _warmup_mmap_file(path):
# with open(path, "rb") as stream:
# while stream.read(100 * 1024 * 1024):
# pass
# import struct
# dtypes = {
# 1: np.uint8,
# 2: np.int8,
# 3: np.int16,
# 4: np.int32,
# 5: np.int64,
# 6: np.float,
# 7: np.double,
# 8: np.uint16,
# }
# path ="./data/total-data-bin/train.idx"
# _HDR_MAGIC = b"MMIDIDX\x00\x00"
# with open(path, "rb") as stream:
# magic_test = stream.read(9)
# assert _HDR_MAGIC == magic_test, (
# "Index file doesn't match expected format. "
# "Make sure that --dataset-impl is configured properly."
# )
# version = struct.unpack("<Q", stream.read(8))
# assert (1,) == version
#
# (dtype_code,) = struct.unpack("<B", stream.read(1))
# _dtype = dtypes[dtype_code]
# _dtype_size = _dtype().itemsize
#
# _len = struct.unpack("<Q", stream.read(8))[0]
# offset = stream.tell()
#
# _warmup_mmap_file(path)
#
# _bin_buffer_mmap = np.memmap(path, mode="r", order="C")
# _bin_buffer = memoryview(_bin_buffer_mmap)
# _sizes = np.frombuffer(
# _bin_buffer, dtype=np.int32, count=_len, offset=offset
# )
# _pointers = np.frombuffer(
# _bin_buffer,
# dtype=np.int64,
# count=_len,
# offset=offset + _sizes.nbytes,
# )
# fairseq-preprocess --only-source --srcdict data/total_txt/dict.txt --trainpref data/total_txt/train.txt --validpref data/total_txt/valid.txt --testpref data/total_txt/test.txt --destdir data/total-data-bin --workers 8
# midi_root_path = r"D:\ProjectData\datasets\lmd_full"
# save_root = r"D:\Project\ControlGeneration\StyleCtrl\jSymbolic_lib\datasets\TopMAGD/midi"
# midi_to_genre = json.load(open("./midi_genre_map.json", "r"))
# top_magd_genre = midi_to_genre["topmagd"]
# for key in tqdm(top_magd_genre.keys()):
# midi_name = key + ".mid"
# midi_src_path = os.path.join(midi_root_path, midi_name[0], midi_name)
# midi_save_path = os.path.join(save_root, midi_name)
# shutil.copyfile(midi_src_path, midi_save_path)
# midi_root_path = r"D:\ProjectData\datasets\lmd_full"
# save_root = r"D:\Project\ControlGeneration\StyleCtrl\jSymbolic_lib\datasets\MASD/midi"
# midi_to_genre = json.load(open("./midi_genre_map.json", "r"))
# masd_genre = midi_to_genre["masd"]
# for key in tqdm(masd_genre.keys()):
# midi_name = key + ".mid"
# midi_src_path = os.path.join(midi_root_path, midi_name[0], midi_name)
# midi_save_path = os.path.join(save_root, midi_name)
# shutil.copyfile(midi_src_path, midi_save_path)

360
emogen/interactive.py Normal file
View File

@@ -0,0 +1,360 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate raw text with a trained model. Batches data on-the-fly.
"""
import fileinput
import logging
import math
import os
import sys
import time
from collections import namedtuple
import numpy as np
import torch
from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils
from fairseq.data import encoders
from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
from fairseq_cli.generate import get_symbols_to_strip_from_output
import linear_decoder.controlled_task
from MidiProcessor.midiprocessor import MidiDecoder
import shutil, random
from fairseq_cli import interactive
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.interactive")
Batch = namedtuple("Batch", "ids src_tokens src_lengths constraints")
Translation = namedtuple("Translation", "src_str hypos pos_scores alignments")
def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def buffered_read(input, buffer_size):
buffer = []
with fileinput.input(files=[input], openhook=fileinput.hook_encoded("utf-8")) as h:
for src_str in h:
buffer.append(src_str.strip())
if len(buffer) >= buffer_size:
yield buffer
buffer = []
if len(buffer) > 0:
yield buffer
def make_batches(lines, args, task, max_positions, encode_fn):
def encode_fn_target(x):
return encode_fn(x)
if args.constraints:
# Strip (tab-delimited) constraints, if present, from input lines,
# store them in batch_constraints
batch_constraints = [list() for _ in lines]
for i, line in enumerate(lines):
if "\t" in line:
lines[i], *batch_constraints[i] = line.split("\t")
# Convert each List[str] to List[Tensor]
for i, constraint_list in enumerate(batch_constraints):
batch_constraints[i] = [
task.target_dictionary.encode_line(
encode_fn_target(constraint),
append_eos=False,
add_if_not_exist=False,
)
for constraint in constraint_list
]
tokens = [
task.source_dictionary.encode_line(
encode_fn(src_str), add_if_not_exist=False
).long()
for src_str in lines
]
if args.constraints:
constraints_tensor = pack_constraints(batch_constraints)
else:
constraints_tensor = None
lengths = [t.numel() for t in tokens]
itr = task.get_batch_iterator(
dataset=task.build_dataset_for_inference(
tokens, lengths, constraints=constraints_tensor
),
max_tokens=args.max_tokens,
max_sentences=args.batch_size,
max_positions=max_positions,
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
).next_epoch_itr(shuffle=False)
for batch in itr:
ids = batch["id"]
src_tokens = batch["net_input"]["src_tokens"]
src_lengths = batch["net_input"]["src_lengths"]
constraints = batch.get("constraints", None)
yield Batch(
ids=ids,
src_tokens=src_tokens,
src_lengths=src_lengths,
constraints=constraints,
)
def attribute_generate(args):
assert args.tgt_emotion in [1, 2, 3, 4], f"Invalid emotion index {args.tgt_emotion}: expected a quadrant in 1-4."
start_time = time.time()
total_translate_time = 0
utils.import_user_module(args)
if args.buffer_size < 1:
args.buffer_size = 1
if args.max_tokens is None and args.batch_size is None:
args.batch_size = 1
assert (
not args.sampling or args.nbest == args.beam
), "--sampling requires --nbest to be equal to --beam"
assert (
not args.batch_size or args.batch_size <= args.buffer_size
), "--batch-size cannot be larger than --buffer-size"
logger.info(args)
# Fix seed for stochastic decoding
if args.seed is not None and not args.no_seed_provided:
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
use_cuda = torch.cuda.is_available() and not args.cpu
# Setup task, e.g., translation_control
task = tasks.setup_task(args)
# Load ensemble
logger.info("loading model(s) from {}".format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble(
args.path.split(os.pathsep),
arg_overrides=eval(args.model_overrides),
task=task,
suffix=getattr(args, "checkpoint_suffix", ""),
strict=(args.checkpoint_shard_count == 1),
num_shards=args.checkpoint_shard_count,
)
# Set dictionaries
src_dict = task.source_dictionary
tgt_dict = task.target_dictionary
# Optimize ensemble for generation
for model in models:
if args.fp16:
model.half()
if use_cuda and not args.pipeline_model_parallel:
model.cuda()
model.prepare_for_inference_(args)
# Initialize generator
generator = task.build_generator(models, args)
# Handle tokenization and BPE
tokenizer = encoders.build_tokenizer(args)
bpe = encoders.build_bpe(args)
def encode_fn(x):
if tokenizer is not None:
x = tokenizer.encode(x)
if bpe is not None:
x = bpe.encode(x)
return x
def decode_fn(x):
if bpe is not None:
x = bpe.decode(x)
if tokenizer is not None:
x = tokenizer.decode(x)
return x
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary)
align_dict = utils.load_align_dict(args.replace_unk)
max_positions = utils.resolve_max_positions(
task.max_positions(), *[model.max_positions() for model in models]
)
if args.constraints:
logger.warning(
"NOTE: Constrained decoding currently assumes a shared subword vocabulary."
)
if args.buffer_size > 1:
logger.info("Sentence buffer size: %s", args.buffer_size)
logger.info("NOTE: hypothesis and token scores are output in base 2")
logger.info("Type the input sentence and press return:")
start_id = 0
command_list = np.load(args.ctrl_command_path)
# for inputs in buffered_read(args.input, args.buffer_size):
save_root = args.save_root
os.makedirs(save_root, exist_ok=True)
midi_decoder = MidiDecoder("REMIGEN2")
for command_index in [args.tgt_emotion - 1]:
os.makedirs(save_root + f"/midi", exist_ok=True)
os.makedirs(save_root + f"/remi", exist_ok=True)
sample_scores = {}
for gen_times in range(1000):
if len(os.listdir(save_root + f"/midi")) > args.need_num:
break
if os.path.exists(save_root + f"/remi/{gen_times*args.batch_size}.txt"):
print(f"command_id: {command_index} sample:{gen_times*args.batch_size} already exists. Skip this batch!")
continue
start_tokens = [""]
for inputs in [start_tokens*args.batch_size]: # "" for none prefix input
results = []
for batch in make_batches(inputs, args, task, max_positions, encode_fn):
bsz = batch.src_tokens.size(0)
src_tokens = batch.src_tokens
src_lengths = batch.src_lengths
constraints = batch.constraints
command_input = []
for i in range(args.batch_size):
command_input.append(command_list[command_index])
command_input = torch.tensor(command_input)
if use_cuda:
src_tokens = src_tokens.cuda()
src_lengths = src_lengths.cuda()
command_input = command_input.cuda()
if constraints is not None:
constraints = constraints.cuda()
sample = {
"net_input": {
"src_tokens": src_tokens,
"src_lengths": src_lengths,
"command_input":command_input,
},
}
translate_start_time = time.time()
translations = task.inference_step(
generator, models, sample, constraints=constraints
)
translate_time = time.time() - translate_start_time
total_translate_time += translate_time
list_constraints = [[] for _ in range(bsz)]
if args.constraints:
list_constraints = [unpack_constraints(c) for c in constraints]
for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
constraints = list_constraints[i]
results.append(
(
start_id + id,
src_tokens_i,
hypos,
{
"constraints": constraints,
"time": translate_time / len(translations),
},
)
)
# sort output to match input order
for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]):
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
# print("S-{}\t{}".format(id_, src_str))
# print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
# for constraint in info["constraints"]:
# print(
# "C-{}\t{}".format(
# id_, tgt_dict.string(constraint, args.remove_bpe)
# )
# )
# Process top predictions
for hypo in hypos[: min(len(hypos), args.nbest)]:
hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
hypo_tokens=hypo["tokens"].int().cpu(),
src_str=src_str,
alignment=hypo["alignment"],
align_dict=align_dict,
tgt_dict=tgt_dict,
remove_bpe=args.remove_bpe,
extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
)
detok_hypo_str = decode_fn(hypo_str)
score = hypo["score"] / math.log(2) # convert to base 2
command_id = command_index
save_id = id_ + gen_times*args.batch_size
sample_scores[save_id] = score.detach().cpu().numpy().item()
print(f"command_id: {command_id} sample:{save_id} over with length {len(hypo_str.split(' '))}")
with open(save_root + f"/remi/{save_id}.txt", "w") as f:
f.write(hypo_str)
remi_token = hypo_str.split(" ")
# try:
midi_obj = midi_decoder.decode_from_token_str_list(remi_token)
midi_obj.dump(save_root + f"/midi/{save_id}.mid")
# except:
# pass
def cli_main():
parser = options.get_interactive_generation_parser()
parser.add_argument("--ctrl_command_path", type=str)
parser.add_argument("--save_root", type=str)
parser.add_argument("--need_num", type=int, default=50)
parser.add_argument("--tgt_emotion", type=int)
args = options.parse_args_and_arch(parser)
attribute_generate(args)
# if args.ctrl_command_path != "None":
# attributes(args)
# else:
# label_embedding(args)
if __name__ == "__main__":
seed_everything(2023)
cli_main()
'''
../StyleCtrlData/train_data/1026_EMO/data-bin
--task
language_modeling_control
--path
checkpoints/controlled/checkpoint_best.pt
--ctrl_command_path
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_rank_100/EMOPIA_inference_nearest.npy
--save_root
generation
--max-len-b
500
--sampling
--beam
1
--sampling-topk
8
--buffer-size
8
--batch-size
8
'''

Binary file added (content not shown):
emogen/jSymbolic_lib/__pycache__/jSymbolic_util.cpython-38.pyc

View File

@@ -0,0 +1,64 @@
import shutil
import pandas as pd
import json
import os
import subprocess
from tqdm import tqdm
import time
import xmltodict
import random
import numpy as np
from multiprocessing import Process, Pool
import json
from jSymbolic_util import read_pitch_feature, read_all_feature
import xmltodict
import subprocess
from functools import partial
import argparse

command_prefix = "java -Xmx6g -jar ./jSymbolic_2_2_user/jSymbolic2.jar -configrun ./jSymbolic_2_2_user/jSymbolicDefaultConfigs.txt"


def rename_midi_path(root):
    midi_name_list = os.listdir(root)
    for midi_name in midi_name_list:
        os.rename(root + "/" + midi_name, root + f"/{midi_name.replace(' ', '_')}")


def get_jSymbolic_feature(file_name, root):
    midi_name = file_name[:-4]
    midi_path = root + f"/midi/{midi_name}.mid"
    path = os.path.join(root, "feature/" + f"{midi_name.replace(' ', '_')}.xml")
    # print(midi_path)
    if os.path.exists(path):
        return 0
    if not os.path.exists(path):
        new_command = " ".join([command_prefix, midi_path, path, "./test_def.xml"])
        os.system(new_command)
    return 0


np.random.seed(42)
random.seed(42)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    data_path = "../data/Piano"
    midi_sampled = os.listdir(data_path + "/midi")
    os.makedirs(data_path + "/feature", exist_ok=True)
    with Pool(processes=8) as pool:
        result = iter(tqdm(pool.imap(partial(get_jSymbolic_feature, root=data_path), midi_sampled),
                           total=len(midi_sampled)))
        for i in range(len(midi_sampled)):
            # try:
            next(result)

View File

@@ -0,0 +1,99 @@
import os
import xmltodict
import numpy as np
from tqdm import tqdm

nan_feature_list = {}
split_name = ["Melodic Interval Histogram", "Vertical Interval Histogram"]
end_at = "Initial Time Signature"
# split_pos = [['Melodic Interval Histogram', 190], ['Vertical Interval Histogram', 342]]
feature_name = []
histogram_feature_name = []


def read_pitch_feature_without_long_histogram(path):
    data = xmltodict.parse(open(path, "r").read())
    data = data['feature_vector_file']["data_set"]["feature"]
    ret = []
    histogram_ret = []
    for f in data:
        if f["name"] == end_at:
            break
        if "histogram" not in f["name"].lower():
            if f["v"] == "NaN":
                ret.append(0)
                if f["name"] not in nan_feature_list.keys():
                    nan_feature_list[f["name"]] = 1
                else:
                    nan_feature_list[f["name"]] += 1
            elif isinstance(f["v"], list):
                ret.extend([eval(i) for i in f["v"]])
            else:
                ret.append(eval(f["v"]))
            # feature_name.append(f["name"])
        else:
            if len(f["v"]) < 20:
                # histogram_feature_name.append(f["name"])
                histogram_ret.extend([eval(i) for i in f["v"]])
    return ret, histogram_ret


def read_pitch_feature(path):
    data = xmltodict.parse(open(path, "r").read())
    data = data['feature_vector_file']["data_set"]["feature"]
    ret = []
    for f in data:
        # print(f["name"])
        # if f["name"] in split_name:
        #     split_pos.append([f["name"], len(ret)])
        if f["name"] == end_at:
            break
        if "histogram" not in f["name"].lower():
            if f["v"] == "NaN":
                ret.append(0)
                if f["name"] not in nan_feature_list.keys():
                    nan_feature_list[f["name"]] = 1
                else:
                    nan_feature_list[f["name"]] += 1
            elif isinstance(f["v"], list):
                ret.extend([eval(i) for i in f["v"]])
            else:
                ret.append(eval(f["v"]))
        else:
            ret.extend([eval(i) for i in f["v"]])
    return ret


def read_all_feature(path, need_name=False):
    data = open(path, "r").read()
    if len(data) == 0:
        return None
    data = xmltodict.parse(open(path, "r").read())
    data = data['feature_vector_file']["data_set"]["feature"]
    ret = []
    feature_names = []
    for f in data:
        if "histogram" not in f["name"].lower():
            if f["v"] == "NaN":
                ret.append(0)
                if need_name:
                    feature_names.append(f["name"])
                if f["name"] not in nan_feature_list.keys():
                    nan_feature_list[f["name"]] = 1
                else:
                    nan_feature_list[f["name"]] += 1
            elif isinstance(f["v"], list):
                ret.extend([eval(i) for i in f["v"]])
                if need_name:
                    feature_names.extend(f["name"] + f"_{i}" for i in range(len(f["v"])))
            else:
                ret.append(eval(f["v"]))
                if need_name:
                    feature_names.append(f["name"])
        else:
            ret.extend([eval(i) for i in f["v"]])
            if need_name:
                feature_names.extend(f["name"] + f"_{i}" for i in range(len(f["v"])))
    if need_name:
        assert len(ret) == len(feature_names)
        return ret, feature_names
    else:
        return ret
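
A small usage sketch for read_all_feature, assuming a feature XML has already been extracted for one piece; gen_data.py keeps only the features listed in data/feature_index.npy and skips files whose flattened vector is not 1495-dimensional:

import numpy as np
from jSymbolic_util import read_all_feature

# Parse one extracted feature vector; need_name=True also returns the flattened feature names.
values, names = read_all_feature("../data/Piano/feature/0.xml", need_name=True)   # hypothetical path
print(len(values))   # expected 1495; gen_data.py skips files where this differs

# Keep the subset the model is conditioned on, as gen_data.py does.
feature_index = np.load("../data/feature_index.npy", allow_pickle=True)
selected = np.array(values)[feature_index]
print(selected.shape)   # feature_num = 100 in the training scripts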

View File

File diff not shown because it is too large.

View File

@@ -0,0 +1,211 @@
import random
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import warnings
from sklearn.manifold import TSNE
from scipy.stats import pearsonr
from sklearn.feature_selection import SelectKBest, f_classif, VarianceThreshold
# Import libraries
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score, roc_auc_score
import joblib
import lightgbm as lgb
# compute per-feature bucketing thresholds from the dataset
def EMOPIA_threshold():
feature_table = np.load(r"../../StyleCtrlData/jSymbolic_lib\datasets\EMOPIA\data\feature_table.npy")
X, Y = feature_table[:, :-1], feature_table[:, -1]
# for i in range(X.shape[1]):
# medium_value = np.percentile(X[:, i], 50)
# thresholds.append(medium_value)
# X_discrete.append(np.searchsorted([medium_value], X[:, i])[:, np.newaxis])
for bucket_num in [3, 5, 8, 12, 16]:
thresholds = []
for i in range(X.shape[1]):
thres = []
for j in range(1, bucket_num):
thres.append(np.percentile(X[:, i], 100.0/bucket_num*j))
thresholds.append(thres)
# X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
thresholds = np.array(thresholds)
# X_discrete = np.concatenate(X_discrete, axis = 1)
# cv_folds = 5
# skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
# y_preds = np.zeros(Y.shape) - 1
# for train_index, test_index in skf.split(X_discrete, Y):
# X_train, y_train = X_discrete[train_index], Y[train_index]
# X_test, y_test = X_discrete[test_index], Y[test_index]
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
# clf.fit(X_train, y_train)
# y_pred_score = clf.predict_proba(X_test)
# y_pred = np.argmax(y_pred_score, axis=1)
# y_pred += 1 # predicted quadrant index
# y_preds[test_index] = y_pred
# print(f"discretized {X_discrete.shape[1]}-dim features:", f1_score(Y, y_preds, average="micro"))
np.save(f"../../StyleCtrlData/jSymbolic_lib/datasets/EMOPIA/data/threshold_{bucket_num}.npy", thresholds)
# 二值化之后的分类精度为0.6762523191094619
# feature_selected = np.load("./data/selected_feature/emotion_And_select_17.npy")
# feature_name = np.load("./data/feature_name.npy")
# name2index = dict(zip(feature_name, range(len(feature_name))))
# feature_selected_index = [name2index[i] for i in feature_selected]
# X_discrete = X_discrete[:, feature_selected_index]
# y_preds = np.zeros(Y.shape) - 1
# for train_index, test_index in skf.split(X_discrete, Y):
# X_train, y_train = X_discrete[train_index], Y[train_index]
# X_test, y_test = X_discrete[test_index], Y[test_index]
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
# clf.fit(X_train, y_train)
# y_pred_score = clf.predict_proba(X_test)
# y_pred = np.argmax(y_pred_score, axis=1)
# y_pred += 1 # predict which emotion quadrant
# y_preds[test_index] = y_pred
# print(f"Discretized {X_discrete.shape[1]}-dim features:", f1_score(Y, y_preds, average="micro"))
# With two buckets:
# discretized 1495-dim features: 0.6762523191094619
# discretized 39-dim features: 0.62430426716141
# discretized 17-dim features: 0.5092764378478665
# With three buckets:
# discretized 1495-dim features: 0.6725417439703154
# discretized 39-dim features: 0.6447124304267161
# discretized 17-dim features: 0.5445269016697588
def VGMIDI_feedback():
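# Apply the binary (bucket_num = 2) thresholds learned on EMOPIA to the VGMIDI feature
# table and run 5-fold RandomForest emotion classification; the micro-F1 printed below
# is computed on the continuous features for reference.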
threshold = np.load("./datasets/EMOPIA/data/threshold_2.npy")
feature_table = np.load("./datasets/VGMIDI/data/feature_table.npy")
X = feature_table[:, :-1]
Y = feature_table[:, -1]
X_discrete = []
for i in range(X.shape[1]):
X_discrete.append(np.searchsorted([threshold[i]], X[:,i])[:, np.newaxis])
X_discrete = np.concatenate(X_discrete, axis = 1)
cv_folds = 5
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
y_preds_continuous = np.zeros(Y.shape) - 1
for train_index, test_index in skf.split(X, Y):
X_train, y_train = X[train_index], Y[train_index]
X_test, y_test = X[test_index], Y[test_index]
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
clf.fit(X_train, y_train)
y_pred_score = clf.predict_proba(X_test)
y_pred = np.argmax(y_pred_score, axis=1)
y_pred += 1 # predict which emotion quadrant
y_preds_continuous[test_index] = y_pred
print(f1_score(Y, y_preds_continuous, average="micro"))
def TopMAGD_threshold():
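# Compute per-feature quantile thresholds on the TopMAGD splits and compare RandomForest
# genre classification on continuous features (one binary label per genre) against the
# discretized features for several bucket counts.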
feature_table = []
labels = []
path_root = "../../StyleCtrlData/data/1004_TopMAGD/truncated_2560/split_data"
for split in ["train","valid", "test"]:
feature_table.append(np.load(path_root + f"/1495/{split}_raw_command_1495.npy"))
labels.append(np.load(path_root + f"/{split}_style_labels.npy"))
feature_table = np.vstack(feature_table)
labels = np.vstack(labels)
X = feature_table
Y = labels
feature_selected = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\jSymbolic_lib\data\style_select_feature\style_or_10_select_103.npy")
feature_names = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\jSymbolic_lib\data\feature_name.npy")
feature2id = dict(zip(feature_names, range(len(feature_names))))
select_features = [feature2id[i] for i in feature_selected]
id2genre = np.load(r"E:\Music\Project\StyleCtrl\StyleCtrlData\data\1004_TopMAGD\truncated_2560\split_data\id2genre.npy")
X = X[:, select_features]
# evaluate the continuous (non-discretized) features
cv_folds = 5
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
y_preds = np.zeros(Y.shape) - 1
for index, style_name in enumerate(id2genre):
Y_style_label = Y[:, index]
for train_index, test_index in skf.split(X, Y_style_label):
X_train, y_train = X[train_index], Y_style_label[train_index]
X_test, y_test = X[test_index], Y_style_label[test_index]
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
clf.fit(X_train, y_train)
y_pred_score = clf.predict_proba(X_test)
y_pred = np.argmax(y_pred_score, axis=1)
y_pred += 1 # predict which quadrant
y_preds[test_index] = y_pred
print(f"Continuous {X.shape[1]}-dim features:", style_name, f1_score(Y_style_label, y_preds, average="micro"), roc_auc_score(Y_style_label, y_preds))
for bucket_num in [2, 3, 5, 8, 12, 16]:
X_discrete = []  # reset for each bucket setting so columns from different settings don't mix
thresholds = []
for i in range(X.shape[1]):
thres = []
for j in range(1, bucket_num):
thres.append(np.percentile(X[:, i], 100.0 / bucket_num * j))
thresholds.append(thres)
X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
# thresholds = np.array(thresholds)
X_discrete = np.concatenate(X_discrete, axis = 1)
cv_folds = 5
skf = StratifiedKFold(n_splits=cv_folds, random_state=2022, shuffle=True)
y_preds = np.zeros(Y.shape) - 1
for train_index, test_index in skf.split(X_discrete, Y):
X_train, y_train = X_discrete[train_index], Y[train_index]
X_test, y_test = X_discrete[test_index], Y[test_index]
clf = RandomForestClassifier(n_estimators=100, random_state=2022)
clf.fit(X_train, y_train)
y_pred_score = clf.predict_proba(X_test)
y_pred = np.argmax(y_pred_score, axis=1)
y_pred += 1 # predict which quadrant
y_preds[test_index] = y_pred
print(f"Discretized {X_discrete.shape[1]}-dim features:", f1_score(Y, y_preds, average="micro"), roc_auc_score(Y, y_preds))
# np.save(path_root + f"/threshold_{bucket_num}.npy", thresholds)
# classification accuracy after binarization: 0.6762523191094619
# feature_selected = np.load("./data/selected_feature/emotion_And_select_17.npy")
# feature_name = np.load("./data/feature_name.npy")
# name2index = dict(zip(feature_name, range(len(feature_name))))
# feature_selected_index = [name2index[i] for i in feature_selected]
# X_discrete = X_discrete[:, feature_selected_index]
# y_preds = np.zeros(Y.shape) - 1
# for train_index, test_index in skf.split(X_discrete, Y):
# X_train, y_train = X_discrete[train_index], Y[train_index]
# X_test, y_test = X_discrete[test_index], Y[test_index]
# clf = RandomForestClassifier(n_estimators=100, random_state=2022)
# clf.fit(X_train, y_train)
# y_pred_score = clf.predict_proba(X_test)
# y_pred = np.argmax(y_pred_score, axis=1)
# y_pred += 1 # predict which quadrant
# y_preds[test_index] = y_pred
# print(f"Discretized {X_discrete.shape[1]}-dim features:", f1_score(Y, y_preds, average="micro"))
def YMDB_threshold():
feature_table = np.load(r"../../StyleCtrlData/jSymbolic_lib/datasets/1224_YMDB/data/fea_table.npy")
X = feature_table
# X, Y = feature_table[:, :-1], feature_table[:, -1]
# for i in range(X.shape[1]):
# medium_value = np.percentile(X[:, i], 50)
# thresholds.append(medium_value)
# X_discrete.append(np.searchsorted([medium_value], X[:, i])[:, np.newaxis])
for bucket_num in [2, 3, 5, 8, 12, 16]:
thresholds = []
for i in range(X.shape[1]):
thres = []
for j in range(1, bucket_num):
thres.append(np.percentile(X[:, i], 100.0 / bucket_num * j))
thresholds.append(thres)
# X_discrete.append(np.searchsorted(thres, X[:, i])[:, np.newaxis])
thresholds = np.array(thresholds)
np.save(f"../../StyleCtrlData/jSymbolic_lib/datasets/1224_YMDB/data/threshold_{bucket_num}.npy", thresholds)
if __name__ == "__main__":
# EMOPIA_threshold()
# TopMAGD_threshold()
YMDB_threshold()

View file

Binary data
emogen/linear_decoder/__pycache__/__init__.cpython-38.pyc Normal file

Binary file not shown.

Binary file not shown.

Binary data
emogen/linear_decoder/__pycache__/controlled_task.cpython-38.pyc Normal file

Binary file not shown.

View file

@ -0,0 +1,913 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import search, utils
from fairseq.data import data_utils
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.fairseq_encoder import EncoderOut
from torch import Tensor
class CommandSequenceGenerator(nn.Module):
def __init__(
self,
models,
tgt_dict,
beam_size=1,
max_len_a=0,
max_len_b=200,
min_len=1,
normalize_scores=True,
len_penalty=1.0,
unk_penalty=0.0,
temperature=1.0,
match_source_len=False,
no_repeat_ngram_size=0,
search_strategy=None,
eos=None,
symbols_to_strip_from_output=None,
lm_model=None,
lm_weight=1.0,
):
"""Generates translations of a given source sentence.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models,
currently support fairseq.models.TransformerModel for scripting
beam_size (int, optional): beam width (default: 1)
max_len_a/b (int, optional): generate sequences of maximum length
ax + b, where x is the source length
min_len (int, optional): the minimum length of the generated output
(not including end-of-sentence)
normalize_scores (bool, optional): normalize scores by the length
of the output (default: True)
len_penalty (float, optional): length penalty, where <1.0 favors
shorter, >1.0 favors longer sentences (default: 1.0)
unk_penalty (float, optional): unknown word penalty, where <0
produces more unks, >0 produces fewer (default: 0.0)
temperature (float, optional): temperature, where values
>1.0 produce more uniform samples and values <1.0 produce
sharper samples (default: 1.0)
match_source_len (bool, optional): outputs should match the source
length (default: False)
"""
super().__init__()
if isinstance(models, EnsembleModel):
self.model = models
else:
self.model = EnsembleModel(models)
self.tgt_dict = tgt_dict
self.pad = tgt_dict.pad()
self.unk = tgt_dict.unk()
self.eos = tgt_dict.eos() if eos is None else eos
self.symbols_to_strip_from_output = (
symbols_to_strip_from_output.union({self.eos})
if symbols_to_strip_from_output is not None
else {self.eos}
)
self.vocab_size = len(tgt_dict)
self.beam_size = beam_size
# the max beam size is the dictionary size - 1, since we never select pad
self.beam_size = min(beam_size, self.vocab_size - 1)
self.max_len_a = max_len_a
self.max_len_b = max_len_b
self.min_len = min_len
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
self.unk_penalty = unk_penalty
self.temperature = temperature
self.match_source_len = match_source_len
self.no_repeat_ngram_size = no_repeat_ngram_size
assert temperature > 0, "--temperature must be greater than 0"
self.search = (
search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy
)
# We only need to set src_lengths in LengthConstrainedBeamSearch.
# As a module attribute, setting it would break in multithread
# settings when the model is shared.
self.should_set_src_lengths = (
hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths
)
self.model.eval()
self.lm_model = lm_model
self.lm_weight = lm_weight
if self.lm_model is not None:
self.lm_model.eval()
def cuda(self):
self.model.cuda()
return self
@torch.no_grad()
def forward(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
"""Generate a batch of translations.
Args:
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, prefix_tokens, bos_token=bos_token)
# TODO(myleott): unused, deprecate after pytorch-translate migration
def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None):
"""Iterate over a batched dataset and yield individual translations.
Args:
cuda (bool, optional): use GPU for generation
timer (StopwatchMeter, optional): time generations
"""
for sample in data_itr:
s = utils.move_to_cuda(sample) if cuda else sample
if "net_input" not in s:
continue
input = s["net_input"]
# model.forward normally channels prev_output_tokens into the decoder
# separately, but SequenceGenerator directly calls model.encoder
encoder_input = {
k: v for k, v in input.items() if k != "prev_output_tokens"
}
if timer is not None:
timer.start()
with torch.no_grad():
hypos = self.generate(encoder_input)
if timer is not None:
timer.stop(sum(len(h[0]["tokens"]) for h in hypos))
for i, id in enumerate(s["id"].data):
# remove padding
src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad)
ref = (
utils.strip_pad(s["target"].data[i, :], self.pad)
if s["target"] is not None
else None
)
yield id, src, ref, hypos[i]
@torch.no_grad()
def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs):
"""Generate translations. Match the api of other fairseq generators.
Args:
models (List[~fairseq.models.FairseqModel]): ensemble of models
sample (dict): batch
prefix_tokens (torch.LongTensor, optional): force decoder to begin
with these tokens
constraints (torch.LongTensor, optional): force decoder to include
the list of constraints
bos_token (int, optional): beginning of sentence token
(default: self.eos)
"""
return self._generate(sample, **kwargs)
def _generate(
self,
sample: Dict[str, Dict[str, Tensor]],
prefix_tokens: Optional[Tensor] = None,
constraints: Optional[Tensor] = None,
bos_token: Optional[int] = None,
):
incremental_states = torch.jit.annotate(
List[Dict[str, Dict[str, Optional[Tensor]]]],
[
torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
for i in range(self.model.models_size)
],
)
net_input = sample["net_input"]
command_input = net_input["command_input"]
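# command_input holds the per-sample control (attribute/emotion) vector; it is fed to
# the decoder at every generation step and pruned together with the batch when
# sentences are finalized.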
if "src_tokens" in net_input:
src_tokens = net_input["src_tokens"]
# length of the source text being the character length except EndOfSentence and pad
src_lengths = (
(src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
)
elif "source" in net_input:
src_tokens = net_input["source"]
src_lengths = (
net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1)
if net_input["padding_mask"] is not None
else torch.tensor(src_tokens.size(-1)).to(src_tokens)
)
else:
raise Exception("expected src_tokens or source in net input")
# bsz: total number of sentences in beam
# Note that src_tokens may have more than 2 dimensions (i.e. audio features)
bsz, src_len = src_tokens.size()[:2]
beam_size = self.beam_size
if constraints is not None and not self.search.supports_constraints:
raise NotImplementedError(
"Target-side constraints were provided, but search method doesn't support them"
)
# Initialize constraints, when active
self.search.init_constraints(constraints, beam_size)
max_len: int = -1
if self.match_source_len:
max_len = src_lengths.max().item()
else:
max_len = min(
int(self.max_len_a * src_len + self.max_len_b),
# exclude the EOS marker
self.model.max_decoder_positions() - 1,
)
assert (
self.min_len <= max_len
), "min_len cannot be larger than max_len, please adjust these!"
# compute the encoder output for each beam
encoder_outs = self.model.forward_encoder(net_input)
# placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device).long()
encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
# ensure encoder_outs is a List.
assert encoder_outs is not None
# initialize buffers
scores = (
torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
) # +1 for eos; pad is never chosen for scoring
tokens = (
torch.zeros(bsz * beam_size, max_len + 2)
.to(src_tokens)
.long()
.fill_(self.pad)
) # +2 for eos and pad
tokens[:, 0] = self.eos if bos_token is None else bos_token
attn: Optional[Tensor] = None
# A list that indicates candidates that should be ignored.
# For example, suppose we're sampling and have already finalized 2/5
# samples. Then cands_to_ignore would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples.
cands_to_ignore = (
torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
) # forward and backward-compatible False mask
# list of completed sentences
finalized = torch.jit.annotate(
List[List[Dict[str, Tensor]]],
[torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
) # contains lists of dictionaries of information about the hypothesis being finalized at each step
finished = [
False for i in range(bsz)
] # a boolean array indicating if the sentence at the index is finished or not
num_remaining_sent = bsz # number of sentences remaining
# number of candidate hypos per step
cand_size = 2 * beam_size # 2 x beam size in case half are EOS
# offset arrays for converting between different indexing schemes
bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
cand_offsets = torch.arange(0, cand_size).type_as(tokens)
reorder_state: Optional[Tensor] = None
batch_idxs: Optional[Tensor] = None
original_batch_idxs: Optional[Tensor] = None
if "id" in sample and isinstance(sample["id"], Tensor):
original_batch_idxs = sample["id"]
else:
original_batch_idxs = torch.arange(0, bsz).type_as(tokens)
for step in range(max_len + 1): # one extra step for EOS marker
# reorder decoder internal states based on the prev choice of beams
# print(f'step: {step}')
if reorder_state is not None:
if batch_idxs is not None:
# update beam indices to take into account removed sentences
corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
batch_idxs
)
reorder_state.view(-1, beam_size).add_(
corr.unsqueeze(-1) * beam_size
)
original_batch_idxs = original_batch_idxs[batch_idxs]
self.model.reorder_incremental_state(incremental_states, reorder_state)
encoder_outs = self.model.reorder_encoder_out(
encoder_outs, reorder_state
)
lprobs, avg_attn_scores = self.model.forward_decoder(
tokens[:, : step + 1],
command_input,
encoder_outs,
incremental_states,
self.temperature,
)
if self.lm_model is not None:
lm_out = self.lm_model(tokens[:, : step + 1])
probs = self.lm_model.get_normalized_probs(
lm_out, log_probs=True, sample=None
)
probs = probs[:, -1, :] * self.lm_weight
lprobs += probs
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
lprobs[:, self.pad] = -math.inf # never select pad
lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty
# handle max length constraint
if step >= max_len:
lprobs[:, : self.eos] = -math.inf
lprobs[:, self.eos + 1 :] = -math.inf
# handle prefix tokens (possibly with different lengths)
if (
prefix_tokens is not None
and step < prefix_tokens.size(1)
and step < max_len
):
lprobs, tokens, scores = self._prefix_tokens(
step, lprobs, scores, tokens, prefix_tokens, beam_size
)
elif step < self.min_len:
# minimum length constraint (does not apply if using prefix_tokens)
lprobs[:, self.eos] = -math.inf
# Record attention scores, only support avg_attn_scores is a Tensor
if avg_attn_scores is not None:
if attn is None:
attn = torch.empty(
bsz * beam_size, avg_attn_scores.size(1), max_len + 2
).to(scores)
attn[:, :, step + 1].copy_(avg_attn_scores)
scores = scores.type_as(lprobs)
eos_bbsz_idx = torch.empty(0).to(
tokens
) # indices of hypothesis ending with eos (finished sentences)
eos_scores = torch.empty(0).to(
scores
) # scores of hypothesis ending with eos (finished sentences)
if self.should_set_src_lengths:
self.search.set_src_lengths(src_lengths)
if self.no_repeat_ngram_size > 0:
lprobs = self._no_repeat_ngram(tokens, lprobs, bsz, beam_size, step)
# Shape: (batch, cand_size)
cand_scores, cand_indices, cand_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
tokens[:, : step + 1],
original_batch_idxs,
)
# cand_bbsz_idx contains beam indices for the top candidate
# hypotheses, with a range of values: [0, bsz*beam_size),
# and dimensions: [bsz, cand_size]
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
# finalize hypotheses that end in eos
# Shape of eos_mask: (batch size, beam size)
eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)
# only consider eos when it's among the top beam_size indices
# Now we know what beam item(s) to finish
# Shape: 1d list of absolute-numbered
eos_bbsz_idx = torch.masked_select(
cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents: List[int] = []
if eos_bbsz_idx.numel() > 0:
eos_scores = torch.masked_select(
cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
)
finalized_sents = self.finalize_hypos(
step,
eos_bbsz_idx,
eos_scores,
tokens,
scores,
finalized,
finished,
beam_size,
attn,
src_lengths,
max_len,
)
num_remaining_sent -= len(finalized_sents)
assert num_remaining_sent >= 0
if num_remaining_sent == 0:
break
if self.search.stop_on_max_len and step >= max_len:
break
assert step < max_len
# Remove finalized sentences (ones for which {beam_size}
# finished hypotheses have been generated) from the batch.
if len(finalized_sents) > 0:
new_bsz = bsz - len(finalized_sents)
# construct batch_idxs which holds indices of batches to keep for the next pass
batch_mask = torch.ones(
bsz, dtype=torch.bool, device=cand_indices.device
)
batch_mask[finalized_sents] = False
# TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
batch_idxs = torch.arange(
bsz, device=cand_indices.device
).masked_select(batch_mask)
# Choose the subset of the hypothesized constraints that will continue
self.search.prune_sentences(batch_idxs)
eos_mask = eos_mask[batch_idxs]
cand_beams = cand_beams[batch_idxs]
bbsz_offsets.resize_(new_bsz, 1)
cand_bbsz_idx = cand_beams.add(bbsz_offsets)
cand_scores = cand_scores[batch_idxs]
cand_indices = cand_indices[batch_idxs]
command_input = command_input[batch_idxs]
if prefix_tokens is not None:
prefix_tokens = prefix_tokens[batch_idxs]
src_lengths = src_lengths[batch_idxs]
cands_to_ignore = cands_to_ignore[batch_idxs]
scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
if attn is not None:
attn = attn.view(bsz, -1)[batch_idxs].view(
new_bsz * beam_size, attn.size(1), -1
)
bsz = new_bsz
else:
batch_idxs = None
# Set active_mask so that values > cand_size indicate eos hypos
# and values < cand_size indicate candidate active hypos.
# After, the min values per row are the top candidate active hypos
# Rewrite the operator since the element wise or is not supported in torchscript.
eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
active_mask = torch.add(
eos_mask.type_as(cand_offsets) * cand_size,
cand_offsets[: eos_mask.size(1)],
)
# get the top beam_size active hypotheses, which are just
# the hypos with the smallest values in active_mask.
# {active_hypos} indicates which {beam_size} hypotheses
# from the list of {2 * beam_size} candidates were
# selected. Shapes: (batch size, beam size)
new_cands_to_ignore, active_hypos = torch.topk(
active_mask, k=beam_size, dim=1, largest=False
)
# update cands_to_ignore to ignore any finalized hypos.
cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
# Make sure there is at least one active item for each sentence in the batch.
assert (~cands_to_ignore).any(dim=1).all()
# update cands_to_ignore to ignore any finalized hypos
# {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
# can be selected more than once).
active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)
active_bbsz_idx = active_bbsz_idx.view(-1)
active_scores = active_scores.view(-1)
# copy tokens and scores for active hypotheses
# Set the tokens for each beam (can select the same row more than once)
tokens[:, : step + 1] = torch.index_select(
tokens[:, : step + 1], dim=0, index=active_bbsz_idx
)
# Select the next token for each of them
tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
cand_indices, dim=1, index=active_hypos
)
if step > 0:
scores[:, :step] = torch.index_select(
scores[:, :step], dim=0, index=active_bbsz_idx
)
scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
cand_scores, dim=1, index=active_hypos
)
# Update constraints based on which candidates were selected for the next beam
self.search.update_constraints(active_hypos)
# copy attention for active hypotheses
if attn is not None:
attn[:, :, : step + 2] = torch.index_select(
attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
)
# reorder incremental state in decoder
reorder_state = active_bbsz_idx
# sort by score descending
for sent in range(len(finalized)):
scores = torch.tensor(
[float(elem["score"].item()) for elem in finalized[sent]]
)
_, sorted_scores_indices = torch.sort(scores, descending=True)
finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
finalized[sent] = torch.jit.annotate(
List[Dict[str, Tensor]], finalized[sent]
)
return finalized
def _prefix_tokens(
self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int
):
"""Handle prefix tokens"""
prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.pad)
lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs)
lprobs[prefix_mask] = lprobs[prefix_mask].scatter(
-1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
)
# if prefix includes eos, then we should make sure tokens and
# scores are the same across all beams
eos_mask = prefix_toks.eq(self.eos)
if eos_mask.any():
# validate that the first beam matches the prefix
first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[
:, 0, 1 : step + 1
]
eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
assert (first_beam == target_prefix).all()
# copy tokens, scores and lprobs from the first beam to all beams
tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size)
scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size)
lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size)
return lprobs, tokens, scores
def replicate_first_beam(self, tensor, mask, beam_size: int):
tensor = tensor.view(-1, beam_size, tensor.size(-1))
tensor[mask] = tensor[mask][:, :1, :]
return tensor.view(-1, tensor.size(-1))
def finalize_hypos(
self,
step: int,
bbsz_idx,
eos_scores,
tokens,
scores,
finalized: List[List[Dict[str, Tensor]]],
finished: List[bool],
beam_size: int,
attn: Optional[Tensor],
src_lengths,
max_len: int,
):
"""Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly.
A sentence is finalized when {beam_size} finished items have been collected for it.
Returns number of sentences (not beam items) being finalized.
These will be removed from the batch and not processed further.
Args:
bbsz_idx (Tensor):
"""
assert bbsz_idx.numel() == eos_scores.numel()
# clone relevant token and attention tensors.
# tokens is (batch * beam, max_len). So the index_select
# gets the newly EOS rows, then selects cols 1..{step + 2}
tokens_clone = tokens.index_select(0, bbsz_idx)[
:, 1 : step + 2
] # skip the first index, which is EOS
tokens_clone[:, step] = self.eos
attn_clone = (
attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
if attn is not None
else None
)
# compute scores per token position
pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
pos_scores[:, step] = eos_scores
# convert from cumulative to per-position scores
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
# normalize sentence-level scores
if self.normalize_scores:
eos_scores /= (step + 1) ** self.len_penalty
# cum_unfin records which sentences in the batch are finished.
# It helps match indexing between (a) the original sentences
# in the batch and (b) the current, possibly-reduced set of
# sentences.
cum_unfin: List[int] = []
prev = 0
for f in finished:
if f:
prev += 1
else:
cum_unfin.append(prev)
# set() is not supported in script export
# The keys here are of the form "{sent}_{unfin_idx}", where
# "unfin_idx" is the index in the current (possibly reduced)
# list of sentences, and "sent" is the index in the original,
# unreduced batch
sents_seen: Dict[str, Optional[Tensor]] = {}
# For every finished beam item
for i in range(bbsz_idx.size()[0]):
idx = bbsz_idx[i]
score = eos_scores[i]
# sentence index in the current (possibly reduced) batch
unfin_idx = idx // beam_size
# sentence index in the original (unreduced) batch
sent = unfin_idx + cum_unfin[unfin_idx]
# print(f"{step} FINISHED {idx} {score} {sent}={unfin_idx} {cum_unfin}")
# Cannot create dict for key type '(int, int)' in torchscript.
# The workaround is to cast int to string
seen = str(sent.item()) + "_" + str(unfin_idx.item())
if seen not in sents_seen:
sents_seen[seen] = None
if self.match_source_len and step > src_lengths[unfin_idx]:
score = torch.tensor(-math.inf).to(score)
# An input sentence (among those in a batch) is finished when
# beam_size hypotheses have been collected for it
if len(finalized[sent]) < beam_size:
if attn_clone is not None:
# remove padding tokens from attn scores
hypo_attn = attn_clone[i]
else:
hypo_attn = torch.empty(0)
finalized[sent].append(
{
"tokens": tokens_clone[i],
"score": score,
"attention": hypo_attn, # src_len x tgt_len
"alignment": torch.empty(0),
"positional_scores": pos_scores[i],
}
)
newly_finished: List[int] = []
for seen in sents_seen.keys():
# check termination conditions for this sentence
sent: int = int(float(seen.split("_")[0]))
unfin_idx: int = int(float(seen.split("_")[1]))
if not finished[sent] and self.is_finished(
step, unfin_idx, max_len, len(finalized[sent]), beam_size
):
finished[sent] = True
newly_finished.append(unfin_idx)
return newly_finished
def is_finished(
self,
step: int,
unfin_idx: int,
max_len: int,
finalized_sent_len: int,
beam_size: int,
):
"""
Check whether decoding for a sentence is finished, which
occurs when the list of finalized sentences has reached the
beam size, or when we reach the maximum length.
"""
assert finalized_sent_len <= beam_size
if finalized_sent_len == beam_size or step == max_len:
return True
return False
def calculate_banned_tokens(
self,
tokens,
step: int,
gen_ngrams: List[Dict[str, List[int]]],
no_repeat_ngram_size: int,
bbsz_idx: int,
):
tokens_list: List[int] = tokens[
bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
].tolist()
# before decoding the next token, prevent decoding of ngrams that have already appeared
ngram_index = ",".join([str(x) for x in tokens_list])
return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))
def transpose_list(self, l: List[List[int]]):
# GeneratorExp aren't supported in TS so ignoring the lint
min_len = min([len(x) for x in l]) # noqa
l2 = [[row[i] for row in l] for i in range(min_len)]
return l2
def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
# for each beam and batch sentence, generate a list of previous ngrams
gen_ngrams: List[Dict[str, List[int]]] = [
torch.jit.annotate(Dict[str, List[int]], {})
for bbsz_idx in range(bsz * beam_size)
]
cpu_tokens = tokens.cpu()
for bbsz_idx in range(bsz * beam_size):
gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
for ngram in self.transpose_list(
[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
):
key = ",".join([str(x) for x in ngram[:-1]])
gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
key, torch.jit.annotate(List[int], [])
) + [ngram[-1]]
if step + 2 - self.no_repeat_ngram_size >= 0:
# no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
banned_tokens = [
self.calculate_banned_tokens(
tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
)
for bbsz_idx in range(bsz * beam_size)
]
else:
banned_tokens = [
torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx][
torch.tensor(banned_tokens[bbsz_idx]).long()
] = torch.tensor(-math.inf).to(lprobs)
return lprobs
class EnsembleModel(nn.Module):
"""A wrapper around an ensemble of models."""
def __init__(self, models):
super().__init__()
self.models_size = len(models)
# method '__len__' is not supported in ModuleList for torch script
self.single_model = models[0]
self.models = nn.ModuleList(models)
self.has_incremental: bool = False
if all(
hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder)
for m in models
):
self.has_incremental = True
def forward(self):
pass
def has_encoder(self):
return hasattr(self.single_model, "encoder")
def has_incremental_states(self):
return self.has_incremental
def max_decoder_positions(self):
return min([m.max_decoder_positions() for m in self.models])
@torch.jit.export
def forward_encoder(self, net_input: Dict[str, Tensor]):
if not self.has_encoder():
return None
return [model.encoder.forward_torchscript(net_input) for model in self.models]
@torch.jit.export
def forward_decoder(
self,
tokens,
command_input,
encoder_outs: List[EncoderOut],
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
temperature: float = 1.0,
):
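# Every decoder in the ensemble receives the control command together with the token
# prefix, so the conditioning is applied at each decoding step.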
log_probs = []
avg_attn: Optional[Tensor] = None
encoder_out: Optional[EncoderOut] = None
for i, model in enumerate(self.models):
if self.has_encoder():
encoder_out = encoder_outs[i]
# decode each model
if self.has_incremental_states():
decoder_out = model.decoder.forward(
tokens,
command_input,
encoder_out=encoder_out,
incremental_state=incremental_states[i],
)
else:
decoder_out = model.decoder.forward(tokens, command_input, encoder_out=encoder_out)
attn: Optional[Tensor] = None
decoder_len = len(decoder_out)
if decoder_len > 1 and decoder_out[1] is not None:
if isinstance(decoder_out[1], Tensor):
attn = decoder_out[1]
else:
attn_holder = decoder_out[1]["attn"]
if isinstance(attn_holder, Tensor):
attn = attn_holder
elif attn_holder is not None:
attn = attn_holder[0]
if attn is not None:
attn = attn[:, -1, :]
decoder_out_tuple = (
decoder_out[0][:, -1:, :].div_(temperature),
None if decoder_len <= 1 else decoder_out[1],
)
probs = model.get_normalized_probs(
decoder_out_tuple, log_probs=True, sample=None
)
probs = probs[:, -1, :]
if self.models_size == 1:
return probs, attn
log_probs.append(probs)
if attn is not None:
if avg_attn is None:
avg_attn = attn
else:
avg_attn.add_(attn)
avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log(
self.models_size
)
if avg_attn is not None:
avg_attn.div_(self.models_size)
return avg_probs, avg_attn
@torch.jit.export
def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
new_outs: List[EncoderOut] = []
if not self.has_encoder():
return new_outs
for i, model in enumerate(self.models):
assert encoder_outs is not None
new_outs.append(
model.encoder.reorder_encoder_out(encoder_outs[i], new_order)
)
return new_outs
@torch.jit.export
def reorder_incremental_state(
self,
incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
new_order,
):
if not self.has_incremental_states():
return
for i, model in enumerate(self.models):
model.decoder.reorder_incremental_state_scripting(
incremental_states[i], new_order
)

View file

@ -0,0 +1,330 @@
from fairseq.data.base_wrapper_dataset import BaseWrapperDataset
import numpy as np
from fairseq.data import data_utils
from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig
from fairseq.tasks import register_task
import logging
from .linear import transformer_lm
from fairseq import search
from fairseq.models.transformer_lm import (
TransformerLanguageModel,
TransformerLanguageModelConfig,
base_lm_architecture,
transformer_lm_gpt,
DEFAULT_MAX_TARGET_POSITIONS
)
from fairseq.models import (
register_model,
register_model_architecture,
)
import torch
logger = logging.getLogger(__name__)
class CommandDataset(BaseWrapperDataset):
def __init__(self, dataset, command_data, args = None):
super().__init__(dataset)
self._sizes = self.dataset.sizes.copy()
self.command_data = command_data # need change to np.mmap
self.args = args
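# Each item pairs a language-modeling sample with its pre-computed control command
# (one row of command_data per dataset index), placed on the same device as the tokens.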
def __getitem__(self, index):
sample = self.dataset[index]
assert len(sample["source"]) <= self.args.truncated_length + 2, f"The maximum length exceeds {self.args.truncated_length}. Please resample the dataset."
return {
"id": index,
"source": sample["source"],
"target": sample["target"],
"command": torch.from_numpy(np.array(self.command_data[index])).to(sample["source"].device)
}
@property
def sizes(self):
return self._sizes
def size(self, index):
return self._sizes[index]
def num_tokens(self, index):
return self._sizes[index]
def filter_indices_by_size(self, indices, max_sizes):
"""
Filter a list of sample indices. Remove those that are longer than
specified in *max_sizes*.
WARNING: don't update, override method in child classes
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if isinstance(max_sizes, float) or isinstance(max_sizes, int):
if hasattr(self, "sizes") and isinstance(self.sizes, np.ndarray):
ignored = indices[self.sizes[indices] > max_sizes].tolist()
indices = indices[self.sizes[indices] <= max_sizes]
elif (
hasattr(self, "sizes")
and isinstance(self.sizes, list)
and len(self.sizes) == 1
):
ignored = indices[self.sizes[0][indices] > max_sizes].tolist()
indices = indices[self.sizes[0][indices] <= max_sizes]
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
else:
indices, ignored = data_utils._filter_by_size_dynamic(
indices, self.size, max_sizes
)
if len(ignored) > 0:
print(self.sizes)
print(ignored)
print(max_sizes)
return indices, ignored
def collater(self, samples):
# samples: list->dict, just as __get_item__(index) rets
# print("samples type:", type(samples))
return self.collate_helper(samples, self.dataset.vocab.pad(), self.dataset.vocab.eos())
def collate_helper(self, samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key, is_list=False):
if is_list:
res = []
for i in range(len(samples[0][key])):
res.append(
data_utils.collate_tokens(
[s[key][i] for s in samples],
pad_idx,
eos_idx,
left_pad=False,
)
)
return res
else:
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx,
eos_idx,
left_pad=False,
)
src_tokens = merge("source")
if samples[0]["command"] is not None:
command_tokens = merge("command")
else:
command_tokens = None
if samples[0]["target"] is not None:
is_target_list = isinstance(samples[0]["target"], list)
target = merge("target", is_target_list)
else:
target = src_tokens
return {
"id": torch.LongTensor([s["id"] for s in samples]),
"nsentences": len(samples),
"ntokens": sum(len(s["source"]) for s in samples),
"net_input": {
"src_tokens": src_tokens,
"command_input": command_tokens,
"src_lengths": torch.LongTensor([s["source"].numel() for s in samples]),
},
"target": target,
}
@register_task("language_modeling_control", dataclass=LanguageModelingConfig)
class LanguageModelingTaskWithControl(LanguageModelingTask):
@classmethod
def add_args(cls, parser):
super().add_args(parser)
parser.add_argument("--command_in_dim", type=int)
parser.add_argument("--command_out_dim", type=int)
parser.add_argument("--truncated_length", type=int, default=8192)
parser.add_argument("--feature_num", type=int, default=3)
parser.add_argument("--control_mode", type=str)
parser.add_argument("--command_path", type=str)
parser.add_argument("--bucket_num", type=int)
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
super().load_dataset(split, epoch=epoch, combine=combine, **kwargs)
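# Load the pre-computed control commands for this split (memory-mapped) and wrap the
# base LM dataset so every sample also carries its command vector; the shape checks
# below guard against a mismatch with the dataset length or the configured feature_num.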
command_dataset = np.load(f"{self.args.command_path}/{split}_command.npy", mmap_mode="r")
assert command_dataset.shape[0] == len(self.datasets[split]), f"error command sample num for {split}!"
assert command_dataset.shape[1] == self.args.feature_num, "command feature_num isn't the same as args feature_num"
logger.info(f'Load CommandSourceTargetDataset for {split} from {self.args.command_path}, truncated length: {self.args.truncated_length}')
self.datasets[split] = CommandDataset(self.datasets[split], command_dataset, self.args)
def get_batch_iterator(
self,
dataset,
max_tokens=None,
max_sentences=None,
max_positions=None,
ignore_invalid_inputs=False,
required_batch_size_multiple=1,
seed=1,
num_shards=1,
shard_id=0,
num_workers=0,
epoch=1,
data_buffer_size=0,
disable_iterator_cache=False,
):
split = None
if 'train' in self.datasets and dataset == self.datasets['train']:
split = 'train'
elif 'valid' in self.datasets and dataset == self.datasets['valid']:
split = 'valid'
elif 'test' in self.datasets and dataset == self.datasets['test']:
split = 'test'
max_positions_split = getattr(self.args, 'max_positions_%s' % split, None)
if max_positions_split is None:
max_positions_split = getattr(self.args, 'truncate_%s' % split, None)
if max_positions_split is not None:
max_positions = max_positions_split
logger.info('Using max_positions limit (%d) for %s' % (max_positions,
split if split is not None else 'unknown'))
return super().get_batch_iterator(
dataset,
max_tokens=max_tokens,
max_sentences=max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=ignore_invalid_inputs,
required_batch_size_multiple=required_batch_size_multiple,
seed=seed,
num_shards=num_shards,
shard_id=shard_id,
num_workers=num_workers,
epoch=epoch,
data_buffer_size=data_buffer_size,
disable_iterator_cache=disable_iterator_cache
)
def build_generator(
self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None
):
if getattr(args, "score_reference", False):
from fairseq.sequence_scorer import SequenceScorer
return SequenceScorer(
self.target_dictionary,
compute_alignment=getattr(args, "print_alignment", False),
)
from .command_seq_generator import CommandSequenceGenerator
# Choose search strategy. Defaults to Beam Search.
sampling = getattr(args, "sampling", False)
sampling_topk = getattr(args, "sampling_topk", -1)
sampling_topp = getattr(args, "sampling_topp", -1.0)
diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
match_source_len = getattr(args, "match_source_len", False)
diversity_rate = getattr(args, "diversity_rate", -1)
constrained = getattr(args, "constraints", False)
prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)
if (
sum(
int(cond)
for cond in [
sampling,
diverse_beam_groups > 0,
match_source_len,
diversity_rate > 0,
]
)
> 1
):
raise ValueError("Provided Search parameters are mutually exclusive.")
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"
if sampling:
search_strategy = search.Sampling(
self.target_dictionary, sampling_topk, sampling_topp
)
elif diverse_beam_groups > 0:
search_strategy = search.DiverseBeamSearch(
self.target_dictionary, diverse_beam_groups, diverse_beam_strength
)
elif match_source_len:
# this is useful for tagging applications where the output
# length should match the input length, so we hardcode the
# length constraints for simplicity
search_strategy = search.LengthConstrainedBeamSearch(
self.target_dictionary,
min_len_a=1,
min_len_b=0,
max_len_a=1,
max_len_b=0,
)
elif diversity_rate > -1:
search_strategy = search.DiverseSiblingsSearch(
self.target_dictionary, diversity_rate
)
elif constrained:
search_strategy = search.LexicallyConstrainedBeamSearch(
self.target_dictionary, args.constraints
)
elif prefix_allowed_tokens_fn:
search_strategy = search.PrefixConstrainedBeamSearch(
self.target_dictionary, prefix_allowed_tokens_fn
)
else:
search_strategy = search.BeamSearch(self.target_dictionary)
if seq_gen_cls is None:
if getattr(args, "print_alignment", False):
raise ImportError("SequenceGeneratorWithAlignment is not allowed!")
# seq_gen_cls = SequenceGeneratorWithAlignment
else:
seq_gen_cls = CommandSequenceGenerator
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {}
return seq_gen_cls(
models,
self.target_dictionary,
beam_size=getattr(args, "beam", 5),
max_len_a=getattr(args, "max_len_a", 0),
max_len_b=getattr(args, "max_len_b", 200),
min_len=getattr(args, "min_len", 1),
normalize_scores=(not getattr(args, "unnormalized", False)),
len_penalty=getattr(args, "lenpen", 1),
unk_penalty=getattr(args, "unkpen", 0),
temperature=getattr(args, "temperature", 1.0),
match_source_len=getattr(args, "match_source_len", False),
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
search_strategy=search_strategy,
**extra_gen_cls_kwargs,
)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View file

@ -0,0 +1,123 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
"""The base attention layer performs all the query key value projections and
output projections leaving the implementation of the attention to the inner
attention module.
The transformer layers, however, are agnostic of the attention implementation
and any layer that implements the same interface can substitute for the
attention layer.
"""
import math
import torch.nn as nn
from torch.nn import Linear, Module
class AttentionLayer(Module):
"""Implement the attention layer. Namely project the inputs to multi-head
queries, keys and values, call the attention implementation and then
reproject the output.
It can be thought of as a decorator (see the decorator design pattern) of an
attention layer.
Arguments
---------
attention: Specific inner attention implementation that just computes a
weighted average of values given a similarity of queries and
keys.
d_model: The input feature dimensionality
n_heads: The number of heads for the multi head attention
d_keys: The dimensionality of the keys/queries
(default: d_model/n_heads)
d_values: The dimensionality of the values (default: d_model/n_heads)
event_dispatcher: str or EventDispatcher instance to be used by this
module for dispatching events (default: the default
global dispatcher)
"""
def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
super(AttentionLayer, self).__init__()
# Fill d_keys and d_values
d_keys = d_keys or (d_model//n_heads)
d_values = d_values or (d_model//n_heads)
self.inner_attention = attention
self.q_proj = Linear(d_model, d_keys * n_heads)
self.k_proj = Linear(d_model, d_keys * n_heads)
self.v_proj = Linear(d_model, d_values * n_heads)
self.out_proj = Linear(d_values * n_heads, d_model)
self.n_heads = n_heads
self.qkv_same_dim = d_keys == d_values == d_model // n_heads
def reset_parameters(self):
if self.qkv_same_dim:
# Empirically observed the convergence to be much better with
# the scaled initialization
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
else:
nn.init.xavier_uniform_(self.k_proj.weight)
nn.init.xavier_uniform_(self.v_proj.weight)
nn.init.xavier_uniform_(self.q_proj.weight)
nn.init.xavier_uniform_(self.out_proj.weight)
if self.out_proj.bias is not None:
nn.init.constant_(self.out_proj.bias, 0.0)
def forward(self, queries, keys, values, attn_mask=None, key_padding_mask=None):
"""Apply attention to the passed in queries/keys/values after
projecting them to multiple heads.
In the argument description we make use of the following sizes
- N: the batch size
- L: The maximum length of the queries
- S: The maximum length of the keys (the actual length per sequence
is given by the length mask)
- D: The input feature dimensionality passed in the constructor as
'd_model'
Arguments
---------
queries: (N, L, D) The tensor containing the queries
keys: (N, S, D) The tensor containing the keys
values: (N, S, D) The tensor containing the values
attn_mask: An implementation of BaseMask that encodes where each
query can attend to
key_padding_mask: A mask that encodes which key positions are valid
for each sequence in the batch
Returns
-------
The new value for each query as a tensor of shape (N, L, D).
"""
# Extract the dimensions into local variables
N, L, _ = queries.shape
_, S, _ = keys.shape
H = self.n_heads
# Project the queries/keys/values
queries = self.q_proj(queries).view(N, L, H, -1)
keys = self.k_proj(keys).view(N, S, H, -1)
values = self.v_proj(values).view(N, S, H, -1)
# Compute the attention
new_values = self.inner_attention(
queries,
keys,
values,
attn_mask,
key_padding_mask
).view(N, L, -1)
# Project the output and return
return self.out_proj(new_values)

View file

@ -0,0 +1,76 @@
#
# Copyright (c) 2020 Idiap Research Institute, http://www.idiap.ch/
# Written by Angelos Katharopoulos <angelos.katharopoulos@idiap.ch>,
# Apoorv Vyas <avyas@idiap.ch>
#
"""Implement causally masked linear attention."""
import torch
from torch.nn import Module
from fast_transformers.causal_product import causal_dot_product
from fast_transformers.feature_maps import elu_feature_map
def causal_linear(Q, K, V):
dtype = Q.dtype
Q = Q.permute(0,2,1,3).float().contiguous() # [bs, n_head, seq_len, d_hidden]
K = K.permute(0,2,1,3).float().contiguous()
V = V.permute(0,2,1,3).float().contiguous()
V_new = causal_dot_product(Q, K, V)
return V_new.permute(0,2,1,3).type(dtype).contiguous() # [bs, seq_len, n_head, d_hidden]
class CausalLinearAttention(Module):
"""Implement causally masked attention using dot product of feature maps in
O(N D^2) complexity.
See fast_transformers.attention.linear_attention.LinearAttention for the
general concept of replacing the softmax with feature maps. In addition to
that, we also make use of the fact that causal masking is a triangular mask
which allows us to apply the masking and still compute the attention in O(N
D^2) complexity.
Arguments
---------
feature_map: callable, a callable that applies the feature map to the
last dimension of a tensor (default: elu(x)+1)
eps: float, a small number to ensure the numerical stability of the
denominator (default: 1e-6)
event_dispatcher: str or EventDispatcher instance to be used by this
module for dispatching events (default: the default
global dispatcher)
"""
def __init__(self, query_dimensions, feature_map=None, eps=1e-6):
super(CausalLinearAttention, self).__init__()
self.feature_map = (
feature_map(query_dimensions) if feature_map else
elu_feature_map(query_dimensions)
)
self.eps = eps
def forward(self, queries, keys, values, attn_mask=None, key_padding_mask=None):
# Apply the feature map to the queries and keys
self.feature_map.new_feature_map(queries.device)
Q = self.feature_map.forward_queries(queries) # [bs, seq_len, n_head, d_hidden]
K = self.feature_map.forward_keys(keys)
assert attn_mask is None, "Cannot assign attn_mask for %s" % self.__class__.__name__
if key_padding_mask is not None:
K = K * key_padding_mask.type(queries.dtype)[:, :, None, None]
# Compute the normalizers
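# Z[n, l, h] = 1 / (phi(q_l) . sum_{s<=l} phi(k_s) + eps): the cumulative sum over the
# keys realizes the causal mask while keeping the O(N D^2) linear-attention cost.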
Z = 1/(torch.einsum("nlhi,nlhi->nlh", Q, K.cumsum(1)) + self.eps)
# Compute the unnormalized result
V = causal_linear(
Q,
K,
values
)
return V * Z[:, :, :, None]

View file

@ -0,0 +1,477 @@
from fairseq.models.transformer import Linear
from fairseq.models import FairseqDecoder
import math, gc
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from fairseq import utils
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.modules import (
AdaptiveSoftmax,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from .transformer_layer import LinearTransformerDecoderLayer
import numpy as np
class LinearTransformerDecoder(FairseqDecoder):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
self.args = args
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
self.dropout_module = FairseqDropout(
args.dropout, module_name=self.__class__.__name__
)
self.decoder_layerdrop = args.decoder_layerdrop
self.share_input_output_embed = args.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = args.decoder_embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = args.decoder_output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if args.no_scale_embedding else math.sqrt(embed_dim)
# add new layers:
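# "embedding_v2" control: each of the feature_num attributes gets its own embedding
# table (bucket index -> vector) when bucket_num > 0, otherwise a linear layer on the
# raw value; the concatenated attribute embeddings are projected back to the token
# embedding dimension by command_project and added to the token embeddings.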
if self.args.control_mode == "embedding_v2":
self.feature_embedding_layer_list = nn.ModuleList()
for i in range(self.args.feature_num):
if self.args.bucket_num > 0:
self.feature_embedding_layer_list.append(
nn.Embedding(self.args.bucket_num, self.args.command_embed_dim))
else:
self.feature_embedding_layer_list.append(
nn.Linear(1, self.args.command_embed_dim))
self.command_project = nn.Linear(self.args.command_embed_dim * self.args.feature_num, input_embed_dim)
else:
raise ValueError("unknown control mode, please chosen from [embedding_v2]")
if not args.adaptive_input and args.quant_noise_pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
args.truncated_length,
embed_dim,
self.padding_idx,
learned=args.decoder_learned_pos,
)
if not args.no_token_positional_embeddings
else None
)
if getattr(args, "layernorm_embedding", False):
self.layernorm_embedding = LayerNorm(embed_dim)
else:
self.layernorm_embedding = None
self.cross_self_attention = getattr(args, "cross_self_attention", False)
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[
self.build_decoder_layer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
]
)
self.num_layers = len(self.layers)
if args.decoder_normalize_before and not getattr(
args, "no_decoder_final_norm", False
):
self.layer_norm = LayerNorm(embed_dim)
else:
self.layer_norm = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
self.output_projection = None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None,
factor=args.adaptive_softmax_factor,
tie_proj=args.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
self.gradient_checkpointing = getattr(self.args, 'gradient_checkpointing', False)
if self.gradient_checkpointing:
checkpointing_layers = getattr(self.args, 'gradient_checkpointing_layers', None)
if checkpointing_layers is None:
gradient_checkpointing_every_n_layer = getattr(self.args, 'gradient_checkpointing_every_n_layer', 1)
checkpointing_layers = tuple(range(0, self.num_layers, gradient_checkpointing_every_n_layer))
self.checkpointing_layers = checkpointing_layers
def build_decoder_layer(self, args, no_encoder_attn=False):
return LinearTransformerDecoderLayer(args, no_encoder_attn)
def forward(
self,
prev_output_tokens,
command_input,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
command_input,
encoder_out=encoder_out,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features(
self,
prev_output_tokens,
command_input,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
return self.extract_features_scriptable(
prev_output_tokens,
command_input,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
"""
A scriptable subclass of this class has an extract_features method and calls
    super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def extract_features_scriptable(
self,
prev_output_tokens,
command_input,
encoder_out: Optional[EncoderOut] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
if alignment_layer is None:
alignment_layer = self.num_layers - 1
# embed positions
positions = (
self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if self.embed_positions is not None
else None
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
command_emb_list = []
command_input = command_input.to(torch.int64)
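        # -1 marks an attribute that is not specified; such attributes contribute a zero embedding below.
        # Note: the check inspects only batch element 0 (all samples in a batch are assumed to share the
        # same set of specified attributes), and the zero placeholder reuses command_emb_list[0]'s shape,
        # so this assumes the first attribute is always specified.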
for i in range(self.args.feature_num):
if command_input[0, i] != -1:
command_emb_list.append(
self.feature_embedding_layer_list[i](command_input[:, i])[:, None, :]) # [bs, 1, emb_size]
else:
command_emb_list.append(torch.zeros(command_emb_list[0].shape).to(prev_output_tokens.device))
command_emb = torch.cat(command_emb_list, dim=1)[:, None, :, :] # [bs, seq_len=1, feature_num, emb_size]
seq_len = prev_output_tokens.shape[1]
gc.collect()
batch_size = command_emb.shape[0]
command_emb = torch.repeat_interleave(command_emb, repeats=seq_len,
dim=1) # [bs, seq_len, feature_num, emb_size]
bs, seq_len, feature_num, emb_size = command_emb.shape
command_emb = command_emb.reshape([bs, seq_len, feature_num * emb_size])
command_emb = self.command_project(command_emb)
x += command_emb
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
gradient_checkpointing_every_n_layer = getattr(self.args, "gradient_checkpointing_every_n_layer", 1)
for idx, layer in enumerate(self.layers):
# if incremental_state is None and not full_context_alignment:
# self_attn_mask = self.buffered_future_mask(x)
# else:
# self_attn_mask = None
            self_attn_mask = None  # causal linear attention does not need an explicit future mask
if (
getattr(self.args, "gradient_checkpointing", False) and self.training and
idx in self.checkpointing_layers
):
x, layer_attn, _ = checkpoint(
layer,
x,
encoder_out.encoder_out if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
incremental_state,
None,
None,
self_attn_mask,
self_attn_padding_mask,
bool((idx == alignment_layer)),
bool((idx == alignment_layer)),
)
else:
x, layer_attn, _ = layer(
x,
encoder_out.encoder_out if encoder_out is not None else None,
encoder_out.encoder_padding_mask if encoder_out is not None else None,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def output_layer(self, features, **kwargs):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features, **kwargs)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict


@ -0,0 +1,191 @@
from typing import Dict, List, Optional
import torch
from torch import Tensor
from fairseq.modules.transformer_layer import TransformerDecoderLayer, TransformerEncoderLayer
# from fast_transformers.attention.attention_layer import AttentionLayer
# from fast_transformers.attention.causal_linear_attention import CausalLinearAttention
# from fast_transformers.masking import TriangularCausalMask, LengthMask, FullMask
from .causal_linear_attention import CausalLinearAttention
from .attention_layer import AttentionLayer
# from fast_transformers.transformers import TransformerEncoderLayer
from fast_transformers.attention import LinearAttention
class LinearTransformerDecoderLayer(TransformerDecoderLayer):
def __init__(self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False):
super().__init__(args, no_encoder_attn=no_encoder_attn, add_bias_kv=add_bias_kv, add_zero_attn=add_zero_attn)
self.decoder_attention_heads = args.decoder_attention_heads
def build_self_attention(
self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
):
causal_linear_attention = CausalLinearAttention(embed_dim)
linear_attention_layer = AttentionLayer(causal_linear_attention,
embed_dim, args.decoder_attention_heads)
return linear_attention_layer
def forward(
self,
x,
encoder_out: Optional[torch.Tensor] = None,
encoder_padding_mask: Optional[torch.Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
prev_self_attn_state: Optional[List[torch.Tensor]] = None,
prev_attn_state: Optional[List[torch.Tensor]] = None,
self_attn_mask: Optional[torch.Tensor] = None,
self_attn_padding_mask: Optional[torch.Tensor] = None,
need_attn: bool = False,
need_head_weights: bool = False,
):
"""
Args:
x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
encoder_padding_mask (ByteTensor, optional): binary
ByteTensor of shape `(batch, src_len)` where padding
elements are indicated by ``1``.
need_attn (bool, optional): return attention weights
need_head_weights (bool, optional): return attention weights
for each head (default: return average over heads).
Returns:
encoded output of shape `(seq_len, batch, embed_dim)`
"""
if need_head_weights:
need_attn = True
residual = x
if self.normalize_before:
x = self.self_attn_layer_norm(x)
# if prev_self_attn_state is not None:
# prev_key, prev_value = prev_self_attn_state[:2]
# saved_state: Dict[str, Optional[Tensor]] = {
# "prev_key": prev_key,
# "prev_value": prev_value,
# }
# if len(prev_self_attn_state) >= 3:
# saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
# assert incremental_state is not None
# self.self_attn._set_input_buffer(incremental_state, saved_state)
# _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
# if self.cross_self_attention and not (
# incremental_state is not None
# and _self_attn_input_buffer is not None
# and "prev_key" in _self_attn_input_buffer
# ):
# if self_attn_mask is not None:
# assert encoder_out is not None
# self_attn_mask = torch.cat(
# (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
# )
# if self_attn_padding_mask is not None:
# if encoder_padding_mask is None:
# assert encoder_out is not None
# encoder_padding_mask = self_attn_padding_mask.new_zeros(
# encoder_out.size(1), encoder_out.size(0)
# )
# self_attn_padding_mask = torch.cat(
# (encoder_padding_mask, self_attn_padding_mask), dim=1
# )
# assert encoder_out is not None
# y = torch.cat((encoder_out, x), dim=0)
# else:
y = x
x, attn = self.run_self_attn(
query=x,
key=y,
value=y,
key_padding_mask=self_attn_padding_mask,
incremental_state=incremental_state,
need_weights=False,
attn_mask=self_attn_mask,
)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.self_attn_layer_norm(x)
assert self.encoder_attn is None and encoder_out is None
# if self.encoder_attn is not None and encoder_out is not None:
# residual = x
# if self.normalize_before:
# x = self.encoder_attn_layer_norm(x)
# if prev_attn_state is not None:
# prev_key, prev_value = prev_attn_state[:2]
# saved_state: Dict[str, Optional[Tensor]] = {
# "prev_key": prev_key,
# "prev_value": prev_value,
# }
# if len(prev_attn_state) >= 3:
# saved_state["prev_key_padding_mask"] = prev_attn_state[2]
# assert incremental_state is not None
# self.encoder_attn._set_input_buffer(incremental_state, saved_state)
#
# x, attn = self.encoder_attn(
# query=x,
# key=encoder_out,
# value=encoder_out,
# key_padding_mask=encoder_padding_mask,
# incremental_state=incremental_state,
# static_kv=True,
# need_weights=need_attn or (not self.training and self.need_attn),
# need_head_weights=need_head_weights,
# )
# x = self.dropout_module(x)
# x = self.residual_connection(x, residual)
# if not self.normalize_before:
# x = self.encoder_attn_layer_norm(x)
residual = x
if self.normalize_before:
x = self.final_layer_norm(x)
x = self.activation_fn(self.fc1(x))
x = self.activation_dropout_module(x)
x = self.fc2(x)
x = self.dropout_module(x)
x = self.residual_connection(x, residual)
if not self.normalize_before:
x = self.final_layer_norm(x)
if self.onnx_trace and incremental_state is not None:
raise NotImplementedError
saved_state = self.self_attn._get_input_buffer(incremental_state)
assert saved_state is not None
if self_attn_padding_mask is not None:
self_attn_state = [
saved_state["prev_key"],
saved_state["prev_value"],
saved_state["prev_key_padding_mask"],
]
else:
self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
return x, attn, self_attn_state
return x, attn, None
def run_self_attn(self, query, key_padding_mask, incremental_state, need_weights, **kwargs):
if incremental_state is not None:
raise NotImplementedError
if need_weights:
raise NotImplementedError
# tgt_len, bsz, embed_dim = query.shape
# src_len = tgt_len
# num_heads = self.decoder_attention_heads
# head_dim = self.embed_dim // num_heads
# query = query.transpose(0, 1).reshape(bsz, tgt_len, num_heads, head_dim)
query = query.transpose(0, 1)
if key_padding_mask is not None:
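            # fairseq marks padded positions with True; the inversion below assumes the underlying
            # linear-attention layer expects True for positions that should be attended to.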
key_padding_mask = ~key_padding_mask
r = self.self_attn(query, query, query, attn_mask=None, key_padding_mask=key_padding_mask)
r = r.transpose(0, 1)
return r, None
# class LinearTransformerEncoderLayer(TransformerEncoderLayer):


@ -0,0 +1,101 @@
from fairseq.models.transformer_lm import TransformerLanguageModel, TransformerLanguageModelConfig, \
DEFAULT_MAX_TARGET_POSITIONS, transformer_lm_gpt, base_lm_architecture
from fairseq import options
from fairseq.modules import AdaptiveInput, CharacterTokenEmbedder
from fairseq.models import register_model, register_model_architecture
from .transformer import LinearTransformerDecoder
@register_model("linear_transformer_lm", dataclass=TransformerLanguageModelConfig)
class LinearTransformerLanguageModel(TransformerLanguageModel):
@classmethod
def add_args(cls, parser):
super().add_args(parser)
parser.add_argument('--gradient-checkpointing', type=lambda x: x.lower() == 'true', default=False)
parser.add_argument('--gradient-checkpointing-every-n-layer', type=int, default=1)
parser.add_argument('--gradient-checkpointing-layers',
type=lambda x: tuple([int(item) for item in x.split(',')]), default=None)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_lm_architecture(args)
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = getattr(
args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
)
if args.character_embeddings:
embed_tokens = CharacterTokenEmbedder(
task.source_dictionary,
eval(args.character_filters),
args.character_embedding_dim,
args.decoder_embed_dim,
args.char_embedder_highway_layers,
)
elif args.adaptive_input:
embed_tokens = AdaptiveInput(
len(task.source_dictionary),
task.source_dictionary.pad(),
args.decoder_input_dim,
args.adaptive_input_factor,
args.decoder_embed_dim,
options.eval_str_list(args.adaptive_input_cutoff, type=int),
args.quant_noise_pq,
args.quant_noise_pq_block_size,
)
else:
embed_tokens = cls.build_embedding(
args, task.source_dictionary, args.decoder_input_dim
)
if args.tie_adaptive_weights:
assert args.adaptive_input
assert args.adaptive_input_factor == args.adaptive_softmax_factor
assert (
args.adaptive_softmax_cutoff == args.adaptive_input_cutoff
), "{} != {}".format(
args.adaptive_softmax_cutoff, args.adaptive_input_cutoff
)
assert args.decoder_input_dim == args.decoder_output_dim
decoder = LinearTransformerDecoder(
args, task.target_dictionary, embed_tokens, no_encoder_attn=True
)
return cls(decoder)
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_gpt")
def linear_transformer_lm_gpt_architecture(args):
transformer_lm_gpt(args)
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_std")
def std_linear_transformer_lm_architecture(args):
args.command_embed_dim = 16
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
args.decoder_layers = getattr(args, "decoder_layers", 6)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)
@register_model_architecture("linear_transformer_lm", "linear_transformer_lm_debug")
def debug_linear_transformer_lm_architecture(args):
args.command_embed_dim = 16
args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 32)
args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 32)
args.decoder_layers = getattr(args, "decoder_layers", 2)
args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
args.dropout = getattr(args, "dropout", 0.1)
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
args.activation_fn = getattr(args, "activation_fn", "gelu")
base_lm_architecture(args)

90
emogen/readMe.md Normal file

@ -0,0 +1,90 @@
# EmoGen
EmoGen (Eliminating Subjective Bias in Emotional Music Generation, by Chenfei Kang, Peiling Lu, Botao Yu, Xu Tan, Wei Ye, Shikun Zhang, Jiang Bian) is an emotional music generation system that uses a set of emotion-related music attributes as the bridge between emotion and music, and divides generation into two stages: emotion-to-attribute mapping with supervised clustering, and attribute-to-music generation with self-supervised learning. Both stages are beneficial: in the first stage, the attribute values around each clustering center represent the general emotion of those samples, which helps eliminate the impact of subjective bias in emotion labels; in the second stage, generation is completely disentangled from emotion labels and is therefore free from that bias. Both subjective and objective evaluations show that EmoGen outperforms previous methods in emotion control accuracy and music quality, demonstrating its advantage in generating emotional music.
demo: [link](https://emo-gen.github.io/)
The following sections describe the steps for EmoGen training and inference.
### 1. Environment
- Hardware environment: We recommend Nvidia V100 16GB/32GB.
- Software environment:
Please make sure you have `python 3.8` installed. Run the following command to install necessary packages:
```sh
bash setup.sh
```
Also, please follow the instructions in [Installation](https://jmir.sourceforge.net/manuals/jSymbolic_manual/installation_files/installation.html) to install `Java`.
### 2. Dataset
We use three datasets: one emotion-labeled dataset, EMOPIA ([link](https://annahung31.github.io/EMOPIA/)), and two unlabeled datasets, Pop1k7 ([link](https://github.com/YatingMusic/compound-word-transformer/blob/main/dataset/Dataset.md)) and LMD-Piano, where LMD-Piano is constructed from the samples in the Lakh MIDI (LMD) dataset ([link](https://colinraffel.com/projects/lmd/)) that contain only piano tracks. To evaluate EmoGen's ability to generate emotional music on an arbitrary dataset, we also train EmoGen on TopMAGD ([link](http://www.ifs.tuwien.ac.at/mir/msd/download.html)), a multi-instrument dataset.
In principle, you can use any MIDI dataset to train the attribute-to-music generation model. Taking the 'Piano' dataset under the 'data/' folder as an example, follow the steps below to process the data.
1. First, put all MIDI files in 'data/Piano/midi'.
2. Run the following command to encode MIDI files.
```shell
cd data_process
python midi_encoding.py
```
3. Run the following command to extract attributes with jSymbolic.
   You first need to download jSymbolic_2_2_user.zip from https://sourceforge.net/projects/jmir/files/jSymbolic/jSymbolic%202.2/ and extract it into ./jSymbolic_lib.
```shell
cd ../jSymbolic_lib
python jSymbolic_feature.py
```
4. Run the following script to prepare the train/validation/test datasets.
```shell
cd ../data_process
python gen_data.py
```
### 3. Train
- Emotion-to-attribute mapping
In this stage, we map the four emotion quadrants of Russell's 4Q model to four different attribute vectors. We first compute the attribute center of each quadrant on the EMOPIA dataset and then select the attribute vector closest to that center as the mapping result for the quadrant. The mapped results are stored in `data/infer_input/inference_command.npy` (a loading sketch follows the table). The emotion quadrants corresponding to these attribute vectors are shown in the following table:
| Index | Emotion |
| -------------------- | ------- |
| inference_command[0] | Q1 |
| inference_command[1] | Q2 |
| inference_command[2] | Q3 |
| inference_command[3] | Q4 |
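For reference, below is a minimal sketch of how the mapped attribute vector for a target quadrant can be loaded and inspected. The file path is the one used by the generation scripts in this repository; the array layout (one attribute vector per quadrant, indexed as in the table above) is an assumption based on that table:

```python
import numpy as np

# Load the emotion-to-attribute mapping produced in the first stage.
# Assumption: inference_command holds one attribute vector per quadrant,
# indexed 0..3 for Q1..Q4 as in the table above.
inference_command = np.load("data/infer_input/inference_command.npy")

target_quadrant = 3  # Q3
attribute_vector = inference_command[target_quadrant - 1]
print(attribute_vector.shape)
```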
- Attribute-to-music generation
Run the following command to train a 6-layer Linear Transformer model on the dataset `data/Piano`:
```shell
bash Piano_train.sh
```
### 4. Inference
Please put the checkpoints under the folder `checkpoints/`. To generate piano songs, run the following command:
```shell
# usage: bash Piano_gen.sh target_emotion
# 1-Q1 2-Q2 3-Q3 4-Q4
bash Piano_gen.sh 3 # generate songs with emotion "Q3"
```
Also, to generate multi-instrument songs, please run the following command:
```shell
bash TopMAGD_gen.sh 1 # generate songs with emotion "Q1"
```

11
emogen/setup.sh Normal file

@ -0,0 +1,11 @@
pip install fairseq==0.10.2
pip install miditoolkit
pip install matplotlib
pip install pytorch-fast-transformers
pip install chorder
pip install plotext
pip install scikit-learn
pip install scipy
git clone https://github.com/btyu/MidiProcessor.git
cd MidiProcessor
pip install .

554
emogen/train.py Normal file

@ -0,0 +1,554 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""
import argparse
import logging
import math
import os
import random
import sys
import numpy as np
import torch
from fairseq import (
checkpoint_utils,
distributed_utils,
options,
quantization_utils,
tasks,
utils,
)
from fairseq.data import iterators
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer
import linear_decoder.controlled_task
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")
def main(args):
utils.import_user_module(args)
assert (
args.max_tokens is not None or args.batch_size is not None
), "Must specify batch size either with --max-tokens or --batch-size"
metrics.reset()
np.random.seed(args.seed)
utils.set_torch_seed(args.seed)
if distributed_utils.is_master(args):
checkpoint_utils.verify_checkpoint_directory(args.save_dir)
# Print args
logger.info(args)
# Setup task, e.g., translation_control, language modeling, etc.
task = tasks.setup_task(args)
# Load valid dataset (we load training data below, based on the latest checkpoint)
for valid_sub_split in args.valid_subset.split(","):
task.load_dataset(valid_sub_split, combine=False, epoch=1)
model = task.build_model(args)
criterion = task.build_criterion(args)
logger.info(model)
logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
logger.info(
"criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
)
logger.info(
"num. model params: {} (num. trained: {})".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
)
# (optionally) Configure quantization
if args.quantization_config_path is not None:
quantizer = quantization_utils.Quantizer(
config_path=args.quantization_config_path,
max_epoch=args.max_epoch,
max_update=args.max_update,
)
else:
quantizer = None
# Build trainer
if args.model_parallel_size == 1:
trainer = Trainer(args, task, model, criterion, quantizer)
else:
trainer = MegatronTrainer(args, task, model, criterion)
logger.info(
"training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
)
logger.info(
"max tokens per GPU = {} and max sentences per GPU = {}".format(
args.max_tokens, args.batch_size
)
)
# Load the latest checkpoint if one is available and restore the
# corresponding train iterator
extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
args,
trainer,
# don't cache epoch iterators for sharded datasets
disable_iterator_cache=task.has_sharded_data("train"),
)
# Train until the learning rate gets too small
max_epoch = args.max_epoch or math.inf
lr = trainer.get_lr()
train_meter = meters.StopwatchMeter()
train_meter.start()
while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
# train for one epoch
valid_losses, should_stop = train(args, trainer, task, epoch_itr)
if should_stop:
break
# only use first validation loss to update the learning rate
lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])
epoch_itr = trainer.get_train_iterator(
epoch_itr.next_epoch_idx,
# sharded data: get train iterator for next epoch
load_dataset=task.has_sharded_data("train"),
# don't cache epoch iterators for sharded datasets
disable_iterator_cache=task.has_sharded_data("train"),
)
train_meter.stop()
logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def should_stop_early(args, valid_loss):
# skip check if no validation was done in the current epoch
if valid_loss is None:
return False
if args.patience <= 0:
return False
def is_better(a, b):
return a > b if args.maximize_best_checkpoint_metric else a < b
prev_best = getattr(should_stop_early, "best", None)
if prev_best is None or is_better(valid_loss, prev_best):
should_stop_early.best = valid_loss
should_stop_early.num_runs = 0
return False
else:
should_stop_early.num_runs += 1
if should_stop_early.num_runs >= args.patience:
logger.info(
"early stop since valid performance hasn't improved for last {} runs".format(
args.patience
)
)
return True
else:
return False
@metrics.aggregate("train")
def train(args, trainer, task, epoch_itr):
"""Train the model for one epoch and return validation losses."""
# Initialize data iterator
itr = epoch_itr.next_epoch_itr(
fix_batches_to_gpus=args.fix_batches_to_gpus,
shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
)
update_freq = (
args.update_freq[epoch_itr.epoch - 1]
if epoch_itr.epoch <= len(args.update_freq)
else args.update_freq[-1]
)
itr = iterators.GroupedIterator(itr, update_freq)
if getattr(args, "tpu", False):
itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch_itr.epoch,
tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None
),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
)
trainer.begin_epoch(epoch_itr.epoch)
valid_losses = [None]
valid_subsets = args.valid_subset.split(",")
should_stop = False
num_updates = trainer.get_num_updates()
for i, samples in enumerate(progress):
with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
"train_step-%d" % i
):
log_output = trainer.train_step(samples)
if log_output is not None: # not OOM, overflow, ...
# log mid-epoch stats
num_updates = trainer.get_num_updates()
if num_updates % args.log_interval == 0:
stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
progress.log(stats, tag="train_inner", step=num_updates)
# reset mid-epoch stats after each log interval
# the end-of-epoch stats will still be preserved
metrics.reset_meters("train_inner")
end_of_epoch = not itr.has_next()
valid_losses, should_stop = validate_and_save(
args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
)
if should_stop:
break
# log end-of-epoch stats
logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
stats = get_training_stats(metrics.get_smoothed_values("train"))
progress.print(stats, tag="train", step=num_updates)
# reset epoch-level meters
metrics.reset_meters("train")
return valid_losses, should_stop
def validate_and_save(args, trainer, task, epoch_itr, valid_subsets, end_of_epoch):
num_updates = trainer.get_num_updates()
max_update = args.max_update or math.inf
do_save = (
(end_of_epoch and epoch_itr.epoch % args.save_interval == 0)
or num_updates >= max_update
or (
args.save_interval_updates > 0
and num_updates > 0
and num_updates % args.save_interval_updates == 0
and num_updates >= args.validate_after_updates
)
)
do_validate = (
(not end_of_epoch and do_save) # validate during mid-epoch saves
or (end_of_epoch and epoch_itr.epoch % args.validate_interval == 0)
or num_updates >= max_update
or (
args.validate_interval_updates > 0
and num_updates > 0
and num_updates % args.validate_interval_updates == 0
)
) and not args.disable_validation
# Validate
valid_losses = [None]
if do_validate:
valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
# Stopping conditions
should_stop = (
should_stop_early(args, valid_losses[0])
or num_updates >= max_update
or (
args.stop_time_hours > 0
and trainer.cumulative_training_time() / (60 * 60) > args.stop_time_hours
)
)
# Save checkpoint
if do_save or should_stop:
logger.info("begin save checkpoint")
checkpoint_utils.save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
return valid_losses, should_stop
def get_training_stats(stats):
stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
return stats
def validate(args, trainer, task, epoch_itr, subsets):
"""Evaluate the model on the validation set(s) and return the losses."""
if args.fixed_validation_seed is not None:
# set fixed seed for every validation
utils.set_torch_seed(args.fixed_validation_seed)
trainer.begin_valid_epoch(epoch_itr.epoch)
valid_losses = []
for subset in subsets:
logger.info('begin validation on "{}" subset'.format(subset))
# Initialize data iterator
itr = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
if getattr(args, "tpu", False):
itr = utils.tpu_data_loader(itr)
progress = progress_bar.progress_bar(
itr,
log_format=args.log_format,
log_interval=args.log_interval,
epoch=epoch_itr.epoch,
prefix=f"valid on '{subset}' subset",
tensorboard_logdir=(
args.tensorboard_logdir if distributed_utils.is_master(args) else None
),
default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
)
# create a new root metrics aggregator so validation metrics
# don't pollute other aggregators (e.g., train meters)
with metrics.aggregate(new_root=True) as agg:
for sample in progress:
trainer.valid_step(sample)
# log validation stats
stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
progress.print(stats, tag=subset, step=trainer.get_num_updates())
valid_losses.append(stats[args.best_checkpoint_metric])
return valid_losses
def get_valid_stats(args, trainer, stats):
stats["num_updates"] = trainer.get_num_updates()
if hasattr(checkpoint_utils.save_checkpoint, "best"):
key = "best_{0}".format(args.best_checkpoint_metric)
best_function = max if args.maximize_best_checkpoint_metric else min
stats[key] = best_function(
checkpoint_utils.save_checkpoint.best, stats[args.best_checkpoint_metric]
)
return stats
def cli_main(modify_parser=None):
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
if args.profile:
with torch.cuda.profiler.profile():
with torch.autograd.profiler.emit_nvtx():
distributed_utils.call_main(args, main)
else:
distributed_utils.call_main(args, main)
if __name__ == "__main__":
cli_main()
'''
../StyleCtrlData/train_data/1026_EMO/data-bin
--task
language_modeling_control
--arch
style_model_std_v2
--control_mode
embedding_v2_bos
--command_path
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_reselect_29
--feature_num
29
--bucket_num
2
--mask_prob
-1
--sample-break-mode
eos
--tokens-per-sample
10000000
--max-tokens
10000000
--batch-size
2
--batch-size-valid
1
--update-freq
1
--optimizer
adam
--adam-betas
(0.9,0.98)
--adam-eps
1e-9
--weight-decay
0.01
--lr
0.0001
--lr-scheduler
inverse_sqrt
--warmup-updates
1000
--log-format
simple
--log-interval
10
--tensorboard-logdir
tb_log/controlled-1
--num-workers
0
--max-update
600000
--validate-interval
1000000000
--save-interval-updates
5000
--save-dir
checkpoints/controlled
--no-epoch-checkpoints
--cpu
'''
# 39169531
'''
../StyleCtrlData/train_data/1024_EMO/bucket2_attributes29/data-bin
--task
language_modeling_control
--arch
style_model_debug
--control_mode
prefix_token
--command_path
None
--feature_num
29
--bucket_num
2
--mask_prob
-1
--add_emotion_token
0
--sample-break-mode
eos
--tokens-per-sample
10000000
--max-tokens
10000000
--batch-size
2
--batch-size-valid
1
--update-freq
1
--optimizer
adam
--adam-betas
(0.9,0.98)
--adam-eps
1e-9
--weight-decay
0.01
--lr
0.0001
--lr-scheduler
inverse_sqrt
--warmup-updates
1000
--log-format
simple
--log-interval
10
--tensorboard-logdir
tb_log/controlled-1
--num-workers
0
--max-update
600000
--validate-interval
1000000000
--save-interval-updates
5000
--save-dir
checkpoints/controlled
--no-epoch-checkpoints
--cpu
'''
'''
../StyleCtrlData/train_data/1026_EMO/data-bin
--task
language_modeling_control
--arch
linear_transformer_lm_debug
--control_mode
embedding_v2
--command_path
../StyleCtrlData/train_data/1026_EMO/bucket2_command_thres_EMO/emotion_rank_100
--feature_num
100
--bucket_num
2
--mask_prob
-1
--add_emotion_token
0
--sample-break-mode
eos
--tokens-per-sample
10000000
--max-tokens
10000000
--batch-size
2
--batch-size-valid
1
--update-freq
1
--optimizer
adam
--adam-betas
(0.9,0.98)
--adam-eps
1e-9
--weight-decay
0.01
--lr
0.0001
--lr-scheduler
inverse_sqrt
--warmup-updates
1000
--log-format
simple
--log-interval
10
--tensorboard-logdir
tb_log/controlled-1
--num-workers
1
--max-update
20
--validate-interval
1000000000
--save-interval-updates
10
--save-dir
checkpoints/controlled
--patience
10
--no-epoch-checkpoints
'''