# Mirror of https://github.com/microsoft/muzic.git
# 163 lines, 5.0 KiB, Python
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# Licensed under the MIT License.
|
|
#
|
|
|
|
import numpy as np
|
|
import os
|
|
import sys
|
|
import miditoolkit
|
|
|
|
# Marker tokens that appear in the serialized melody strings.
ALIGN = '[align]'
SEP = '[sep]'

# Duration token id -> duration value. Ids start at 129 and the values run
# from 0.25 to 33.0 in steps of 0.25 (132 entries, ids 129..260).
Duration_vocab = {129 + i: x / 100 for i, x in enumerate(range(25, 3325, 25))}
# Inverse mapping (duration value -> token id), derived from the forward map
# so the two tables can never disagree.
Duration_vocab_re = {dur: token for token, dur in Duration_vocab.items()}
# Largest duration token id; used as the fallback for out-of-range durations.
MAX_DUR = max(Duration_vocab)
# Token id used for a rest in the pitch position.
REST = 128
# Bounds used when folding pitches into a fixed register.
min_pitch = 60
max_pitch = 72
|
|
|
|
|
|
def get_pitch_count(x):
    """Build a 129-bin histogram of the pitch tokens in ``x``.

    Tokens equal to 128 are counted as-is in the last bin. Every other token
    is first folded: values at or below ``min_pitch`` become
    ``min_pitch + token % 12`` and values at or above ``max_pitch + 12``
    become ``max_pitch + token % 12``; values in between are left unchanged.

    Args:
        x: Iterable of integer pitch tokens.

    Returns:
        A NumPy array of length 129 holding the per-bin counts.
    """
    histogram = [0] * (128 + 1)
    upper_bound = max_pitch + 12
    for token in x:
        if token == 128:
            folded = token
        elif token <= min_pitch:
            folded = min_pitch + token % 12
        elif token >= upper_bound:
            folded = max_pitch + token % 12
        else:
            folded = token
        # NOTE(review): the source had lost its indentation; the increment is
        # taken to apply to every token (including 128), which is what the
        # 129-slot histogram suggests. Confirm against the upstream file.
        histogram[folded] += 1
    return np.array(histogram)
|
|
|
|
|
|
def get_dur_count(x):
    """Build a histogram of durations relative to the shortest one in ``x``.

    Each duration token is decoded through ``Duration_vocab`` and divided by
    the decoded value of the smallest token in ``x``; the ratio, truncated to
    an integer, selects the bin to increment. Ratios that fall beyond the
    last bin are accumulated in the last bin.

    Args:
        x: Non-empty sequence of duration token ids (keys of
            ``Duration_vocab``).

    Returns:
        A NumPy array with one bin per entry of ``Duration_vocab``.

    Raises:
        ValueError: If ``x`` is empty (there is no minimum to normalize by).
        KeyError: If a token is not a key of ``Duration_vocab``.
    """
    # One bin per vocabulary entry. The original sized this list with a
    # comprehension whose loop variable was also called `x`, shadowing the
    # parameter.
    cnt = [0] * len(Duration_vocab)
    last_bin = len(cnt) - 1
    min_dur = Duration_vocab[min(x)]
    for dur in x:
        ratio = int(Duration_vocab[dur] / min_dur)
        # Clamp oversized ratios into the final bin.
        cnt[ratio if ratio < len(cnt) else last_bin] += 1
    return np.array(cnt)
|
|
|
|
|
|
def separate(string, use_word=True):
    """Split a serialized melody string into pitch and duration token lists.

    Args:
        string: Whitespace-separated tokens: alternating pitch and duration
            ids, interleaved with ``ALIGN`` and ``SEP`` markers.
        use_word: When false, both markers are discarded and the remaining
            tokens are read as consecutive (pitch, duration) pairs. When
            true, ``SEP`` is discarded and ``ALIGN`` closes a group: the
            decoded durations seen since the previous ``ALIGN`` are summed
            and re-encoded through ``Duration_vocab_re``.

    Returns:
        A ``(pitches, durations)`` tuple of integer lists. In the grouped
        mode a summed duration that is not a key of ``Duration_vocab_re`` is
        dropped, so the two lists may differ in length.
    """
    tokens = string.strip().split()

    if not use_word:
        flat = [tok for tok in tokens if tok != SEP and tok != ALIGN]
        pitch = []
        dur = []
        for pos in range(0, len(flat), 2):
            pitch.append(int(flat[pos]))
            dur.append(int(flat[pos + 1]))
        return pitch, dur

    pitches = []
    group_durs = []
    expect_pitch = True
    running = 0
    for tok in tokens:
        if tok == SEP:
            continue
        if tok == ALIGN:
            # Close the current group and start a new one.
            if running in Duration_vocab_re:
                group_durs.append(Duration_vocab_re[running])
            # NOTE(review): indentation was lost in the source; the reset is
            # assumed to happen on every ALIGN, not only after a successful
            # lookup. Confirm against the upstream file.
            running = 0
            continue
        if expect_pitch:
            pitches.append(int(tok))
        else:
            running += Duration_vocab[int(tok)]
        expect_pitch = not expect_pitch

    # Flush a trailing group that was not terminated by ALIGN.
    if running > 0 and running in Duration_vocab_re:
        group_durs.append(Duration_vocab_re[running])
    return pitches, group_durs
|
|
|
|
|
|
def get_melody(mid_obj):
    """Serialize the first instrument track of ``mid_obj`` as a token string.

    Every note contributes a pitch token followed by a duration token. A gap
    between the end of the previous note and the start of the next one
    contributes a rest (``128``) plus its duration. An ``ALIGN`` marker is
    emitted before a note whose start equals the time of the next pending
    lyric (lyrics are consumed from index 1 onwards), and a final ``ALIGN``
    and ``SEP`` close the sequence. ``ALIGN`` markers that are not preceded
    by at least one pitched note since the previous ``ALIGN`` are removed.

    Args:
        mid_obj: Object exposing ``lyrics`` (items with ``.time``),
            ``instruments[0].notes`` (items with ``.start``, ``.end``,
            ``.pitch``) and ``ticks_per_beat``; in this file it is a
            ``miditoolkit.MidiFile``.

    Returns:
        The tokens joined by single spaces.

    Raises:
        IndexError: If there is no instrument or the first one has no notes.
        ZeroDivisionError: If ``ticks_per_beat`` is smaller than 4.
    """
    def get_vocab(dur):
        # Quantize a tick length to quarter-beat steps (at least one step)
        # and map it to its duration token; unknown lengths use MAX_DUR.
        beat = dur // (ticks // 4)
        if beat == 0:
            beat = 1
        beat = beat / 4
        return str(Duration_vocab_re.get(beat, MAX_DUR))

    lyrics = mid_obj.lyrics
    notes = mid_obj.instruments[0].notes
    last_end = notes[0].start
    raw_tokens = []
    ticks = mid_obj.ticks_per_beat
    max_lyric = len(lyrics)
    lyric_id = 1
    for note in notes:
        if lyric_id < max_lyric and lyrics[lyric_id].time == note.start:
            raw_tokens.append(ALIGN)
            lyric_id += 1
        # Only a positive gap is a rest. With `!=`, a note starting before
        # the previous one ended produced a negative length, which
        # get_vocab() mapped to MAX_DUR, i.e. a spurious maximum-length rest.
        if note.start > last_end:
            raw_tokens.append(str(128))
            raw_tokens.append(get_vocab(note.start - last_end))
        raw_tokens.append(str(note.pitch))
        raw_tokens.append(get_vocab(note.end - note.start))
        last_end = note.end

    raw_tokens.append(ALIGN)
    raw_tokens.append(SEP)

    # Second pass: keep an ALIGN only if a pitched note (< 128) was emitted
    # since the previous ALIGN; everything else is copied through.
    melody = []
    seen_pitch = False
    for token in raw_tokens:
        if token == ALIGN:
            if seen_pitch:
                melody.append(token)
            seen_pitch = False
        else:
            melody.append(token)
            if token != SEP and int(token) < 128:
                seen_pitch = True

    return ' '.join(melody)
