This commit is contained in:
UranusYu 2023-08-06 21:22:29 +08:00
Parent c36f750899
Commit 8c81749b0d
18 changed files with 2619 additions and 0 deletions

5
copilot/.env Normal file

@@ -0,0 +1,5 @@
OPENAI_API_KEY=""
OPENAI_ORG_ID=""
AZURE_OPENAI_DEPLOYMENT_NAME=""
AZURE_OPENAI_ENDPOINT=""
AZURE_OPENAI_API_KEY=""

163
copilot/README.md Normal file

@@ -0,0 +1,163 @@
<!-- <p align="center"> <b> Music Copilot </b> </p> -->
<div align="center">
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)]()
[![Open in Spaces](https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-blue)]()
</div>
## Demo Video
[Download the demo video](https://drive.google.com/file/d/1W0iJPHNPA6ENLJrPef0vtQytboSubxXe/view?usp=sharing)
## Features
- Accessibility: Music Copilot dynamically selects the most appropriate methods for each music-related task.
- Unity: Music Copilot unifies a wide array of tools into a single system, incorporating Huggingface models, GitHub projects, and Web APIs.
- Modularity: Music Copilot offers high modularity, allowing users to effortlessly enhance its capabilities by integrating new functions.
## Installation
### Docker (Recommended)
To be created.
### Conda / Pip
#### Install Dependencies
To set up the system from source, follow the steps below:
```bash
# Make sure git-lfs is installed
sudo apt-get update
sudo apt-get install -y git-lfs
# Install music-related libs
sudo apt-get install -y libsndfile1-dev
sudo apt-get install -y fluidsynth
sudo apt-get install -y ffmpeg
# Clone the repository from TODO
git clone https://github.com/TODO
cd DIR
```
Next, install the dependent libraries. There might be some conflicts, but they should not affect the functionality of the system.
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install semantic-kernel
pip install numpy==1.23.0
pip install protobuf==3.20.3
```
By following these steps, you will be able to successfully set up the system from the provided source.
#### Download Hugging Face / GitHub Parameters
```bash
cd models/ # Or your custom folder for tools
bash download.sh
```
Note: download the GitHub checkpoints below according to your own needs.
To use [muzic/roc](https://github.com/microsoft/muzic/tree/main/roc), follow these steps:
```bash
cd YOUR_MODEL_DIR # models/ by default
cd muzic/roc
```
1. Download the checkpoint and database from the following [link](https://drive.google.com/drive/folders/1TpWOMlRAaUL-R6CRLWfZK1ZeE1VCaubp).
2. Place the downloaded checkpoint files in the *music-ckps* folder (the path expected by *plugins.py*).
3. Create a folder named *database* to store the downloaded database files.
To use [DiffSinger](https://github.com/MoonInTheRiver/DiffSinger), follow these steps:
```bash
cd YOUR_MODEL_DIR
cd DiffSinger
```
1. Download the checkpoint and config from the following [link](https://github.com/MoonInTheRiver/DiffSinger/releases/download/pretrain-model/0228_opencpop_ds100_rel.zip) and unzip it into the *checkpoints* folder.
2. You can find other DiffSinger checkpoints in its [docs](https://github.com/MoonInTheRiver/DiffSinger/blob/master/docs/README-SVS.md).
To use [DDSP](https://github.com/magenta/ddsp/tree/main), follow these steps:
```bash
cd YOUR_MODEL_DIR
mkdir ddsp
cd ddsp
pip install gsutil
mkdir violin; gsutil cp gs://ddsp/models/timbre_transfer_colab/2021-07-08/solo_violin_ckpt/* violin/
mkdir flute; gsutil cp gs://ddsp/models/timbre_transfer_colab/2021-07-08/solo_flute_ckpt/* flute/
```
To use audio synthesis, please download [MS Basic.sf3](https://github.com/musescore/MuseScore/tree/master/share/sound) and place it in the main folder.
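As a reference for how the soundfont is used, the agent renders generated MIDI files to audio by calling `fluidsynth` (see *plugins.py*); below is a minimal sketch of the same invocation, with illustrative file paths:

```python
import subprocess

# Render a generated MIDI file to WAV with the "MS Basic.sf3" soundfont,
# mirroring the fluidsynth call used in plugins.py (file paths are illustrative).
subprocess.run([
    "fluidsynth", "-l", "-ni", "-a", "file", "-z", "2048",
    "-F", "public/audios/demo.wav",  # output audio file
    "MS Basic.sf3",                  # soundfont placed in the main folder
    "public/audios/demo.mid",        # input MIDI file
], check=True)
```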
## Usage
Change the *config.yaml* file to ensure that it is suitable for your application scenario.
```yaml
# optional tools
huggingface:
  token: YOUR_HF_TOKEN
spotify:
  client_id: YOUR_CLIENT_ID
  client_secret: YOUR_CLIENT_SECRET
google:
  api_key: YOUR_API_KEY
  custom_search_engine_id: YOUR_SEARCH_ENGINE_ID
```
- Set your [Hugging Face token](https://huggingface.co/settings/tokens).
- Set your [Spotify Client ID and Secret](https://developer.spotify.com/dashboard), according to the [doc](https://developer.spotify.com/documentation/web-api); a quick way to verify them is sketched below.
- Set your [Google API key](https://console.cloud.google.com/apis/dashboard) and [Google Custom Search Engine ID](https://programmablesearchengine.google.com/controlpanel/create).
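The snippet below is a minimal sketch (with placeholder credentials) of how the agent exchanges the Spotify Client ID and Secret for an access token in *model_utils.py*; you can run it yourself to check that your credentials work:

```python
import requests

# Client-credentials flow, as used by spotify_search in model_utils.py (placeholders shown).
response = requests.post(
    "https://accounts.spotify.com/api/token",
    headers={"Content-Type": "application/x-www-form-urlencoded"},
    data={
        "grant_type": "client_credentials",
        "client_id": "YOUR_CLIENT_ID",
        "client_secret": "YOUR_CLIENT_SECRET",
    },
)
response.raise_for_status()
print("Access token:", response.json()["access_token"])
```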
### CLI
Fill in the *.env* file:
```bash
OPENAI_API_KEY=""
OPENAI_ORG_ID=""
# optional
AZURE_OPENAI_DEPLOYMENT_NAME=""
AZURE_OPENAI_ENDPOINT=""
AZURE_OPENAI_API_KEY=""
```
If you use Azure OpenAI, remember to set *use_azure_openai* to true in *config.yaml*.
Now you can run the agent:
```bash
python agent.py --config config.yaml
```
### Gradio
We also provide a Gradio interface:
```bash
python gradio_agent.py --config config.yaml
```
No *.env* setup is required for the Gradio interface; enter your OpenAI API key directly in the UI (only OpenAI keys are supported, not Azure OpenAI).
## Citation
If you use this code, please cite it as:
```
To be published
```

351
copilot/agent.py Normal file

@@ -0,0 +1,351 @@
""" Agent for CLI or APPs"""
import io
import os
import sys
import time
import re
import json
import logging
import yaml
import threading
import argparse
import pdb
import semantic_kernel as sk
from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, OpenAITextCompletion
from model_utils import lyric_format
from plugins import get_task_map, init_plugins
class MusicCoplilotAgent:
"""
Attributes:
config_path: A path to a YAML file, referring to the example config.yaml
mode: Supports "cli" or "gradio", determining when to load the LLM backend.
"""
def __init__(
self,
config_path: str,
mode: str = "cli",
):
self.config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader)
os.makedirs("logs", exist_ok=True)
self.src_fold = self.config["src_fold"]
os.makedirs(self.src_fold, exist_ok=True)
self._init_logger()
self.kernel = sk.Kernel()
self.task_map = get_task_map()
self.pipes = init_plugins(self.config)
if mode == "cli":
self._init_backend_from_env()
def _init_logger(self):
self.logger = logging.getLogger(__name__)
self.logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
if not self.config["debug"]:
handler.setLevel(logging.CRITICAL)
self.logger.addHandler(handler)
log_file = self.config["log_file"]
if log_file:
filehandler = logging.FileHandler(log_file)
filehandler.setLevel(logging.DEBUG)
filehandler.setFormatter(formatter)
self.logger.addHandler(filehandler)
def _init_semantic_kernel(self):
skills_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "skills")
copilot_funcs = self.kernel.import_semantic_skill_from_directory(skills_directory, "MusicCopilot")
# task planning
self.task_planner = copilot_funcs["TaskPlanner"]
self.task_context = self.kernel.create_new_context()
self.task_context["history"] = ""
# model selection
self.tool_selector = copilot_funcs["ToolSelector"]
self.tool_context = self.kernel.create_new_context()
self.tool_context["history"] = ""
self.tool_context["tools"] = ""
# response
self.responder = copilot_funcs["Responder"]
self.response_context = self.kernel.create_new_context()
self.response_context["history"] = ""
self.response_context["processes"] = ""
# chat
self.chatbot = copilot_funcs["ChatBot"]
self.chat_context = self.kernel.create_new_context()
self.chat_context["history"] = ""
def clear_history(self):
self.task_context["history"] = ""
self.tool_context["history"] = ""
self.response_context["history"] = ""
self.chat_context["history"] = ""
def _init_backend_from_env(self):
# Configure AI service used by the kernel
if self.config["use_azure_openai"]:
deployment, api_key, endpoint = sk.azure_openai_settings_from_dot_env()
self.kernel.add_text_completion_service("dv", AzureTextCompletion(deployment, endpoint, api_key))
else:
api_key, org_id = sk.openai_settings_from_dot_env()
self.kernel.add_text_completion_service("dv", OpenAITextCompletion(self.config["model"], api_key, org_id))
self._init_semantic_kernel()
self._init_task_context()
self._init_tool_context()
def _init_backend_from_input(self, api_key):
# Only OpenAI api is supported in Gradio demo
self.kernel.add_text_completion_service("dv", OpenAITextCompletion(self.config["model"], api_key, ""))
self._init_semantic_kernel()
self._init_task_context()
self._init_tool_context()
def _init_task_context(self):
self.task_context["tasks"] = json.dumps(list(self.task_map.keys()))
def _init_tool_context(self):
self.tool_context["tools"] = json.dumps(
[{"id": pipe.id, "attr": pipe.get_attributes()} for pipe in self.pipes.values()]
)
def update_tool_attributes(self, pipe_id, **kwargs):
        self.pipes[pipe_id].update_attributes(**kwargs)
self._init_tool_context()
def model_inference(self, model_id, command, device="cpu"):
output = self.pipes[model_id].inference(command["args"], command["task"], device)
locals = []
for result in output:
if "audio" in result or "sheet_music" in result:
locals.append(result)
if len(locals) > 0:
self.task_context["history"] += f"In this task, <GENERATED>-{command['id']}: {json.dumps(locals)}. "
return output
def skillchat(self, input_text, chat_function, context):
context["input"] = input_text
answer = chat_function.invoke(context=context)
answer = str(answer).strip()
context["history"] += f"\nuser: {input_text}\nassistant: {answer}\n"
# Manage history
context["history"] = ' '.join(context["history"].split()[-self.config["history_len"]:])
return answer
def fix_depth(self, tasks):
for task in tasks:
task["dep"] = list(set(re.findall(r"<GENERATED>-([0-9]+)", json.dumps(task))))
task["dep"] = [int(d) for d in task["dep"]]
if len(task["dep"]) == 0:
task["dep"] = [-1]
return tasks
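    # Illustrative example (arg keys shown are placeholders): for a plan such as
    #   [{"id": 0, "task": "lyric-generation", "args": [{"description": "..."}]},
    #    {"id": 1, "task": "lyric-to-melody", "args": [{"lyric": "<GENERATED>-0"}]}]
    # fix_depth sets task 0's dep to [-1] (no dependency) and task 1's dep to [0],
    # because its args reference the output of task 0 via the <GENERATED>-0 placeholder.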
def collect_result(self, command, choose, inference_result):
result = {"task": command}
result["inference result"] = inference_result
result["choose model result"] = choose
self.logger.debug(f"inference result: {inference_result}")
return result
def run_task(self, input_text, command, results):
id = command["id"]
args = command["args"]
task = command["task"]
deps = command["dep"]
if deps[0] != -1:
dep_tasks = [results[dep] for dep in deps]
else:
dep_tasks = []
self.logger.debug(f"Run task: {id} - {task}")
self.logger.debug("Deps: " + json.dumps(dep_tasks))
inst_args = []
for arg in args:
for key in arg:
if isinstance(arg[key], str):
if "<GENERATED>" in arg[key]:
dep_id = int(arg[key].split("-")[1])
for result in results[dep_id]["inference result"]:
if key in result:
tmp_arg = arg.copy()
tmp_arg[key] = result[key]
inst_args.append(tmp_arg)
else:
tmp_arg = arg.copy()
inst_args.append(tmp_arg)
elif isinstance(arg[key], list):
tmp_arg = arg.copy()
for t in range(len(tmp_arg[key])):
item = tmp_arg[key][t]
if "<GENERATED>" in item:
dep_id = int(item.split("-")[1])
for result in results[dep_id]["inference result"]:
if key in result:
tmp_arg[key][t] = result[key]
break
inst_args.append(tmp_arg)
for arg in inst_args:
for resource in ["audio", "sheet_music"]:
if resource in arg:
if not arg[resource].startswith(self.config["src_fold"]) and not arg[resource].startswith("http") and len(arg[resource]) > 0:
arg[resource] = f"{self.config['src_fold']}/{arg[resource]}"
command["args"] = inst_args
self.logger.debug(f"parsed task: {command}")
if task in ["lyric-generation"]: # ChatGPT Can do
best_model_id = "ChatGPT"
reason = "ChatGPT performs well on some NLP tasks as well."
choose = {"id": best_model_id, "reason": reason}
inference_result = []
for arg in command["args"]:
                chat_input = f"[{input_text}] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {arg}. Just help me do {command['task']} and give me the result without any additional description. The result must be in text form without any urls."
response = self.skillchat(chat_input, self.chatbot, self.chat_context)
inference_result.append({"lyric":lyric_format(response)})
else:
if task not in self.task_map:
self.logger.warning(f"no available models on {task} task.")
inference_result = [{"error": f"{command['task']} not found in available tasks."}]
results[id] = self.collect_result(command, "", inference_result)
return False
candidates = [pipe_id for pipe_id in self.task_map[task] if pipe_id in self.pipes]
candidates = candidates[:self.config["candidate_tools"]]
self.logger.debug(f"avaliable models on {command['task']}: {candidates}")
if len(candidates) == 0:
self.logger.warning(f"unloaded models on {task} task.")
inference_result = [{"error": f"models for {command['task']} are not loaded."}]
results[id] = self.collect_result(command, "", inference_result)
return False
if len(candidates) == 1:
best_model_id = candidates[0]
reason = "Only one model available."
choose = {"id": best_model_id, "reason": reason}
self.logger.debug(f"chosen model: {choose}")
else:
self.tool_context["available"] = ', '.join([cand.id for cand in candidates])
choose_str = self.skillchat(input_text, self.tool_selector, self.tool_context)
self.logger.debug(f"chosen model: {choose_str}")
choose = json.loads(choose_str)
reason = choose["reason"]
best_model_id = choose["id"]
inference_result = self.model_inference(best_model_id, command, device=self.config["device"])
results[id] = self.collect_result(command, choose, inference_result)
return True
def chat(self, input_text):
start = time.time()
self.logger.info(f"input: {input_text}")
task_str = self.skillchat(input_text, self.task_planner, self.task_context)
self.logger.info(f"plans: {task_str}")
try:
tasks = json.loads(task_str)
except Exception as e:
self.logger.debug(e)
response = self.skillchat(input_text, self.chatbot, self.chat_context)
return response
if len(tasks) == 0:
response = self.skillchat(input_text, self.chatbot, self.chat_context)
return response
tasks = self.fix_depth(tasks)
results = {}
threads = []
d = dict()
retry = 0
while True:
num_thread = len(threads)
for task in tasks:
# logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
for dep_id in task["dep"]:
if dep_id >= task["id"]:
task["dep"] = [-1]
break
dep = task["dep"]
if dep[0] == -1 or len(list(set(dep).intersection(d.keys()))) == len(dep):
tasks.remove(task)
thread = threading.Thread(target=self.run_task, args=(input_text, task, d))
thread.start()
threads.append(thread)
if num_thread == len(threads):
time.sleep(0.5)
retry += 1
if retry > 120:
self.logger.debug("User has waited too long, Loop break.")
break
if len(tasks) == 0:
break
for thread in threads:
thread.join()
results = d.copy()
self.logger.debug("results: ", results)
self.response_context["processes"] = str(results)
response = self.skillchat(input_text, self.responder, self.response_context)
end = time.time()
during = end - start
self.logger.info(f"time: {during}s")
return response
def parse_args():
parser = argparse.ArgumentParser(description="A path to a YAML file")
parser.add_argument("--config", type=str, help="a YAML file path.")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
agent = MusicCoplilotAgent(args.config, mode="cli")
print("Input exit or quit to stop the agent.")
while True:
message = input("Send a message: ")
if message in ["exit", "quit"]:
break
print(agent.chat(message))


@@ -0,0 +1,775 @@
import sqlite3
import random
import copy
import numpy as np
import argparse
import re
from math import ceil
from utils.lyrics_match import Lyrics_match
from midiutil.MidiFile import MIDIFile
from fairseq.models.transformer_lm import TransformerLanguageModel
from cnsenti import Sentiment
import miditoolkit
from midi2dict import midi_to_lyrics
from textblob import TextBlob
_CHORD_KIND_PITCHES = {
'': [0, 4, 7],
'm': [0, 3, 7],
'+': [0, 4, 8],
'dim': [0, 3, 6],
'7': [0, 4, 7, 10],
'maj7': [0, 4, 7, 11],
'm7': [0, 3, 7, 10],
'm7b5': [0, 3, 6, 10],
}
# custom_lm = TransformerLanguageModel.from_pretrained('music-ckps/', 'checkpoint_best.pt', tokenizer='space',
# batch_size=8192).cuda()
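# Melody string format used throughout this file (inferred from fill_template,
# polish and save_demo below): each note is 5 space-separated tokens,
#   <cadence NOT/HALF/AUT> bar_<idx> Pos_<0-15> Pitch_<midi> Dur_<sixteenth steps>
# with 16 sixteenth-note positions per bar and a trailing space at the end of every
# bar string (splice relies on it). Chord strings look like "C: A:m F: G:", with
# two chords per bar (see polish_chord).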
def select_melody(c, is_maj, is_chorus, length, last_bar, chord, chord_ptr, is_last_sentence):
cursor = c.execute(
"SELECT DISTINCT NOTES, CHORDS from MELOLIB where LENGTH = '{}' and CHORUS = '{}' and MAJOR = '{}' ".format(
length, is_chorus, is_maj)) # and MAJOR = '{}'
candidates_bars = []
if is_debug:
print("Retrive melody...")
for row in cursor:
notes = row[0]
cd_ = row[1]
candidates_bars.append((notes, cd_))
# Filter by chords.
chord_list_ = chord.strip().split(' ')
chord_list_ = chord_list_[chord_ptr:] + chord_list_[:chord_ptr]
re_str = ''
if not is_last_sentence:
key = ''
else:
if is_maj:
key = ' C:'
else:
key = ' A:m'
# For the given chord progression, we generate a regex like:
# A:m F: G: C: -> ^A:m( A:m)*( F:)+( G:)+( C:)*$|^A:m( A:m)*( F:)+( G:)*$|^A:m( A:m)*( F:)*$|^A:m( A:m)*$
# Given the regex, we find matched pieces.
# We design the regex like this because alternations in regular expressions are evaluated from left to right,
# the piece with the most various chords will be selected, if there's any.
for j in range(len(chord_list_), 0, -1):
re_str += '^({}( {})*'.format(chord_list_[0], chord_list_[0])
for idx in range(1, j):
re_str += '( {})+'.format(chord_list_[idx])
re_str = re_str[:-1]
re_str += '*{})$|'.format(key)
re_str = re_str[:-1]
tmp_candidates = []
for row in candidates_bars:
if re.match(r'{}'.format(re_str), row[1]):
tmp_candidates.append(row)
if len(tmp_candidates) == 0:
re_str = '^{}( {})*$'.format(chord_list_[-1], chord_list_[-1])
for row in candidates_bars:
if re.match(r'{}'.format(re_str), row[1]):
tmp_candidates.append(row)
if len(tmp_candidates) > 0:
candidates_bars = tmp_candidates
else:
if is_maj:
re_str = '^C:( C:)*$'
else:
re_str = '^A:m( A:m)*$'
for row in candidates_bars:
if re.match(r'{}'.format(re_str), row[1]):
tmp_candidates.append(row)
if len(tmp_candidates) > 0:
candidates_bars = tmp_candidates
candidates_cnt = len(candidates_bars)
if candidates_cnt == 0:
if is_debug:
print('No Matched Rhythm as {}'.format(length))
return []
    if last_bar is None:  # we are at the beginning of a song, randomly select bars.
if is_debug:
print('Start a song...')
def not_too_high(bar):
notes = bar.split(' ')[:-1][3::5]
notes = [int(x[6:]) for x in notes]
for i in notes:
if 57 > i or i > 66:
return False
return True
tmp = []
for bar in candidates_bars:
if not_too_high(bar[0]):
tmp.append(bar)
return tmp
else:
last_note = int(last_bar.split(' ')[-3][6:])
# tendency
selected_bars = []
prefer_note = None
# Major C
if is_maj:
if last_note % 12 == 2 or last_note % 12 == 9:
prefer_note = last_note - 2
elif last_note % 12 == 5:
prefer_note = last_note - 1
elif last_note % 12 == 11:
prefer_note = last_note + 1
# Minor A
else:
if last_note % 12 == 11 or last_note % 12 == 2: # 2 -> 1, 4 -> 3
prefer_note = last_note - 2
elif last_note % 12 == 6: # 6 -> 5
prefer_note = last_note - 1
elif last_note % 12 == 7: # 7-> 1
prefer_note = last_note + 2
if prefer_note is not None:
for x in candidates_bars:
                if int(x[0].split(' ')[3][6:]) == prefer_note:  # compare the first Pitch_ token of the candidate bar
selected_bars.append(x)
if len(selected_bars) > 0:
if is_debug:
print('Filter by tendency...')
candidates_bars = selected_bars
selected_bars = []
for bar in candidates_bars:
first_pitch = int(bar[0].split(' ')[3][6:])
if (first_pitch > last_note - 8 and first_pitch < last_note + 8):
selected_bars.append(bar)
if len(selected_bars) > 0:
if is_debug:
print('Filter by pitch range...')
return selected_bars
# No candidates yet? randomly return some.
if is_debug:
print("Randomly selected...")
return candidates_bars
def lm_score(custom_lm, bars, note_string, bar_idx):
tmp_string = []
n = ' '.join(note_string.split(' ')[-100:])
for sbar in bars:
sbar_, _ = fill_template(sbar[0], bar_idx)
tmp_string.append(n + sbar_)
score = [x['score'].item() for x in custom_lm.score(tmp_string)]
assert len(score) == len(tmp_string)
tmp = list(zip(bars, score))
tmp.sort(key=lambda x: x[1], reverse=True)
tmp = tmp[:30]
best_score = tmp[0][1]
res = []
for x in tmp:
if best_score - x[1] < 0.1:
res.append(x[0])
return res
def get_chorus(chorus_start, chorus_length, lyrics):
return range(chorus_start, chorus_start + chorus_length)
def save_demo(notes_str, select_chords, name, lang, sentence, word_counter):
pitch_dict = {'C': 0, 'C#': 1, 'D': 2, 'Eb': 3, 'E': 4, 'F': 5, 'F#': 6, 'G': 7, 'Ab': 8, 'A': 9, 'Bb': 10, 'B': 11}
_CHORD_KIND_PITCHES = {
'': [0, 4, 7],
'm': [0, 3, 7],
'+': [0, 4, 8],
'dim': [0, 3, 6],
'7': [0, 4, 7, 10],
'maj7': [0, 4, 7, 11],
'm7': [0, 3, 7, 10],
'm7b5': [0, 3, 6, 10],
}
print('Save the melody to {}.mid'.format(name))
mf = MIDIFile(2) # only 1 track
melody_track = 0 # the only track
chord_track = 1
time = 0 # start at the beginning
channel = 0
mf.addTrackName(melody_track, time, "melody")
mf.addTrackName(chord_track, time, "chord")
mf.addTempo(melody_track, time, 120)
mf.addTempo(chord_track, time, 120)
notes = notes_str.split(' ')
cnt = 0
sen_idx = 0
chord_time = []
for i in range(len(notes) // 5):
if is_debug:
            print('writing idx: ', i)
# cadence = notes[5 * i]
bar = int(notes[5 * i + 1][4:])
pos = int(notes[5 * i + 2][4:]) # // pos_resolution
pitch = int(notes[5 * i + 3][6:])
dur = int(notes[5 * i + 4][4:]) / 4
time = bar * 4 + pos / 4 # + delta
# if cadence == 'HALF':
# delta += 2
# if cadence == 'AUT':
# delta += 4
mf.addNote(melody_track, channel, pitch, time, dur, 100)
# fill all chords into bars before writing notes
if cnt == 0:
cds = select_chords[sen_idx].split(' ')
t = time - time % 2
if len(chord_time) > 0:
blank_dur = t - chord_time[-1] - 2
insert_num = int(blank_dur / 2)
if is_debug:
print('Chords:', cds[0].split(':'))
root, cd_type = cds[0].split(':')
root = pitch_dict[root]
for i in range(insert_num):
for shift in _CHORD_KIND_PITCHES[cd_type]:
mf.addNote(chord_track, channel, 36 + root + shift, chord_time[-1] + 2, 2, 75)
chord_time.append(chord_time[-1] + 2)
if is_debug:
print('begin sentence:', sen_idx)
for cd in cds:
root, cd_type = cd.split(':')
root = pitch_dict[root]
# mf.addNote(chord_track, channel, 36+root, t, 2, 75) # 36 is C3
for shift in _CHORD_KIND_PITCHES[cd_type]:
mf.addNote(chord_track, channel, 36 + root + shift, t, 2, 75)
chord_time.append(t)
t += 2
cnt += 1
if cnt == word_counter[sen_idx]:
cnt = 0
sen_idx += 1
name += '.mid'
with open(name, 'wb') as outf:
mf.writeFile(outf)
midi_obj = miditoolkit.midi.parser.MidiFile(name)
if lang == 'zh':
lyrics = ''.join(sentence)
else:
print(sentence)
lyrics = ' '.join(sentence).split(' ')
print(lyrics)
word_idx = 0
for idx, word in enumerate(lyrics):
if word not in [',', '.', '']:
note = midi_obj.instruments[0].notes[word_idx]
midi_obj.lyrics.append(
miditoolkit.Lyric(text=word, time=note.start))
word_idx += 1
else:
midi_obj.lyrics[-1].text += word
# print(midi_obj.lyrics)
midi_obj.dump(f'{name}', charset='utf-8')
return midi_to_lyrics(midi_obj)
def fill_template(s_bar, bar_idx):
notes = s_bar.split(' ')
tmp = []
last_bar_idx = notes[1][4:]
for i in range(len(notes)):
if i % 5 == 1:
if notes[i][4:] != last_bar_idx:
bar_idx += 1
last_bar_idx = notes[i][4:]
tmp.append('bar_' + str(bar_idx))
else:
tmp.append(notes[i])
return ' '.join(tmp), bar_idx + 1
def splice(bar1, bar2):
"""
    Concatenate bar1 and bar2.
    In bar1 the bar index has already been replaced, while bar2 still keeps the X or Y placeholder, like 'X {} {} {} {} '.format(pos, pitch, dur, cadence).
"""
if bar1 == '':
return bar2
if bar2 == '':
return bar1
assert bar1[-1] == ' ' # For the ease of concatenation, there's a space at the end of bar
assert bar2[-1] == ' '
notes1 = bar1.split(' ')[:-1]
notes2 = bar2.split(' ')[:-1]
bar_cnt = len(set(notes1[1::5]))
    # If the last note ending time in bar1 is not far from the beginning time of the first note in bar2, just return bar1 + bar2
# Calculate the note intervals in bars. If interval between two bars <= the average interval inside a bar, then it is regarded as 'not far away'.
def get_interval(notes):
begin = []
dur = []
if notes[1][4:] != 'X' and notes[1][4:] != 'Y':
start_bar = int(notes[1][4:])
else:
start_bar = 0
for idx in range(len(notes) // 5):
if notes[5 * idx + 1][4:] == 'X':
bar_idx_ = 0
elif notes[5 * idx + 1][4:] == 'Y':
bar_idx_ = 1
else:
bar_idx_ = int(notes[5 * idx + 1][4:])
begin.append(16 * (bar_idx_ - start_bar) + int(notes[5 * idx + 2][4:]))
dur.append(int(notes[5 * idx + 4][4:]))
end = list(np.array(begin) + np.array(dur))
return list(np.array(begin[1:]) - np.array(end[:-1])), begin[0], end[-1] - 16 if end[-1] > 16 else end[-1]
inter1, _, end1 = get_interval(notes1)
inter2, begin2, _ = get_interval(notes2)
def avg(notes):
return sum(notes) / len(notes)
avg_interval = avg(inter1 + inter2)
last_bar1_idx = int(notes1[-4][4:])
bar2, _ = fill_template(bar2, last_bar1_idx + 1)
if avg_interval < (16 - end1 + begin2):
# If interval between two bars is big, shift the second bar forward.
notes2 = bar2.split(' ')[:-1]
tmp = ''
for idx in range(len(notes2) // 5):
pos = int(notes2[5 * idx + 2][4:]) - (16 - end1 + begin2)
bar_idx_ = int(notes2[5 * idx + 1][4:])
if pos < 0:
bar_idx_ += pos // 16
pos = pos % 16
tmp += '{} bar_{} Pos_{} {} {} '.format(notes2[5 * idx], bar_idx_, pos, notes2[5 * idx + 3],
notes2[5 * idx + 4])
return bar1 + tmp
else:
return bar1 + bar2
def not_mono(bar):
"""
Filter monotonous pieces.
"""
notes = bar.split(' ')[:-1][3::5]
notes = [int(x[6:]) for x in notes]
tmp = [0] * 128
for idx in range(len(notes)):
tmp[int(notes[idx])] = 1
if (1 < len(notes) <= 3 and sum(tmp) == 1) or (len(notes) >= 4 and sum(tmp) < 3):
return False
return True
def not_duplicate(bar1, bar2):
"""
De-duplication, only care about the pitch.
"""
notes1 = bar1.split(' ')[:-1][3::5] # For the ease of concatenation, there's a space at the end of bar
notes2 = bar2.split(' ')[:-1][3::5]
return notes1 != notes2
def no_keep_trend(bars):
def is_sorted(a):
return all([a[i] <= a[i + 1] for i in range(len(a) - 1)])
candidates_bars = []
for bar_and_chord in bars:
bar = bar_and_chord[0]
notes = bar.split(' ')[:-1][3::5]
notes = [int(x[6:]) for x in notes]
if not is_sorted(notes):
candidates_bars.append(bar_and_chord)
return candidates_bars
def polish(bar, last_note_end, iscopy=False):
"""
    Three functions:
1. Avoid bars overlapping.
2. Make the first note in all bars start at the position 0.
3. Remove rest and cadence in a bar.
"""
notes = bar.strip().split(' ')
tmp = ''
first_note_start = 0
is_tuned = False
for idx in range(len(notes) // 5):
pos = int(notes[5 * idx + 2][4:])
bar_idx_ = int(notes[5 * idx + 1][4:])
dur = int(notes[5 * idx + 4][4:])
this_note_start = 16 * bar_idx_ + pos
cadence = 'NOT'
if idx == 0:
first_note_start = this_note_start
blank_after_last_note = 16 - last_note_end % 16
threshold = blank_after_last_note
else:
threshold = 0
if dur == 1: # the minimum granularity is a 1/8 note.
dur = 2
if dur > 8: # the maximum granularity is a 1/2 note.
dur = 8
# Function 3:
if this_note_start - last_note_end != threshold:
pos += (last_note_end + threshold - this_note_start)
bar_idx_ += pos // 16
pos = pos % 16
if idx == len(notes) // 5 - 2:
if 12 < pos + dur <= 16 or len(notes) // 5 <= 4:
dur = 16 - pos
is_tuned = True
if idx == len(notes) // 5 - 1:
if is_tuned:
pos = 0
else:
if 12 < pos + dur <= 16:
dur += 6
cadence = 'HALF' # just for the ease of model scoring
last_note_end = 16 * bar_idx_ + pos + dur
assert pos <= 16
tmp += '{} bar_{} Pos_{} {} Dur_{} '.format(cadence, bar_idx_, pos, notes[5 * idx + 3], dur)
return tmp, bar_idx_ + 1, last_note_end
def chord_truc(bar, schord):
"""
Given a bar string, remove redundant chords.
"""
schord_list = schord.split(' ')
notes = bar.strip().split(' ')
start_pos = 16 * int(notes[1][4:]) + int(notes[2][4:])
end_pos = 16 * int(notes[-4][4:]) + int(notes[-3][4:]) + int(notes[-1][4:])
duration = end_pos - start_pos
chord_num = ceil(duration / 8)
assert chord_num >= 1, 'bar:{},chord:{}'.format(bar, schord)
if len(schord_list) >= chord_num:
schord_list = schord_list[:chord_num]
else:
tmp = []
for i in schord_list:
tmp.append(i)
tmp.append(i)
schord_list = tmp[:chord_num]
return schord_list
def polish_chord(bar, schord, chord, chord_ptr):
"""
Align chords and the bar. When this function is called, the bar index is already replaced by the true index instead of X or Y.
    In our setting, there are 2 chords in a bar. Therefore for any position % 8 == 0, we write a chord.
Of course, you can modify this setting as needed.
"""
schord_list = chord_truc(bar, schord)
last_chord = schord_list[-1]
schord = ' '.join(schord_list)
chord_list = chord.split(' ')
if last_chord not in chord_list:
chord_ptr = (chord_ptr + 1) % len(chord_list)
else:
chord_ptr = (chord_list.index(last_chord) + 1) % len(chord_list)
return schord, chord_ptr
# if __name__ == '__main__':
def main(
custom_lm,
lyrics_corpus=['明月几时有 把酒问青天 不知天上宫阙 今夕是何年'],
chord_corpus=['zh C: G: A:m E: D:m G:'],
output_file_name='generated',
db_path='database/ROC.db'
):
use_sentiment = False
global is_debug
is_debug = False
conn = sqlite3.connect(db_path)
# global c
c = conn.cursor()
print("Database connected")
for lyrics, chord in zip(lyrics_corpus, chord_corpus):
lang = chord[:2]
assert lang in ['zh', 'en'] # Note that ROC is not language-sensitive, you can extend this.
chord = chord[2:].strip()
print('CHORD:', chord)
chord_ptr = 0
is_maj = 1
if lang == 'zh':
sentence = lyrics.strip().split(' ') # The list of lyric sentences
name = sentence[0]
if use_sentiment:
senti = Sentiment()
pos = 0
neg = 0
for s in sentence:
result = senti.sentiment_calculate(s)
pos += result['pos']
neg += result['neg']
if neg < 0 and pos >= 0:
is_maj = 1
elif pos < 0 and neg >= 0:
is_maj = 0
else:
if pos / neg < 1:
is_maj = 0
else:
is_maj = 1
elif lang == 'en':
sentence = lyrics.strip().split('[sep]')
name = sentence[0]
sentence = [len(x.strip().split(' ')) * '_' for x in sentence]
if use_sentiment:
sent = '.'.join(sentence)
blob = TextBlob(sent)
polarity = 0
for s in blob.sentences:
polarity += s.sentiment.polarity
if polarity >= 0:
is_maj = 1
else:
is_maj = 0
print('Tonality:', is_maj)
# structure recognition
parent, chorus_start, chorus_length = Lyrics_match(
sentence) # The last element must be -1, because the chord should go back to tonic
if is_debug:
print('Struct Array: ', parent)
chorus_range = get_chorus(chorus_start, chorus_length, lyrics)
if is_debug:
print('Recognized Chorus: ', chorus_start, chorus_length)
select_notes = [] # selected_melodies
select_chords = [] # selected chords
is_chorus = 0 # is a chorus?
note_string = '' # the 'melody context' mentioned in the paper.
bar_idx = 0 # current bar index. it is used to replace bar index in retrieved pieces.
last_note_end = -16
# is_1smn = 0 # Does 1 Syllable align with Multi Notes? In the future, we will explore better methods to realize this function. Here by default, we disable it.
for i in range(len(sentence)):
if lang == 'zh':
print('Lyrics: ', sentence[i])
else:
print('Lyrics: ', lyrics.strip().split('[sep]')[i])
is_last_sentence = (i == len(sentence) - 1)
if i in chorus_range:
is_chorus = 1
else:
is_chorus = 0
cnt = len(sentence[i])
if cnt <= 2 and parent[i] == -2: # if length is too short, do not partially share
parent[i] = -1
# Following codes correspond to 'Retrieval and Re-ranking' in Section 3.2.
# parent[i] is the 'struct value' in the paper.
if parent[i] == -1:
if is_debug:
print('No sharing.')
# one_syllable_multi_notes_probabilty = random.randint(1,100)
# if one_syllable_multi_notes_probabilty == 1:
# is_1smn = 1
# connect_notes = random.randint(1,2)
# cnt += connect_notes
# connect_start = random.randint(1,cnt)
# print('One Syllable Multi Notes range:',connect_start, connect_start + connect_notes)
if len(select_notes) == 0: # The first sentence of a song.
last_bar = None
else:
last_bar = select_notes[-1]
selected_bars = select_melody(c, is_maj, is_chorus, cnt, last_bar, chord, chord_ptr, is_last_sentence)
if cnt < 9 and len(selected_bars) > 0:
selected_bars = lm_score(custom_lm, selected_bars, note_string, bar_idx)
# selected_bars = no_keep_trend(selected_bars)
bar_chord = selected_bars[random.randint(0, len(selected_bars) - 1)]
s_bar = bar_chord[0]
s_chord = bar_chord[1]
s_bar, bar_idx = fill_template(s_bar,
bar_idx) # The returned bar index is the first bar index which should be in the next sentence, that is s_bar + 1.
                else:  # If no piece is retrieved or there are too many syllables in the lyric.
                    if is_debug:
                        print('No piece was retrieved or there are too many syllables in the lyric. Splitting the lyric.')
s_bar = ''
s_chord = ''
origin_cnt = cnt
error = 0
while cnt > 0:
l = max(origin_cnt // 3, 5)
r = max(origin_cnt // 2, 7) # Better to use long pieces, for better coherency.
split_len = random.randint(l, r)
if split_len > cnt:
split_len = cnt
if is_debug:
print('Split at ', split_len)
selected_bars = select_melody(c, is_maj, is_chorus, split_len, last_bar, chord, chord_ptr,
is_last_sentence)
if len(selected_bars) > 0:
selected_bars = lm_score(custom_lm, selected_bars, note_string + s_bar, bar_idx)
bar_chord = selected_bars[random.randint(0, len(selected_bars) - 1)]
last_bar = bar_chord[0]
last_chord = bar_chord[1]
s_bar = splice(s_bar, last_bar)
s_chord += ' ' + last_chord
# Explanation: if this condition is true, i.e., the length of s_bar + last_bar == the length of last_bar,
# then the only possibility is that we are in the first step of this while loop. We need to replace the bar index in retrieved pieces with the true bar index.
                            # In the following steps, there is no need to do so because there is an implicit 'fill_template' in 'splice'.
if len(s_bar) == len(last_bar):
s_bar, bar_idx = fill_template(s_bar, bar_idx)
s_chord, chord_ptr = polish_chord(s_bar, s_chord, chord, chord_ptr)
last_bar = s_bar
cnt -= split_len
else:
error += 1
if error >= 10:
                                print('The database does not have enough pieces to support ROC.')
exit()
s_chord = s_chord[1:]
s_bar, bar_idx, last_note_end = polish(s_bar, last_note_end)
s_chord, chord_ptr = polish_chord(s_bar, s_chord, chord, chord_ptr)
note_string += s_bar
select_notes.append(s_bar)
select_chords.append(s_chord)
if is_debug:
print('Selected notes: ', s_bar)
print('Chords: ', s_chord)
elif parent[i] == -2:
if is_debug:
print('Share partial melody from the previous lyric.')
                l = min(cnt // 3, 3)  # As mentioned in the 'Concatenation and Polish' section, for adjacent lyrics with the same syllable count,
                r = min(cnt // 2, 5)  # we 'polish their melodies to sound similar'.
# modify some notes then share.
replace_len = random.randint(l, r)
last_bar = ' '.join(select_notes[-1].split(' ')[:- replace_len * 5 - 1]) + ' '
tail = select_notes[-1].split(' ')[- replace_len * 5 - 1:]
last_chord = ' '.join(chord_truc(last_bar, select_chords[-1]))
selected_bars = select_melody(c, is_maj, is_chorus, replace_len, last_bar, chord, chord_ptr,
is_last_sentence)
selected_bars = lm_score(custom_lm, selected_bars, note_string + last_bar, bar_idx)
for bar_chord in selected_bars:
bar = bar_chord[0]
s_chord = bar_chord[1]
s_bar = splice(last_bar, bar)
if not_mono(s_bar) and not_duplicate(s_bar, select_notes[-1]):
s_chord = last_chord + ' ' + s_chord
break
s_bar, bar_idx = fill_template(s_bar, bar_idx)
s_bar = s_bar.split(' ')
for i in range(2, len(tail)): # Modify duration
if i % 5 == 2 or i % 5 == 1: # dur and cadence
s_bar[-i] = tail[-i]
s_bar = ' '.join(s_bar)
s_bar, bar_idx, last_note_end = polish(s_bar, last_note_end, True)
s_chord, chord_ptr = polish_chord(s_bar, s_chord, chord, chord_ptr)
note_string += s_bar
select_notes.append(s_bar)
select_chords.append(s_chord)
if is_debug:
print('Modified notes: ', s_bar)
print('chords: ', s_chord)
else:
                # 'struct value is positive' as mentioned in the paper: we directly share melodies.
if is_debug:
print('Share notes with sentence No.', parent[i])
s_bar = copy.deepcopy(select_notes[parent[i]])
s_chord = copy.deepcopy(select_chords[parent[i]])
s_bar, bar_idx = fill_template(s_bar, bar_idx)
s_bar, bar_idx, last_note_end = polish(s_bar, last_note_end, True)
s_chord, chord_ptr = polish_chord(s_bar, s_chord, chord, chord_ptr)
note_string += s_bar
select_notes.append(s_bar)
select_chords.append(s_chord)
if is_debug:
print(
'----------------------------------------------------------------------------------------------------------')
if is_debug:
print(select_chords)
print(select_notes)
output = save_demo(note_string, select_chords, output_file_name, lang,
lyrics.strip().split('[sep]') if lang == 'en' else sentence, [len(i) for i in sentence])
print(output)
print(
'--------------------------------------------A song is composed.--------------------------------------------')
conn.close()
return output
if __name__ == '__main__':
model = TransformerLanguageModel.from_pretrained('music-ckps/', 'checkpoint_best.pt', tokenizer='space',
batch_size=8192).cuda()
main(model)

20
copilot/config.yaml Normal file

@@ -0,0 +1,20 @@
# optional tools
huggingface:
  token:
spotify:
  client_id:
  client_secret:
google:
  api_key:
  custom_search_engine_id:
# agent settings
debug: false
use_azure_openai: false
model: text-davinci-003
device: cuda:0
local_fold: models
log_file: logs/debug.log
src_fold: public/audios
disabled_tools:
history_len: 200
candidate_tools: 5

142
copilot/gradio_agent.py Normal file

@@ -0,0 +1,142 @@
import uuid
import gradio as gr
import re
import requests
from agent import MusicCoplilotAgent
import soundfile
import argparse
all_messages = []
OPENAI_KEY = ""
def add_message(content, role):
message = {"role": role, "content": content}
all_messages.append(message)
def extract_medias(message):
audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav|mp3)")
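    # Matches audio URLs or local paths ending in .flac/.wav/.mp3,
    # e.g. "https://example.com/a.wav" or "/b.mp3" (example strings are illustrative).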
audio_urls = []
for match in audio_pattern.finditer(message):
if match.group(0) not in audio_urls:
audio_urls.append(match.group(0))
return audio_urls
def set_openai_key(openai_key):
global OPENAI_KEY
OPENAI_KEY = openai_key
agent._init_backend_from_input(openai_key)
return OPENAI_KEY, gr.update(visible=True), gr.update(visible=True)
def add_text(messages, message):
if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
return messages, "Please set your OpenAI API key first."
add_message(message, "user")
messages = messages + [(message, None)]
audio_urls = extract_medias(message)
for audio_url in audio_urls:
if audio_url.startswith("http"):
ext = audio_url.split(".")[-1]
name = f"{str(uuid.uuid4()[:4])}.{ext}"
response = requests.get(audio_url)
with open(f"{agent.config['src_fold']}/{name}", "wb") as f:
f.write(response.content)
messages = messages + [((f"{audio_url} is saved as {name}",), None)]
return messages, ""
def upload_audio(file, messages):
file_name = str(uuid.uuid4())[:4]
audio_load, sr = soundfile.read(file.name)
soundfile.write(f"{agent.config['src_fold']}/{file_name}.wav", audio_load, samplerate=sr)
messages = messages + [(None, f"** {file_name}.wav ** uploaded")]
return messages
def bot(messages):
if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
return messages
message = agent.chat(messages[-1][0])
audio_urls = extract_medias(message)
add_message(message, "assistant")
messages[-1][1] = message
for audio_url in audio_urls:
if not audio_url.startswith("http") and not audio_url.startswith(f"{agent.config['src_fold']}"):
audio_url = f"{agent.config['src_fold']}/{audio_url}"
messages = messages + [((None, (audio_url,)))]
return messages
def clear_all_history(messages):
agent.clear_history()
messages = messages + [(("All histories of LLM are cleared", None))]
return messages
def parse_args():
parser = argparse.ArgumentParser(description="A path to a YAML file")
parser.add_argument("--config", type=str, help="a YAML file path.")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
agent = MusicCoplilotAgent(args.config, mode="gradio")
with gr.Blocks() as demo:
gr.Markdown("<h2><center>Music Copilot (Dev)</center></h2>")
with gr.Row():
openai_api_key = gr.Textbox(
show_label=False,
placeholder="Set your OpenAI API key here and press Enter",
lines=1,
type="password",
)
state = gr.State([])
chatbot = gr.Chatbot([], elem_id="chatbot", label="music_copilot", visible=False).style(height=500)
with gr.Row(visible=False) as text_input_raws:
with gr.Column(scale=0.8):
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter or click Run button").style(container=False)
# with gr.Column(scale=0.1, min_width=0):
# run = gr.Button("🏃Run")
with gr.Column(scale=0.1, min_width=0):
clear_txt = gr.Button("🔄Clear")
with gr.Column(scale=0.1, min_width=0):
btn = gr.UploadButton("🖼Upload", file_types=["audio"])
openai_api_key.submit(set_openai_key, [openai_api_key], [openai_api_key, chatbot, text_input_raws])
clear_txt.click(clear_all_history, [chatbot], [chatbot])
btn.upload(upload_audio, [btn, chatbot], [chatbot])
txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
bot, chatbot, chatbot
)
gr.Examples(
examples=["Write a piece of lyric about the recent World Cup.",
"生成一首古风歌词的中文歌",
"Download a song by Jay Zhou for me and separate the vocals and the accompanies.",
"Convert the vocals in /b.wav to a violin sound.",
"Give me the sheet music and lyrics in the song /a.wav",
"近一个月流行的音乐类型",
"把c.wav中的人声搭配合适的旋律变成一首歌"
],
inputs=txt
)
demo.launch(share=True)

361
copilot/model_utils.py Normal file

@@ -0,0 +1,361 @@
import os
import requests
import urllib.parse
import librosa
import soundfile as sf
import re
import numpy as np
import ddsp
import ddsp.training
import pickle
import gin
from ddsp.training.postprocessing import (
detect_notes, fit_quantile_transform
)
# from ddsp.colab.colab_utils import (
# auto_tune, get_tuning_factor
# )
import tensorflow.compat.v2 as tf
import pdb
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2PreTrainedModel,
Wav2Vec2Model
)
from transformers.models.hubert.modeling_hubert import (
HubertPreTrainedModel,
HubertModel
)
from dataclasses import dataclass
from typing import Optional, Tuple
from transformers.file_utils import ModelOutput
@dataclass
class SpeechClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
class Wav2Vec2ClassificationHead(nn.Module):
"""Head for wav2vec classification task."""
def __init__(self, config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.final_dropout)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.pooling_mode = config.pooling_mode
self.config = config
self.wav2vec2 = Wav2Vec2Model(config)
self.classifier = Wav2Vec2ClassificationHead(config)
self.init_weights()
def freeze_feature_extractor(self):
self.wav2vec2.feature_extractor._freeze_parameters()
def merged_strategy(
self,
hidden_states,
mode="mean"
):
if mode == "mean":
outputs = torch.mean(hidden_states, dim=1)
elif mode == "sum":
outputs = torch.sum(hidden_states, dim=1)
elif mode == "max":
outputs = torch.max(hidden_states, dim=1)[0]
else:
raise Exception(
"The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
return outputs
def forward(
self,
input_values,
attention_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output
return SpeechClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def shift_ld(audio_features, ld_shift=0.0):
"""Shift loudness by a number of ocatves."""
audio_features['loudness_db'] += ld_shift
return audio_features
def shift_f0(audio_features, pitch_shift=0.0):
"""Shift f0 by a number of ocatves."""
audio_features['f0_hz'] *= 2.0 ** (pitch_shift)
audio_features['f0_hz'] = np.clip(audio_features['f0_hz'],
0.0,
librosa.midi_to_hz(110.0))
return audio_features
def timbre_transfer(filename, out_path, instrument="violin", sample_rate=16000):
audio, _ = librosa.load(filename, sr=sample_rate)
audio = audio[np.newaxis, :]
# Setup the session.
ddsp.spectral_ops.reset_crepe()
audio_features = ddsp.training.metrics.compute_audio_features(audio)
model_dir = f"models/ddsp/{instrument}"
gin_file = os.path.join(model_dir, 'operative_config-0.gin')
DATASET_STATS = None
dataset_stats_file = os.path.join(model_dir, 'dataset_statistics.pkl')
if tf.io.gfile.exists(dataset_stats_file):
with tf.io.gfile.GFile(dataset_stats_file, 'rb') as f:
DATASET_STATS = pickle.load(f)
with gin.unlock_config():
gin.parse_config_file(gin_file, skip_unknown=True)
ckpt_files = [f for f in tf.io.gfile.listdir(model_dir) if 'ckpt' in f]
ckpt_name = ckpt_files[0].split('.')[0]
ckpt = os.path.join(model_dir, ckpt_name)
time_steps_train = gin.query_parameter('F0LoudnessPreprocessor.time_steps')
n_samples_train = gin.query_parameter('Harmonic.n_samples')
hop_size = int(n_samples_train / time_steps_train)
time_steps = int(audio.shape[1] / hop_size)
n_samples = time_steps * hop_size
gin_params = [
'Harmonic.n_samples = {}'.format(n_samples),
'FilteredNoise.n_samples = {}'.format(n_samples),
'F0LoudnessPreprocessor.time_steps = {}'.format(time_steps),
'oscillator_bank.use_angular_cumsum = True', # Avoids cumsum accumulation errors.
]
with gin.unlock_config():
gin.parse_config(gin_params)
for key in ['f0_hz', 'f0_confidence', 'loudness_db']:
audio_features[key] = audio_features[key][:time_steps]
audio_features['audio'] = audio_features['audio'][:, :n_samples]
model = ddsp.training.models.Autoencoder()
model.restore(ckpt)
threshold = 1
ADJUST = True
quiet = 20
autotune = 0
pitch_shift = 0
loudness_shift = 0
audio_features_mod = {k: v.copy() for k, v in audio_features.items()}
mask_on = None
if ADJUST and DATASET_STATS is not None:
# Detect sections that are "on".
mask_on, note_on_value = detect_notes(audio_features['loudness_db'],
audio_features['f0_confidence'],
threshold)
if np.any(mask_on):
# Shift the pitch register.
target_mean_pitch = DATASET_STATS['mean_pitch']
pitch = ddsp.core.hz_to_midi(audio_features['f0_hz'])
mean_pitch = np.mean(pitch[mask_on])
p_diff = target_mean_pitch - mean_pitch
p_diff_octave = p_diff / 12.0
round_fn = np.floor if p_diff_octave > 1.5 else np.ceil
p_diff_octave = round_fn(p_diff_octave)
audio_features_mod = shift_f0(audio_features_mod, p_diff_octave)
# Quantile shift the note_on parts.
_, loudness_norm = fit_quantile_transform(
audio_features['loudness_db'],
mask_on,
inv_quantile=DATASET_STATS['quantile_transform'])
# Turn down the note_off parts.
mask_off = np.logical_not(mask_on)
loudness_norm[mask_off] -= quiet * (1.0 - note_on_value[mask_off][:, np.newaxis])
loudness_norm = np.reshape(loudness_norm, audio_features['loudness_db'].shape)
audio_features_mod['loudness_db'] = loudness_norm
# Auto-tune.
if autotune:
f0_midi = np.array(ddsp.core.hz_to_midi(audio_features_mod['f0_hz']))
tuning_factor = get_tuning_factor(f0_midi, audio_features_mod['f0_confidence'], mask_on)
f0_midi_at = auto_tune(f0_midi, tuning_factor, mask_on, amount=autotune)
audio_features_mod['f0_hz'] = ddsp.core.midi_to_hz(f0_midi_at)
else:
print('\nSkipping auto-adjust (no notes detected or ADJUST box empty).')
af = audio_features if audio_features_mod is None else audio_features_mod
outputs = model(af, training=False)
audio_gen = model.get_audio_from_outputs(outputs)
sf.write(out_path, audio_gen[0], sample_rate)
def pad_wave_mixing(file_name1, file_name2, out_path='mixed_audio.wav', sr=16000):
audio1, _ = librosa.load(file_name1, sr=sr)
audio2, _ = librosa.load(file_name2, sr=sr)
max_len = max(len(audio1), len(audio2))
audio1 = librosa.util.fix_length(audio1, size=max_len)
audio2 = librosa.util.fix_length(audio2, size=max_len)
mixed_audio = audio1 + audio2
sf.write(out_path, mixed_audio, sr)
def spotify_search(src, tgt, output_file_name, client_id, client_secret):
# request API access token
url = "https://accounts.spotify.com/api/token"
headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
data = {
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret
}
response = requests.post(url, headers=headers, data=data)
if response.status_code == 200:
token_data = response.json()
access_token = token_data["access_token"]
print("Access Token:", access_token)
else:
print("Error:", response.status_code)
# POST query
query = ["remaster"]
for key in src:
if key in ["track", "album", "artist", "genre"]:
value = " ".join(src[key])
query.append(f"{key}:{value}")
if tgt == "playlist":
query[0] = src["description"][0]
query = " ".join(query).replace(" ", "%20")
query = urllib.parse.quote(query)
url = f"https://api.spotify.com/v1/search?query={query}&type={tgt}"
headers = {"Authorization": f"Bearer {access_token}"}
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()[tgt + "s"]["items"][0]
text = dict()
spotify_id = data["id"]
text[tgt] = [data["name"]]
if tgt == "track":
url = data["preview_url"]
with open(output_file_name, "wb") as f:
f.write(requests.get(url).content)
text["album"] = [data["album"]["name"]]
text["artist"] = [d["name"] for d in data["artists"]]
if tgt == "album":
text["date"] = [data["release_date"]]
text["artist"] = [d["name"] for d in data["artists"]]
url = f"https://api.spotify.com/v1/albums/{spotify_id}"
album = requests.get(url, headers=headers).json()
if len(album["genres"]) > 0:
text["genre"] = album["genres"]
text["track"] = [d["name"] for d in album["tracks"]["items"]]
if tgt == "playlist":
url = f"https://api.spotify.com/v1/playlists/{spotify_id}"
album = requests.get(url, headers=headers).json()
text["track"] = [d["track"]["name"] for d in album["tracks"]["items"]]
if tgt == "artist":
if len(data["genres"]) > 0:
text["genre"] = data["genres"]
return text
else:
print('Response Failed: ', response.status_code)
return None
def lyric_format(text):
text = text.split('\n\n')
delimiters = "\n|,.;?!,。;、?!"
text = [re.split("["+delimiters+"]", chap) for chap in text]
i = 0
while i < len(text):
if len(text[i]) == 1:
text.pop(i)
continue
if len(text[i]) > 4:
text[i] = text[i][1:]
i += 1
return ' '.join([' '.join(chap) for chap in text]).split()
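# Note: lyric_format flattens an LLM lyric response into a list of whitespace-delimited
# tokens, dropping single-segment chapters and the first segment of longer ones; the
# agent wraps the result as {"lyric": [...]} and tools such as MuzicROC.inference in
# plugins.py re-join it with spaces.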


@@ -0,0 +1,33 @@
#!/bin/bash
# Set models to download
models=(
"m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
"sander-wood/text-to-music"
"jonatasgrosman/whisper-large-zh-cv11"
)
# Set the current directory
CURRENT_DIR=$(pwd)
# Download models
for model in "${models[@]}"; do
echo "----- Downloading from https://huggingface.co/${model} -----"
if [ -d "${model}" ]; then
(cd "${model}" && git pull && git lfs pull)
else
git clone --recurse-submodules "https://huggingface.co/${model}" "${model}"
fi
done
# Set Git project to download
libs=(
"microsoft/muzic"
"MoonInTheRiver/DiffSinger"
)
for lib in "${libs[@]}"; do
echo "----- Downloading from https://github.com/${lib}.git -----"
git clone "https://github.com/${lib}.git"
done
cp -r ../auxiliary/* ./

615
copilot/plugins.py Normal file

@@ -0,0 +1,615 @@
""" Models and APIs"""
import uuid
import numpy as np
import importlib
from transformers import pipeline, AutoConfig, Wav2Vec2FeatureExtractor
from model_utils import Wav2Vec2ForSpeechClassification, timbre_transfer
from pydub import AudioSegment
import requests
import urllib
import librosa
import torch
import torch.nn.functional as F
# import torchaudio
from fairseq.models.transformer_lm import TransformerLanguageModel
import soundfile as sf
import os
import sys
import json
import pdb
def get_task_map():
task_map = {
"text-to-sheet-music": [
"sander-wood/text-to-music"
],
"music-classification": [
"m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
],
"lyric-to-melody": [
"muzic/roc"
],
"lyric-to-audio": [
"DiffSinger"
],
"web-search": [
"google-search"
],
"artist-search": [
"spotify"
],
"track-search": [
"spotify"
],
"album-search": [
"spotify"
],
"playlist-search": [
"spotify"
],
"separate-track": [
"demucs"
],
"lyric-recognition": [
"jonatasgrosman/whisper-large-zh-cv11"
],
"score-transcription": [
"basic-pitch"
],
"timbre-transfer": [
"ddsp"
],
"accompaniment": [
"getmusic"
],
"audio-mixing": [
"basic-merge"
],
"audio-crop": [
"basic-crop"
],
"audio-splice": [
"basic-splice"
],
"web-search": [
"google-search"
],
}
return task_map
def init_plugins(config):
if config["disabled_tools"] is not None:
disabled = [tool.strip() for tool in config["disabled_tools"].split(",")]
else:
disabled = []
pipes = {}
if "muzic/roc" not in disabled:
pipes["muzic/roc"] = MuzicROC(config)
if "DiffSinger" not in disabled:
pipes["DiffSinger"] = DiffSinger(config)
if "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres" not in disabled:
pipes["m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"] = Wav2Vec2Base(config)
if "jonatasgrosman/whisper-large-zh-cv11" not in disabled:
pipes["jonatasgrosman/whisper-large-zh-cv11"] = WhisperZh(config)
if "spotify" not in disabled:
pipes["spotify"] = Spotify(config)
if "ddsp" not in disabled:
pipes["ddsp"] = DDSP(config)
if "demucs" not in disabled:
pipes["demucs"] = Demucs(config)
if "basic-merge" not in disabled:
pipes["basic-merge"] = BasicMerge(config)
if "basic-crop" not in disabled:
pipes["basic-crop"] = BasicCrop(config)
if "basic-splice" not in disabled:
pipes["basic-splice"] = BasicSplice(config)
if "basic-pitch" not in disabled:
pipes["basic-pitch"] = BasicPitch(config)
if "google-search" not in disabled:
pipes["google-search"] = GoogleSearch(config)
return pipes
class BaseToolkit:
def __init__(self, config):
self.local_fold = config["local_fold"]
self.id = "basic toolkit"
self.attributes = {}
def get_attributes(self):
return json.dumps(self.attributes)
def update_attributes(self, **kwargs):
for key in kwargs:
self.attributes[key] = kwargs[key]
class MuzicROC(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "muzic/roc"
self.attributes = {
"description": "ROC is a new paradigm for lyric-to-melody generation"
}
self._init_toolkit(config)
def _init_toolkit(self, config):
sys.path.append(os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc"))
from main import main as roc_processer
self.processer = roc_processer
self.model = TransformerLanguageModel.from_pretrained(os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc/music-ckps/"), "checkpoint_best.pt", tokenizer="space",
batch_size=8192)
sys.path.remove(os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc"))
def inference(self, args, task, device="cpu"):
results = []
self.model.to(device)
for arg in args:
if "lyric" in arg:
prompt = arg["lyric"]
prompt = " ".join(prompt)
file_name = str(uuid.uuid4())[:4]
outputs = self.processer(
self.model,
[prompt],
output_file_name=f"public/audios/{file_name}",
db_path=f"{self.local_fold}/muzic/roc/database/ROC.db"
)
os.system(f"fluidsynth -l -ni -a file -z 2048 -F public/audios/{file_name}.wav 'MS Basic.sf3' public/audios/{file_name}.mid")
results.append(
{
"score": str(outputs),
"audio": f"{file_name}.wav",
"sheet_music": f"{file_name}.mid"
}
)
self.model.to("cpu")
return results
class DiffSinger(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "DiffSinger"
self.attributes = {
"description": "Singing Voice Synthesis via Shallow Diffusion Mechanism",
"star": 3496
}
self._init_toolkit(config)
def _init_toolkit(self, config):
sys.path.append(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
import utils
importlib.reload(utils)
from inference.svs.ds_e2e import DiffSingerE2EInfer
from utils.hparams import hparams, set_hparams
from utils.audio import save_wav
work_dir = os.getcwd()
os.chdir(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
set_hparams('usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
'0228_opencpop_ds100_rel', print_hparams=False)
self.processer = save_wav
self.model = DiffSingerE2EInfer(hparams, device="cuda:0")
self.model.model.to("cpu")
self.model.vocoder.to("cpu")
os.chdir(work_dir)
sys.path.remove(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
def inference(self, args, task, device="cpu"):
results = []
self.model.model.to(device)
self.model.vocoder.to(device)
self.model.device = device
for arg in args:
if "score" in arg:
prompt = arg["score"]
prompt = eval(prompt)
wav = self.model.infer_once(prompt)
file_name = str(uuid.uuid4())[:4]
self.processer(wav, f"public/audios/{file_name}.wav", sr=16000)
results.append({"audio": f"{file_name}.wav"})
self.model.model.to("cpu")
self.model.vocoder.to("cpu")
self.model.device = "cpu"
return results
class Wav2Vec2Base(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
self.attributes = {
"description": "Music Genre Classification using Wav2Vec 2.0"
}
self._init_toolkit(config)
def _init_toolkit(self, config):
self.processer = Wav2Vec2FeatureExtractor.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
self.model = Wav2Vec2ForSpeechClassification.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
self.config = AutoConfig.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
def inference(self, args, task, device="cpu"):
results = []
self.model.to(device)
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
sampling_rate = self.processer.sampling_rate
#speech_array, _sampling_rate = torchaudio.load(prompt)
#resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
#speech = resampler(speech_array).squeeze().numpy()
speech, _ = librosa.load(prompt, sr=sampling_rate)
inputs = self.processer(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
logits = self.model(**inputs).logits
scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
genre = self.config.id2label[np.argmax(scores)]
# outputs = [{"Label": pipes[pipe_id]["config"].id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
results.append({"genre": genre})
self.model.to("cpu")
return results
class WhisperZh(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "jonatasgrosman/whisper-large-zh-cv11"
self.attributes = {
"description": "a fine-tuned version of openai/whisper-large-v2 on Chinese (Mandarin)"
}
self._init_toolkit(config)
def _init_toolkit(self, config):
self.model = pipeline("automatic-speech-recognition", model=f"{self.local_fold}/jonatasgrosman/whisper-large-zh-cv11", device="cuda:0")
self.model.model.to("cpu")
def inference(self, args, task, device="cpu"):
results = []
self.model.model.to(device)
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
self.model.model.config.forced_decoder_ids = (
self.model.tokenizer.get_decoder_prompt_ids(
language="zh",
task="transcribe"
)
)
results.append({"lyric": self.model(prompt)})
self.model.model.to("cpu")
return results
class Spotify(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "spotify"
self.attributes = {
"description": "Spotify is a digital music service that gives you access to millions of songs."
}
self._init_toolkit(config)
def _init_toolkit(self, config):
client_id = config["spotify"]["client_id"]
client_secret = config["spotify"]["client_secret"]
url = "https://accounts.spotify.com/api/token"
headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
data = {
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret
}
response = requests.post(url, headers=headers, data=data)
if response.status_code == 200:
token_data = response.json()
self.access_token = token_data["access_token"]
print("Access Token:", self.access_token)
else:
print("Error:", response.status_code)
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
tgt = task.split("-")[0]
query = ["remaster"]
for key in arg:
if key in ["track", "album", "artist", "genre"]:
if isinstance(arg[key], list):
value = " ".join(arg[key])
else:
value = arg[key]
query.append(f"{key}:{value}")
if tgt == "playlist":
query[0] = arg["description"]
query = " ".join(query).replace(" ", "%20")
query = urllib.parse.quote(query)
url = f"https://api.spotify.com/v1/search?query={query}&type={tgt}"
headers = {"Authorization": f"Bearer {self.access_token}"}
response = requests.get(url, headers=headers)
if response.status_code == 200:
data = response.json()[tgt + "s"]["items"][0]
text = dict()
spotify_id = data["id"]
text[tgt] = data["name"]
if tgt == "track":
if "preview_url" in data and len(data["preview_url"]) > 0:
url = data["preview_url"]
file_name = str(uuid.uuid4())[:4]
with open(f"public/audios/{file_name}.mp3", "wb") as f:
f.write(requests.get(url).content)
text["audio"] = f"{file_name}.mp3"
text["album"] = data["album"]["name"]
text["artist"] = [d["name"] for d in data["artists"]]
if tgt == "album":
text["date"] = data["release_date"]
text["artist"] = [d["name"] for d in data["artists"]]
url = f"https://api.spotify.com/v1/albums/{spotify_id}"
album = requests.get(url, headers=headers).json()
if len(album["genres"]) > 0:
text["genre"] = album["genres"]
text["track"] = [d["name"] for d in album["tracks"]["items"]]
if tgt == "playlist":
url = f"https://api.spotify.com/v1/playlists/{spotify_id}"
album = requests.get(url, headers=headers).json()
text["track"] = [d["track"]["name"] for d in album["tracks"]["items"]]
if tgt == "artist":
if len(data["genres"]) > 0:
text["genre"] = data["genres"]
results.append(text)
else:
results.append({"error": "No corresponding song found."})
return results
class GoogleSearch(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "google"
self.attributes = {
"description": "Google Custom Search Engine."
}
self._init_toolkit(config)
def _init_toolkit(self, config):
api_key = config["google"]["api_key"]
custom_search_engine_id = config["google"]["custom_search_engine_id"]
self.url = "https://www.googleapis.com/customsearch/v1"
self.params = {
"key": api_key,
"cx": custom_search_engine_id,
"max_results": 5
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "description" in arg:
self.params["q"] = arg["description"]
response = requests.get(self.url, self.params)
if response.status_code == 200:
data = response.json()
items = data.get("items")
descriptions = []
for item in items:
descriptions.append(
{
"title": item.get("title"),
"snippet": item.get("snippet")
}
)
results.append(
{
"description": json.dumps(descriptions)
}
)
return results
class Demucs(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "demucs"
self.attributes = {
"description": "Demucs Music Source Separation"
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
os.system(f"python -m demucs --two-stems=vocals -o public/audios/ {prompt}")
os.system(f"cp public/audios/htdemucs/{prompt.split('/')[-1].split('.')[0]}/no_vocals.wav public/audios/{file_name}.wav")
results.append({"audio": f"{file_name}.wav"})
return results
class BasicMerge(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "pad-wave-merge"
self.attributes = {
"description": "Merge audios."
}
def inference(self, args, task, device="cpu"):
audios = []
        sr = 16000
for arg in args:
if "audio" in arg:
audio, _ = librosa.load(arg["audio"], sr=sr)
audios.append(audio)
max_len = max([len(audio) for audio in audios])
audios = [librosa.util.fix_length(audio, size=max_len) for audio in audios]
mixed_audio = sum(audios)
file_name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{file_name}.wav", mixed_audio, sr)
results = [{"audio": f"{file_name}.wav"}]
return results
class DDSP(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "ddsp"
self.attributes = {
"description": "Convert audio between sound sources with pretrained models."
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
timbre_transfer(prompt, f"public/audios/{file_name}.wav", instrument="violin")
results.append({"audio": f"{file_name}.wav"})
return results
class BasicPitch(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "basic-pitch"
self.attributes = {
"description": "Demucs Music Source Separation"
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
os.system(f"basic-pitch public/audios/ {prompt}")
os.system(f"cp public/audios/{prompt.split('/')[-1].split('.')[0]}_basic_pitch.mid public/audios/{file_name}.mid")
results.append({"sheet music": f"{file_name}.mid"})
return results
class BasicCrop(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "audio-crop"
self.attributes = {
"description": "Trim audio based on time"
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg and "time" in arg:
prompt = arg["audio"]
time = arg["time"]
file_name = str(uuid.uuid4())[:4]
audio = AudioSegment.from_file(prompt)
start_ms = int(float(time[0]) * 1000)
end_ms = int(float(time[1]) * 1000)
if start_ms < 0:
start_ms += len(audio)
if end_ms < 0:
end_ms += len(audio)
                # Clamp to the clip length (negative offsets above are measured from the end).
                start_ms = min(start_ms, len(audio))
                end_ms = min(end_ms, len(audio))
if start_ms > end_ms:
continue
trimmed_audio = audio[start_ms:end_ms]
trimmed_audio.export(f"public/audios/{file_name}.wav", format="wav")
results.append({"audio": f"{file_name}.wav"})
return results
class BasicSplice(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "audio-splice"
self.attributes = {
"description": "Basic audio splice"
}
def inference(self, args, task, device="cpu"):
audios = []
results = []
for arg in args:
if "audio" in arg:
audios.append(arg["audio"])
audio = AudioSegment.from_file(audios[0])
for i in range(1, len(audios)):
audio = audio + AudioSegment.from_file(audios[i])
file_name = str(uuid.uuid4())[:4]
audio.export(f"public/audios/{file_name}.wav", format="wav")
results.append({"audio": f"{file_name}.wav"})
return results

44
copilot/requirements.txt Normal file
View file

@ -0,0 +1,44 @@
transformers==4.29.2
tiktoken==0.3.3
google-search-results==2.4.2
flask==2.2.3
flask_cors==3.0.10
waitress==2.1.2
speechbrain==0.5.14
datasets==2.11.0
timm==0.6.13
typeguard==2.13.3
accelerate==0.18.0
pytesseract==0.3.10
gradio==3.26.0
librosa==0.8.0
tensorflow==2.11.0
tensorflow_probability==0.19.0
tensorboardX==2.6
pyloudnorm==0.1.1
g2p_en==2.1.0
inflect==5.3.0
pydantic==1.8.2
webrtcvad==2.0.10
scikit-learn==0.24.1
scikit-image==0.16.2
textgrid==1.5
jiwer==3.0.1
pycwt==0.3.0a22
praat-parselmouth==0.3.3
jieba==0.42.1
einops==0.6.1
pretty-midi==0.2.9
h5py==3.1.0
pypinyin==0.39.0
g2pM==0.1.2.5
cnsenti==0.0.7
midiutil==1.2.1
miditoolkit==0.1.16
textblob==0.17.1
basic-pitch==0.2.5
hydra-core==1.0.7
ddsp==1.6.0
torch==1.12.1
fairseq==0.12.0
demucs==4.0.0

View file

@ -0,0 +1,15 @@
{
"schema": 1,
"description": "",
"type": "completion",
"completion": {
"max_tokens": 800,
"temperature": 0.9,
"top_p": 0.0,
"presence_penalty": 0.6,
"frequency_penalty": 0.0,
"stop_sequences": [
"[Done]"
]
}
}

View file

@ -0,0 +1,7 @@
This is a conversation.
[CONTEXT]
user: I have a question. Can you help?
assistant: Of course. Go on!
[END CONTEXT]
user: {{$input}}
assistant:

View file

@ -0,0 +1,15 @@
{
"schema": 1,
"description": "",
"type": "completion",
"completion": {
"max_tokens": 800,
"temperature": 0.9,
"top_p": 0.0,
"presence_penalty": 0.6,
"frequency_penalty": 0.0,
"stop_sequences": [
"[Done]"
]
}
}

View file

@ -0,0 +1,8 @@
Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
[DONE]
{{$history}}
[DONE]
user: {{$input}}.
assistant: I want to introduce my workflow for your request, which is shown in the following JSON data: {{$processes}}.
user: Please first think carefully and directly answer my request based on the inference results. Then detail your workflow for my request, including the methods used and their inference results, in a friendly tone. Please filter out information that is not relevant to my request. Tell me the complete paths or URLs of the files in the inference results. If there is nothing in the results, please tell me that you can't fulfill the request. Answer in the language of {{$input}}.
assistant:

View file

@ -0,0 +1,15 @@
{
"schema": 1,
"description": "",
"type": "completion",
"completion": {
"max_tokens": 800,
"temperature": 0.9,
"top_p": 0.0,
"presence_penalty": 0.6,
"frequency_penalty": 0.0,
"stop_sequences": [
"[Done]"
]
}
}

View file

@ -0,0 +1,28 @@
Task Planning Stage: The AI assistant can parse the user input into several tasks: [{"task": task, "id": task_id, "dep": dependency_task_id, "args": list of {"track", "album", "artist", "description", "lyric", "score", "date", "genre", "language", where the value is text or <GENERATED>-dep_id, "audio": audio_url or <GENERATED>-dep_id, "sheet_music": midi_url or <GENERATED>-dep_id}}]. The special tag "<GENERATED>-dep_id" refers to the text/audio/MIDI generated in the dependency task; consider whether the dependency task generates a resource of this type, and "dep_id" must appear in the "dep" list. The "args" field must be in ["text", "audio", "midi"]. The task MUST be selected from: {{$tasks}}. Think step by step about all the tasks needed to resolve the user's request. Parse out as few tasks as possible. If the user input can't be parsed, you need to reply with an empty JSON [].
[CONTEXT]
user: 生成一首古风歌词的中文歌
assistant: [{"task": "lyric-generation", "id": 0, "dep": [-1], "args": [{"description": "古风歌词的中文歌"}]}, {"task": "lyric-to-melody", "id": 1, "dep": [0], "args": [{"lyric": "<GENERATED>-0"}]}, {"task": "lyric-to-audio", "id": 2, "dep": [1], "args": [{"score": "<GENERATED>-1"}]}, {"task": "audio-mixing", "id": 3, "dep": [1,2], "args": [{"audio": "<GENERATED>-1"}, {"audio": "<GENERATED>-2"}]}]
user: Give me the sheet music and lyrics in the song /a.wav
assistant: [{"task": "separate-track", "id": 0, "dep": [-1], "args": [{"audio": "/a.wav"}]}, {"task": "score-transcription", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}, {"task": "lyric-recognition", "id": 2, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: I want a rock song by Imagine Dragons.
assistant: [{"task": "track-search", "id": 0, "dep": [-1], "args": [{"artist": ["Imagine Dragons"], "genre": "rock"}]}]
user: 分析近期热门音乐风格趋势.
assistant: [{"task": "playlist-search", "id": 0, "dep": [-1], "args": [{"description": "近期热门音乐"}]}, {"task": "music-classification", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: 请给我周杰伦的千里之外的伴奏.
assistant: [{"task": "track-search", "id": 0, "dep": [-1], "args": [{"track": "千里之外", "artist": ["周杰伦"], "language": "zh"}]}, {"task": "separate-track", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: Rewind the music in /c.wav.
assistant: [{"task": "separate-track", "id": 0, "dep": [-1], "args": [{"audio": "/c.wav"}]}, {"task": "score-transcription", "id": 1, "dep": [0], "args": [{"audio": ["<GENERATED>-0"]}]}, {"task": "reverse-music", "id": 2, "dep": [1], "args": [{"sheet_music": "<GENERATED>-1"}]}]
user: Convert the vocals in /b.wav to a violin sound.
assistant: [{"task": "separate-track", "id": 0, "dep": [-1], "args": [{"audio": "/b.wav"}]}, {"task": "timbre-transfer", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: Splice the first 10s of /a.wav with the last 10s of /b.wav
assistant: [{"task": "audio-crop", "id": 0, "dep": [-1], "args": [{"time": ["0.0", "10.0"], "audio": "/a.wav"}]}, {"task": "audio-crop", "id": 1, "dep": [-1], "args": [{"time": ["-10.00", "-0.00"], "audio": "/b.wav"}]}, {"task": "audio-splice", "id": 2, "dep": [0,1], "args": [{"audio": "<GENERATED>-0"}, {"audio": "<GENERATED>-1"}]}]
user: Turning speech in /a.wav into music.
assistant: [{"task": "lyric-recognition", "id": 0, "dep": [-1], "args": [{"audio": "/a.wav"}]}, {"task": "lyric-to-melody", "id": 1, "dep": [0], "args": [{"lyric": "<GENERATED>-0"}]}, {"task": "lyric-to-audio", "id": 2, "dep": [1], "args": [{"score": "<GENERATED>-1"}]}, {"task": "audio-mixing", "id": 3, "dep": [1,2], "args": [{"audio": "<GENERATED>-1"}, {"audio": "<GENERATED>-2"}]}]
user: Write a piece of lyric about the World Cup.
assistant: [{"task": "web-search", "id": 0, "dep": [-1], "args": [{"description": "World Cup"}]}, {"task": "lyric-generation", "id": 1, "dep": [0], "args": [{"description": "<GENERATED>-0"}]}]
[END CONTEXT]
[DONE]
{{$history}}
[DONE]
user: {{$input}}
assistant:

View file

@ -0,0 +1,15 @@
{
"schema": 1,
"description": "",
"type": "completion",
"completion": {
"max_tokens": 800,
"temperature": 0.9,
"top_p": 0.0,
"presence_penalty": 0.6,
"frequency_penalty": 0.0,
"stop_sequences": [
"[Done]"
]
}
}

View file

@ -0,0 +1,7 @@
Given the user request and the parsed tasks, the AI assistant helps the user to select a suitable model from a list of models to process the user request.
[CONTEXT]
user: Determine the style of music /a.wav.
assistant: {"id": "", "reason": ""}
[END CONTEXT]
user: Please choose the most suitable tool from {{$tools}} for the task {{$task}} with the input [{{$input}}]. Please focus on the {{$focus}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
assistant: