[MuseCoco] Update data processing code

This commit is contained in:
btyu 2023-09-21 09:50:27 -04:00
Родитель d967a51a3f
Коммит 5b3890d2bb
51 изменённых файлов: 5220 добавлений и 11 удалений

Просмотреть файл

@ -1,2 +1 @@
# Path to the MidiDataExtractor package.
midi_data_extractor_path = 'MidiDataExtractor'

# Labels of the attribute units to process, in the order they are used.
attribute_list = [
    'I1s2', 'R1', 'R3', 'S2s1', 'S4', 'B1s1',
    'TS1s1', 'K1', 'T1s1', 'P4', 'EM1', 'TM1',
]

Просмотреть файл

@ -1,14 +1,12 @@
import argparse
import os
import sys
from tqdm.auto import tqdm
import msgpack
import json
from file_list import generate_file_list
from config import midi_data_extractor_path, attribute_list
from config import attribute_list
sys.path.append(os.path.abspath(midi_data_extractor_path))
import midi_data_extractor as mde
# import midiprocessor as mp

Просмотреть файл

@ -0,0 +1,2 @@
from .data_extractor import DataExtractor
from . import attribute_unit

Просмотреть файл

@ -0,0 +1,45 @@
import importlib
from .unit_base import UnitBase
def load_unit_class(attribute_label):
    """Import and return the ``Unit<attribute_label>`` class.

    The characters of ``attribute_label`` that precede its first decimal digit,
    lower-cased, name the module: label ``'B1s1'`` is looked up in
    ``midi_data_extractor.attribute_unit.unit_b`` as ``UnitB1s1``.

    :param attribute_label: attribute label such as ``'B1s1'`` or ``'TM1'``
    :return: the class object found in the imported module
    """
    # Position of the first digit; the whole label is the prefix if there is none.
    cut = len(attribute_label)
    for position, char in enumerate(attribute_label):
        if char in '0123456789':
            cut = position
            break
    file_label = ''.join(char.lower() for char in attribute_label[:cut])

    module_name = '.attribute_unit.unit_%s' % file_label
    module = importlib.import_module(module_name, package='midi_data_extractor')
    return getattr(module, 'Unit%s' % attribute_label)
def load_raw_unit_class(raw_attribute_label):
    """Import and return the ``RawUnit<raw_attribute_label>`` class.

    The characters of ``raw_attribute_label`` that precede its first decimal
    digit, lower-cased, name the module: label ``'B1'`` is looked up in
    ``midi_data_extractor.attribute_unit.raw_unit_b`` as ``RawUnitB1``.

    :param raw_attribute_label: raw attribute label such as ``'B1'`` or ``'TM1'``
    :return: the class object found in the imported module
    """
    # Position of the first digit; the whole label is the prefix if there is none.
    cut = len(raw_attribute_label)
    for position, char in enumerate(raw_attribute_label):
        if char in '0123456789':
            cut = position
            break
    file_label = ''.join(char.lower() for char in raw_attribute_label[:cut])

    module_name = '.attribute_unit.raw_unit_%s' % file_label
    module = importlib.import_module(module_name, package='midi_data_extractor')
    return getattr(module, 'RawUnit%s' % raw_attribute_label)
def convert_value_into_unit(attribute_label, attribute_value, encoder=None):
    """Wrap an already-extracted attribute value in its unit class.

    :param attribute_label: label used to look up the unit class
    :param attribute_value: value passed to the unit constructor
    :param encoder: forwarded to the unit constructor as ``encoder``
    :return: an instance of the unit class for ``attribute_label``
    """
    return load_unit_class(attribute_label)(attribute_value, encoder=encoder)
def convert_value_dict_into_unit_dict(value_dict, encoder=None):
    """Convert a ``{label: value}`` mapping into a ``{label: unit}`` mapping.

    :param value_dict: mapping from attribute label to attribute value
    :param encoder: forwarded to every unit constructor
    :return: dict with the same keys, each value wrapped in its unit class
    """
    return {
        label: convert_value_into_unit(label, value, encoder=encoder)
        for label, value in value_dict.items()
    }

Просмотреть файл

@ -0,0 +1,18 @@
from .raw_unit_base import RawUnitBase
class RawUnitB1(RawUnitBase):
    """Raw unit that reports the number of bars in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Count the bars in ``[bar_begin, bar_end)``.

        :return: int, the number of bars
        """
        num_bars = bar_end - bar_begin
        return num_bars

Просмотреть файл

@ -0,0 +1,54 @@
from abc import ABC
class RawUnitBase(ABC):
    """Base class for raw attribute extractors."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Extract this attribute's raw information from the given inputs.

        The returned information should be exactly what is needed, or a superset
        of it. Overriding implementations should document the format and content
        of what they return.

        :param encoder: mp.MidiEncoder instance
        :param midi_dir: full path of the dataset
        :param midi_path: path of the MIDI file relative to the dataset path
        :param pos_info: pos_info in which ts and tempo have been filled in at
            the first position of every bar, for convenience
        :param bars_positions: dict, start and end positions of each bar in pos_info
        :param bars_chords: chord information of the bar sequence, two entries per
            bar; may be None, meaning chords could not be extracted for this MIDI
        :param bars_insts: instrument ids used in each bar; each item is a set
        :param bar_begin: first bar to extract from, counted from 0
        :param bar_end: last bar to extract from, exclusive
        :param kwargs: other information, an empty dict by default
        :return: defined by the overriding class
        """
        raise NotImplementedError

    @classmethod
    def repr_value(cls, value):
        """Return the representation of ``value``; the base class returns it unchanged."""
        return value

    @classmethod
    def derepr_value(cls, rep_value):
        """Inverse of :meth:`repr_value`; the base class returns ``rep_value`` unchanged."""
        return rep_value
class RawUnitForExistedValue(RawUnitBase):
    """Raw unit whose values are taken directly from the caller's keyword arguments."""

    @classmethod
    def get_fields(cls):
        """Return the keyword name, or an iterable of names, to copy from ``kwargs``."""
        raise NotImplementedError

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Collect the fields named by :meth:`get_fields` from ``kwargs``.

        :return: dict mapping each field name to its value in ``kwargs``, or to
            None when the field was not supplied
        """
        fields = cls.get_fields()
        # A single field may be given as a bare string.
        names = (fields,) if isinstance(fields, str) else fields
        return {name: kwargs.get(name) for name in names}

Просмотреть файл

@ -0,0 +1,26 @@
from .raw_unit_base import RawUnitBase
class RawUnitC1(RawUnitBase):
    """Chord sequence of the segment, two chords per bar."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Slice the segment's chords out of the whole-piece chord sequence.

        :return:
            - list: chord sequence of the segment, two chords per bar.
              None when chords could not be detected for the MIDI
              (``bars_chords`` is None).
        """
        if bars_chords is None:
            return None
        # There must be exactly two chord entries for every bar.
        assert len(bars_positions) * 2 == len(bars_chords)
        first_chord = bar_begin * 2
        last_chord = bar_end * 2
        return bars_chords[first_chord:last_chord]

Просмотреть файл

@ -0,0 +1,7 @@
from .raw_unit_base import RawUnitForExistedValue
class RawUnitEM1(RawUnitForExistedValue):
    """Raw unit that passes through the ``emotion`` keyword argument."""

    @classmethod
    def get_fields(cls):
        """Return the name of the keyword argument to copy."""
        field_name = 'emotion'
        return field_name

Просмотреть файл

@ -0,0 +1,50 @@
from .raw_unit_base import RawUnitBase
class RawUnitI1(RawUnitBase):
    """Instruments used in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path, pos_info, bars_positions, bars_chords, bars_insts,
        bar_begin, bar_end, **kwargs
    ):
        """Collect the ids of all instruments used in ``[bar_begin, bar_end)``.

        :return:
            - tuple: ids of the instruments used, without duplicates; never None
        """
        used = {
            inst_id
            for bar_insts in bars_insts[bar_begin: bar_end]
            for inst_id in bar_insts
        }
        return tuple(used)
class RawUnitI2(RawUnitBase):
    """Instruments used in the first and in the second half of the segment.

    - tuple: instruments used in the first half; None when the number of bars
      is not a positive even number
    - tuple: instruments used in the second half; None when the number of bars
      is not a positive even number
    """

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path, pos_info, bars_positions, bars_chords, bars_insts,
        bar_begin, bar_end, **kwargs
    ):
        """Return ``(first_half_insts, second_half_insts)`` as described on the class."""
        num_bars = bar_end - bar_begin
        if num_bars <= 0 or num_bars % 2 == 1:
            return None, None

        middle = bar_begin + num_bars // 2
        first_half = {
            inst_id
            for bar_insts in bars_insts[bar_begin: middle]
            for inst_id in bar_insts
        }
        second_half = {
            inst_id
            for bar_insts in bars_insts[middle: bar_end]
            for inst_id in bar_insts
        }
        return tuple(first_half), tuple(second_half)

Просмотреть файл

@ -0,0 +1,26 @@
from .raw_unit_base import RawUnitBase
class RawUnitK1(RawUnitBase):
    """Key mode (major or minor) taken from the ``is_major`` keyword argument."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Translate ``kwargs['is_major']`` into a mode name.

        :return:
            - str: 'major' or 'minor'; None when the mode is unknown
              (``is_major`` is missing or None)
        :raises ValueError: if ``is_major`` is given but is not True, False or None
        """
        if 'is_major' not in kwargs:
            return None

        is_major = kwargs['is_major']
        # Identity checks on purpose: only the real True/False/None are accepted.
        if is_major is True:
            return 'major'
        if is_major is False:
            return 'minor'
        if is_major is None:
            return None
        raise ValueError('is_major argument is set to a wrong value:', is_major)

Просмотреть файл

@ -0,0 +1,22 @@
from .raw_unit_base import RawUnitBase
class RawUnitM1(RawUnitBase):
    """Self-similarity matrix (SSM) of every track."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Cut the segment's bars out of each instrument's SSM.

        ``kwargs['ssm']`` is iterated by instrument id and each entry is sliced
        with ``[bar_begin:bar_end, bar_begin:bar_end]``.
        NOTE(review): this presumably requires each per-instrument value to be a
        2-D array indexed by (bar, bar) — confirm against the caller that builds
        ``ssm``.

        :return:
            - dict: key is inst_id, value is that instrument's bar-to-bar
              similarity restricted to the bars of the segment
        :raises KeyError: if ``ssm`` is not present in ``kwargs``
        """
        ssm = kwargs['ssm']
        r = {}
        for inst_id in ssm:
            # Slice the instrument's own matrix; slicing ``ssm`` itself (as the
            # previous code did) indexes the container with a tuple of slices.
            r[inst_id] = ssm[inst_id][bar_begin: bar_end, bar_begin: bar_end]
        return r

Просмотреть файл

@ -0,0 +1,43 @@
from .raw_unit_base import RawUnitBase
from ..utils.data import convert_dict_key_to_str, convert_dict_key_to_int
class RawUnitN2(RawUnitBase):
    """Number of notes played by each instrument in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path, pos_info, bars_positions, bars_chords, bars_insts,
        bar_begin, bar_end, **kwargs
    ):
        """Count the notes of every instrument between the segment's positions.

        :return:
            - dict: number of notes per instrument, keyed by instrument id
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]

        counts = {}
        for idx in range(first_pos, last_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id, inst_notes in insts_notes.items():
                counts[inst_id] = counts.get(inst_id, 0) + len(inst_notes)
        return counts

    @classmethod
    def repr_value(cls, value):
        """Convert the dict keys with ``convert_dict_key_to_str``."""
        return convert_dict_key_to_str(value)

    @classmethod
    def derepr_value(cls, rep_value):
        """Convert the dict keys back with ``convert_dict_key_to_int``."""
        return convert_dict_key_to_int(rep_value)

Просмотреть файл

@ -0,0 +1,110 @@
from .raw_unit_base import RawUnitBase
from ..utils.data import convert_dict_key_to_str, convert_dict_key_to_int
class RawUnitP1(RawUnitBase):
    """Lowest pitch in the segment, ignoring instrument ids of 128 and above."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Find the lowest pitch among the non-drum notes of the segment.

        :return:
            - int: lowest pitch, drums not considered; None if there are no notes
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]

        lowest = 1000  # starting bound for the running minimum
        seen_note = False
        for idx in range(first_pos, last_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id, inst_notes in insts_notes.items():
                if inst_id >= 128:
                    continue
                for pitch, _, _ in inst_notes:
                    seen_note = True
                    if pitch < lowest:
                        lowest = pitch
        return lowest if seen_note else None
class RawUnitP2(RawUnitBase):
    """Highest pitch in the segment, ignoring instrument ids of 128 and above."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Find the highest pitch among the non-drum notes of the segment.

        :return:
            - int: highest pitch, drums not considered; None if there are no notes
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]

        highest = -1  # starting bound for the running maximum
        seen_note = False
        for idx in range(first_pos, last_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id, inst_notes in insts_notes.items():
                if inst_id >= 128:
                    continue
                for pitch, _, _ in inst_notes:
                    seen_note = True
                    if pitch > highest:
                        highest = pitch
        return highest if seen_note else None
class RawUnitP3(RawUnitBase):
    """Sum of pitches per instrument, ignoring instrument ids of 128 and above."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path, pos_info, bars_positions, bars_chords, bars_insts,
        bar_begin, bar_end, **kwargs
    ):
        """Add up the pitches of every non-drum instrument in the segment.

        :return:
            - dict: key is the inst id, value is the sum of its pitches; an empty
              dict if no instrument other than drums has notes
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]

        totals = {}
        for idx in range(first_pos, last_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id, inst_notes in insts_notes.items():
                if inst_id >= 128:
                    continue
                # Continue the running total so the accumulation order is unchanged.
                totals[inst_id] = sum(
                    (pitch for pitch, _, _ in inst_notes),
                    totals.get(inst_id, 0),
                )
        return totals

    @classmethod
    def repr_value(cls, value):
        """Convert the dict keys with ``convert_dict_key_to_str``."""
        return convert_dict_key_to_str(value)

    @classmethod
    def derepr_value(cls, rep_value):
        """Convert the dict keys back with ``convert_dict_key_to_int``."""
        return convert_dict_key_to_int(rep_value)

Просмотреть файл

@ -0,0 +1,124 @@
from .raw_unit_base import RawUnitBase
from ..utils.data import convert_dict_key_to_str, convert_dict_key_with_eval
class RawUnitR1(RawUnitBase):
    """Number of onset positions in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Count the positions whose last field (the notes entry) is not None.

        :return:
            - int: total number of onsets in the segment, 0 if there are no notes
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]
        return sum(
            1 for pos_item in pos_info[first_pos: last_pos]
            if pos_item[-1] is not None
        )
class RawUnitR2(RawUnitBase):
    """Number of beats in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Divide the number of positions in the segment by the position resolution.

        :return:
            - float: total number of beats in the segment
        """
        pos_resolution = 12
        # The hard-coded resolution must match the encoder's.
        assert pos_resolution == encoder.vm.pos_resolution, str(encoder.vm.pos_resolution)

        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]
        return (last_pos - first_pos) / pos_resolution
class RawUnitR3(RawUnitBase):
    """Note density of the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Divide the onset count (RawUnitR1) by the beat count (RawUnitR2).

        :return:
            - float: num_onsets / num_beats, i.e. the note density
        """
        positional = (
            encoder, midi_dir, midi_path,
            pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end,
        )
        num_onsets = RawUnitR1.extract(*positional, **kwargs)
        num_beats = RawUnitR2.extract(*positional, **kwargs)
        return num_onsets / num_beats
class RawUnitR4(RawUnitBase):
    """Number of drum notes at each local position of a bar."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Accumulate drum note counts per local position, grouped by time signature.

        :return:
            - dict: key is the TS, value is a list whose length is the number of
              positions per bar under that TS and whose elements are the number
              of drum notes at that position. The value is None if the TS has
              no drum entry.
        """
        drum_id = 128
        counts = {}
        has_drum = {}
        for bar_idx in range(bar_begin, bar_end):
            first_pos, last_pos = bars_positions[bar_idx]
            ts = pos_info[first_pos][1]
            assert ts is not None
            if ts not in counts:
                # One slot per position of the first bar seen with this TS.
                counts[ts] = [0] * (last_pos - first_pos)
                has_drum[ts] = False
            for pos_item in pos_info[first_pos: last_pos]:
                insts_notes = pos_item[-1]
                if insts_notes is None or drum_id not in insts_notes:
                    continue
                counts[ts][pos_item[2]] += len(insts_notes[drum_id])
                has_drum[ts] = True

        for ts, found in has_drum.items():
            if not found:
                counts[ts] = None
        return counts

    @classmethod
    def repr_value(cls, value):
        """Convert the dict keys with ``convert_dict_key_to_str``."""
        return convert_dict_key_to_str(value)

    @classmethod
    def derepr_value(cls, rep_value):
        """Convert the dict keys back with ``convert_dict_key_with_eval``."""
        return convert_dict_key_with_eval(rep_value)

Просмотреть файл

@ -0,0 +1,13 @@
from .raw_unit_base import RawUnitForExistedValue
class RawUnitS1(RawUnitForExistedValue):
    """Raw unit that passes through the ``artist`` keyword argument."""

    @classmethod
    def get_fields(cls):
        """Return the name of the keyword argument to copy."""
        field_name = 'artist'
        return field_name
class RawUnitS2(RawUnitForExistedValue):
    """Raw unit that passes through the ``genre`` keyword argument."""

    @classmethod
    def get_fields(cls):
        """Return the name of the keyword argument to copy."""
        field_name = 'genre'
        return field_name

Просмотреть файл

@ -0,0 +1,7 @@
from .raw_unit_base import RawUnitForExistedValue
class RawUnitST1(RawUnitForExistedValue):
    """Raw unit that passes through the ``piece_structure`` keyword argument."""

    @classmethod
    def get_fields(cls):
        """Return the name of the keyword argument to copy."""
        field_name = 'piece_structure'
        return field_name

Просмотреть файл

@ -0,0 +1,24 @@
from .raw_unit_base import RawUnitBase
class RawUnitT1(RawUnitBase):
    """Tempos used in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Collect every tempo value that appears in the segment.

        :return:
            - tuple of float: all tempos used, without duplicates
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]
        # The first position of the segment must carry a tempo.
        assert pos_info[first_pos][3] is not None

        tempos = set()
        for idx in range(first_pos, last_pos):
            tempo = pos_info[idx][3]
            if tempo is not None:
                tempos.add(tempo)
        return tuple(tempos)

Просмотреть файл

@ -0,0 +1,40 @@
from .raw_unit_base import RawUnitBase
class RawUnitTM1(RawUnitBase):
    """Duration of the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Compute the segment's duration from its positions and tempos.

        Each position contributes ``60 / pos_resolution / tempo``, using the most
        recent tempo at or before that position.
        NOTE(review): the previous docstring gave the unit as minutes, but with a
        tempo in beats per minute the ``* 60`` factor yields seconds — confirm
        which unit downstream code expects.

        :return:
            - float: duration; never None
        """
        pos_resolution = 12
        # The hard-coded resolution must match the encoder's.
        assert pos_resolution == encoder.vm.pos_resolution, str(encoder.vm.pos_resolution)

        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]
        assert pos_info[first_pos][3] is not None

        # Number of positions played under each tempo.
        positions_per_tempo = {}
        current_tempo = None
        for idx in range(first_pos, last_pos):
            tempo = pos_info[idx][3]
            if tempo is not None:
                current_tempo = tempo
            positions_per_tempo[current_tempo] = positions_per_tempo.get(current_tempo, 0) + 1

        return sum(
            count * 60 / pos_resolution / tempo
            for tempo, count in positions_per_tempo.items()
        )

Просмотреть файл

@ -0,0 +1,23 @@
from .raw_unit_base import RawUnitBase
class RawUnitTS1(RawUnitBase):
    """Time signatures used in the segment."""

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Collect every time signature that appears in the segment.

        :return:
            - list: all time signatures used; each element is a tuple such as (3, 4)
        """
        first_pos = bars_positions[bar_begin][0]
        last_pos = bars_positions[bar_end - 1][1]
        # The first position of the segment must carry a time signature.
        assert pos_info[first_pos][1] is not None

        signatures = set()
        for idx in range(first_pos, last_pos):
            ts = pos_info[idx][1]
            if ts is not None:
                signatures.add(ts)
        return list(signatures)

Просмотреть файл

@ -0,0 +1,81 @@
from .unit_base import UnitBase
from .raw_unit_b import RawUnitB1
class UnitB1(UnitBase):
    """Number of bars in the segment."""

    @property
    def version(self) -> str:
        """Version tag of this unit."""
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Count the bars in ``[bar_begin, bar_end)``.

        :return: int, the number of bars
        """
        num_bars = bar_end - bar_begin
        return num_bars

    def get_vector(self, use=True, use_info=None):
        """Return a 14-element one-hot list.

        When ``use`` is false the last element is set; otherwise the element at
        index ``value - 4`` is set.
        """
        num_bars = self.value
        one_hot = [0] * 14
        hot_index = num_bars - 4 if use else -1
        one_hot[hot_index] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 14
class UnitB1s1(UnitBase):
    """Number of bars in the segment, bucketed into ranges of four bars."""

    @classmethod
    def get_raw_unit_class(cls):
        """This unit is derived from the raw bar count."""
        return RawUnitB1

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """Turn the raw bar count into ``(num_bars, bar_id)``.

        :return:
            - int: number of bars
            - int: id of the bar-count range: 0 for 1-4, 1 for 5-8, 2 for 9-12,
              3 for 13-16; -1 when the count is outside 1..16
            Never None.
        """
        num_bars = raw_data['B1']
        if 0 < num_bars <= 16:
            return num_bars, cls.convert_num_bars_to_id(num_bars)
        # Counts outside 1..16 are not rejected; they get the id -1.
        return num_bars, -1

    @classmethod
    def convert_num_bars_to_id(cls, num_bars):
        """Map a bar count to the index of its four-bar range."""
        zero_based = max(num_bars - 1, 0)
        return int(zero_based / 4)

    def get_vector(self, use=True, use_info=None):
        """Return a one-hot list ordered as: range 0, 1, 2, 3, NA."""
        _, bar_id = self.value
        one_hot = [0] * self.vector_dim
        hot_index = bar_id if use else -1
        one_hot[hot_index] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 5

Просмотреть файл

@ -0,0 +1,103 @@
from abc import ABC
from typing import Tuple, Union
from .raw_unit_base import RawUnitBase
class UnitBase(ABC):
    """Base class for attribute units: a value plus its vector encoding."""

    def __init__(self, value, encoder=None):
        self.value = value
        self.encoder = encoder

    @classmethod
    def get_raw_unit_class(cls):
        """Return the raw unit class, or a tuple of raw unit classes, this unit needs."""
        raise NotImplementedError

    @classmethod
    def new(cls, encoder, *args, **kwargs):
        """Extract the value with :meth:`extract` and wrap it in a unit instance."""
        return cls(cls.extract(encoder, *args, **kwargs), encoder=encoder)

    @classmethod
    def extract_raw(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Run every raw unit of this unit and collect the results.

        :return: dict mapping each raw unit's label (its class name without the
            ``RawUnit`` prefix) to the value its ``extract`` returned
        """
        raw_units = cls.get_raw_unit_class()
        if not isinstance(raw_units, tuple):
            raw_units = (raw_units,)

        # Validate and label all raw unit classes before extracting anything.
        prefix_len = 7  # length of the 'RawUnit' prefix
        labelled = {}
        for raw_unit_class in raw_units:
            assert issubclass(raw_unit_class, RawUnitBase)
            class_name = raw_unit_class.__name__
            assert class_name.startswith('RawUnit')
            labelled[class_name[prefix_len:]] = raw_unit_class

        return {
            label: raw_unit.extract(
                encoder, midi_dir, midi_path,
                pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
            )
            for label, raw_unit in labelled.items()
        }

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Convert the dict produced by :meth:`extract_raw` into this unit's value."""
        raise NotImplementedError

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Extract this attribute's information from the given inputs.

        The returned information should be exactly what is needed, or a superset
        of it. Overriding implementations should document the format and content
        of what they return.

        :param encoder: mp.MidiEncoder instance
        :param midi_dir: full path of the dataset
        :param midi_path: path of the MIDI file relative to the dataset path
        :param pos_info: pos_info in which ts and tempo have been filled in at
            the first position of every bar, for convenience
        :param bars_positions: dict, start and end positions of each bar in pos_info
        :param bars_chords: chord information of the bar sequence, two entries per
            bar; may be None, meaning chords could not be extracted for this MIDI
        :param bars_insts: instrument ids used in each bar; each item is a set
        :param bar_begin: first bar to extract from, counted from 0
        :param bar_end: last bar to extract from, exclusive
        :param kwargs: other information, an empty dict by default
        :return: the result of :meth:`convert_raw_to_value` on the raw data
        """
        raw_data = cls.extract_raw(
            encoder, midi_dir, midi_path,
            pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
        )
        return cls.convert_raw_to_value(
            raw_data,
            encoder, midi_dir, midi_path,
            pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
        )

    @classmethod
    def repr_value(cls, value):
        """Return the representation of ``value``; the base class returns it unchanged."""
        return value

    @classmethod
    def derepr_value(cls, rep_value):
        """Inverse of :meth:`repr_value`; the base class returns ``rep_value`` unchanged."""
        return rep_value

    def get_vector(self, use=True, use_info=None) -> list:
        """Return the attribute as a list.

        The elements are int or float values and the length is ``vector_dim``.
        """
        raise NotImplementedError

    @property
    def vector_dim(self) -> Union[int, Tuple[int, int]]:
        """Dimension of the vector."""
        raise NotImplementedError

Просмотреть файл

@ -0,0 +1,93 @@
import math
from .unit_base import UnitBase
from .raw_unit_c import RawUnitC1
class UnitC1(UnitBase):
    """How the brightness of the chords changes over the segment."""

    @classmethod
    def get_raw_unit_class(cls):
        """This unit is derived from the raw chord sequence."""
        return RawUnitC1

    @classmethod
    def is_bright(cls, chord):
        """Classify a chord label of the form ``'<root>:<quality>'``.

        :return: None for ``'N.C.'``; True when the quality is ``''``, ``'maj7'``
            or ``'7'``; False otherwise
        """
        if chord == 'N.C.':
            return None
        chord = chord.split(':')
        assert len(chord) == 2
        return chord[1] in ('', 'maj7', '7')

    @staticmethod
    def _count_brightness(brights):
        """Return ``(number of True items, number of False items)``; None items are ignored."""
        num_bright = sum(1 for item in brights if item is True)
        num_not_bright = sum(1 for item in brights if item is False)
        return num_bright, num_not_bright

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Classify the brightness trend of the segment's chords.

        :return:
            - int in 0-3: always bright, always dark, bright then dark, dark
              then bright.
              May be None: there is no chord information, the segment has no
              chords, or the change is too complicated to classify.
        """
        seg_bars_chords = raw_data['C1']
        if seg_bars_chords is None:
            return None

        seg_brights = [cls.is_bright(item) for item in seg_bars_chords]
        num_all = len(seg_brights)
        if num_all == 0:
            # An empty segment has no trend; avoids dividing by zero below.
            return None

        threshold = 0.875
        num_seg_bars = bar_end - bar_begin

        num_brights, num_not_brights = cls._count_brightness(seg_brights)
        if num_brights / num_all >= threshold:
            return 0
        if num_not_brights / num_all >= threshold:
            return 1

        # Try every bar boundary in the middle half of the segment as the point
        # where the brightness flips.
        first_bp = max(1, math.floor(num_seg_bars / 4))
        last_bp = min(num_seg_bars - 1, math.ceil(num_seg_bars * 3 / 4))
        for bp in range(first_bp, last_bp):
            num_left = bp * 2  # two chords per bar
            num_right = num_all - num_left
            left_bright, left_not_bright = cls._count_brightness(seg_brights[:num_left])
            right_bright, right_not_bright = cls._count_brightness(seg_brights[num_left:])
            if left_bright / num_left >= threshold and right_not_bright / num_right >= threshold:
                return 2
            if left_not_bright / num_left >= threshold and right_bright / num_right >= threshold:
                return 3
        return None

    def get_vector(self, use=True, use_info=None):
        """Return a 5-element one-hot list: values 0-3, then NA."""
        value = self.value
        vector = [0] * self.vector_dim
        if value is None or not use:
            vector[-1] = 1
            return vector
        assert 0 <= value < 4
        vector[value] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 5

Просмотреть файл

@ -0,0 +1,49 @@
from .unit_base import UnitBase
import os
from .raw_unit_em import RawUnitEM1
def get_emotion_by_file_name_1(file_name):
    """Return the second ``_``-separated field of a ``.mid`` file name.

    :param file_name: file name ending in ``.mid``
    :return: the text between the first and second underscore of the name
        without its extension
    """
    assert file_name.endswith('.mid')
    stem = file_name[:-len('.mid')]
    fields = stem.split('_')
    return fields[1]


# Registry of emotion-lookup functions, keyed by scheme name.
em1_funcs = {
    'file_name_1': get_emotion_by_file_name_1,
}
class UnitEM1(UnitBase):
    """Emotion label of the piece."""

    @classmethod
    def get_raw_unit_class(cls):
        """This unit is derived from the raw ``emotion`` field."""
        return RawUnitEM1

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Return the ``emotion`` entry of the raw EM1 data."""
        return raw_data['EM1']['emotion']

    def get_vector(self, use=True, use_info=None):
        """Return a 5-element one-hot list: four emotion ids, then NA.

        The emotion id is the second character of the value, read as a 1-based
        number.
        NOTE(review): presumably the value looks like 'Q1'..'Q4' — confirm.
        """
        label = self.value
        one_hot = [0] * 5
        if label is None or not use:
            one_hot[-1] = 1
            return one_hot
        emo_id = int(label[1]) - 1
        assert 0 <= emo_id < 4
        one_hot[emo_id] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 5

Просмотреть файл

@ -0,0 +1,701 @@
from typing import Tuple, Union
from .unit_base import UnitBase
from ..const import inst_id_to_inst_class_id, inst_id_to_inst_class_id_2
from .raw_unit_i import RawUnitI1
from .raw_unit_p import RawUnitP3
from .raw_unit_n import RawUnitN2
class UnitI1(UnitBase):
    """Instrument classes used in the segment."""

    @property
    def version(self) -> str:
        """Version tag of this unit."""
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Map every instrument used in the segment to its class id.

        :return:
            tuple of the class ids of all instruments, without duplicates;
            None if no instrument is used
        """
        all_insts = set()
        for bar_insts in bars_insts[bar_begin: bar_end]:
            all_insts = all_insts | bar_insts

        class_ids = [inst_id_to_inst_class_id[inst_id] for inst_id in all_insts]
        if not class_ids:
            return None
        return tuple(set(class_ids))

    def get_vector(self, use=True, use_info=None):
        """Return a 17-element multi-hot list of instrument classes.

        ``use_info``, when given, replaces the unit's own value. The list is all
        zeros when ``use`` is false or the value is None.
        """
        multi_hot = [0] * 17
        class_ids = self.value if use_info is None else use_info
        if use and class_ids is not None:
            for inst_class_id in class_ids:
                multi_hot[inst_class_id] = 1
        return multi_hot

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 17
class UnitI1s1(UnitBase):
    """Instrument classes (v2) used in the segment."""

    @property
    def version(self) -> str:
        """Version tag of this unit."""
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Map every instrument used in the segment to its v2 class id.

        :return:
            tuple of the class ids of all instruments, without duplicates;
            None if no instrument is used
        """
        all_insts = set()
        for bar_insts in bars_insts[bar_begin: bar_end]:
            all_insts = all_insts | bar_insts

        class_ids = [inst_id_to_inst_class_id_2[inst_id] for inst_id in all_insts]
        if not class_ids:
            return None
        return tuple(set(class_ids))

    def get_vector(self, use=True, use_info=None):
        """Return a multi-hot list of instrument classes, ``vector_dim`` long.

        ``use_info``, when given, replaces the unit's own value. The list is all
        zeros when ``use`` is false or the value is None.
        """
        multi_hot = [0] * self.vector_dim
        class_ids = self.value if use_info is None else use_info
        if use and class_ids is not None:
            for inst_class_id in class_ids:
                multi_hot[inst_class_id] = 1
        return multi_hot

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 14
class UnitI1s2(UnitBase):
    """Instrument classes (v3) used in the segment."""

    inst_class_version = 'v3'

    # Single source of truth for the v3 classes: (class name, instrument ids).
    # The position of an entry in this tuple is its class id.
    _inst_class_table = (
        ('piano', (0, 1, 2, 3, 4, 5)),
        ('keyboard', (6, 7, 8, 9)),
        ('percussion', (11, 12, 13, 14, 15, 47, 55, 112, 113, 115, 117, 119)),
        ('organ', (16, 17, 18, 19, 20, 21, 22, 23)),
        ('guitar', (24, 25, 26, 27, 28, 29, 30, 31)),
        ('bass', (32, 33, 34, 35, 36, 37, 38, 39, 43)),
        ('violin', (40,)),
        ('viola', (41,)),
        ('cello', (42,)),
        ('harp', (46,)),
        ('strings', (44, 45, 48, 49, 50, 51)),
        ('voice', (52, 53, 54)),
        ('trumpet', (56, 59)),
        ('trombone', (57,)),
        ('tuba', (58,)),
        ('horn', (60, 69)),
        ('brass', (61, 62, 63)),
        ('sax', (64, 65, 66, 67)),
        ('oboe', (68,)),
        ('bassoon', (70,)),
        ('clarinet', (71,)),
        ('piccolo', (72,)),
        ('flute', (73, 75)),
        ('pipe', (74, 76, 77, 78, 79)),
        ('synthesizer', (80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95)),
        ('ethnic instrument', (104, 105, 106, 107, 108, 109, 110, 111)),
        ('sound effect', (10, 120, 121, 122, 123, 124, 125, 126, 127,
                          96, 97, 98, 99, 100, 101, 102, 103)),
        ('drum', (128, 118, 114, 116)),
    )

    # Lookup tables derived from the table above.
    inst_id_to_inst_class_id = {}
    inst_class_id_to_inst_class_name = {}
    inst_class_name_to_inst_class_id = {}
    for inst_class_id, (inst_class_name, _inst_ids) in enumerate(_inst_class_table):
        inst_class_id_to_inst_class_name[inst_class_id] = inst_class_name
        inst_class_name_to_inst_class_id[inst_class_name] = inst_class_id
        for _inst_id in _inst_ids:
            inst_id_to_inst_class_id[_inst_id] = inst_class_id
    del _inst_ids, _inst_id

    num_classes = len(inst_class_id_to_inst_class_name)

    @classmethod
    def convert_inst_id_to_inst_class_id(cls, inst_id):
        """Return the class id of an instrument id."""
        return cls.inst_id_to_inst_class_id[inst_id]

    @classmethod
    def convert_inst_class_id_to_inst_class_name(cls, inst_class_id):
        """Return the name of a class id."""
        return cls.inst_class_id_to_inst_class_name[inst_class_id]

    @classmethod
    def convert_inst_class_name_to_inst_class_id(cls, inst_class_name):
        """Return the class id of a class name."""
        return cls.inst_class_name_to_inst_class_id[inst_class_name]

    @classmethod
    def get_raw_unit_class(cls):
        """This unit is derived from the raw instrument ids."""
        return RawUnitI1

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """Map the raw instrument ids to v3 class ids.

        :return:
            - tuple of the class ids of all instruments, without duplicates;
              None if no instrument is used
        """
        inst_ids = raw_data['I1']
        if len(inst_ids) == 0:
            return None
        class_ids = set()
        for inst_id in inst_ids:
            class_ids.add(cls.convert_inst_id_to_inst_class_id(inst_id))
        return tuple(class_ids)

    def get_vector(self, use=True, use_info=None) -> list:
        """Return one ``[yes, no, NA]`` list per instrument class.

        :param use: when false, every class is marked NA
        :param use_info: optional ``(used_insts, unused_insts)`` pair of class-id
            iterables (either may be None) that replaces the unit's own value;
            classes in neither are marked NA
        :return: list of ``num_classes`` lists of length 3
        """
        num_classes = len(self.inst_class_id_to_inst_class_name)
        vector = [[0, 0, 0] for _ in range(num_classes)]
        if not use:
            for item in vector:
                item[2] = 1
            return vector

        if use_info is not None:
            used_insts, unused_insts = use_info
            used = set(used_insts) if used_insts is not None else set()
            unused = set(unused_insts) if unused_insts is not None else set()
            # A class cannot be requested as both used and unused.
            assert len(used & unused) == 0
        else:
            # Without use_info only the classes in the value are marked "yes".
            used = set(self.value) if self.value is not None else set()
            unused = set()

        for inst_class_id in used:
            vector[inst_class_id][0] = 1
        for inst_class_id in unused:
            vector[inst_class_id][1] = 1
        for inst_class_id in set(range(num_classes)) - used - unused:
            vector[inst_class_id][2] = 1
        return vector

    @property
    def vector_dim(self) -> Tuple[int, int]:
        """Shape of :meth:`get_vector`: (number of classes, 3)."""
        return len(self.inst_class_id_to_inst_class_name), 3
class UnitI2(UnitBase):
    """Instrument classes added or removed within the segment."""

    @property
    def version(self) -> str:
        """Version tag of this unit."""
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Detect a single change of the instrument-class set in the segment.

        :return:
            None when there is no usable change (no change, more than one
            change, or classes both added and removed); the sample can then be
            ignored. Otherwise four values:
            - 'inc' or 'dec': classes were added or removed
            - tuple of the added class ids, None for 'dec'
            - tuple of the removed class ids, None for 'inc'
            - int: index, within the segment, of the bar where the change happens
        """
        bars_classes = [
            {inst_id_to_inst_class_id[inst_id] for inst_id in bar_insts}
            for bar_insts in bars_insts[bar_begin: bar_end]
        ]

        # Consecutive distinct class sets; exactly two means one change.
        phases = []
        change_point = None
        for idx, bar_classes in enumerate(bars_classes):
            if phases and phases[-1] == bar_classes:
                continue
            if phases:
                change_point = idx
            phases.append(bar_classes)
            if len(phases) > 2:
                return None
        if len(phases) != 2:
            return None

        before, after = phases
        increased_insts = tuple(after - before)
        decreased_insts = tuple(before - after)
        if increased_insts and not decreased_insts:
            return 'inc', increased_insts, None, change_point
        if decreased_insts and not increased_insts:
            return 'dec', None, decreased_insts, change_point
        return None

    def get_vector(self, use=True, use_info=None):
        """Return a 34-element list: 17 slots for added classes, then 17 for removed."""
        vector = [0] * 34
        value = self.value
        if value is None or not use:
            return vector
        change_type, inc_insts, dec_insts, change_point = value
        if change_type == 'inc':
            offset, change_insts = 0, inc_insts
        else:
            offset, change_insts = 17, dec_insts
        for inst_class_id in change_insts:
            vector[inst_class_id + offset] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 34
class UnitI2s1(UnitBase):
    """Instrument classes (v2) added or removed within the segment."""

    @property
    def version(self) -> str:
        """Version tag of this unit."""
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Detect a single change of the v2 instrument-class set in the segment.

        :return:
            None when there is no usable change (no change, more than one
            change, or classes both added and removed); the sample can then be
            ignored. Otherwise four values:
            - 'inc' or 'dec': classes were added or removed
            - tuple of the added class ids, None for 'dec'
            - tuple of the removed class ids, None for 'inc'
            - int: index, within the segment, of the bar where the change happens
        """
        bars_classes = [
            {inst_id_to_inst_class_id_2[inst_id] for inst_id in bar_insts}
            for bar_insts in bars_insts[bar_begin: bar_end]
        ]

        # Consecutive distinct class sets; exactly two means one change.
        phases = []
        change_point = None
        for idx, bar_classes in enumerate(bars_classes):
            if phases and phases[-1] == bar_classes:
                continue
            if phases:
                change_point = idx
            phases.append(bar_classes)
            if len(phases) > 2:
                return None
        if len(phases) != 2:
            return None

        before, after = phases
        increased_insts = tuple(after - before)
        decreased_insts = tuple(before - after)
        if increased_insts and not decreased_insts:
            return 'inc', increased_insts, None, change_point
        if decreased_insts and not increased_insts:
            return 'dec', None, decreased_insts, change_point
        return None

    def get_vector(self, use=True, use_info=None):
        """Return a ``vector_dim``-element list: 14 slots for added classes, then 14 for removed."""
        vector = [0] * self.vector_dim
        value = self.value
        if value is None or not use:
            return vector
        change_type, inc_insts, dec_insts, change_point = value
        if change_type == 'inc':
            offset, change_insts = 0, inc_insts
        else:
            offset, change_insts = 14, dec_insts
        for inst_class_id in change_insts:
            vector[inst_class_id + offset] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        """Length of the list returned by :meth:`get_vector`."""
        return 28
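# class UnitI3(UnitBase):
#     """
#     Change of instrument classes (class mapping v3) between the first and the second half.
#     """
#
#     @classmethod
#     def convert_raw_to_value(cls, raw_data):
#         """
#
#         :return:
#         - tuple: ids of the instrument classes added in the second half relative to the first half (deduplicated); None if empty
#         - tuple: ids of the instrument classes removed in the second half relative to the first half (deduplicated); None if empty
#         """
#         pass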
class UnitI4(UnitBase):
    """
    Instrument class (class mapping v3) of the instrument that plays the melody.
    """

    inst_class_version = 'v3'
    # The v3 class tables are shared with UnitI1s2.
    inst_id_to_inst_class_id = UnitI1s2.inst_id_to_inst_class_id
    inst_class_id_to_inst_class_name = UnitI1s2.inst_class_id_to_inst_class_name
    inst_class_name_to_inst_class_id = UnitI1s2.inst_class_name_to_inst_class_id
    num_classes = UnitI1s2.num_classes

    @classmethod
    def convert_inst_id_to_inst_class_id(cls, inst_id):
        return cls.inst_id_to_inst_class_id[inst_id]

    @classmethod
    def convert_inst_class_id_to_inst_class_name(cls, inst_class_id):
        return cls.inst_class_id_to_inst_class_name[inst_class_id]

    @classmethod
    def convert_inst_class_name_to_inst_class_id(cls, inst_class_name):
        return cls.inst_class_name_to_inst_class_id[inst_class_name]

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitP3, RawUnitN2

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        Pick the instrument with the highest average pitch as the melody instrument.

        ``raw_data['P3']`` and ``raw_data['N2']`` are both indexed by instrument id and their
        ratio is used as the average pitch (presumably P3 is the pitch sum and N2 the note
        count per instrument; confirm against RawUnitP3 / RawUnitN2).

        :return:
            - int: id of the instrument class (v3) playing the melody; None when there is no
              instrument or the candidate has 20 notes or fewer
            - bool: True if exactly one instrument is present in P3, otherwise False
        """
        raw_p3 = raw_data['P3']
        raw_n2 = raw_data['N2']

        r = None
        if len(raw_p3) > 0:
            avg_pitch_dict = {inst_id: raw_p3[inst_id] / raw_n2[inst_id] for inst_id in raw_p3}
            # Rank by average pitch (the dict value). The previous code sorted by the
            # instrument id, which made the averages computed above irrelevant.
            candidate_inst_id = max(avg_pitch_dict, key=avg_pitch_dict.get)
            if raw_n2[candidate_inst_id] > 20:
                r = cls.convert_inst_id_to_inst_class_id(candidate_inst_id)

        sin = len(raw_p3) == 1
        return r, sin

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot over the instrument classes; the last slot means NA (unused or unknown)."""
        r, sin = self.value
        vector = [0] * (self.num_classes + 1)
        if not use or r is None:
            vector[-1] = 1
        else:
            vector[r] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        # Exposed as a property, like the vector_dim of every other unit.
        return self.num_classes + 1

Просмотреть файл

@ -0,0 +1,56 @@
from .unit_base import UnitBase
from .raw_unit_k import RawUnitK1
class UnitK1(UnitBase):
    """
    Major or minor key.
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitK1

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """Return the raw 'K1' entry unchanged."""
        return raw_data['K1']

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Read the mode from the optional ``is_major`` keyword argument.

        :return:
            - str: 'major' or 'minor'; None when the mode is unknown
        """
        is_major = kwargs.get('is_major')
        if is_major is True:
            return 'major'
        if is_major is False:
            return 'minor'
        return None

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: major, minor, NA."""
        one_hot = [0] * self.vector_dim
        mode = self.value
        if mode is None or not use:
            one_hot[-1] = 1
        elif mode == 'major':
            one_hot[0] = 1
        elif mode == 'minor':
            one_hot[1] = 1
        else:
            raise ValueError("The K1 value is \"%s\", which is abnormal." % str(mode))
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 3

Просмотреть файл

@ -0,0 +1,51 @@
from .unit_base import UnitBase
class UnitM1(UnitBase):
    """
    Track repetition (placeholder: the unit is not implemented yet).
    """

    @classmethod
    def get_raw_unit_class(cls):
        raise NotImplementedError

    @classmethod
    def convert_raw_to_value(cls, raw_data):
        """
        Placeholder; currently returns None.

        :return: (intended contract)
            - dict: keys are raw instrument ids (0-128); each value is a bool telling whether
              the track contains repetition, or None when this cannot be determined
        """
        return None

    def get_vector(self, use=True, use_info=None) -> list:
        raise NotImplementedError

    @property
    def vector_dim(self) -> int:
        raise NotImplementedError
class UnitM2(UnitBase):
    """
    Melody pattern (placeholder: the unit is not implemented yet).
    """

    @classmethod
    def get_raw_unit_class(cls):
        raise NotImplementedError

    @classmethod
    def convert_raw_to_value(cls, raw_data):
        """
        Placeholder; currently returns None.

        :return: (intended contract)
            - int: 0 = rising, 1 = falling, 2 = rising then falling, 3 = falling then rising,
              4 = flat; None for any other case that is not considered
        """
        return None

    def get_vector(self, use=True, use_info=None) -> list:
        raise NotImplementedError

    @property
    def vector_dim(self) -> int:
        raise NotImplementedError

Просмотреть файл

@ -0,0 +1,211 @@
import math
from .unit_base import UnitBase
from .raw_unit_p import RawUnitP1, RawUnitP2
class UnitP1(UnitBase):
    """
    Lowest pitch of the segment.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Scan every position of bars [bar_begin, bar_end) for the lowest pitch.

        :return:
            - int: lowest pitch
            None if the segment has no notes
        """
        first_pos = bars_positions[bar_begin][0]
        stop_pos = bars_positions[bar_end - 1][1]

        lowest = 1000  # starting value for the running minimum
        found_note = False
        for idx in range(first_pos, stop_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id in insts_notes:
                # Instrument ids >= 128 are ignored (presumably percussion; confirm).
                if inst_id < 128:
                    for pitch, _dur, _vel in insts_notes[inst_id]:
                        lowest = min(lowest, pitch)
                        found_note = True

        return lowest if found_note else None

    def get_vector(self, use=True, use_info=None):
        """One-hot over pitches 0-127; the last slot (index 128) means NA."""
        one_hot = [0] * 129
        pitch = self.value
        if not use or pitch is None:
            one_hot[-1] = 1
        else:
            one_hot[pitch] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 129
class UnitP2(UnitBase):
    """
    Highest pitch of the segment.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Scan every position of bars [bar_begin, bar_end) for the highest pitch.

        :return:
            - int: highest pitch
            None if the segment has no notes
        """
        first_pos = bars_positions[bar_begin][0]
        stop_pos = bars_positions[bar_end - 1][1]

        highest = -1  # starting value for the running maximum
        found_note = False
        for idx in range(first_pos, stop_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id in insts_notes:
                # Instrument ids >= 128 are ignored (presumably percussion; confirm).
                if inst_id < 128:
                    for pitch, _dur, _vel in insts_notes[inst_id]:
                        highest = max(highest, pitch)
                        found_note = True

        return highest if found_note else None

    def get_vector(self, use=True, use_info=None):
        """One-hot over pitches 0-127; the last slot (index 128) means NA."""
        one_hot = [0] * 129
        pitch = self.value
        if not use or pitch is None:
            one_hot[-1] = 1
        else:
            one_hot[pitch] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 129
class UnitP3(UnitBase):
    """
    Pitch classes used in the segment.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Collect the pitch classes (pitch modulo 12) of all notes in bars [bar_begin, bar_end).

        :return:
            tuple containing every pitch class that occurs; None if the segment has no notes
        """
        first_pos = bars_positions[bar_begin][0]
        stop_pos = bars_positions[bar_end - 1][1]

        seen_classes = set()
        found_note = False
        for idx in range(first_pos, stop_pos):
            insts_notes = pos_info[idx][4]
            if insts_notes is None:
                continue
            for inst_id in insts_notes:
                # Instrument ids >= 128 are ignored (presumably percussion; confirm).
                if inst_id < 128:
                    for pitch, _dur, _vel in insts_notes[inst_id]:
                        seen_classes.add(int(pitch) % 12)
                        found_note = True

        if not found_note:
            return None
        return tuple(seen_classes)

    def get_vector(self, use=True, use_info=None):
        """Multi-hot over the 12 pitch classes; all zeros when unused or unknown."""
        multi_hot = [0] * 12
        pitch_classes = self.value
        if use and pitch_classes is not None:
            for pitch_class in pitch_classes:
                multi_hot[pitch_class] = 1
        return multi_hot

    @property
    def vector_dim(self) -> int:
        return 12
class UnitP4(UnitBase):
    """
    Pitch range, measured in octaves.
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitP1, RawUnitP2

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        :return:
            - int: number of whole octaves between raw 'P1' (low) and raw 'P2' (high);
              None if either of them is None (no notes)
        """
        low, high = raw_data['P1'], raw_data['P2']
        if low is None or high is None:
            return None
        return math.floor((high - low) / 12)

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: 0 octaves, 1 octave, ..., 11 octaves, NA."""
        one_hot = [0] * self.vector_dim
        num_octaves = self.value
        if not use or num_octaves is None:
            one_hot[-1] = 1
        else:
            one_hot[num_octaves] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 13

Просмотреть файл

@ -0,0 +1,225 @@
from .unit_base import UnitBase
from .raw_unit_r import RawUnitR4, RawUnitR3
class UnitR1(UnitBase):
    """
    Whether the music is danceable.
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitR4

    @classmethod
    def count_on_and_off_beat_notes(cls, ts, pos_notes):
        """
        Sum the per-position counts that fall on strong beats and on the remaining positions.

        :param ts: time signature as a (numerator, denominator) pair; only 4/4, 3/4 and 2/4
            are supported
        :param pos_notes: one count per position of a bar; its length must be a multiple of
            the number of beats (presumably the number of notes starting at each position;
            confirm against RawUnitR4)
        :return: (num_on, num_off), or None if pos_notes is None or ts is unsupported
        """
        if pos_notes is None:
            return None

        ts = tuple(ts)
        if ts == (4, 4):
            num_beats, on_beats = 4, (1, 3)
        elif ts == (3, 4):
            num_beats, on_beats = 3, (1,)
        elif ts == (2, 4):
            num_beats, on_beats = 2, (1,)
        else:
            return None

        assert len(pos_notes) % num_beats == 0
        beat_pos = len(pos_notes) // num_beats

        # All positions that belong to a strong (1-based) beat.
        on_positions = set()
        for on_beat in on_beats:
            on_positions.update(range((on_beat - 1) * beat_pos, on_beat * beat_pos))

        num_on = 0
        num_off = 0
        for idx, num in enumerate(pos_notes):
            # Accumulate the count stored at the position. Adding 1 per position (as before)
            # only counted positions, so the result did not depend on the music at all.
            if idx in on_positions:
                num_on += num
            else:
                num_off += num
        return num_on, num_off

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        :return:
            - bool: True if more notes fall on strong beats than elsewhere, False if fewer;
              None if it cannot be decided (no usable data or a tie)
        """
        raw_r4 = raw_data['R4']
        num_on, num_off = 0, 0
        for ts in raw_r4:
            counts = cls.count_on_and_off_beat_notes(ts, raw_r4[ts])
            if counts is None:
                continue
            num_on += counts[0]
            num_off += counts[1]

        if num_on > num_off:
            return True
        if num_on < num_off:
            return False
        # Covers both "no data" (0 == 0) and an exact tie.
        return None

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot in the order: yes, no, NA."""
        value = self.value
        vector = [0] * 3
        if value is None or not use:
            vector[2] = 1
        elif value is True:
            vector[0] = 1
        else:
            vector[1] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        return 3
class UnitR2(UnitBase):
    """
    Whether the music is lively: some track has a staccato-note ratio of at least 50%.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        A note counts as staccato when it lasts at most 0.1 seconds.

        :return:
            - bool: whether the segment is lively; None if the segment has no notes
        """
        begin = bars_positions[bar_begin][0]
        end = bars_positions[bar_end - 1][1]
        pos_resolution = encoder.vm.pos_resolution

        insts_num_notes = {}
        insts_small_dur_notes = {}
        last_tempo = None
        for idx in range(begin, end):
            pos_item = pos_info[idx]
            if pos_item[3] is not None:
                last_tempo = pos_item[3]
            insts_notes = pos_item[4]
            if insts_notes is None:
                continue
            for inst_id in insts_notes:
                # Instrument ids >= 128 are ignored (presumably percussion; confirm).
                if inst_id >= 128:
                    continue
                for _, dur, _ in insts_notes[inst_id]:
                    insts_num_notes[inst_id] = insts_num_notes.get(inst_id, 0) + 1
                    # dur is in positions, pos_resolution is positions per beat and the tempo
                    # is beats per minute: seconds = beats * 60 / tempo. The previous formula
                    # multiplied by the tempo, so no note could ever be <= 0.1 s.
                    num_seconds = dur / pos_resolution * 60 / last_tempo
                    if num_seconds <= 0.1:
                        insts_small_dur_notes[inst_id] = insts_small_dur_notes.get(inst_id, 0) + 1

        if not insts_num_notes:
            return None
        for inst_id, num_small_notes in insts_small_dur_notes.items():
            if num_small_notes / insts_num_notes[inst_id] >= 0.5:
                return True
        return False

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: yes, no, NA."""
        value = self.value
        vec = [0] * self.vector_dim
        if value is None or not use:
            vec[-1] = 1
        elif value:
            vec[0] = 1
        else:
            vec[1] = 1
        return vec

    @property
    def vector_dim(self) -> int:
        return 3
class UnitR3(UnitBase):
    """
    Rhythm intensity (note density).
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitR3

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        Bucket the raw 'R3' value: <= 1 gives 0, between 1 and 2 gives 1, >= 2 gives 2.

        :return:
            - int: 0 = not intense, 1 = moderate, 2 = intense; None if it cannot be decided
              (raw value missing or None)
        """
        raw_r3 = raw_data.get('R3') if hasattr(raw_data, 'get') else (
            raw_data['R3'] if 'R3' in raw_data else None
        )
        if raw_r3 is None:
            return None
        if raw_r3 <= 1:
            return 0
        if raw_r3 < 2:
            return 1
        return 2

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot in the order: 0, 1, 2, NA."""
        vector = [0] * 4
        value = self.value
        if value is None or not use:
            vector[-1] = 1
        else:
            vector[value] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        # Was `raise 4`, which fails with TypeError; the vector built above has 4 slots.
        return 4

Просмотреть файл

@ -0,0 +1,368 @@
from .unit_base import UnitBase
import os
from .raw_unit_s import RawUnitS1, RawUnitS2
def s1_func_by_is_symphony(file_path):
    """Treat every file as a symphony, regardless of its path."""
    is_symphony = True
    return is_symphony
def s1_func_by_has_symphony_1(file_path):
    """Return True if the path contains 'symphony'; otherwise None (unknown)."""
    normalized_path = file_path.replace('\\', '/')
    return True if 'symphony' in normalized_path else None
# Registry of the symphony judges; UnitS1.extract selects one through the `s1_func` keyword.
s1_funcs = dict(
    is_symphony=s1_func_by_is_symphony,
    has_symphony_1=s1_func_by_has_symphony_1,
)
class UnitS1(UnitBase):
    """
    Whether the piece is a symphony.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Apply the judge named by the ``s1_func`` keyword argument to the file's base name.

        :return:
            - bool: whether the piece is a symphony
            May be None, meaning it is unknown (also when ``s1_func`` is absent or None).
        """
        func_key = kwargs.get('s1_func')
        if func_key is None:
            return None
        return s1_funcs[func_key](os.path.basename(midi_path))

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: yes, no, NA."""
        one_hot = [0] * self.vector_dim
        is_symphony = self.value
        if not use or is_symphony is None:
            one_hot[-1] = 1
        elif is_symphony:
            one_hot[0] = 1
        else:
            one_hot[1] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 3
# Top-level directory name -> display name of the artist.
dir_name_to_artist_name = {
    'beethoven': 'Beethoven',
    'mozart': 'Mozart',
    'chopin': 'Chopin',
    'schubert': 'Schubert',
    'schumann': 'Schumann',
}

# Artist display name -> index of its slot in the UnitS2 vector.
artist_name_to_id = {
    artist_name: artist_id
    for artist_id, artist_name in enumerate(
        ('Beethoven', 'Mozart', 'Chopin', 'Schubert', 'Schumann')
    )
}
def s2_func_by_file_path_1(file_path):
    """Map the first directory of the path to an artist name; None if it is not a known artist."""
    top_dir = file_path.replace('\\', '/').split('/')[0]
    return dir_name_to_artist_name.get(top_dir)
# Registry of the artist judges; UnitS2.extract selects one through the `s2_func` keyword.
s2_funcs = dict(
    file_path_1=s2_func_by_file_path_1,
)
class UnitS2(UnitBase):
    """
    Which artist the piece belongs to.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Apply the judge named by the ``s2_func`` keyword argument to ``midi_path``.

        :return:
            - str: artist name
            May be None, meaning the artist is unknown (also when ``s2_func`` is absent or None).
        """
        func_key = kwargs.get('s2_func')
        if func_key is None:
            return None
        return s2_funcs[func_key](midi_path)

    def get_vector(self, use=True, use_info=None):
        """One-hot over the known artists; the last slot means NA."""
        one_hot = [0] * self.vector_dim
        artist_name = self.value
        if not use or artist_name is None:
            one_hot[-1] = 1
            return one_hot
        artist_id = artist_name_to_id[artist_name]
        assert 0 <= artist_id < self.vector_dim - 1
        one_hot[artist_id] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return len(artist_name_to_id) + 1
class UnitS2s1(UnitBase):
    """
    Artist.
    """

    # NOTE(review): 'Handel' is the only capitalized label; confirm it matches the labels
    # produced by RawUnitS1.
    artist_label_to_artist_id = {
        'beethoven': 0,
        'mozart': 1,
        'chopin': 2,
        'schubert': 3,
        'schumann': 4,
        'bach-js': 5,
        'haydn': 6,
        'brahms': 7,
        'Handel': 8,
        'tchaikovsky': 9,
        'mendelssohn': 10,
        'dvorak': 11,
        'liszt': 12,
        'stravinsky': 13,
        'mahler': 14,
        'prokofiev': 15,
        'shostakovich': 16,
    }

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitS1

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        :return:
            - str: artist label taken from raw_data['S1']['artist']; may be None, meaning the
              artist is unknown
        """
        return raw_data['S1']['artist']

    @classmethod
    def convert_label_to_id(cls, label):
        return cls.artist_label_to_artist_id[label]

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot in the order: artist 0, artist 1, ..., NA."""
        one_hot = [0] * (len(self.artist_label_to_artist_id) + 1)
        if use and self.value is not None:
            one_hot[self.convert_label_to_id(self.value)] = 1
        else:
            one_hot[-1] = 1
        return one_hot

    @property
    def vector_dim(self):
        return len(self.artist_label_to_artist_id) + 1
def s3_func_by_is_classical(file_name):
    """Treat every file as classical music, regardless of its name."""
    is_classical = True
    return is_classical
def s3_func_by_has_classical_1(file_path):
    """Return True if the path contains 'classical'; otherwise None (unknown)."""
    normalized_path = file_path.replace('\\', '/')
    return True if 'classical' in normalized_path else None
# Registry of the classical-music judges; UnitS3.extract selects one through the `s3_func` keyword.
s3_funcs = dict(
    is_classical=s3_func_by_is_classical,
    has_classical_1=s3_func_by_has_classical_1,
)
class UnitS3(UnitBase):
    """
    Whether the piece is classical music.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        Apply the judge named by the ``s3_func`` keyword argument to ``midi_path``.

        :return:
            - bool: whether the piece is classical music
            May be None, meaning it is unknown (also when ``s3_func`` is absent or None).
        """
        func_key = kwargs.get('s3_func')
        if func_key is None:
            return None
        return s3_funcs[func_key](midi_path)

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: yes, no, NA."""
        one_hot = [0] * self.vector_dim
        is_classical = self.value
        if not use or is_classical is None:
            one_hot[-1] = 1
        elif is_classical:
            one_hot[0] = 1
        else:
            one_hot[1] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 3
class UnitS4(UnitBase):
    """
    Genre.
    """

    genre_label_to_genre_id = {
        'New Age': 0,
        'Electronic': 1,
        'Rap': 2,
        'Religious': 3,
        'International': 4,
        'Easy_Listening': 5,
        'Avant_Garde': 6,
        'RnB': 7,
        'Latin': 8,
        'Children': 9,
        'Jazz': 10,
        'Classical': 11,
        'Comedy_Spoken': 12,
        'Pop_Rock': 13,
        'Reggae': 14,
        'Stage': 15,
        'Folk': 16,
        'Blues': 17,
        'Vocal': 18,
        'Holiday': 19,
        'Country': 20,
        'Symphony': 21,
    }

    @classmethod
    def convert_label_to_id(cls, label):
        return cls.genre_label_to_genre_id[label]

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitS2

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        :return:
            - tuple of str: all applicable genre labels (deduplicated); None if unknown
        """
        genres = raw_data['S2']['genre']
        if genres is None:
            return None
        return tuple(set(genres))

    def get_vector(self, use=True, use_info=None):
        """
        Return one [yes, no, NA] triple per genre.

        :param use: when False every genre is marked NA.
        :param use_info: optional pair (used_genres, unused_genres), each an iterable of genre
            labels or None. When given it replaces the stored value: used genres are marked
            "yes", unused genres "no" and all remaining genres NA. Without it, the genres of
            the stored value are marked "yes" and all others NA.
        """
        all_genres = self.genre_label_to_genre_id
        vector = [[0, 0, 0] for _ in range(len(all_genres))]

        def mark(genres, slot):
            for genre in genres:
                vector[self.convert_label_to_id(genre)][slot] = 1

        if not use:
            for triple in vector:
                triple[2] = 1
            return vector

        if use_info is None:
            value = self.value
            if value is None:
                value = tuple()
            mark(value, 0)
            mark(set(all_genres.keys()) - set(value), 2)
            return vector

        used_genres, unused_genres = use_info
        used_set = set(used_genres) if used_genres is not None else set()
        unused_set = set(unused_genres) if unused_genres is not None else set()
        if used_genres is not None and unused_genres is not None:
            # A genre cannot be requested as used and unused at the same time.
            assert len(used_set & unused_set) == 0
        mark(used_set, 0)
        mark(unused_set, 1)
        mark(set(all_genres.keys()) - used_set - unused_set, 2)
        return vector

    @property
    def vector_dim(self):
        return len(self.genre_label_to_genre_id), 3

Просмотреть файл

@ -0,0 +1,76 @@
from .unit_base import UnitBase
import os
from .raw_unit_st import RawUnitST1
# Structure labels; the position of a label in this list is its structure id.
structure_id_to_structure_label = [
    "A", "AB", "AA", "ABA", "AAB", "ABB", "AAAA",
    "AAAB", "AABB", "AABA", "ABAA", "ABAB", "ABBA", "ABBB",
]
# Inverse lookup, built with a comprehension so that no loop variable leaks into the module.
structure_label_to_structure_id = {
    label: structure_id for structure_id, label in enumerate(structure_id_to_structure_label)
}
def remove_digit(t):
    """Return ``t`` with every decimal digit character removed."""
    return ''.join(letter for letter in t if letter not in '0123456789')
def get_structure_by_file_name_1(file_name):
    """
    Return the second '_'-separated field of a '.mid' file name (extension removed).

    The file name must end with '.mid'.
    """
    assert file_name.endswith('.mid')
    stem = file_name[:-4]
    fields = stem.split('_')
    return fields[1]
# Registry of the functions that read a structure label from a file name.
st1_funcs = dict(
    file_name_1=get_structure_by_file_name_1,
)
class UnitST1(UnitBase):
    """
    Musical structure of the piece (e.g. "AABA").
    """

    # The position of a label in this list is its structure id.
    structure_id_to_structure_label = [
        "A", "AB", "AA", "ABA", "AAB", "ABB", "AAAA", "AAAB", "AABB", "AABA", "ABAA", "ABAB", "ABBA", "ABBB"
    ]
    # Inverse lookup. A comprehension keeps loop variables out of the class namespace
    # (the previous loop left `idx` and `item` behind as class attributes).
    structure_label_to_structure_id = {
        label: structure_id for structure_id, label in enumerate(structure_id_to_structure_label)
    }

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitST1

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        :return:
            str: the structure string taken from raw 'ST1'; it may contain digits
        """
        return raw_data['ST1']

    def get_vector(self, use=True, use_info=None):
        """
        One-hot over the known structure labels; the last slot means NA.

        Digits are stripped from the stored value before the lookup. A label that is not in
        the table raises KeyError.
        """
        value = self.value
        vector = [0] * (len(self.structure_id_to_structure_label) + 1)
        if value is None or not use:
            vector[-1] = 1
            return vector
        structure = remove_digit(value)
        vector[self.structure_label_to_structure_id[structure]] = 1
        return vector

    @property
    def vector_dim(self) -> int:
        return len(self.structure_id_to_structure_label) + 1

Просмотреть файл

@ -0,0 +1,115 @@
from .unit_base import UnitBase
from .raw_unit_t import RawUnitT1
def convert_tempo_value_to_type_name_and_type_id(value):
    """
    Map a tempo value to its (tempo class name, tempo class id).

    The classes are checked from fastest to slowest; the first one whose lower bound is
    reached wins. Anything below 40 is 'Grave' (id 8).
    """
    lower_bounds_and_names = (
        (200, 'Prestissimo'),
        (168, 'Presto'),
        (120, 'Allegro'),
        (108, 'Moderato'),
        (76, 'Andante'),
        (66, 'Adagio'),
        (60, 'Larghetto'),
        (40, 'Largo'),
    )
    for type_id, (lower_bound, type_name) in enumerate(lower_bounds_and_names):
        if value >= lower_bound:
            return type_name, type_id
    return 'Grave', 8
class UnitT1(UnitBase):
    """
    The single tempo used in the segment.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        The first position of the segment must carry a tempo.

        :return:
            - float: the single tempo used
            - str: name of its tempo class
            Both values are None if more than one tempo is used.
        """
        begin = bars_positions[bar_begin][0]
        end = bars_positions[bar_end - 1][1]
        assert pos_info[begin][3] is not None

        tempo_set = set()
        for idx in range(begin, end):
            candidate = pos_info[idx][3]
            if candidate is not None:
                tempo_set.add(candidate)
        if len(tempo_set) > 1:
            return None, None

        tempo = list(tempo_set)[0]
        type_name = convert_tempo_value_to_type_name_and_type_id(tempo)[0]
        return tempo, type_name

    def get_vector(self, use=True, use_info=None):
        """One-hot over the 9 tempo classes; the last slot (index 9) means NA."""
        tempo = self.value[0]
        one_hot = [0] * 10
        if not use or tempo is None:
            one_hot[-1] = 1
        else:
            one_hot[convert_tempo_value_to_type_name_and_type_id(tempo)[1]] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 10
class UnitT1s1(UnitBase):
    """
    Playing speed.
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitT1

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        :return:
            - float: the single tempo used; None if there are several tempos
            - int: 0 = slow (tempo <= 76), 1 = moderate, 2 = fast (tempo >= 120); None if
              there are several tempos
        """
        tempo_list = raw_data['T1']
        if len(tempo_list) > 1:
            return None, None

        tempo = tempo_list[0]
        if tempo >= 120:
            label_id = 2
        elif tempo <= 76:
            label_id = 0
        else:
            label_id = 1
        return tempo, label_id

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot in the order: slow, moderate, fast, NA."""
        _tempo, label_id = self.value
        one_hot = [0] * 4
        if not use or label_id is None:
            one_hot[-1] = 1
        else:
            one_hot[label_id] = 1
        return one_hot

    @property
    def vector_dim(self):
        return 4

Просмотреть файл

@ -0,0 +1,47 @@
from .unit_base import UnitBase
from .raw_unit_tm import RawUnitTM1
class UnitTM1(UnitBase):
    """
    Duration of the segment.
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitTM1

    @classmethod
    def convert_raw_to_value(cls, raw_data, encoder, midi_dir, midi_path,
                             pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs):
        """
        :return:
            - float: duration of the segment in seconds
            - int: bucket id; 0 = (0, 15], 1 = (15, 30], 2 = (30, 45], 3 = (45, 60],
              4 = everything else (more than 60 seconds; non-positive values also land here)
        """
        time_second = raw_data['TM1']
        if time_second > 0:
            for label_id, upper_bound in enumerate((15, 30, 45, 60)):
                if time_second <= upper_bound:
                    return time_second, label_id
        return time_second, 4

    def get_vector(self, use=True, use_info=None) -> list:
        """One-hot in the order: bucket 0, 1, 2, 3, 4, NA."""
        _seconds, label_id = self.value
        one_hot = [0] * 6
        if use:
            one_hot[label_id] = 1
        else:
            one_hot[-1] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 6

Просмотреть файл

@ -0,0 +1,146 @@
from .unit_base import UnitBase
from .raw_unit_ts import RawUnitTS1
class UnitTS1(UnitBase):
    """
    The single time signature used in the segment.
    """

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        The first position of the segment must carry a time signature.

        :return:
            the single time signature used, as a tuple, e.g. (3, 4) for 3/4;
            None if more than one time signature is used
        """
        begin = bars_positions[bar_begin][0]
        end = bars_positions[bar_end - 1][1]
        assert pos_info[begin][1] is not None

        ts_set = set()
        for idx in range(begin, end):
            candidate = pos_info[idx][1]
            if candidate is not None:
                ts_set.add(candidate)
        if len(ts_set) > 1:
            return None
        return list(ts_set)[0]

    def get_vector(self, use=True, use_info=None):
        """One-hot over the encoder's time-signature list; the last slot means NA."""
        vocab = self.encoder.vm
        ts = self.value
        one_hot = [0] * (len(vocab.ts_list) + 1)
        if not use or ts is None:
            one_hot[-1] = 1
        else:
            one_hot[self.encoder.vm.convert_ts_to_id(ts)] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return len(self.encoder.vm.ts_list) + 1
class UnitTS1s1(UnitBase):
    """
    The single time signature used: one of the common types, or "other".
    """

    @classmethod
    def get_raw_unit_class(cls):
        return RawUnitTS1

    @classmethod
    def convert_raw_to_value(
        cls, raw_data, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        :return:
            None if raw 'TS1' holds more than one time signature; the string 'other' if the
            single time signature is not a common one; otherwise the time signature itself
        """
        ts_set = raw_data['TS1']
        if len(ts_set) > 1:
            return None
        ts = tuple(ts_set)[0]
        return 'other' if cls.convert_ts_to_id(ts) == -1 else ts

    @classmethod
    def convert_ts_to_id(cls, ts):
        """Return the index of ``ts`` among the common time signatures, or -1 if it is not one."""
        common_ts = [(4, 4), (2, 4), (3, 4), (1, 4), (6, 8), (3, 8)]
        if ts in common_ts:
            return common_ts.index(ts)
        return -1

    def get_vector(self, use=True, use_info=None):
        """One-hot in the order: (4, 4), (2, 4), (3, 4), (1, 4), (6, 8), (3, 8), other, NA."""
        one_hot = [0] * self.vector_dim
        ts = self.value
        if not use or ts is None:
            one_hot[-1] = 1
            return one_hot
        if ts == 'other':
            one_hot[-2] = 1
            return one_hot
        ts_id = self.convert_ts_to_id(ts)
        assert ts_id != -1
        one_hot[ts_id] = 1
        return one_hot

    @property
    def vector_dim(self) -> int:
        return 8
class UnitTS2(UnitBase):
    """
    Whether the time signature changes inside the segment.

    The unit is disabled: creating an instance always raises NotImplementedError.
    """

    def __init__(self, value, encoder=None):
        super().__init__(value, encoder=encoder)
        # The message says the unit still "needs refinement".
        raise NotImplementedError("需要refine")

    @property
    def version(self) -> str:
        return 'v1.0'

    @classmethod
    def extract(
        cls, encoder, midi_dir, midi_path,
        pos_info, bars_positions, bars_chords, bars_insts, bar_begin, bar_end, **kwargs
    ):
        """
        The first position of the segment must carry a time signature.

        :return:
            bool: True if the time signature changes, False if it does not
        """
        begin = bars_positions[bar_begin][0]
        end = bars_positions[bar_end - 1][1]
        assert pos_info[begin][1] is not None

        ts_set = set()
        for idx in range(begin, end):
            candidate = pos_info[idx][1]
            if candidate is not None:
                ts_set.add(candidate)
        return len(ts_set) > 1

    def get_vector(self, use=True, use_info=None):
        """Single-element vector: [1] if the value is True, otherwise [0]; ``use`` is ignored."""
        return [1] if self.value is True else [0]

    @property
    def vector_dim(self) -> int:
        return 1

Просмотреть файл

@ -0,0 +1,103 @@
import math
import numpy as np
from .utils.magenta_chord_recognition import infer_chords_for_sequence, _key_chord_distribution, \
_key_chord_transition_distribution
class Item(object):
    """
    A note-like event handed to the chord-recognition code.

    Two items are equal when name, start, pitch and track match; end, velocity and value
    are not compared. Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable.
    """

    def __init__(self, name, start, end, vel=0, pitch=0, track=0, value=''):
        self.name = name
        self.start = start  # start step
        self.end = end  # end step
        self.vel = vel
        self.pitch = pitch
        self.track = track
        self.value = value

    def __repr__(self):
        return f'Item(name={self.name:>10s}, start={self.start:>4d}, end={self.end:>4d}, ' \
               f'vel={self.vel:>3d}, pitch={self.pitch:>3d}, track={self.track:>2d}, ' \
               f'value={self.value:>10s})\n'

    def __eq__(self, other):
        try:
            return self.name == other.name and self.start == other.start and \
                self.pitch == other.pitch and self.track == other.track
        except AttributeError:
            # Not an Item-like object: let Python fall back to its default comparison
            # instead of raising from inside `==`.
            return NotImplemented
class ChordDetector(object):
    """Infers chords for a ``pos_info`` sequence with the magenta chord-recognition code."""

    def __init__(self, encoder):
        self.encoder = encoder
        self.pos_resolution = self.encoder.vm.pos_resolution
        self.key_chord_loglik, self.key_chord_transition_loglik = self.init_for_chord_detection()

    @staticmethod
    def init_for_chord_detection():
        """Build the log-likelihood tables (key/chord and key/chord transition) used for inference."""
        key_chord_distribution = _key_chord_distribution(chord_pitch_out_of_key_prob=0.01)
        key_chord_loglik = np.log(key_chord_distribution)

        key_chord_transition_distribution = _key_chord_transition_distribution(
            key_chord_distribution,
            key_change_prob=0.001,
            chord_change_prob=0.5,
        )
        key_chord_transition_loglik = np.log(key_chord_transition_distribution)

        return key_chord_loglik, key_chord_transition_loglik

    def infer_chord_for_pos_info(self, pos_info):
        """
        Infer two chords per bar for a piece in 4/4.

        Input (per the original author's note): a pos_info that has been normalized to
        major/minor, had overlapping melody-track notes removed, and gone through
        decode(encode(pos_info)).
        Output: a list of chords with one chord per ``pos_resolution * 2`` positions, i.e. two
        chords per bar; padded with 'N.C.' or truncated to exactly ``2 * number of bars``.
        The original author notes that the "one extra chord" bug of the magenta algorithm has
        already been fixed.

        :raises NotImplementedError: if any position carries a time signature other than 4/4
        """
        pos_resolution = self.pos_resolution
        # A bar is assumed to have 4 beats, which is why only 4/4 is supported.
        pos_per_bar = pos_resolution * 4

        max_pos = 0
        note_items = []
        for bar, ts, pos, tempo, insts_notes in pos_info:
            if ts is not None and tuple(ts) != (4, 4):
                raise NotImplementedError("This implementation only supports time signature 4/4.")
            if insts_notes is None:
                continue

            for inst_id in insts_notes:
                # Instrument ids >= 128 are ignored (presumably percussion; confirm).
                if inst_id >= 128:
                    continue
                for pitch, duration, velocity in insts_notes[inst_id]:
                    onset = bar * pos_per_bar + pos
                    release = onset + duration
                    max_pos = max(max_pos, release)
                    if not 0 <= pitch < 128:
                        continue
                    # Squeeze pitch ranges to facilitate chord detection.
                    while pitch > 72:
                        pitch -= 12
                    while pitch < 48:
                        pitch += 12
                    note_items.append(
                        Item(name='On', start=onset, end=release, vel=velocity, pitch=pitch, track=0)
                    )

        note_items.sort(key=lambda x: (x.start, -x.end))

        pos_per_chord = pos_resolution * 2
        max_chords = math.ceil(max_pos / pos_per_chord)
        chords = infer_chords_for_sequence(
            note_items,
            pos_per_chord=pos_per_chord,
            max_chords=max_chords,
            key_chord_loglik=self.key_chord_loglik,
            key_chord_transition_loglik=self.key_chord_transition_loglik
        )

        num_bars = pos_info[-1][0] + 1
        expected_chords = num_bars * 2
        while len(chords) < expected_chords:
            chords.append('N.C.')
        # A very long note at the end can yield more chords than 2 per bar; drop the surplus.
        if len(chords) > expected_chords:
            chords = chords[:expected_chords]
        assert len(chords) == num_bars * 2, 'chord length: %d, number of bars: %d' % (len(chords), num_bars)

        return chords

Просмотреть файл

@ -0,0 +1,5 @@
# Attribute labels that make up each attribute-set version.
attribute_versions_list = dict(
    v1=['I1', 'I2', 'B1', 'TS1', 'T1', 'P1', 'P2', 'P3', 'ST1', 'EM1'],
    v2=[
        'I1s1', 'I2s1', 'C1', 'R2', 'S1', 'S2', 'S3', 'B1s1',
        'TS1s1', 'K1', 'T1', 'P3', 'P4', 'ST1', 'EM1',
    ],
    v3=[
        'I1s2', 'I4', 'C1', 'R1', 'R3', 'S2s1', 'S4', 'B1s1',
        'TS1s1', 'K1', 'T1s1', 'P4', 'ST1', 'EM1', 'TM1',
    ],
)

Просмотреть файл

@ -0,0 +1,490 @@
# Instrument id -> [full name, second name].
# Ids 0-127 are melodic instruments grouped in families of eight; id 128 is
# the drum entry. The second name is often a coarser label shared by related
# instruments (e.g. both 0 and 1 map to 'piano').
# NOTE(review): the second name looks like the one meant for generated text --
# confirm against the callers of this table.
# The family headers below follow the grouping used by inst_id_to_inst_class_id
# and inst_class_id_to_name.
inst_id_to_names = {
    # piano:
    0: ['acoustic grand piano', 'piano'],
    1: ['bright acoustic piano', 'piano'],
    2: ['electric grand piano', 'electric piano'],
    3: ['honky-tonk piano', 'honky-tonk piano'],
    4: ['electric piano 1', 'electric piano'],
    5: ['electric piano 2', 'electric piano'],
    6: ['harpsichord', 'harpsichord'],
    7: ['clavinet', 'clavinet'],
    # chromatic percussion:
    8: ['celesta', 'celesta'],
    9: ['glockenspiel', 'glockenspiel'],
    10: ['music box', 'music box'],
    11: ['vibraphone', 'vibraphone'],
    12: ['marimba', 'marimba'],
    13: ['xylophone', 'xylophone'],
    14: ['tubular bells', 'tubular bells'],
    15: ['dulcimer', 'dulcimer'],
    # organ:
    16: ['drawbar organ', 'drawbar organ'],
    17: ['percussive organ', 'percussive organ'],
    18: ['rock organ', 'rock organ'],
    19: ['church organ', 'church organ'],
    20: ['reed organ', 'reed organ'],
    21: ['accordion', 'accordion'],
    22: ['harmonica', 'harmonica'],
    23: ['tango accordion', 'tango accordion'],
    # guitar:
    24: ['acoustic guitar (nylon)', 'guitar'],
    25: ['acoustic guitar (steel)', 'guitar'],
    26: ['electric guitar (jazz)', 'electric guitar'],
    27: ['electric guitar (clean)', 'electric guitar'],
    28: ['electric guitar (muted)', 'electric guitar'],
    29: ['overdriven guitar', 'electric guitar'],
    30: ['distortion guitar', 'electric guitar'],
    31: ['guitar harmonics', 'electric guitar'],
    # bass:
    32: ['acoustic bass', 'bass'],
    33: ['electric bass (finger)', 'electric bass'],
    34: ['electric bass (pick)', 'electric bass'],
    35: ['fretless bass', 'fretless bass'],
    36: ['slap bass 1', 'slap bass'],
    37: ['slap bass 2', 'slap bass'],
    38: ['synth bass 1', 'synth bass'],
    39: ['synth bass 2', 'synth bass'],
    # strings:
    40: ['violin', 'violin'],
    41: ['viola', 'viola'],
    42: ['cello', 'cello'],
    43: ['contrabass', 'contrabass'],
    44: ['tremolo strings', 'tremolo strings'],
    45: ['pizzicato strings', 'pizzicato strings'],
    46: ['orchestral harp', 'orchestral harp'],
    47: ['timpani', 'timpani'],
    # ensemble (string ensembles, voices, orchestra hit):
    48: ['string ensemble 1', 'string ensemble'],
    49: ['string ensemble 2', 'string ensemble'],
    50: ['synth strings 1', 'synth strings'],
    51: ['synth strings 2', 'synth strings'],
    52: ['choir aahs', 'choir singing'],
    53: ['voice oohs', 'singing voice'],
    54: ['synth voice', 'synth voice'],
    55: ['orchestra hit', 'orchestra hit'],
    # brass:
    56: ['trumpet', 'trumpet'],
    57: ['trombone', 'trombone'],
    58: ['tuba', 'tuba'],
    59: ['muted trumpet', 'muted trumpet'],
    60: ['french horn', 'french horn'],
    61: ['brass section', 'brass section'],
    62: ['synth brass 1', 'synth brass'],
    63: ['synth brass 2', 'synth brass'],
    # reed (64-71; the clarinet belongs here, matching class id 8 in
    # inst_id_to_inst_class_id):
    64: ['soprano sax', 'soprano sax'],
    65: ['alto sax', 'alto sax'],
    66: ['tenor sax', 'tenor sax'],
    67: ['baritone sax', 'baritone sax'],
    68: ['oboe', 'oboe'],
    69: ['english horn', 'english horn'],
    70: ['bassoon', 'bassoon'],
    71: ['clarinet', 'clarinet'],
    # pipe:
    72: ['piccolo', 'piccolo'],
    73: ['flute', 'flute'],
    74: ['recorder', 'recorder'],
    75: ['pan flute', 'pan flute'],
    76: ['blown bottle', 'blown bottle'],
    77: ['shakuhachi', 'shakuhachi'],
    78: ['whistle', 'whistle'],
    79: ['ocarina', 'ocarina'],
    # synth lead:
    80: ['lead 1 (square)', 'synth lead'],
    81: ['lead 2 (sawtooth)', 'synth lead'],
    82: ['lead 3 (calliope)', 'synth lead'],
    83: ['lead 4 (chiff)', 'synth lead'],
    84: ['lead 5 (charang)', 'synth lead'],
    85: ['lead 6 (voice)', 'synth lead'],
    86: ['lead 7 (fifths)', 'synth lead'],
    87: ['lead 8 (bass + lead)', 'synth lead'],
    # synth pad:
    88: ['pad 1 (new age)', 'synth pad'],
    89: ['pad 2 (warm)', 'synth pad'],
    90: ['pad 3 (polysynth)', 'synth pad'],
    91: ['pad 4 (choir)', 'synth pad'],
    92: ['pad 5 (bowed)', 'synth pad'],
    93: ['pad 6 (metallic)', 'synth pad'],
    94: ['pad 7 (halo)', 'synth pad'],
    95: ['pad 8 (sweep)', 'synth pad'],
    # synth effects:
    96: ['fx 1 (rain)', 'rain sound'],
    97: ['fx 2 (soundtrack)', 'soundtrack'],
    98: ['fx 3 (crystal)', 'crystal'],
    99: ['fx 4 (atmosphere)', 'atmosphere'],
    100: ['fx 5 (brightness)', 'brightness'],
    101: ['fx 6 (goblins)', 'goblins'],
    102: ['fx 7 (echoes)', 'echoes'],
    103: ['fx 8 (sci-fi)', 'sci-fi'],
    # ethnic:
    104: ['sitar', 'sitar'],
    105: ['banjo', 'banjo'],
    106: ['shamisen', 'shamisen'],
    107: ['koto', 'koto'],
    108: ['kalimba', 'kalimba'],
    109: ['bag pipe', 'bag pipe'],
    110: ['fiddle', 'fiddle'],
    111: ['shanai', 'shanai'],
    # percussive:
    112: ['tinkle bell', 'tinkle bell'],
    113: ['agogo', 'agogo'],
    114: ['steel drums', 'steel drums'],
    115: ['woodblock', 'woodblock'],
    116: ['taiko drum', 'taiko drum'],
    117: ['melodic tom', 'melodic tom'],
    118: ['synth drum', 'synth drum'],
    119: ['reverse cymbal', 'reverse cymbal'],
    # sound effects:
    120: ['guitar fret noise', 'guitar fret noise'],
    121: ['breath noise', 'breath noise'],
    122: ['seashore', 'seashore'],
    123: ['bird tweet', 'bird tweet'],
    124: ['telephone ring', 'telephone ring'],
    125: ['helicopter', 'helicopter'],
    126: ['applause', 'applause'],
    127: ['gunshot', 'gunshot'],
    # drum:
    128: ['drum', 'drum'],
}
# Instrument id -> fine-grained instrument class id (17 classes).
# The 128 melodic ids form 16 consecutive families of eight, so the class is
# simply the id divided by eight:
#   0 piano, 1 chromatic percussion, 2 organ, 3 guitar, 4 bass, 5 strings,
#   6 strings (continued), 7 brass, 8 reed, 9 pipe, 10 synth lead,
#   11 synth pad, 12 synth effects, 13 ethnic, 14 percussive, 15 sound effects.
# The drum id (128) gets its own class, 16.
inst_id_to_inst_class_id = {inst_id: inst_id // 8 for inst_id in range(128)}
inst_id_to_inst_class_id[128] = 16
# Instrument id -> coarse instrument class id (14 classes).
# Same families of eight as inst_id_to_inst_class_id, except that all synth
# families (lead, pad, effects) share class 10 and the "percussive" family
# (112-119) is folded into class 1 together with ids 8-15.
# Each row is (first id, last id inclusive, class id).
inst_id_to_inst_class_id_2 = {
    inst_id: class_id
    for first_id, last_id, class_id in (
        (0, 7, 0),       # piano
        (8, 15, 1),      # pitched percussion
        (16, 23, 2),     # organ
        (24, 31, 3),     # guitar
        (32, 39, 4),     # bass
        (40, 47, 5),     # strings
        (48, 55, 6),     # strings (continued)
        (56, 63, 7),     # brass
        (64, 71, 8),     # reed
        (72, 79, 9),     # pipe
        (80, 87, 10),    # synth lead
        (88, 95, 10),    # synth pad
        (96, 103, 10),   # synth effects
        (104, 111, 11),  # ethnic
        (112, 119, 1),   # pitched percussion
        (120, 127, 12),  # sound effects
        (128, 128, 13),  # drum
    )
    for inst_id in range(first_id, last_id + 1)
}
# Fine-grained instrument class id (the values of inst_id_to_inst_class_id)
# -> human-readable class name.
inst_class_id_to_name = {
    0: 'piano',
    1: 'chromatic percussion',
    2: 'organ',
    3: 'guitar',
    4: 'bass',
    5: 'strings',
    6: 'ensemble',
    7: 'brass',
    8: 'reed',
    9: 'pipe',
    10: 'synth lead',
    11: 'synth pad',
    12: 'synth effects',
    13: 'ethnic instrument',
    14: 'percussion',  # was misspelled 'purcussion'
    15: 'sound effects',
    16: 'drum',
}
# Coarse instrument class id (the values of inst_id_to_inst_class_id_2)
# -> human-readable class name. The position in the tuple is the class id.
inst_class_id_to_name_2 = dict(enumerate((
    'piano',
    'pitched percussion',
    'organ',
    'guitar',
    'bass',
    'strings',
    'ensemble',
    'brass',
    'reed',
    'pipe',
    'synthesizer',
    'ethnic instrument',
    'sound effect',
    'drum',
)))
# Ordinals for 1..16 in digit form ('1st') and spelled-out form ('first');
# index i holds the ordinal for i + 1.
ordinal_digit_numbers = [
    '1st', '2nd', '3rd', '4th', '5th', '6th', '7th', '8th',  # '3rd' was misspelled '3rt'
    '9th', '10th', '11th', '12th', '13th', '14th', '15th', '16th',
]
ordinal_letter_numbers = [
    'first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth',
    'ninth', 'tenth', 'eleventh', 'twelfth', 'thirteenth', 'fourteenth', 'fifteenth', 'sixteenth',
]

Просмотреть файл

@ -0,0 +1,299 @@
import os
import pickle
import random
import midiprocessor as mp
from copy import deepcopy
from .midi_processing import get_midi_pos_info, convert_pos_info_to_tokens
from .chord_detection import ChordDetector
from .config import attribute_versions_list
from .verbalizer import Verbalizer
from .attribute_unit import load_unit_class
from .utils.pos_process import fill_pos_ts_and_tempo_
def cut_by_none(num_bars, k, min_bar, max_bar):
    """Do not cut: return the whole piece as a single ``(begin, end)`` segment.

    ``k``, ``min_bar`` and ``max_bar`` are accepted only so the signature
    matches the other cut methods; they are ignored.
    """
    whole_piece = (0, num_bars)
    return [whole_piece]
def cut_by_random_1(num_bars, k, min_bar, max_bar, auto_k=True):
    """Randomly pick ``k`` bar segments whose length lies in ``[min_bar, max_bar]``.

    Segments are ``(begin, end)`` pairs with an exclusive ``end``, the same
    convention the other cut methods use.

    Args:
        num_bars: total number of bars in the piece.
        k: number of segments to draw.
        min_bar: minimum segment length in bars.
        max_bar: maximum segment length in bars.
        auto_k: if true, silently lower ``k`` to the number of candidates.

    Returns:
        A list of distinct ``(begin, end)`` tuples, or ``None`` when the piece
        is shorter than ``min_bar``.

    Raises:
        ValueError: from ``random.sample`` when ``auto_k`` is false and ``k``
            exceeds the number of candidate segments.
    """
    if num_bars < min_bar:
        return None
    # Segments are never empty, even if a caller passes min_bar == 0.
    shortest = max(min_bar, 1)
    # The previous range started at begin + 1 and stopped before
    # min(begin + max_bar, num_bars): it ignored min_bar and could never yield
    # a segment of max_bar bars or one ending on the last bar.
    candidates = [
        (begin, end)
        for begin in range(num_bars - min_bar + 1)
        for end in range(begin + shortest, min(begin + max_bar, num_bars) + 1)
    ]
    if auto_k:
        k = min(len(candidates), k)
    # Sample from a list: random.sample() rejects sets on Python >= 3.11, and
    # a list keeps the draw reproducible under a fixed seed.
    return random.sample(candidates, k)
def cut_by_random_2(num_bars, k, min_bar, max_bar, auto_k=True):
    """Randomly pick ``k`` windows of exactly ``max_bar`` bars.

    If the piece is shorter than ``max_bar`` (but at least ``min_bar`` long),
    the only candidate is the whole piece. Segments are ``(begin, end)`` pairs
    with an exclusive ``end``.

    Args:
        num_bars: total number of bars in the piece.
        k: number of segments to draw.
        min_bar: minimum number of bars a piece must have to be cut at all.
        max_bar: window length in bars.
        auto_k: if true, silently lower ``k`` to the number of candidates.

    Returns:
        A list of distinct ``(begin, end)`` tuples, or ``None`` when the piece
        is shorter than ``min_bar``.

    Raises:
        ValueError: from ``random.sample`` when ``auto_k`` is false and ``k``
            exceeds the number of candidate segments.
    """
    if num_bars < min_bar:
        return None
    if num_bars >= max_bar:
        candidates = [(begin, begin + max_bar) for begin in range(num_bars - max_bar + 1)]
    else:
        candidates = [(0, num_bars)]
    if auto_k:
        k = min(len(candidates), k)
    # Sample from a list: random.sample() rejects sets on Python >= 3.11, and
    # a list keeps the draw reproducible under a fixed seed.
    return random.sample(candidates, k)
# Registry of segment-cutting strategies, keyed by the name callers pass as
# ``cut_method``.
cut_methods = dict(
    none=cut_by_none,
    random_1=cut_by_random_1,
    random_2=cut_by_random_2,
)
def get_bar_positions(pos_info):
    """Return, for every bar, the ``[begin, end)`` range of its rows in ``pos_info``.

    The first element of each ``pos_info`` row is read as the bar index. The
    result is a list indexed by bar, each entry a two-element list
    ``[first_row, last_row + 1]``. Bar indices must cover ``0..n-1`` without
    gaps, otherwise the final lookup raises ``KeyError``.
    """
    spans = {}
    for row_idx, row in enumerate(pos_info):
        # Create the span on first sight of a bar, then keep pushing its end.
        span = spans.setdefault(row[0], [row_idx, row_idx])
        span[1] = row_idx + 1
    return [spans[bar_idx] for bar_idx in range(len(spans))]
def get_bars_insts(pos_info, bars_positions):
    """Collect, per bar, the ids of the instruments that have an entry in it.

    ``bars_positions`` gives the ``(begin, end)`` row range of each bar in
    ``pos_info``; the last element of a row is either ``None`` or a mapping
    keyed by instrument id. Returns one tuple of distinct instrument ids per
    bar (empty when nothing is present).
    """
    per_bar = []
    for begin, end in bars_positions:
        seen = set()
        for row_idx in range(begin, end):
            insts_notes = pos_info[row_idx][-1]
            if insts_notes is None:
                continue
            for inst_id in insts_notes:
                seen.add(inst_id)
        per_bar.append(tuple(seen))
    return per_bar
class DataExtractor(object):
    """Extract attribute values for bar segments of a MIDI file and render them as text/vectors.

    ``extract`` cuts a MIDI into bar segments and computes one value per
    configured attribute for each segment. ``represent`` turns those values
    into text (via the verbalizer) plus one vector per attribute.
    """

    def __init__(self, attribute_list_version, encoding_method='REMIGEN', attribute_list=None):
        """Set up the encoder, chord detector, attribute units and verbalizer.

        Args:
            attribute_list_version: key into ``attribute_versions_list``; only
                used when ``attribute_list`` is ``None``.
            encoding_method: ``'REMIGEN'`` or ``'REMIGEN2'``.
            attribute_list: explicit attribute labels; duplicates are dropped.

        Raises:
            NotImplementedError: for any other ``encoding_method``.
        """
        if encoding_method not in ('REMIGEN', 'REMIGEN2'):
            raise NotImplementedError("Other encoding method such as %s is not supported yet." % encoding_method)
        self.encoder = mp.MidiEncoder(encoding_method)
        self.chord_detector = ChordDetector(self.encoder)

        if attribute_list is not None:
            # De-duplicate while keeping the caller's order. tuple(set(...))
            # ordered the labels by randomized string hashes, so the order
            # changed from run to run.
            self.attribute_list = tuple(dict.fromkeys(attribute_list))
        else:
            self.attribute_list = attribute_versions_list[attribute_list_version]
        self.unit_cls_dict = self.init_units(self.attribute_list)
        self.verbalizer = Verbalizer()

    @staticmethod
    def init_units(attribute_list):
        """Return ``{attribute_label: unit class}`` for every label in ``attribute_list``."""
        return {attribute_label: load_unit_class(attribute_label) for attribute_label in attribute_list}

    @staticmethod
    def _load_pickle(path):
        """Return the object pickled at ``path``, or ``None`` if ``path`` is ``None`` or missing.

        SECURITY: ``pickle.load`` can execute arbitrary code stored in the
        file; only point this at cache files you produced yourself.
        """
        if path is None:
            return None
        try:
            with open(path, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return None

    def extract(
        self, midi_dir, midi_path,
        cut_method='random_1',
        normalize_pitch_value=True,
        pos_info_path=None,
        ignore_chord=False,
        chord_path=None,
        **kwargs,
    ):
        """Cut one MIDI into bar segments and compute every configured attribute per segment.

        Args:
            midi_dir: directory the MIDI lives in.
            midi_path: path of the MIDI relative to ``midi_dir``.
            cut_method: key of ``cut_methods`` selecting the segmentation.
            normalize_pitch_value: whether to run ``encoder.normalize_pitch``;
                a failure there is tolerated and leaves ``is_major`` as ``None``.
            pos_info_path: optional pickle with a cached ``pos_info``.
            ignore_chord: skip chord loading/detection entirely.
            chord_path: optional pickle with cached chords.
            **kwargs: forwarded to every unit's ``new``.

        Returns:
            ``(tokens, pos_info, bars_chords, info_dict, loaded_record)``.
            ``info_dict`` holds ``midi_dir``, ``midi_path`` and ``pieces``;
            ``loaded_record`` says whether ``pos_info`` / chords came from the
            cache files.

        Raises:
            ValueError: if no segment can be cut, or no segment has any notes.
        """
        pos_info = self._load_pickle(pos_info_path)
        loaded_pos_info = pos_info is not None
        if pos_info is None:
            midi_obj = mp.midi_utils.load_midi(os.path.join(midi_dir, midi_path))
            pos_info = get_midi_pos_info(self.encoder, midi_path=None, midi_obj=midi_obj)
        pos_info = fill_pos_ts_and_tempo_(pos_info)

        is_major = None
        if normalize_pitch_value:
            try:
                pos_info, is_major, _ = self.encoder.normalize_pitch(pos_info)
            except Exception:
                # Best effort: keep the un-normalized pos_info.
                is_major = None

        # Load chords for the whole MIDI from file, or detect them from the sequence.
        # Currently only MIDI using only 4/4 measures is supported; otherwise bars_chords is None.
        bars_chords = None
        loaded_chords = False
        if not ignore_chord:
            bars_chords = self._load_pickle(chord_path)
            if bars_chords is None:
                try:
                    bars_chords = self.chord_detector.infer_chord_for_pos_info(pos_info)
                except Exception:
                    # Best effort: chord-dependent attributes get None.
                    bars_chords = None
            else:
                loaded_chords = True

        bars_positions = get_bar_positions(pos_info)
        bars_instruments = get_bars_insts(pos_info, bars_positions)
        num_bars = len(bars_positions)
        assert num_bars == len(bars_instruments)

        attribute_list = self.attribute_list
        unit_cls_dict = self.unit_cls_dict

        length = min(16, num_bars)
        pieces_pos = cut_methods[cut_method](num_bars, 3, length, length)  # Todo: allow settings
        if pieces_pos is None:
            print('pieces_pos is None', num_bars, midi_path)
            raise ValueError("No valid pieces are for this MIDI.")

        # Todo: move to better place
        tokens = convert_pos_info_to_tokens(self.encoder, pos_info)
        assert tokens[-1] == 'b-1'
        # Each 'b-1' token closes a bar; record the token range of every bar.
        last_begin = 0
        last_idx = 0
        bars_token_positions = {}
        for idx, token in enumerate(tokens):
            if token == 'b-1':
                bars_token_positions[last_idx] = (last_begin, idx + 1)
                last_begin = idx + 1
                last_idx = last_idx + 1

        pieces = []
        for bar_begin, bar_end in pieces_pos:
            # Skip the piece if no instrument is played in it.
            seg_insts = bars_instruments[bar_begin: bar_end]
            if not any(len(item) > 0 for item in seg_insts):
                continue

            value_dict = {}
            for attribute_label in attribute_list:
                unit_cls = unit_cls_dict[attribute_label]
                unit = unit_cls.new(
                    self.encoder, midi_dir, midi_path, pos_info, bars_positions, bars_chords, bars_instruments,
                    bar_begin, bar_end,  # Todo
                    is_major=is_major,
                    **kwargs,
                )
                value_dict[attribute_label] = unit.value

            piece_sample = {
                'bar_begin': bar_begin,
                'bar_end': bar_end,
                'values': value_dict,
                'token_begin': bars_token_positions[bar_begin][0],
                'token_end': bars_token_positions[bar_end - 1][1],
            }
            pieces.append(piece_sample)

        if len(pieces) == 0:
            raise ValueError("No valid results for all the pieces.")

        info_dict = {
            'midi_dir': midi_dir,
            'midi_path': midi_path,
            'pieces': pieces
        }
        loaded_record = {
            'pos_info': loaded_pos_info,
            'chord': loaded_chords,
        }
        return tokens, pos_info, bars_chords, info_dict, loaded_record

    def represent(self, info_dict, remove_raw=False):
        """Like ``represent_`` but works on a deep copy, leaving ``info_dict`` untouched."""
        info_dict = deepcopy(info_dict)
        return self.represent_(info_dict, remove_raw=remove_raw)

    def represent_(self, info_dict, remove_raw=False):
        """Add text/vector representations to every piece of ``info_dict``, in place.

        For each piece the verbalizer produces texts together with the
        attributes each text uses. Every text becomes one entry of
        ``piece['reps']``: ``{'text': ..., 'vectors': {label: tuple}}``.
        Pieces for which the verbalizer yields no text are left unchanged.

        Units are taken from ``piece['units']`` when present. ``extract`` does
        not store units, so otherwise they are rebuilt from ``piece['values']``
        with ``unit_cls(value, encoder=self.encoder)``.

        Args:
            info_dict: the ``info_dict`` returned by ``extract``.
            remove_raw: drop ``'values'`` (and ``'units'``, if any) from each
                represented piece.

        Returns:
            The same ``info_dict`` object.
        """
        midi_path = info_dict['midi_path']
        pieces = info_dict['pieces']

        for piece in pieces:
            value_dict = piece['values']
            bar_begin = piece['bar_begin']
            bar_end = piece['bar_end']

            try:
                text_list, used_attributes = self.verbalizer.get_text(value_dict)
            except Exception:
                # Print enough context to locate the offending sample.
                print(midi_path)
                print(bar_begin, bar_end)
                print(value_dict)
                raise
            if len(text_list) == 0:
                continue
            assert len(text_list) == len(used_attributes)

            unit_dict = piece.get('units')
            if unit_dict is None:
                unit_dict = {
                    attribute_label: self.unit_cls_dict[attribute_label](
                        value_dict[attribute_label], encoder=self.encoder
                    )
                    for attribute_label in self.attribute_list
                }

            reps = []
            for text, u_attributes in zip(text_list, used_attributes):
                vectors = {}
                for attribute_label in self.attribute_list:
                    try:
                        unit = unit_dict[attribute_label]
                        if attribute_label in u_attributes:
                            vector = unit.get_vector(use=True, use_info=u_attributes[attribute_label])
                        else:
                            vector = unit.get_vector(use=False, use_info=None)
                        vectors[attribute_label] = tuple(vector)
                    except Exception:
                        print('Error while vectorizing "%s".' % attribute_label)
                        raise
                reps.append({
                    'text': text,
                    'vectors': vectors
                })

            piece['reps'] = reps
            if remove_raw:
                piece.pop('values')
                piece.pop('units', None)

        return info_dict

Просмотреть файл

@ -0,0 +1,63 @@
from midiprocessor import midi_utils, MidiEncoder, enc_remigen_utils, enc_remigen2_utils
def get_midi_pos_info(
    encoder,
    midi_path=None,
    midi_obj=None,
    remove_empty_bars=True,
):
    """Build the ``pos_info`` structure of a MIDI with ``encoder``.

    The MIDI is loaded from ``midi_path`` unless an already-loaded
    ``midi_obj`` is given. The collected info is converted to ids and back so
    that later analysis sees exactly what the token sequence will encode.
    With ``remove_empty_bars`` the result is additionally passed through
    ``encoder.remove_empty_bars_for_pos_info``.
    """
    if midi_obj is None:
        midi_obj = midi_utils.load_midi(midi_path)

    pos_info = encoder.collect_pos_info(
        midi_obj,
        trunc_pos=None,
        tracks=None,
        remove_same_notes=False,
        end_offset=0,
    )
    del midi_obj  # the MIDI object is not needed past this point

    # Round-trip through the id form for consistency with the token sequence.
    pos_info = encoder.convert_pos_info_id_to_pos_info(
        encoder.convert_pos_info_to_pos_info_id(pos_info)
    )

    if not remove_empty_bars:
        return pos_info
    return encoder.remove_empty_bars_for_pos_info(pos_info)
def convert_pos_info_to_tokens(encoder, pos_info, **kwargs):
    """Convert ``pos_info`` into a list of token strings for ``encoder``'s encoding method.

    Extra keyword arguments are forwarded to the encoding utility's
    ``convert_pos_info_to_token_lists``. Raises ``ValueError`` (carrying the
    method name) for an encoding method other than REMIGEN / REMIGEN2.
    """
    pos_info_id = encoder.convert_pos_info_to_pos_info_id(pos_info)

    method = encoder.encoding_method
    if method not in ('REMIGEN', 'REMIGEN2'):
        raise ValueError(method)
    enc_utils = enc_remigen_utils if method == 'REMIGEN' else enc_remigen2_utils

    token_lists = enc_utils.convert_pos_info_to_token_lists(
        pos_info_id, ignore_ts=False, sort_insts='id', sort_notes=None, **kwargs
    )
    # Only the first returned token list is used.
    return enc_utils.convert_remigen_token_list_to_token_str_list(token_lists[0])
# Small demo: load a MIDI, build its pos_info and print the first tokens.
if __name__ == '__main__':
    midi_path = 'test.mid'
    enc = MidiEncoder("REMIGEN")
    pi = get_midi_pos_info(enc, midi_path)
    # pi is a list holding all the information of the MIDI; in principle it is
    # the most convenient place to get whatever information you need.
    # The length of the list is the maximum number of positions in the MIDI.
    # For example, if a MIDI has only 1 bar, that bar has 4 beats, and the
    # program splits every beat into 12 positions, pos_info has 1*4*12=48 entries.
    # Every element of pos_info is itself a list of length 5, holding in order:
    #   bar index: index of the bar, starting from 0
    #   ts: time signature; only present where it changes, otherwise None
    #   local_pos: onset position inside the bar; e.g. in a 4-beat bar this
    #       number runs from 0 to 47
    #   tempo: tempo value; only present where it changes, otherwise None
    #   insts_notes: a dict whose key is the instrument id (drums are 128) and
    #       whose value is the collection of all notes at this position, each
    #       with pitch, duration, velocity
    # Try it out with a MIDI of your own.
    # Tip: if you will need this MIDI's information again, save pos_info and
    # load it directly next time; that is faster.

    # Convert to a token sequence.
    tokens = convert_pos_info_to_tokens(enc, pi)
    print(tokens[:100])

Просмотреть файл

@ -0,0 +1,32 @@
{
"I1": "[INSTRUMENTS] should be included in the music.; The music should feature [INSTRUMENTS].; [INSTRUMENTS] are utilized in the musical performance.; The musical performance employs [INSTRUMENTS].; The music is brought to life through the use of [INSTRUMENTS].; [INSTRUMENTS] play an important role in the music.; The music is given its sound through [INSTRUMENTS].; The [INSTRUMENTS] add to the musical composition.; The music is enriched by [INSTRUMENTS].; The use of [INSTRUMENTS] is vital to the music.",
"I1s1": "[INSTRUMENTS] should be included in the music.; The music should feature [INSTRUMENTS].; [INSTRUMENTS] are utilized in the musical performance.; The musical performance employs [INSTRUMENTS].; The music is brought to life through the use of [INSTRUMENTS].; [INSTRUMENTS] play an important role in the music.; The music is given its sound through [INSTRUMENTS].; The [INSTRUMENTS] add to the musical composition.; The music is enriched by [INSTRUMENTS].; The use of [INSTRUMENTS] is vital to the music.",
"B1":"The music is comprised of [NUM_BARS] bars.; The music consists of [NUM_BARS] bars.;The music spans [NUM_BARS] bars.; The music covers [NUM_BARS] bars.; The music has [NUM_BARS] bars in total." ,
"B1s1":"The music is comprised of [NUM_BARS] bars.; The music consists of [NUM_BARS] bars.;The music spans [NUM_BARS] bars.; The music covers [NUM_BARS] bars.; The music has [NUM_BARS] bars in total." ,
"TS1": "The music is in [TIME_SIGNATURE].; The time signature of the music is [TIME_SIGNATURE].;The music has a time signature of [TIME_SIGNATURE].; The meter of the music is [TIME_SIGNATURE].;The music follows a [TIME_SIGNATURE] meter.; The music is based on a [TIME_SIGNATURE] time signature.;The [TIME_SIGNATURE] time signature is used in the music.; The music features a [TIME_SIGNATURE] meter.;[TIME_SIGNATURE] is the time signature of the music.; [TIME_SIGNATURE] is the meter of the music.",
"TS1s1": "The music is in [TIME_SIGNATURE].; The time signature of the music is [TIME_SIGNATURE].;The music has a time signature of [TIME_SIGNATURE].; The meter of the music is [TIME_SIGNATURE].;The music follows a [TIME_SIGNATURE] meter.; The music is based on a [TIME_SIGNATURE] time signature.;The [TIME_SIGNATURE] time signature is used in the music.; The music features a [TIME_SIGNATURE] meter.;[TIME_SIGNATURE] is the time signature of the music.; [TIME_SIGNATURE] is the meter of the music.",
"T1": "This song is marked at [TEMPO] in musical notation.;The musical notation for the tempo of this tune is [TEMPO].;[TEMPO] is the tempo indication for this song.;The tempo indication in musical notation for this melody is [TEMPO].;The tempo of this piece is notated as [TEMPO].",
"P1": "This song features [LOW_PITCH] as its minimum pitch.;The bottom pitch in this song is [LOW_PITCH].;The song has a minimum note pitch of [LOW_PITCH].;The lowest note pitch in this song is [LOW_PITCH].;This song contains a low pitch of [LOW_PITCH].;The minimum note in this song is [LOW_PITCH].;In this song, the lowest note is [LOW_PITCH].;[LOW_PITCH] represents the lowest note in the song.;[LOW_PITCH] is the lowest note in the piece.;The lowest sound in the song is [LOW_PITCH].",
"P2" : "The top pitch in this song is [HIGH_PITCH].;This song has a maximum note pitch of [HIGH_PITCH].;The highest note pitch in this song is [HIGH_PITCH].;This song features [HIGH_PITCH] as its maximum pitch.;The song contains a high pitch of [HIGH_PITCH].;The peak pitch in this song is [HIGH_PITCH].;This song reaches its highest pitch at [HIGH_PITCH].;The song's highest note is [HIGH_PITCH].;[HIGH_PITCH] is the song's peak note.;[HIGH_PITCH] represents the highest note in this song." ,
"P3": "The music features [NUM_PITCH_CLASS] distinct pitches.;[NUM_PITCH_CLASS] different pitches are present in the music.;The music consists of [NUM_PITCH_CLASS] unique pitches.;[NUM_PITCH_CLASS] types of pitches are included in the music.;The music has [NUM_PITCH_CLASS] separate pitches.",
"ST1": "The song is arranged in [STRUCTURE] form.; The [STRUCTURE] structure helps the listener follow the progression of the song.;The [STRUCTURE] form gives a clear and repetitive structure to the music.; This makes it easier for the listener to understand the music.;[STRUCTURE] is the chosen form for the music.; This provides a clear and repetitive structure to the song.;The music follows the [STRUCTURE] form.; This makes the song easy to follow for the listener.;The [STRUCTURE] structure is utilized in the music to make the progression of the song easier to understand for the listener.",
"EM1": "The music conveys [EMOTION].; The music has a [EMOTION] feeling.;The music expresses [EMOTION].; The music is filled with [EMOTION].;The music is [EMOTION] in nature.; The music is imbued with [EMOTION].;The music is characterized by [EMOTION].; The music is defined by [EMOTION].;The music radiates [EMOTION].; The music projects [EMOTION].",
"I1_ALL_EM1_Q2": "The [INSTRUMENT]'s jarring sounds create a feeling of tension and unease.;The fast-paced, agitated notes on the [INSTRUMENT] evoke a sense of stress and anxiety.;The [INSTRUMENT]'s discordant harmonies convey a feeling of uneasiness and turmoil.;The brooding, minor-key melody played on the [INSTRUMENT] conveys a sense of distress and frustration.;The [INSTRUMENT]'s dissonant chords amplify the sense of tension and upset.;The slow, deliberate pace of the [INSTRUMENT]'s performance intensifies the feeling of stress and agitation.;The [INSTRUMENT]'s haunting sounds contribute to the overall feeling of tension and unease.",
"I1_ALL_EM1_Q1": "The [INSTRUMENT]'s lively melody evokes feelings of happiness and excitement.;The upbeat, energetic playing on the [INSTRUMENT] creates a sense of elation.;The [INSTRUMENT]'s cheerful harmonies contribute to a feeling of happiness and joy.;The [INSTRUMENT]'s lighthearted melody conveys a sense of excitement and exhilaration.;The [INSTRUMENT]'s cheerful chords amplify the emotion of happiness and elation.;The fast-paced playing on the [INSTRUMENT] intensifies feelings of excitement and joy.;The [INSTRUMENT]'s upbeat sounds produce an overall sense of happiness and elation.",
"I1_ALL_EM1_Q3":"The [INSTRUMENT]'s mournful melody evokes feelings of sadness and depression.;The slow, melancholy playing on the [INSTRUMENT] creates a sense of fatigue and dejection.;The [INSTRUMENT]'s somber harmonies contribute to a feeling of sorrow and depression.;The [INSTRUMENT]'s melancholy melody conveys a sense of sadness and discouragement.;The [INSTRUMENT]'s poignant chords amplify the emotion of sadness and exhaustion.;The slow, deliberate playing on the [INSTRUMENT] intensifies feelings of fatigue and depression.;The [INSTRUMENT]'s sorrowful sounds produce an overall sense of sadness and dejection.",
"I1_ALL_EM1_Q4":"The [INSTRUMENT]'s soothing melody evokes feelings of calmness and relaxation.;The slow, smooth playing on the [INSTRUMENT] creates a sense of serenity and tranquility.;The [INSTRUMENT]'s peaceful harmonies contribute to a feeling of relaxation and calmness.;The [INSTRUMENT]'s relaxed melody conveys a sense of serenity and peace.;The [INSTRUMENT]'s gentle chords amplify the emotion of calmness and tranquility.;The slow, melodic playing on the [INSTRUMENT] intensifies feelings of relaxation and serenity.;The [INSTRUMENT]'s peaceful sounds produce an overall sense of calmness and tranquility.",
"I1_ST1": "The musical feature of [INSTRUMENT] adds a rich and dynamic layer to the [STRUCTURE] structure.;The [STRUCTURE] structure is emphasized by the prominent use of [INSTRUMENT] throughout the piece.;The [STRUCTURE] structure is elevated by the intricate and melodic playing of [INSTRUMENT].;The inclusion of [INSTRUMENT] adds a unique timbre to the [STRUCTURE] structure, enhancing its overall musicality.;[INSTRUMENT] serves as the perfect complement to the simple yet effective [STRUCTURE] structure.;The [STRUCTURE] structure is elevated to new heights through the masterful use of [INSTRUMENT].",
"I1_ST1_A":"The intricate [INSTRUMENT] melodies set a serene tone for this piece, which follows the [STRUCTURE] form of two repeating sections.;The [STRUCTURE] form provides structure to the piece and allows the [INSTRUMENT] to shine.;The fluidity and grace of the [INSTRUMENT] playing are highlighted in this piece, which follows the [STRUCTURE] form. ;The music utilizes repetitive [STRUCTURE] structure and features [INSTRUMENT] to form music ideas.;The music should have repetitive [STRUCTURE] and feature [INSTRUMENT].",
"C2" : "The music has a Jazz-like sound.; The music evokes a Jazz-like feeling.;It has a Jazz flavor.; It gives off a Jazz vibe.;The music carries a Jazz style.; The music resembles Jazz.;The music has a Jazz atmosphere.; The music has a Jazz-infused feeling.;The music has a Jazz touch.; The music feels Jazz-inspired.",
"R1" : "The songs have a heavy feeling.;The songs carry a weighty atmosphere.;The songs are characterized by their heaviness.;The songs have a dense feeling.;The songs convey a sense of weightiness.;The songs have a significant feeling.;The songs have a ponderous quality.;The songs have a heavy-hearted feeling.;The songs have a somber feeling.;The songs have a heavy tone.",
"R2" : "The song is full of life.; The song is energetic.;The song has a lively spirit.; The song is upbeat.;The song is vivacious.; The song is bouncy.;The song is effervescent.; The song is peppy.;The song is sparkling.; The song is high-spirited.",
"S1" : "The music is a symphonic composition.;The music is a classical orchestral work.;The music is a musical symphony.;The music is an extended musical piece.;The music is a full-scale orchestral work.;The music is a large-scale musical composition.;The music is a complex musical composition.;The music is a rich musical piece.;The music is a multi-movement musical composition.;The music is a multi-sectional musical piece.",
"K1_S2" : "This is a song in a [KEY] key in the style of [ARTIST].;This music is composed in the [KEY] key and showcases the distinct style of [ARTIST], creating a unique and memorable listening experience.;With its characteristic rhythms and harmonies, this music captures the essence of [ARTIST]'s signature style within the framework of the [KEY] key.;The use of the [KEY] key imbues this music with a sense of energy and vibrancy that perfectly complements [ARTIST]'s musical style.;This music's fusion of the [KEY] key and [ARTIST]'s unique style results in a fresh and exciting sound that is sure to captivate listeners.;The [KEY] key serves as the perfect foundation for the distinct rhythms and melodies that define [ARTIST]'s musical style in this piece.;This music's skillful incorporation of the [KEY] key and [ARTIST]'s style showcases a seamless fusion of classical and modern elements.;With its evocative harmonies and rhythms, this music captures the essence of [ARTIST]'s musical style in the context of the [KEY] key, creating a truly immersive listening experience.",
"S2" : "The music is in the vein of [ARTIST].;The music is similar to [ARTIST]'s style.;The music mimics [ARTIST]'s style.;The music echoes [ARTIST]'s compositions.;The music is influenced by [ARTIST].;The music follows in [ARTIST]'s footsteps.;The music reflects [ARTIST]'s style.;The music embodies [ARTIST]'s sound.;The music pays homage to [ARTIST].;The music honors [ARTIST]'s style.",
"C1" : "The music begins with a [FEELING_A] mood and then transitions to [FEELING_B].; The music commences with an atmosphere of [FEELING_A], which then shifts to [FEELING_B].; The music sets the tone with a [FEELING_A] vibe, which then evolves into [FEELING_B].; The music opens with a [FEELING_A] emotion and progresses to [FEELING_B].; The music starts off with a [FEELING_A] tone and then moves on to [FEELING_B].;The music initiates with a [FEELING_A] sensation and later becomes [FEELING_B].; The music launches with a [FEELING_A] mood and later transforms into [FEELING_B].; The music begins with a [FEELING_A] aura and then changes to [FEELING_B].; The music starts with a [FEELING_A] atmosphere and then shifts to [FEELING_B].",
"C1_0" : "This is a song that has a bright feeling from the beginning to the end.;This music's uplifting and bright melodies create a sense of positivity and joy that lasts from beginning to end.;With its lively rhythms and cheerful harmonies, this music captures a sense of brightness and optimism that is both infectious and uplifting.;From the opening notes, this music's bright and buoyant tone sets the stage for an upbeat and energetic listening experience.;This music's optimistic and positive mood creates a sense of enthusiasm and hope that carries throughout the entire piece.;The bright and colorful melodies of this music evoke a sense of joy and vitality that is sure to lift the listener's spirits.;This music's use of vibrant harmonies and energetic rhythms creates a sense of excitement and dynamism that is both infectious and inspiring.;From its sunny beginning to its jubilant conclusion, this music offers a deeply uplifting and joyful listening experience.",
"C1_1" : "This is a song that has a very gloomy feeling from the beginning to the end.;This music maintains a consistently gloomy atmosphere from beginning to end, creating a haunting and emotional listening experience.;With its melancholic melodies and mournful harmonies, this music evokes a sense of sadness and despair throughout.;The dark and brooding tone of this music pervades every note, leaving a lasting impression of gloominess and solemnity.;From the opening chords to the final notes, this music immerses the listener in a world of gloom and despair.;This music's somber and foreboding mood creates a sense of unease and sadness that lasts throughout the entire piece.;The persistent feeling of melancholy and sorrow in this music lends it a powerful emotional impact.;From its bleak beginning to its mournful conclusion, this music offers a deeply introspective and hauntingly beautiful listening experience.",
"S3" : "This is a piece of classical music.; The use of dynamics in classical music creates a dramatic effect.;The structure of classical music often features movements with distinct themes.;I would like to listen to a piece of classical music.",
"P4" : "Its pitch range is within [RANGE] octaves.;The musical piece showcases a pitch range within [RANGE] octaves.;With a pitch range spanning [RANGE] octaves, this music offers a diverse and dynamic listening experience.;The compact pitch range of [RANGE] octaves results in a focused and impactful musical performance.;The music's limited pitch range of [RANGE] octaves allows for a greater emphasis on the nuances of tone and phrasing.;The pitch range of [RANGE] octaves adds a distinctive character to the music, emphasizing its emotional depth.;This music's pitch range of [RANGE] octaves offers a unique and memorable listening experience.;The use of a specific pitch range of [RANGE] octaves creates a cohesive and unified sound throughout the musical piece.",
"K1" : "This music is composed in the [KEY] key.; [KEY] key adds a unique flavor to this music.;This music's use of [KEY] key creates a distinct atmosphere.;[KEY] key gives this music a special emotional quality.;The [KEY] key in this music provides a powerful and memorable sound.;This music's choice of [KEY] key results in a captivating and memorable experience.;With its use of [KEY] key, this music conveys a unique and resonant sound.;This music's use of [KEY] key creates a rich and dynamic sonic palette."
}

Просмотреть файл

@ -0,0 +1,21 @@
from . import pos_process
from . import data
from . import similarity
from .pos_process import get_bar_num_from_sample_tgt_pos
# get_bar_num_from_sample_tgt_pos: Get the number of bars for a sample.
from .remigen_process import (
remove_instrument, count_token_num, count_bar_num, get_bar_ranges,
get_instrument_played, get_instrument_seq,
sample_bars
)
# remove_instrument: Remove an instrument from the remigen sequence.
# count_token_num: Count the number of a specific token in a remigen sequence.
# count_bar_num: Count the number of bars, including the complete bars and a possible incomplete bar.
# get_instrument_played: Get the instrument tokens played in a remigen sequence.
# get_instrument_seq: Get sub-sequence of an instrument.
# sample_bars: Get a certain number of bars.
from .random import seed_everything
from .file_list import generate_file_list, read_file_list

Просмотреть файл

@ -0,0 +1,39 @@
from .magenta_chord_recognition import _PITCH_CLASS_NAMES, _CHORD_KIND_PITCHES, NO_CHORD
# Reverse lookup: pitch-class name -> its position in _PITCH_CLASS_NAMES.
_PITCH_CLASS_NAMES_TO_INDEX = {
    name: index for index, name in enumerate(_PITCH_CLASS_NAMES)
}
def get_chord_pitch_indices(chord_label):
    """Return the pitch indices of the notes in a chord label.

    Args:
        chord_label: Either ``NO_CHORD`` or a ``'<pitch class>:<chord kind>'``
            string, e.g. ``'C:maj7'`` or ``'A:'`` for a plain major triad.

    Returns:
        An empty list for ``NO_CHORD``; otherwise the root's pitch-class index
        plus each interval of the chord kind. Values are not wrapped modulo 12.

    Raises:
        KeyError: If the pitch class or chord kind is unknown.
        ValueError: If the label does not contain exactly one ``':'``.
    """
    if chord_label == NO_CHORD:
        return []
    root_name, kind = chord_label.split(':')
    root = _PITCH_CLASS_NAMES_TO_INDEX[root_name]
    indices = []
    for interval in _CHORD_KIND_PITCHES[kind]:
        indices.append(root + interval)
    return indices
def convert_pitch_index_to_token(pitch_index, offset=0, min_pitch=None, max_pitch=None, tag='p'):
    """Format a pitch as a ``'<tag>-<pitch>'`` token, octave-shifted into range.

    The pitch is ``pitch_index + offset``. It is first raised by whole octaves
    until it is at least ``min_pitch`` (if given), then lowered by whole
    octaves until it is at most ``max_pitch`` (if given).

    Raises:
        AssertionError: If the final pitch is negative or still violates one
            of the bounds (e.g. the allowed window is narrower than an octave).
    """
    context = (pitch_index, offset, min_pitch, max_pitch)
    p = pitch_index + offset
    if min_pitch is not None and p < min_pitch:
        # Fewest whole octaves that lift p to min_pitch or above.
        p += 12 * -((p - min_pitch) // 12)
    if max_pitch is not None and p > max_pitch:
        # Fewest whole octaves that drop p to max_pitch or below.
        p -= 12 * -((max_pitch - p) // 12)
    assert p >= 0, context
    assert min_pitch is None or p >= min_pitch, context
    assert max_pitch is None or p <= max_pitch, context
    return '%s-%d' % (tag, p)
def get_chord_pitch_tokens(chord_label, offset=0, min_pitch=None, max_pitch=None, tag='p'):
    """Return one pitch token per note of ``chord_label``.

    Each pitch index from ``get_chord_pitch_indices`` is formatted by
    ``convert_pitch_index_to_token`` with the given ``offset``, bounds and
    ``tag``.
    """
    tokens = []
    for pitch_index in get_chord_pitch_indices(chord_label):
        token = convert_pitch_index_to_token(
            pitch_index, offset=offset, min_pitch=min_pitch, max_pitch=max_pitch, tag=tag
        )
        tokens.append(token)
    return tokens

Просмотреть файл

@ -0,0 +1,37 @@
from copy import deepcopy
def convert_dict_key_to_str_(dict_data):
    """Replace every key of ``dict_data`` with ``str(key)``, in place.

    Returns the same dictionary object. Keys whose string forms collide end
    up sharing a single entry.
    """
    for key in list(dict_data):
        value = dict_data.pop(key)
        dict_data[str(key)] = value
    return dict_data
def convert_dict_key_to_str(dict_data):
    """Return a deep copy of ``dict_data`` whose keys are converted with ``str``.

    The input dictionary is left untouched.
    """
    return convert_dict_key_to_str_(deepcopy(dict_data))
def convert_dict_key_to_int_(dict_data):
    """Replace every key of ``dict_data`` with ``int(key)``, in place.

    Returns the same dictionary object.

    Raises:
        ValueError / TypeError: If a key cannot be converted by ``int``.
    """
    for key in list(dict_data):
        value = dict_data.pop(key)
        dict_data[int(key)] = value
    return dict_data
def convert_dict_key_to_int(dict_data):
    """Return a deep copy of ``dict_data`` whose keys are converted with ``int``.

    The input dictionary is left untouched.
    """
    return convert_dict_key_to_int_(deepcopy(dict_data))
def convert_dict_key_with_eval_(dict_data):
    """Replace every key of ``dict_data`` with ``eval(key)``, in place.

    Returns the same dictionary object.

    NOTE(review): ``eval`` executes arbitrary expressions. Only call this on
    dictionaries whose keys come from a trusted source; if the keys are always
    plain literals, ``ast.literal_eval`` would be a safer choice.
    """
    for key in list(dict_data):
        value = dict_data.pop(key)
        dict_data[eval(key)] = value
    return dict_data
def convert_dict_key_with_eval(dict_data):
    """Return a deep copy of ``dict_data`` whose keys are converted with ``eval``.

    The input dictionary is left untouched. See the in-place variant for the
    caveat about evaluating untrusted keys.
    """
    return convert_dict_key_with_eval_(deepcopy(dict_data))

Просмотреть файл

@ -0,0 +1,47 @@
import os
def dump_file_list(file_list, save_path):
    """Write ``file_list`` to ``save_path``, one entry per line.

    Missing parent directories of ``save_path`` are created. The file is
    written as UTF-8 so that it round-trips through ``read_file_list`` (which
    reads UTF-8) regardless of the platform's default encoding.

    Args:
        file_list: Iterable of path strings.
        save_path: Destination text file.
    """
    dirname = os.path.dirname(save_path)
    if dirname != '':
        os.makedirs(dirname, exist_ok=True)
    with open(save_path, 'w', encoding='utf-8') as f:
        f.writelines(item + '\n' for item in file_list)
def generate_file_list(dir, suffixes=None, ignore_suffix_case=True, save_path=None):
    """Recursively list the files under ``dir`` as paths relative to ``dir``.

    Args:
        dir: Root directory to walk.
        suffixes: Optional iterable of file-name suffixes; only files ending
            with one of them are kept. ``None`` keeps every file.
        ignore_suffix_case: Compare suffixes case-insensitively.
        save_path: If given, the resulting list is also written there via
            ``dump_file_list``.

    Returns:
        List of relative paths using ``'/'`` as the separator, in
        ``os.walk`` order.
    """
    # Materialize the suffixes once: this avoids re-lowering them for every
    # file and keeps one-shot iterables (generators) usable for all files.
    if suffixes is None:
        wanted = None
    elif ignore_suffix_case:
        wanted = tuple(sf.lower() for sf in suffixes)
    else:
        wanted = tuple(suffixes)

    file_list = []
    for root_dir, _, files in os.walk(dir):
        for file_name in files:
            if wanted is not None:
                candidate = file_name.lower() if ignore_suffix_case else file_name
                if not candidate.endswith(wanted):
                    continue
            file_path = os.path.relpath(os.path.join(root_dir, file_name), dir)
            # Normalize separators AFTER relpath (which re-inserts os.sep), and
            # only where os.sep is not already '/', so that literal backslashes
            # in POSIX file names are left intact.
            if os.sep != '/':
                file_path = file_path.replace(os.sep, '/')
            file_list.append(file_path)

    if save_path is not None:
        dump_file_list(file_list, save_path)
    return file_list
def read_file_list(path):
    """Read a UTF-8 text file and return its non-blank lines, stripped.

    Args:
        path: File containing one entry per line.

    Returns:
        List of lines with surrounding whitespace removed; lines that are
        empty after stripping are dropped.
    """
    with open(path, 'r', encoding='utf-8') as f:
        stripped = (raw.strip() for raw in f)
        return [entry for entry in stripped if entry != '']

Просмотреть файл

@ -0,0 +1,367 @@
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Chord inference for NoteSequences."""
import bisect
import itertools
import math
import numbers
from absl import logging
import numpy as np
# Names of pitch classes to use (mostly ignoring spelling).
_PITCH_CLASS_NAMES = 'C C# D Eb E F F# G Ab A Bb B'.split()

# Pitch classes in a key (rooted at zero).
_KEY_PITCHES = [0, 2, 4, 5, 7, 9, 11]

# Pitch classes in each chord kind (rooted at zero).
_CHORD_KIND_PITCHES = {
    '': [0, 4, 7],
    'm': [0, 3, 7],
    '+': [0, 4, 8],
    'dim': [0, 3, 6],
    '7': [0, 4, 7, 10],
    'maj7': [0, 4, 7, 11],
    'm7': [0, 3, 7, 10],
    'm7b5': [0, 3, 6, 10],
}
_CHORD_KINDS = _CHORD_KIND_PITCHES.keys()

NO_CHORD = 'N.C.'

# All usable chords, including no-chord. Every entry except the first is a
# (root pitch class, chord kind) pair.
_CHORDS = [NO_CHORD, *itertools.product(range(12), _CHORD_KINDS)]

# All key-chord pairs.
_KEY_CHORDS = list(itertools.product(range(12), _CHORDS))

# Maximum length of chord sequence to infer.
_MAX_NUM_CHORDS = 1000

# MIDI programs that typically sound unpitched.
UNPITCHED_PROGRAMS = [*range(96, 104), *range(112, 128)]

# Mapping from time signature to number of chords to infer per bar.
_DEFAULT_TIME_SIGNATURE_CHORDS_PER_BAR = {
    (2, 2): 1,
    (2, 4): 1,
    (3, 4): 1,
    (4, 4): 2,
    (6, 8): 2,
}
def _key_chord_distribution(chord_pitch_out_of_key_prob):
    """Probability distribution over chords for each key.

    Args:
        chord_pitch_out_of_key_prob: Probability that a single chord pitch
            falls outside the key.

    Returns:
        Array of shape ``[12, len(_CHORDS)]``; each row (one per key) sums to
        one. The no-chord column has zero in-key and zero out-of-key pitches.
    """
    shape = [12, len(_CHORDS)]
    num_pitches_in_key = np.zeros(shape, dtype=np.int32)
    num_pitches_out_of_key = np.zeros(shape, dtype=np.int32)

    # Count, for every key and every real chord, how many chord notes lie in
    # the key and how many lie outside it. Index 0 (no-chord) stays at zero.
    for key in range(12):
        key_pitches = {(key + offset) % 12 for offset in _KEY_PITCHES}
        for chord_index, (root, kind) in enumerate(_CHORDS[1:], start=1):
            chord_pitches = {
                (root + offset) % 12 for offset in _CHORD_KIND_PITCHES[kind]
            }
            num_pitches_in_key[key, chord_index] = len(chord_pitches & key_pitches)
            num_pitches_out_of_key[key, chord_index] = len(chord_pitches - key_pitches)

    # Probability of each chord under each key, normalized per key.
    mat = ((1 - chord_pitch_out_of_key_prob) ** num_pitches_in_key *
           chord_pitch_out_of_key_prob ** num_pitches_out_of_key)
    row_totals = mat.sum(axis=1)
    mat /= row_totals[:, np.newaxis]
    return mat
def _key_chord_transition_distribution(
        key_chord_distribution, key_change_prob, chord_change_prob):
    """Transition distribution between key-chord pairs.

    Args:
        key_chord_distribution: Per-key chord distribution, indexed as
            ``[key, chord_index]``.
        key_change_prob: Probability of a key change between adjacent frames.
        chord_change_prob: Probability of a chord change between adjacent
            frames.

    Returns:
        Square array indexed ``[from_state, to_state]`` over ``_KEY_CHORDS``.
    """
    num_chords = len(_CHORDS)
    num_states = len(_KEY_CHORDS)
    mat = np.zeros([num_states, num_states])

    for i, (key_1, chord_1) in enumerate(_KEY_CHORDS):
        chord_index_1 = i % num_chords
        for j, (key_2, chord_2) in enumerate(_KEY_CHORDS):
            chord_index_2 = j % num_chords

            if key_1 != key_2:
                # Key change. Chord probability depends only on the new key
                # and not on the previous chord.
                mat[i, j] = (key_change_prob / 11)
                mat[i, j] *= key_chord_distribution[key_2, chord_index_2]
                continue

            # No key change.
            mat[i, j] = 1 - key_change_prob
            if chord_1 == chord_2:
                # No chord change.
                mat[i, j] *= 1 - chord_change_prob
            else:
                # Chord probability depends on key, but the probability mass
                # of the previous chord has to be redistributed since we know
                # the chord changed.
                mat[i, j] *= (
                    chord_change_prob * (
                        key_chord_distribution[key_2, chord_index_2] +
                        key_chord_distribution[key_2, chord_index_1] / (num_chords - 1)))
    return mat
def _chord_pitch_vectors():
    """Unit vectors over pitch classes for all chords.

    Returns:
        Array of shape ``[len(_CHORDS), 12]``. Row 0 (no-chord) is all zeros;
        every other row marks the chord's pitch classes and has unit norm.
    """
    x = np.zeros([len(_CHORDS), 12])
    for row, (root, kind) in enumerate(_CHORDS[1:], start=1):
        for offset in _CHORD_KIND_PITCHES[kind]:
            x[row, (root + offset) % 12] = 1
    # Row 0 is skipped: it is all zeros and cannot be normalized.
    norms = np.linalg.norm(x[1:, :], axis=1)
    x[1:, :] /= norms[:, np.newaxis]
    return x
def sequence_note_pitch_vectors(sequence, seconds_per_frame):
    """Compute pitch class vectors for temporal frames across a sequence.

    Args:
        sequence: Iterable of notes; each note must expose ``start``, ``end``
            and ``pitch`` attributes.
        seconds_per_frame: Iterable of frame boundary positions (not including
            the initial start and the final end). It is sorted before use, and
            ``k`` boundaries define ``k + 1`` frames.

    Returns:
        A numpy array with shape `[num_frames, 12]` where each non-empty row
        is a unit-normalized pitch class vector weighted by how long each
        pitch class sounds inside the frame. Frames without notes stay zero.
    """
    boundaries = sorted(seconds_per_frame)
    x = np.zeros([len(boundaries) + 1, 12])

    for note in sequence:
        # Drum / unpitched-program filtering of the upstream implementation is
        # intentionally not applied here; every note given is counted.
        first = bisect.bisect_right(boundaries, note.start)
        last = bisect.bisect_left(boundaries, note.end)
        pitch_class = note.pitch % 12

        if first >= last:
            # The whole note lies inside a single frame.
            x[first, pitch_class] += note.end - note.start
            continue

        # Head of the note, up to the first boundary it crosses.
        x[first, pitch_class] += (boundaries[first] - note.start)
        # Frames that the note covers completely.
        for frame in range(first + 1, last):
            x[frame, pitch_class] += (boundaries[frame] - boundaries[frame - 1])
        # Tail of the note, after the last boundary it crosses.
        x[last, pitch_class] += (note.end - boundaries[last - 1])

    x_norm = np.linalg.norm(x, axis=1)
    nonzero_frames = x_norm > 0
    x[nonzero_frames, :] /= x_norm[nonzero_frames, np.newaxis]
    return x
def _chord_frame_log_likelihood(note_pitch_vectors, chord_note_concentration):
    """Log-likelihood of observing each frame of note pitches under each chord.

    Returns ``chord_note_concentration`` times the dot product between every
    frame's pitch vector and every chord's pitch vector, i.e. an array of
    shape ``[num_frames, len(_CHORDS)]``.
    """
    chord_vectors = _chord_pitch_vectors()
    similarity = np.dot(note_pitch_vectors, chord_vectors.T)
    return chord_note_concentration * similarity
def _key_chord_viterbi(chord_frame_loglik,
                       key_chord_loglik,
                       key_chord_transition_loglik):
    """Use the Viterbi algorithm to infer a sequence of key-chord pairs.

    Args:
        chord_frame_loglik: ``[num_frames, num_chords]`` observation
            log-likelihoods.
        key_chord_loglik: ``[12, num_chords]`` log-probability of each chord
            under each key.
        key_chord_transition_loglik: Square log transition matrix over
            key-chord pairs.

    Returns:
        One ``(key, chord)`` pair per frame, where ``chord`` is an entry of
        ``_CHORDS``.
    """
    num_frames, num_chords = chord_frame_loglik.shape
    num_key_chords = len(key_chord_transition_loglik)

    loglik_matrix = np.zeros([num_frames, num_key_chords])
    path_matrix = np.zeros([num_frames, num_key_chords], dtype=np.int32)

    # Initialize with a uniform distribution over keys.
    for state, (key, _unused_chord) in enumerate(_KEY_CHORDS):
        chord_index = state % len(_CHORDS)
        loglik_matrix[0, state] = (
            -np.log(12) + key_chord_loglik[key, chord_index] +
            chord_frame_loglik[0, chord_index])

    columns = np.arange(num_key_chords)
    for frame in range(1, num_frames):
        # scores[p, s]: best log-likelihood of reaching state s from parent p.
        # For every state keep the best parent and the score it yields.
        previous = loglik_matrix[frame - 1]
        scores = previous[:, np.newaxis] + key_chord_transition_loglik
        best_parents = scores.argmax(axis=0)
        path_matrix[frame, :] = best_parents
        loglik_matrix[frame, :] = (
            scores[path_matrix[frame, :], columns] +
            np.tile(chord_frame_loglik[frame], 12))

    # Reconstruct the most likely sequence of key-chord pairs, walking the
    # parent pointers from the last frame back to the first.
    path = [np.argmax(loglik_matrix[-1])]
    for frame in range(num_frames - 1, 0, -1):
        path.append(path_matrix[frame, path[-1]])
    path.reverse()

    return [(index // num_chords, _CHORDS[index % num_chords])
            for index in path]
class ChordInferenceError(Exception):  # pylint:disable=g-bad-exception-name
    """Base class for every error raised by chord inference."""


class SequenceAlreadyHasChordsError(ChordInferenceError):
    """The sequence already has chords."""


class UncommonTimeSignatureError(ChordInferenceError):
    """The sequence has an uncommon time signature."""


class NonIntegerStepsPerChordError(ChordInferenceError):
    """The number of quantized steps per chord is not an integer."""


class EmptySequenceError(ChordInferenceError):
    """The sequence is empty."""


class SequenceTooLongError(ChordInferenceError):
    """The number of chords to be inferred is too large."""
def infer_chords_for_sequence(sequence,
                              pos_per_chord,
                              max_chords,
                              key_chord_loglik=None,
                              key_chord_transition_loglik=None,
                              key_change_prob=0.001,
                              chord_change_prob=0.5,
                              chord_pitch_out_of_key_prob=0.01,
                              chord_note_concentration=100.0,
                              add_key_signatures=False):
    """Infer one chord label per frame using the Viterbi algorithm.

    The time axis is cut into ``max_chords`` frames of ``pos_per_chord``
    positions each; a key and a chord are inferred for every frame and the
    chord labels are returned. The input sequence is not modified.

    Args:
        sequence: Iterable of notes exposing ``start``, ``end`` and ``pitch``.
        pos_per_chord: Length of one chord frame, in the same unit as the
            notes' ``start`` / ``end``.
        max_chords: Number of chord frames to infer; must be at least 1.
        key_chord_loglik: Optional precomputed log-probabilities of each chord
            under each key. Computed from ``chord_pitch_out_of_key_prob`` when
            omitted.
        key_chord_transition_loglik: Optional precomputed log transition
            matrix between key-chord pairs. Computed when omitted.
        key_change_prob: Probability of a key change between two adjacent
            frames.
        chord_change_prob: Probability of a chord change between two adjacent
            frames.
        chord_pitch_out_of_key_prob: Probability of a pitch in a chord not
            belonging to the current key.
        chord_note_concentration: Concentration parameter for the distribution
            of observed pitches played over a chord. At zero, all pitches are
            equally likely. As concentration increases, observed pitches must
            match the chord pitches more closely.
        add_key_signatures: Currently ignored; kept for interface
            compatibility.

    Returns:
        A list with one chord figure per frame: ``NO_CHORD`` or a
        ``'<pitch class>:<chord kind>'`` string.

    Raises:
        ValueError: If ``max_chords`` is smaller than 1.
        SequenceTooLongError: If the number of frame boundaries exceeds
            ``_MAX_NUM_CHORDS``.
    """
    if max_chords < 1:
        raise ValueError('max chords should > 0')

    # Interior frame boundaries; a single chord frame has no boundary at all.
    beats = [pos_per_chord * i for i in range(1, max_chords)]

    num_chords = len(beats)
    if num_chords > _MAX_NUM_CHORDS:
        raise SequenceTooLongError(
            'NoteSequence too long for chord inference: %d frames' % num_chords)

    # Compute pitch vectors for each chord frame, then compute log-likelihood
    # of observing those pitch vectors under each possible chord.
    note_pitch_vectors = sequence_note_pitch_vectors(sequence, beats)
    chord_frame_loglik = _chord_frame_log_likelihood(
        note_pitch_vectors, chord_note_concentration)

    # Compute distribution over chords for each key, and transition
    # distribution between key-chord pairs.
    key_chord_distribution = None
    if key_chord_loglik is None:
        key_chord_distribution = _key_chord_distribution(
            chord_pitch_out_of_key_prob=chord_pitch_out_of_key_prob)
        key_chord_loglik = np.log(key_chord_distribution)

    if key_chord_transition_loglik is None:
        if key_chord_distribution is None:
            # The caller supplied key_chord_loglik only: derive the matching
            # distribution from it instead of failing on an unbound name.
            key_chord_distribution = np.exp(key_chord_loglik)
        key_chord_transition_distribution = _key_chord_transition_distribution(
            key_chord_distribution,
            key_change_prob=key_change_prob,
            chord_change_prob=chord_change_prob)
        key_chord_transition_loglik = np.log(key_chord_transition_distribution)

    key_chords = _key_chord_viterbi(
        chord_frame_loglik, key_chord_loglik, key_chord_transition_loglik)

    # Only the chord of each (key, chord) pair is reported.
    chords = []
    for _key, chord in key_chords:
        if chord == NO_CHORD:
            figure = NO_CHORD
        else:
            root, kind = chord
            figure = '%s:%s' % (_PITCH_CLASS_NAMES[root], kind)
        chords.append(figure)
    return chords

Просмотреть файл

@ -0,0 +1,24 @@
def judge_melody_track(pos_info):
    """Pick the instrument most likely to carry the melody.

    Instruments are ranked by mean pitch (highest first); instrument 128 is
    excluded. The top instrument is accepted as the melody track only if it
    plays at least 0.5 notes per beat, where a beat is taken to be 12
    positions of ``pos_info``.

    Args:
        pos_info: Sequence of position items whose last element is either
            ``None`` or a dict mapping instrument id to a list of
            ``(pitch, duration, velocity)`` notes.

    Returns:
        ``(melody_inst_id, items)`` where ``items`` is the ranked list of
        ``(inst_id, (pitch_sum, note_count))``. ``melody_inst_id`` is ``None``
        when no instrument qualifies, including when there are no pitched
        notes at all.
    """
    pitch_record = {}
    for pos_item in pos_info:
        insts_notes = pos_item[-1]
        if insts_notes is None:
            continue
        for inst_id in insts_notes:
            if inst_id not in pitch_record:
                pitch_record[inst_id] = (0, 0)
            for pitch, dur, vel in insts_notes[inst_id]:
                sum_pitch, num_notes = pitch_record[inst_id]
                pitch_record[inst_id] = (sum_pitch + pitch, num_notes + 1)

    # Drop instrument 128 and instruments that never played a note; the
    # latter have no mean pitch and would otherwise divide by zero.
    candidates = [
        (inst_id, record) for inst_id, record in pitch_record.items()
        if inst_id != 128 and record[1] > 0
    ]
    items = sorted(candidates, key=lambda x: x[1][0] / x[1][1], reverse=True)
    if not items:
        return None, items

    # pos_info is non-empty here because at least one note was found.
    num_beats = len(pos_info) / 12
    if items[0][1][1] / num_beats >= 0.5:
        return items[0][0], items
    return None, items

Просмотреть файл

@ -0,0 +1,88 @@
from typing import List, Tuple
from copy import deepcopy
import msgpack
def get_bar_num_from_sample_tgt_pos(sample_tgt_pos: List[List[Tuple[List[int], List[int]]]]) -> int:
    """
    Get the number of bars for a sample.
    Args:
        sample_tgt_pos: The target pos information for a sample: a sequence
            of two-element segments whose second element holds the
            segment's bars.

    Returns:
        The total number of bars over all segments.
    """
    return sum(len(seg_bars) for _, seg_bars in sample_tgt_pos)
def fill_pos_ts_and_tempo_(pos_info):
    """Fill in missing time signature and tempo where item[2] is 0, in place.

    Each position item holds the time signature at index 1, a value at index 2
    and the tempo at index 3. The most recent non-``None`` time signature and
    tempo are tracked while scanning; items whose index-2 value is ``0``
    receive the tracked values in any slot that is ``None``.

    Raises:
        AssertionError: If the first item lacks a time signature or a tempo.
        IndexError: If ``pos_info`` is empty.

    Returns:
        The same ``pos_info`` object.
    """
    first = pos_info[0]
    cur_ts, cur_tempo = first[1], first[3]
    assert cur_ts is not None
    assert cur_tempo is not None

    for pos_item in pos_info:
        ts, tempo = pos_item[1], pos_item[3]
        if ts is not None:
            cur_ts = ts
        if tempo is not None:
            cur_tempo = tempo
        if pos_item[2] != 0:
            continue
        if ts is None:
            pos_item[1] = cur_ts
        if tempo is None:
            pos_item[3] = cur_tempo
    return pos_info
def string_pos_info_inst_id_(pos_info):
    """Convert the instrument ids of every position item to ``str``, in place.

    The last element of each item is either ``None`` (skipped) or a dict
    keyed by instrument id. Returns the same ``pos_info`` object.
    """
    for pos_item in pos_info:
        insts_notes = pos_item[-1]
        if insts_notes is not None:
            for inst_id in list(insts_notes):
                notes = insts_notes.pop(inst_id)
                insts_notes[str(inst_id)] = notes
    return pos_info
def string_pos_info_inst_id(pos_info):
    """Return a deep copy of ``pos_info`` with instrument ids converted to ``str``."""
    return string_pos_info_inst_id_(deepcopy(pos_info))
def destring_pos_info_inst_id_(pos_info):
    """Convert the instrument ids of every position item to ``int``, in place.

    The last element of each item is either ``None`` (skipped) or a dict
    keyed by instrument id. Returns the same ``pos_info`` object.
    """
    for pos_item in pos_info:
        insts_notes = pos_item[-1]
        if insts_notes is not None:
            for inst_id in list(insts_notes):
                notes = insts_notes.pop(inst_id)
                insts_notes[int(inst_id)] = notes
    return pos_info
def destring_pos_info_inst_id(pos_info):
    """Return a deep copy of ``pos_info`` with instrument ids converted to ``int``."""
    return destring_pos_info_inst_id_(deepcopy(pos_info))
def serialize_pos_info(pos_info, need_string_inst_id=True):
    """Serialize ``pos_info`` with msgpack.

    When ``need_string_inst_id`` is true, a copy with string instrument ids
    is serialized instead of the original; the input itself is not modified.
    """
    payload = string_pos_info_inst_id(pos_info) if need_string_inst_id else pos_info
    return msgpack.dumps(payload)
def deserialize_pos_info(pos_info, need_destring_inst_id=True):
    """Load msgpack-encoded position info.

    Every non-``None`` time signature (index 1 of each item) is converted to a
    tuple. When ``need_destring_inst_id`` is true, instrument ids are
    converted back to ``int``.
    """
    loaded = msgpack.loads(pos_info)
    for pos_item in loaded:
        if pos_item[1] is not None:
            pos_item[1] = tuple(pos_item[1])
    if not need_destring_inst_id:
        return loaded
    return destring_pos_info_inst_id_(loaded)

Просмотреть файл

@ -0,0 +1,7 @@
import numpy as np
import random
def seed_everything(seed=42):
    """Seed Python's ``random`` module and NumPy's global random state."""
    random.seed(seed)
    np.random.seed(seed)

Просмотреть файл

@ -0,0 +1,156 @@
import random
from typing import List, Tuple, Union
def remove_instrument(remigen_tokens: List[str], instrument_token: str) -> List[str]:
    """
    Remove an instrument from the remigen sequence.
    The instrument token and everything after it is dropped until a token
    starting with 'o', 't' or 'b', or a different instrument token, appears.
    Onsets left empty by the removal are kept, since they do not matter in
    most cases and the tempo token may still carry valuable information.
    Args:
        remigen_tokens: remigen sequence
        instrument_token: the instrument token, e.g. "i-128"

    Returns:
        remigen sequence without the designated instrument

    Raises:
        NotImplementedError: if a token starting with 's' is encountered.
    """
    kept = []
    skipping = False
    for token in remigen_tokens:
        if token.startswith('s'):
            raise NotImplementedError("This function is not implemented to support time signature token.")
        if not skipping:
            # Start dropping tokens at the designated instrument.
            skipping = token == instrument_token
        elif token.startswith(('o', 't', 'b')) or \
                (token.startswith('i') and token != instrument_token):
            # The removed instrument's section has ended.
            skipping = False
        if not skipping:
            kept.append(token)
    return kept
def count_token_num(remigen_seq: List[str], token: str, return_indices=False) -> Union[int, Tuple[int, list]]:
    """
    Count the number of a specific token in a remigen sequence.
    Args:
        remigen_seq: remigen sequence
        token: token str
        return_indices: bool, whether to also return the indices at which
            the token occurs

    Returns:
        the number of appearances of the token, or a tuple
        (number, indices) when return_indices is true.
    """
    indices = [idx for idx, candidate in enumerate(remigen_seq) if token == candidate]
    num = len(indices)
    return (num, indices) if return_indices else num
def count_bar_num(
    remigen_seq: List[str], bar_token: str = 'b-1', return_bar_token_indices=False
) -> Union[Tuple[int, int], Tuple[int, List, int]]:
    """
    Count the number of bars, including the complete bars and a possible incomplete bar.
    Args:
        remigen_seq: remigen sequence
        bar_token: bar token string
        return_bar_token_indices: bool, whether to also return the indices of
            the bar tokens

    Returns:
        (num_of_complete_bars, num_of_incomplete_bars), or
        (num_of_complete_bars, bar_token_indices, num_of_incomplete_bars)
        when return_bar_token_indices is true.
        num_of_incomplete_bars is 0 or 1: a non-empty sequence that does not
        end with the bar token is regarded as ending in an incomplete bar.
        An empty sequence has no bars of either kind.
    """
    result = count_token_num(remigen_seq, bar_token, return_indices=return_bar_token_indices)

    # Guard the [-1] lookup: an empty sequence used to raise IndexError.
    if len(remigen_seq) > 0 and remigen_seq[-1] != bar_token:
        num_incomplete_bar = 1
    else:
        num_incomplete_bar = 0

    if return_bar_token_indices:
        return result + (num_incomplete_bar,)
    return result, num_incomplete_bar
def get_bar_ranges(remigen_seq: List[str], bar_token: str = 'b-1'):
    """Return the token index ranges of the bars in a remigen sequence.

    Returns:
        ``(complete, incomplete)``: ``complete`` holds one half-open
        ``(begin, end)`` range per complete bar (the bar token is included);
        ``incomplete`` holds at most one range for a trailing bar that is not
        closed by ``bar_token``.
    """
    _, bar_token_indices, num_incomplete_bar = count_bar_num(
        remigen_seq, bar_token=bar_token, return_bar_token_indices=True
    )

    # Every bar starts right after the previous bar token (or at index 0).
    bar_ends = [index + 1 for index in bar_token_indices]
    bar_begins = [0] + bar_ends[:-1]
    complete_bar_result = list(zip(bar_begins, bar_ends))

    in_complete_bar_result = []
    if num_incomplete_bar > 0:
        tail_begin = bar_ends[-1] if bar_ends else 0
        in_complete_bar_result.append((tail_begin, len(remigen_seq)))

    return complete_bar_result, in_complete_bar_result
def get_instrument_played(remi_seq: List[str], max_num=None) -> List[str]:
    """Get the distinct instrument tokens played in a remigen sequence.

    Args:
        remi_seq: remigen sequence
        max_num: stop scanning as soon as this many distinct instrument
            tokens have been collected; ``None`` scans the whole sequence.

    Returns:
        The distinct tokens starting with "i", in order of first appearance.
        The order is deterministic and does not depend on string hashing.
    """
    # A dict is used as an insertion-ordered set.
    seen = {}
    for ev in remi_seq:
        # startswith (rather than ev[0]) also tolerates empty tokens.
        if ev.startswith("i"):
            seen[ev] = None
        if len(seen) == max_num:
            break
    return list(seen)
def get_instrument_seq(remi_seq: List[str], instru_id: int) -> List[str]:
    """Get the sub-sequence of one instrument from a remigen sequence.

    Bar tokens are always kept. For every pitch token that belongs to
    ``instru_id``, the pitch token and the two tokens following it are kept.
    Before the first kept note after each position ('o') or tempo ('t')
    token, the most recently stored ``[position, tempo]`` pair and the
    instrument token are emitted once.

    Args:
        remi_seq: remigen sequence
        instru_id: instrument id, i.e. the integer of an ``"i-<id>"`` token

    Returns:
        The tokens of the selected instrument.

    Raises:
        ValueError: if an instrument token does not carry an integer id.
    """
    chosen_events = []
    cur_pos = ""
    pos_tempo = []
    cur_instru = -1
    has_pushed = 0
    for i, note in enumerate(remi_seq):
        if note[0] == "o":
            cur_pos = note
            has_pushed = 0
        elif note[0] == 'b':
            chosen_events.append(note)
        elif note[0] == "t":
            # Buggy for REMIGEN2, because there 't' only appears at the
            # beginning of a bar.
            pos_tempo = [cur_pos, note]
            has_pushed = 0
        elif note[0] == "i":
            # The id is a plain integer; int() parses it without evaluating
            # the token as Python code.
            cur_instru = int(note[2:])
        elif note[0] == "p":
            if cur_instru == instru_id:
                if not has_pushed:
                    chosen_events.extend(pos_tempo)
                    chosen_events.append(f"i-{instru_id}")
                    has_pushed = 1
                # Pitch token plus the two tokens that follow it.
                chosen_events.extend(remi_seq[i:i + 3])
    return chosen_events
def sample_bars(input_seq: List[str], num_sampled_bars) -> Tuple[List[str], int]:
    """Get a random run of consecutive bars from a remigen sequence.

    Args:
        input_seq: remigen sequence made of complete bars closed by 'b-1'.
        num_sampled_bars: requested number of bars; capped at the number of
            bars available.

    Returns:
        (tokens of the sampled bars, number of bars actually sampled)

    Raises:
        AssertionError: if num_sampled_bars is not positive, the sequence
            ends in an incomplete bar, or it contains no bar at all.
    """
    assert num_sampled_bars > 0
    bar_ranges, incomplete_bar_range = get_bar_ranges(input_seq, bar_token='b-1')
    assert len(incomplete_bar_range) == 0
    num_bars = len(bar_ranges)
    assert num_bars > 0

    num_sampled_bars = min(num_bars, num_sampled_bars)
    first_bar = random.randint(0, num_bars - num_sampled_bars)
    last_bar = first_bar + num_sampled_bars - 1
    start, _ = bar_ranges[first_bar]
    _, end = bar_ranges[last_bar]
    return input_seq[start: end], num_sampled_bars

Просмотреть файл

@ -0,0 +1,188 @@
import os
from multiprocessing import Pool
from functools import partial
import numpy as np
from .data import convert_dict_key_to_str, convert_dict_key_to_int
def _collect_bar_note_keys(inst_poses_notes, include_pitch, include_duration):
    """Build the set of comparable note keys of one instrument in one bar."""
    keys = set()
    for pos in inst_poses_notes:
        for note in inst_poses_notes[pos]:
            key = [pos]
            if include_pitch:
                key.append(note[0])
            if include_duration:
                key.append(note[1])
            keys.add(tuple(key))
    return keys


def cal_bar_similarity_basic(bar1_insts_poses_notes, bar2_insts_poses_notes, bar1_insts, bar2_insts,
                             ignore_pitch=False, ignore_duration_for_drum=True):
    """Compute the per-instrument similarity between two bars.

    For an instrument present in both bars, the similarity is
    ``|common notes| / |all notes|`` where a note is identified by its
    position plus, unless disabled, its pitch (``note[0]``) and duration
    (``note[1]``).

    Args:
        bar1_insts_poses_notes: dict inst -> {position: notes} for bar 1.
        bar2_insts_poses_notes: dict inst -> {position: notes} for bar 2.
        bar1_insts: set of instruments in bar 1.
        bar2_insts: set of instruments in bar 2.
        ignore_pitch: leave the pitch out of the note identity.
        ignore_duration_for_drum: leave the duration out of the note identity
            for instrument 128.

    Returns:
        dict inst -> similarity. The value is the string ``'O'`` when the
        instrument appears in only one of the bars or no note is shared.

    Raises:
        ValueError: if instrument 128 is in both bars while both
            ``ignore_pitch`` and ``ignore_duration_for_drum`` are true.
    """
    o_sim = {}
    # Instruments present in exactly one of the two bars.
    for inst in bar1_insts ^ bar2_insts:
        o_sim[inst] = 'O'

    for inst in bar1_insts & bar2_insts:
        is_drum = inst == 128
        if is_drum and ignore_pitch and ignore_duration_for_drum:
            raise ValueError(
                "ignore_pitch and ignore_duration_for_drum cannot be True at the same time for drum.")

        include_pitch = not ignore_pitch
        include_duration = not (is_drum and ignore_duration_for_drum)
        inst_bar1_note = _collect_bar_note_keys(
            bar1_insts_poses_notes[inst], include_pitch, include_duration)
        inst_bar2_note = _collect_bar_note_keys(
            bar2_insts_poses_notes[inst], include_pitch, include_duration)

        num_common_notes = len(inst_bar1_note & inst_bar2_note)
        if num_common_notes > 0:
            o_sim[inst] = num_common_notes / len(inst_bar1_note | inst_bar2_note)
        else:
            o_sim[inst] = 'O'

    return o_sim
def generate_bar_insts_pos_index(bar):
    """Index the notes of a bar by instrument and then by position.

    Args:
        bar: sequence of position items whose last element is ``None`` or a
            dict mapping instrument -> notes.

    Returns:
        dict inst -> {index of the item within ``bar``: notes}.
    """
    index = {}
    for pos, item in enumerate(bar):
        insts_notes = item[-1]
        if insts_notes is None:
            continue
        for inst, notes in insts_notes.items():
            index.setdefault(inst, {})[pos] = notes
    return index
def construct_bars_info(pos_info, bars_positions):
    """Collect the note index and the time signature of every bar.

    Args:
        pos_info: sequence of position items; index 1 of an item holds the
            time signature.
        bars_positions: sequence of ``(begin, end)`` index pairs into
            ``pos_info``, one per bar.

    Returns:
        ``(bars_note_info, bars_ts_info)``: per bar, the result of
        ``generate_bar_insts_pos_index`` and the time signature found at the
        bar's first position.

    Raises:
        AssertionError: if a bar's first position has no time signature.
    """
    bars_note_info = []
    bars_ts_info = []
    for begin, end in bars_positions:
        ts = pos_info[begin][1]
        assert ts is not None
        bars_ts_info.append(ts)
        bars_note_info.append(generate_bar_insts_pos_index(pos_info[begin: end]))
    return bars_note_info, bars_ts_info
def cal_for_bar_i_and_j(bar_indices, bars_insts, bars_note_info, bars_ts_info):
    """Compute the similarity between bars ``i`` and ``j``.

    Args:
        bar_indices: the pair ``(i, j)``.
        bars_insts: per-bar collections of instruments.
        bars_note_info: per-bar note indexes.
        bars_ts_info: per-bar time signatures.

    Returns:
        ``None`` if the two bars have different time signatures, otherwise
        the result of ``cal_bar_similarity_basic`` for the two bars.
    """
    i, j = bar_indices
    if bars_ts_info[i] != bars_ts_info[j]:
        return None
    return cal_bar_similarity_basic(
        bars_note_info[i], bars_note_info[j], set(bars_insts[i]), set(bars_insts[j])
    )
def cal_song_similarity(pos_info, bars_positions, bars_insts, use_multiprocess=True, use_sparse_format=True):
    """Compute the pairwise bar similarity of a song for every instrument.

    Args:
        pos_info: position information of the song.
        bars_positions: ``(begin, end)`` index pairs into ``pos_info``, one
            per bar.
        bars_insts: per-bar collections of instruments.
        use_multiprocess: compute the bar pairs in a process pool.
        use_sparse_format: post-process the result with ``compress_value``.

    Returns:
        dict inst_id -> {(i, j): value} for every ``j < i``, where value is
        the similarity, ``'O'``, or ``None`` if the pair was not comparable.
        With ``use_sparse_format`` each inner dict is replaced by the output
        of ``compress_value``.
    """
    num_bars = len(bars_positions)
    bars_note_info, bars_ts_info = construct_bars_info(pos_info, bars_positions)

    all_insts = {inst_id for bar_insts in bars_insts for inst_id in bar_insts}
    inputs = [(i, j) for i in range(num_bars) for j in range(i)]

    # r: inst_id: {(i, j): value / None}; every instrument gets its own dict.
    r = {inst_id: dict.fromkeys(inputs) for inst_id in all_insts}

    worker = partial(
        cal_for_bar_i_and_j,
        bars_insts=bars_insts,
        bars_note_info=bars_note_info,
        bars_ts_info=bars_ts_info,
    )

    def record(results):
        # `results` yields one entry per pair, in the order of `inputs`.
        for pair, ij_r in zip(inputs, results):
            if ij_r is not None:
                for inst_id in ij_r:
                    r[inst_id][pair] = ij_r[inst_id]

    if use_multiprocess and inputs:
        # Pool() rejects 0 processes (fewer than two bars means no pairs) and
        # os.cpu_count() may return None, so both cases are guarded.
        num_workers = min(os.cpu_count() or 1, len(inputs))
        with Pool(num_workers) as pool:
            # Results must be consumed while the pool is still alive.
            record(pool.imap(worker, inputs))
    else:
        record(map(worker, inputs))

    if use_sparse_format:
        r = compress_value(r, num_bars)

    return r
def compress_value(data, num_bars):
    """Convert dense per-instrument similarity dicts to sparse lists, in place.

    Args:
        data: dict inst_id -> {(i, j): value} holding an entry for every
            ``j < i < num_bars``.
        num_bars: number of bars.

    Returns:
        The same ``data`` object, where each inner dict has been replaced by
        a list of ``(i, j, value)`` triples (ordered by ``i``, then ``j``)
        that omits the entries equal to ``'O'``.
    """
    pairs = [(i, j) for i in range(num_bars) for j in range(i)]
    for inst_id, record in data.items():
        entries = ((i, j, record[(i, j)]) for i, j in pairs)
        # Replacing the value of an existing key is safe while iterating.
        data[inst_id] = [entry for entry in entries if not entry[2] == 'O']
    return data
def convert_sparse_to_numpy(value, num_bars, ignore_none=True):
    """Expand sparse similarity records into symmetric matrices.

    Args:
        value: dict inst_id -> iterable of ``(i, j, similarity)`` triples.
        num_bars: side length of the matrices.
        ignore_none: skip triples whose similarity is ``None``.

    Returns:
        dict inst_id -> ``(num_bars, num_bars)`` array. Listed entries are
        written symmetrically, the diagonal is 1.0 and everything else is 0.
    """
    matrices = {}
    for inst_id, record in value.items():
        tensor = np.zeros((num_bars, num_bars))
        for i, j, s in record:
            if ignore_none and s is None:
                continue
            tensor[i, j] = tensor[j, i] = s
        # A bar is always identical to itself.
        np.fill_diagonal(tensor, 1.0)
        matrices[inst_id] = tensor
    return matrices
def repr_value(value):
    """Return a copy of ``value`` with its top-level keys converted to ``str``."""
    stringified = convert_dict_key_to_str(value)
    return stringified
def derepr_value(value):
    """Return a copy of ``value`` with its top-level keys converted to ``int``."""
    restored = convert_dict_key_to_int(value)
    return restored

Просмотреть файл

@ -0,0 +1,296 @@
import json
import random
import os
from music21 import pitch
from .const import inst_class_id_to_name_2
# from const import inst_class_id_to_name
# Instrument id -> name table used when verbalizing instruments.
_inst_class_id_to_name = inst_class_id_to_name_2

# Attribute labels that carry instrument information.
inst_label = "I1s1"
inst_c_label = "I2s1"

# Attribute labels that may be picked as meta information of a text.
metas = [
    "B1", "TS1", "T1", "P1", "P2", "P3", "C2", "R1", "R2", "S1", "S3", "P4",
    "TS1s1", "B1s1", "K1",
]

# Attribute label -> placeholder that is substituted in the text templates.
label_to_template = {
    inst_label: "[INSTRUMENTS]",
    inst_c_label: "[INSTRUMENT]",
    "B1": "[NUM_BARS]",
    "B1s1": "[NUM_BARS]",
    "TS1": "[TIME_SIGNATURE]",
    "TS1s1": "[TIME_SIGNATURE]",
    "T1": "[TEMPO]",
    "P1": "[LOW_PITCH]",
    "P2": "[HIGH_PITCH]",
    "P3": "[NUM_PITCH_CLASS]",
    "ST1": "[STRUCTURE]",
    "EM1": "[EMOTION]",
    "C2": "[]",
    "R1": "[]",
    "R2": "[]",
    "S1": "[]",
    "S3": "[]",
    "K1": "[KEY]",
    "S2": "[ARTIST]",
    "P4": "[RANGE]",
}
def remove_digit(t):
    """Return ``t`` with every decimal digit character (0-9) removed."""
    return ''.join(letter for letter in t if letter not in '0123456789')
class Verbalizer(object):
def __init__(self):
pass
def instr_to_str(self, instr):
    """Join the names of the given instrument ids into one phrase.

    Each id is looked up in the module's instrument name table and the names
    are combined with ``concat_str``.
    """
    names = [_inst_class_id_to_name[inst_id] for inst_id in instr]
    return self.concat_str(names)
def attribute_to_str(self, attribute_values):
    """Convert raw attribute values into the strings used in templates.

    Args:
        attribute_values: dict mapping attribute label -> extracted value.

    Returns:
        A new dict holding the verbalized attributes. Attributes whose value
        is ``None`` or ``(None, None)``, flag attributes whose value is not
        1, and labels that are not handled are left out.
    """
    flag_labels = ("C2", "R1", "R2", "S1", "S3")
    bar_range = ["1-4", "5-8", "9-12", "13-16"]

    converted = {}
    for att, v in attribute_values.items():
        if v is None or v == (None, None):
            continue

        if att in flag_labels and v == 1:
            converted[att] = "True"
        elif att in ("B1", "P4"):
            converted[att] = str(v)
        elif att == "B1s1":
            # v[1] selects one of the bar-count buckets.
            converted[att] = bar_range[v[1]]
        elif att in ("P1", "P2"):
            converted[att] = pitch.Pitch(midi=v).nameWithOctave
        elif att == "TS1":
            converted[att] = f"{v[0]}/{v[1]}"
        elif att == "T1":
            converted[att] = v[1]
        elif att == "P3":
            # Only the number of pitch classes is verbalized.
            converted[att] = str(len(v))
        elif att == inst_label:
            converted[att] = self.instr_to_str(v)
        elif att == inst_c_label:
            # When both v[1] and v[2] are set, v[2] takes precedence.
            if v[1] is not None:
                converted[att] = (v[0], self.instr_to_str(v[1]))
            if v[2] is not None:
                converted[att] = (v[0], self.instr_to_str(v[2]))
        elif att == "ST1":
            converted[att] = remove_digit(v)
        elif att in ("EM1", "K1", "S2", "C1"):
            converted[att] = v
        elif att == "TS1s1":
            if len(v) == 2:
                converted[att] = f"{v[0]}/{v[1]}"
    return converted
def concat_str(self, str_list):
    """Join sentences or phrases into a single enumeration.

    Empty strings are dropped. The first item keeps its capitalization but
    loses a trailing period; every following item gets a lower-cased first
    letter. Items are separated by ", " and the last one is attached with
    " and ", keeping its own trailing period.

    Returns:
        "" for no items, the item itself for a single item, otherwise the
        joined string.
    """
    parts = [s for s in str_list if s != ""]
    if not parts:
        return ""
    if len(parts) == 1:
        return parts[0]

    head, last = parts[0], parts[-1]
    if head[-1] == '.':
        head = head[:-1]

    pieces = [head]
    for middle in parts[1:-1]:
        body = middle[1:-1] if middle[-1] == "." else middle[1:]
        pieces.append(f", {middle[0].lower()}{body}")
    pieces.append(f" and {last[0].lower()}{last[1:]}")
    return "".join(pieces)
def emotion_to_str(self, emo):
    """Return a randomly chosen word describing an emotion quadrant.

    Args:
        emo: one of "Q1", "Q2", "Q3", "Q4".

    Returns:
        One of the two words associated with the quadrant, without any
        surrounding whitespace so it can be substituted into a sentence
        as is.

    Raises:
        KeyError: if ``emo`` is not a known quadrant.
    """
    _emo = {
        "Q1": ["happiness", "excitement"],
        "Q2": ["tension", "unease"],
        "Q3": ["sadness", "depression"],
        "Q4": ["calmness", "relaxation"]
    }
    return random.choice(_emo[emo])
def feeling_to_str(self, feel):
    """Return a randomly chosen word for a feeling label ("F1" or "F2").

    Raises:
        KeyError: if ``feel`` is not a known label.
    """
    words_by_feeling = {"F1": ["Bright"], "F2": ["Gloomy"]}
    candidates = words_by_feeling[feel]
    return random.choice(candidates)
def get_text(self, attribute_values):
    """
    Generate text descriptions for one set of attribute values.

    Each text is built from a "core" sentence (emotion / structure /
    instruments), an optional "middle" sentence (key, feeling change,
    instrument change) and a "meta" sentence, all filled in from the
    templates stored in ``template.txt`` next to this module.

    :param attribute_values: dict whose keys are attribute labels and whose
        values are the information returned by the ``extract`` function of
        the corresponding attribute_unit.
    :return: tuple ``(res_str, used_attr)``; ``res_str`` is the list of
        generated texts and ``used_attr`` is a parallel list of dicts that
        record which attributes each text mentions.
    :raises ValueError: raised by ``random`` when fewer than four core
        sentences can be built or when no meta attribute is available.
    """
    # TODO: singular/plural templates
    # TODO: capitalization of proper nouns
    template_path = os.path.join(os.path.dirname(__file__), "template.txt")
    # Use a context manager so the template file is closed after parsing.
    with open(template_path, "r") as template_file:
        template = json.load(template_file)
    # Every entry holds several ";"-separated alternatives.
    for label in template:
        template[label] = [option.strip() for option in template[label].split(";")]

    # Keep the raw instrument collection before the values are stringified.
    instr_all = attribute_values[inst_label]
    attribute_values = self.attribute_to_str(attribute_values)

    # Meta sentences: 10 random combinations of 1 to 4 meta attributes.
    meta_info = [meta for meta in metas if meta in attribute_values]
    meta_str_with_info = []
    random_meta = 10
    max_meta = 4
    for _ in range(random_meta):
        _meta_info = random.sample(meta_info, random.randint(1, min(max_meta, len(meta_info))))
        meta_strs = []
        for meta in _meta_info:
            meta_temp = random.choice(template[meta])
            try:
                meta_strs.append(meta_temp.replace(label_to_template[meta], attribute_values[meta]))
            except Exception:
                # Report which attribute broke the substitution, then re-raise.
                print(meta)
                raise
        meta_str_with_info.append((_meta_info, self.concat_str(meta_strs)))

    # Core sentences, each stored as (labels it covers, text).
    core_strs = []
    if "EM1" in attribute_values:
        v = attribute_values["EM1"]
        for temp in template["I1_ALL_EM1_" + v]:
            core_strs.append((["EM1", "I1_ALL"], temp.replace("[INSTRUMENT]", attribute_values[inst_label])))
        if len(instr_all) > 1:
            # random.sample only accepts sequences since Python 3.11, so a
            # set of instruments is converted to a tuple first.
            instr = random.sample(tuple(instr_all), 1)
            instr_name = self.instr_to_str(instr)
            for temp_em in template["EM1"]:
                for temp_i in template[inst_label]:
                    # temp_em is rebound here, so once the placeholder has
                    # been replaced on the first inner iteration the later
                    # replace calls no longer find it.
                    temp_em = temp_em.replace("[EMOTION]", self.emotion_to_str(v))
                    temp_i = temp_i.replace("[INSTRUMENTS]", instr_name)
                    core_strs.append((["EM1", inst_label, instr], self.concat_str([temp_em, temp_i])))
    if "ST1" in attribute_values:
        v = attribute_values["ST1"]
        if v == "AA" or v == "AAAA":
            _id = "I1_ST1_A"
        else:
            _id = "I1_ST1"
        for temp in template[_id]:
            _temp = temp.replace("[INSTRUMENT]", attribute_values[inst_label])
            core_strs.append((["ST1", "I1_ALL"], _temp.replace("[STRUCTURE]", v)))
    if "I1s1" in attribute_values:
        for temp in template[inst_label]:
            _temp = temp.replace("[INSTRUMENTS]", attribute_values[inst_label])
            core_strs.append((["I1_ALL"], _temp))

    # Middle sentences, each stored as [label, text].
    mid_strs = []
    if "K1" in attribute_values and "S2" in attribute_values:
        K1_S2_temp = random.choice(template["K1_S2"])
        K1_S2_temp = K1_S2_temp.replace(label_to_template["K1"], attribute_values["K1"])
        K1_S2_temp = K1_S2_temp.replace(label_to_template["S2"], attribute_values["S2"])
        mid_strs.append(["K1_S2", K1_S2_temp])
    if "C1" in attribute_values:
        v = attribute_values["C1"]
        # NOTE(review): no branch handles values of C1 other than 0-3.
        if v == 0 or v == 1:
            temp = random.choice(template[f"C1_{v}"])
        elif v == 2:
            temp = random.choice(template["C1"])
            temp = temp.replace("[FEELING_A]", self.feeling_to_str("F1"))
            temp = temp.replace("[FEELING_B]", self.feeling_to_str("F2"))
        elif v == 3:
            temp = random.choice(template["C1"])
            temp = temp.replace("[FEELING_A]", self.feeling_to_str("F2"))
            temp = temp.replace("[FEELING_B]", self.feeling_to_str("F1"))
        mid_strs.append(["C1", temp])
    if "S2" in attribute_values and "K1" not in attribute_values:
        temp = random.choice(template["S2"])
        S2_temp = temp.replace(label_to_template["S2"], attribute_values["S2"])
        mid_strs.append(["S2", S2_temp])
    if inst_c_label in attribute_values:
        if attribute_values[inst_c_label][0] == "inc":
            instr_var = "[INSTRUMENT] is added in the middle."
        else:
            instr_var = "[INSTRUMENT] is removed in the middle."
        instr_var = instr_var.replace("[INSTRUMENT]", attribute_values[inst_c_label][1])
        # The instrument name starts the sentence, so capitalize it.
        instr_var = instr_var[0].upper() + instr_var[1:]
        mid_strs.append([inst_c_label, instr_var])
    mid_attr = [entry[0] for entry in mid_strs]
    mid_str = self.concat_str([entry[1] for entry in mid_strs])

    # Combine 4 random core sentences with 4 random meta sentences.
    core_strs = random.sample(core_strs, 4)
    meta_str_with_info = random.sample(meta_str_with_info, 4)
    res_str = []
    used_attr = []
    for core_labels, core_text in core_strs:
        for meta_labels, meta_text in meta_str_with_info:
            attr = {}
            for _meta in meta_labels:
                attr[_meta] = 1
            for _i, _meta in enumerate(core_labels):
                if _meta == "I1_ALL":
                    # The text names every instrument.
                    attr[inst_label] = instr_all
                    break
                if _meta == inst_label:
                    # The sampled instruments are stored right after the label.
                    attr[inst_label] = tuple(core_labels[_i + 1])
                    break
                attr[_meta] = 1
            for att in mid_attr:
                if att == "K1_S2":
                    attr["K1"] = 1
                    attr["S2"] = 1
                else:
                    attr[att] = 1
            if mid_str:
                s = " ".join([core_text, mid_str, meta_text])
            else:
                s = " ".join([core_text, meta_text])
            res_str.append(s)
            used_attr.append(attr)
    return res_str, used_attr
if __name__ == "__main__":
    # Smoke test: get_text must run without raising on a range of attribute
    # dictionaries (complete ones, ones full of None, v2 attribute sets).
    verbalizer = Verbalizer()
    test_attr = {inst_label: {0, 1}, inst_c_label: ("inc", None, {3, 4}, 4), "ST1": None, "B1": 16, "TS1": (3, 4),
                 "T1": (108.0, "Moderato"), "P1": 50, "P2": 70, "P3": {1, 2, 3, 4}, "EM1": "Q1"}
    # This case used to share its name with the v2 "none" case defined
    # further down, which overwrote it before it was ever exercised.
    test_none_attr_v1 = {inst_label: {0, 1}, inst_c_label: None, "ST1": None, "B1": 16, "TS1": None,
                         "T1": (None, None), "P1": 50, "P2": 70, "P3": {1, 2, 3, 4}, "EM1": "Q1"}
    no_emo_and_st_attr = {inst_label: {0, 1}, inst_c_label: None, "ST1": None, "B1": 16, "TS1": None,
                          "T1": (None, None), "P1": 50, "P2": 70, "P3": {1, 2, 3, 4}, "EM1": None}
    test_attr_2 = {inst_label: {16}, inst_c_label: ('inc', {16}, None, 1), 'B1': 16, 'TS1': (4, 4), 'T1': (None, None), 'P1': None, 'P2': None, 'P3': None, 'ST1': None, 'EM1': 'Q3'}
    v2_test = {inst_label: {0, 1}, inst_c_label: ("inc", None, {3, 4}, 4), "ST1": None, "B1": 16, "TS1": (3, 4),
               "T1": (108.0, "Moderato"), "P1": 50, "P2": 70, "P3": {1, 2, 3, 4}, "EM1": "Q1",
               "C1": 0, "C2": 1, "R1": 1, "R2": 1, "S1": 1, "S2": None, "K1": "major",
               "P4": 3, "S3": 1, "I1s1": {0, 1}, "I2s1": ("inc", None, {3, 4}, 4)}
    v2_test_P4 = {inst_label: {0, 1}, inst_c_label: ("inc", None, {3, 4}, 4), "ST1": None, "B1": None, "TS1": None,
                  "T1": (None, None), "P3": None, "EM1": "Q1",
                  "P4": 3, "S3": 1, "K1": "major"}
    v2_test_B1s1_TS1s1 = {inst_label: {0, 1}, inst_c_label: None, "ST1": None,
                          "T1": (None, None), "P3": None, "EM1": "Q1",
                          "C1": None, "C2": None, "R1": None, "R2": None, "S1": None, "S2": None, "K1": None, "B1s1": (16, 3), "TS1s1": (3, 4)}
    test_none_attr = {inst_label: {0, 1}, inst_c_label: None, "ST1": None, "B1": 16, "TS1": None,
                      "T1": (None, None), "P1": 50, "P2": 70, "P3": None, "EM1": "Q1",
                      "C1": None, "C2": None, "R1": None, "R2": None, "S1": None, "S2": None, "K1": None}
    v2_test_2 = {'I1s1': {0, 3, 4, 13}, 'I2s1': ('inc', {3, 4, 13}, None, 14), 'C1': 2, 'R2': False, 'S1': None, 'S2': None, 'S3': None, 'B1s1': (16, 3), 'TS1s1': (4, 4), 'K1': 'minor', 'T1': (120.81591202325676, 'Allegro'), 'P3': {0, 2, 4, 5, 7, 9, 11}, 'P4': 5, 'ST1': None, 'EM1': 'Q2'}
    v2_test_3 = {'I1s1': {0}, 'I2s1': ('inc', {0}, None, 2), 'C1': 3, 'R2': False, 'S1': None, 'S2': None, 'S3': None, 'B1s1': (16, 3), 'TS1s1': (4, 4), 'K1': 'minor', 'T1': (120.81591202325676, 'Allegro'), 'P3': {0, 2, 4, 5, 7, 9, 11}, 'P4': 3, 'ST1': None, 'EM1': 'Q4'}
    print("=====Start Test====")
    test_cases = [
        test_attr,
        test_attr_2,
        test_none_attr_v1,
        no_emo_and_st_attr,
        v2_test,
        v2_test_P4,
        v2_test_B1s1_TS1s1,
        v2_test_2,
        v2_test_3,
        test_none_attr,
    ]
    for case in test_cases:
        verbalizer.get_text(case)
    print("=====Passed====")

Просмотреть файл

@ -73,23 +73,17 @@ The checkpoint of the fine-tuned model and `num_labels.json` are obtained.
## II. Attribute-to-Music Generation
### 1. Data processing
Switch to the `2-attribute2music_dataprepare` folder, and set `midi_data_extractor_path` in `config.py` to the path that contains `midi_data_extractor`.
Then, in `data_tool` folder, run the following command to obtain the packed data.
Switch to the `2-attribute2music_dataprepare` folder. Then, run the following command to obtain the packed data.
```bash
python extract_data.py path/to/the/folder/containing/midi/files path/to/save/the/dataset
```
**Note:** The tool can only automatically extract the values of the objective attributes from MIDI files. If you want to provide values for the subjective attributes, please enter them manually at L40-L42 in `extract_data.py`.
Prepare `Token.bin, Token_index.json, RID.bin, RID_index.json` in folder `data/`. Then run the following command to process the data into `train, validation, test`.
```shell