|
|
|
|
|
|
def cal_overlap(gt_d, hyp_d):
    """Return the overlapped area of two histograms after normalization.

    Each histogram is divided by its own sum (an all-zero histogram is
    divided by 1 and therefore stays all-zero), then the element-wise
    minimum of the two normalized histograms is summed.

    Args:
        gt_d: Reference counts; any array-like accepted by ``np.asarray``.
        hyp_d: Hypothesis counts, broadcast-compatible with ``gt_d``.

    Returns:
        The overlap as a NumPy floating-point scalar. For non-negative
        counts it is 1 for identical distributions and 0 for disjoint ones.
    """
    # Accept plain lists/tuples as well as ndarrays.
    gt_d = np.asarray(gt_d)
    hyp_d = np.asarray(hyp_d)

    gt_total = np.sum(gt_d)
    hyp_total = np.sum(hyp_d)
    sum_gt = gt_total if gt_total > 0 else 1
    sum_hyp = hyp_total if hyp_total > 0 else 1

    gt_norm = gt_d.astype(np.float32) / sum_gt
    hyp_norm = hyp_d.astype(np.float32) / sum_hyp

    # (a + b - |a - b|) / 2 equals min(a, b) element-wise.
    overlap = (gt_norm + hyp_norm - np.abs(gt_norm - hyp_norm)) / 2
    return np.sum(overlap)
|
|
|
|
|
|
if __name__ == '__main__':
    # Usage: <script> <gt_dir> <hyp_dir>
    # For every file in <hyp_dir>, load the same-named file from <gt_dir>,
    # serialize both melodies and compare their pitch and duration
    # histograms; finally print mean and standard deviation of the overlaps.

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(sys.argv) != 1 + 2:
        sys.exit(f'usage: {sys.argv[0]} <gt_dir> <hyp_dir>')
    gt_prefix = sys.argv[1]
    hyp_prefix = sys.argv[2]

    pitch_overlap = []
    dur_overlap = []
    print(f'hyp: {hyp_prefix} gt: {gt_prefix}')
    hyp_files = os.listdir(f'{hyp_prefix}/')
    print(hyp_files)
    for filename in hyp_files:
        try:
            hyp_midi = miditoolkit.MidiFile(f'{hyp_prefix}/{filename}')
            hyp = get_melody(hyp_midi)
            gt_midi = miditoolkit.MidiFile(f'{gt_prefix}/{filename}')
            gt = get_melody(gt_midi)
            hyp_pitch, hyp_dur = separate(hyp)
            gt_pitch, gt_dur = separate(gt)
            pitch_score = cal_overlap(
                get_pitch_count(gt_pitch), get_pitch_count(hyp_pitch))
            dur_score = cal_overlap(
                get_dur_count(gt_dur), get_dur_count(hyp_dur))
        except Exception as e:
            # Best effort: skip files that cannot be processed, but say so,
            # and let KeyboardInterrupt/SystemExit propagate.
            print(f'skipped {filename}: {e!r}', file=sys.stderr)
            continue
        # Append only after both scores exist so the two lists always
        # describe the same set of files.
        pitch_overlap.append(pitch_score)
        dur_overlap.append(dur_score)

    pitch_overlap = np.array(pitch_overlap)
    dur_overlap = np.array(dur_overlap)
    print('PITCH', np.mean(pitch_overlap), np.std(pitch_overlap))
    print('DUR', np.mean(dur_overlap), np.std(dur_overlap))
    print()
|