This commit is contained in:
UranusYu 2023-10-16 13:09:28 +08:00
Parent efa2d4c0bc
Commit 6f23864a27
20 changed files: 506 additions and 261 deletions


@ -1,4 +1,4 @@
<!-- <p align="center"> <b> Music Pilot </b> </p> -->
<!-- <p align="center"> <b> Music Agent </b> </p> -->
<div align="center">
@ -9,13 +9,16 @@
## Demo Video
![Download demo video](https://drive.google.com/file/d/1W0iJPHNPA6ENLJrPef0vtQytboSubxXe/view?usp=sharing)
[![Watch the video](https://img.youtube.com/vi/tpNynjdcBqA/maxresdefault.jpg)](https://youtu.be/tpNynjdcBqA)
## Features
- Accessibility: Music Pilot dynamically selects the most appropriate methods for each music-related task.
- Unity: Music Pilot unifies a wide array of tools into a single system, incorporating Huggingface models, GitHub projects, and Web APIs.
- Modularity: Music Pilot offers high modularity, allowing users to effortlessly enhance its capabilities by integrating new functions.
- Accessibility: Music Agent dynamically selects the most appropriate methods for each music-related task.
- Unity: Music Agent unifies a wide array of tools into a single system, incorporating Huggingface models, GitHub projects, and Web APIs.
- Modularity: Music Agent offers high modularity, allowing users to effortlessly enhance its capabilities by integrating new functions.
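New tools can be wired in by following the `BaseToolkit` pattern in `agent/plugins.py` and registering the class in `init_plugins`. The sketch below is illustrative only; the class name, model id, and output keys are hypothetical, and the stubbed `BaseToolkit` stands in for the real one defined in `agent/plugins.py`:

```python
# Minimal sketch of a new tool plugin; in the real code this subclasses the
# BaseToolkit defined in agent/plugins.py (stubbed here so the sketch runs).
class BaseToolkit:                          # stub for illustration only
    def __init__(self, config):
        self.config = config

class MyGenreTagger(BaseToolkit):           # hypothetical tool, not in the repo
    def __init__(self, config):
        super().__init__(config)
        self.id = "my-org/my-genre-model"   # hypothetical model id
        self.attributes = {"description": "Short description used by the task planner"}

    def inference(self, args, task, device="cpu"):
        # Each plugin receives a list of argument dicts and returns a list of result dicts.
        results = []
        for arg in args:
            if "audio" in arg:
                results.append({"genre": "pop"})   # placeholder output
        return results

# Registration would then happen in init_plugins(config):
#     if "my-org/my-genre-model" not in disabled:
#         pipes["my-org/my-genre-model"] = MyGenreTagger(config)
```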
## Skills
## Installation
@ -38,10 +41,11 @@ sudo apt-get install -y git-lfs
sudo apt-get install -y libsndfile1-dev
sudo apt-get install -y fluidsynth
sudo apt-get install -y ffmpeg
sudo apt-get install -y lilypond
# Clone the repository from TODO
git clone https://github.com/TODO
cd DIR
# Clone the muzic repository
git clone https://github.com/microsoft/muzic
cd muzic/agent
```
Next, install the dependent libraries. There might be some conflicts, but they should not affect the functionality of the system.
@ -49,8 +53,8 @@ Next, install the dependent libraries. There might be some conflicts, but they s
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install semantic-kernel
pip install -r requirements.txt
pip install numpy==1.23.0
pip install protobuf==3.20.3
```


@ -18,7 +18,7 @@ from semantic_kernel.connectors.ai.open_ai import AzureTextCompletion, OpenAITex
from model_utils import lyric_format
from plugins import get_task_map, init_plugins
class MusicPilotAgent:
class MusicAgent:
"""
Attributes:
config_path: A path to a YAML file, referring to the example config.yaml
@ -64,7 +64,7 @@ class MusicPilotAgent:
def _init_semantic_kernel(self):
skills_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "skills")
pilot_funcs = self.kernel.import_semantic_skill_from_directory(skills_directory, "MusicPilot")
pilot_funcs = self.kernel.import_semantic_skill_from_directory(skills_directory, "MusicAgent")
# task planning
self.task_planner = pilot_funcs["TaskPlanner"]
@ -168,6 +168,9 @@ class MusicPilotAgent:
return result
def run_task(self, input_text, command, results):
if self.error_event.is_set():
return
id = command["id"]
args = command["args"]
task = command["task"]
@ -226,7 +229,7 @@ class MusicPilotAgent:
inference_result = []
for arg in command["args"]:
chat_input = f"[{input_text}] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {arg}. Just help me do {command['task']} and give me the resultwithout any additional description. The result must be in text form without any urls."
chat_input = f"[{input_text}] contains a task in JSON format {command}. Now you are a {command['task']} system, the arguments are {arg}. Just help me do {command['task']} and give me the result without any additional description."
response = self.skillchat(chat_input, self.chatbot, self.chat_context)
inference_result.append({"lyric":lyric_format(response)})
@ -263,7 +266,12 @@ class MusicPilotAgent:
inference_result = self.model_inference(best_model_id, command, device=self.config["device"])
results[id] = self.collect_result(command, choose, inference_result)
return True
for result in inference_result:
if "error" in result:
self.error_event.set()
break
return
def chat(self, input_text):
start = time.time()
@ -277,19 +285,22 @@ class MusicPilotAgent:
except Exception as e:
self.logger.debug(e)
response = self.skillchat(input_text, self.chatbot, self.chat_context)
return response
return response, {"0": "Task parsing error, reply using ChatGPT."}
if len(tasks) == 0:
response = self.skillchat(input_text, self.chatbot, self.chat_context)
return response
return response, {"0": "No task detected, reply using ChatGPT."}
tasks = self.fix_depth(tasks)
results = {}
threads = []
d = dict()
retry = 0
self.error_event = threading.Event()
while True:
num_thread = len(threads)
if self.error_event.is_set():
break
for task in tasks:
# logger.debug(f"d.keys(): {d.keys()}, dep: {dep}")
for dep_id in task["dep"]:
@ -326,10 +337,10 @@ class MusicPilotAgent:
end = time.time()
during = end - start
self.logger.info(f"time: {during}s")
return response
return response, results
def parse_args():
parser = argparse.ArgumentParser(description="A path to a YAML file")
parser = argparse.ArgumentParser(description="music agent config")
parser.add_argument("--config", type=str, help="a YAML file path.")
args = parser.parse_args()
@ -337,10 +348,10 @@ def parse_args():
if __name__ == "__main__":
args = parse_args()
agent = MusicPilotAgent(args.config, mode="cli")
agent = MusicAgent(args.config, mode="cli")
print("Input exit or quit to stop the agent.")
while True:
message = input("Send a message: ")
message = input("User input: ")
if message in ["exit", "quit"]:
break
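The `error_event` introduced above is the standard cooperative-cancellation pattern around `threading.Event`: each worker checks the flag before doing work and sets it when a sub-task returns an error, so the dispatch loop stops scheduling the remaining tasks. A minimal standalone sketch of the same idea (task logic and names here are illustrative, not the agent's API):

```python
import threading

error_event = threading.Event()

def run_task(task_id, results):
    if error_event.is_set():          # another task already failed; skip work
        return
    result = {"error": "boom"} if task_id == 2 else {"output": task_id}
    results[task_id] = result
    if "error" in result:             # tell the other workers to stop early
        error_event.set()

results = {}
threads = [threading.Thread(target=run_task, args=(i, results)) for i in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results, "aborted:", error_event.is_set())
```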


agent/gradio_agent.py (new file, 205 lines)

@ -0,0 +1,205 @@
import uuid
import os
import gradio as gr
import re
import requests
from agent import MusicAgent
import soundfile
import argparse
all_messages = []
OPENAI_KEY = ""
def add_message(content, role):
message = {"role": role, "content": content}
all_messages.append(message)
def extract_medias(message):
# audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav|mp3)")
audio_pattern = re.compile(r"(http(s?):|\/)?[a-zA-Z0-9\/.:-]*\.(flac|wav|mp3)")
symbolic_button = re.compile(r"(http(s?):|\/)?[a-zA-Z0-9\/.:-]*\.(mid)")
audio_urls = []
for match in audio_pattern.finditer(message):
if match.group(0) not in audio_urls:
audio_urls.append(match.group(0))
symbolic_urls = []
for match in symbolic_button.finditer(message):
if match.group(0) not in symbolic_urls:
symbolic_urls.append(match.group(0))
return list(set(audio_urls)), list(set(symbolic_urls))
def set_openai_key(openai_key):
global OPENAI_KEY
OPENAI_KEY = openai_key
agent._init_backend_from_input(openai_key)
if not OPENAI_KEY.startswith("sk-"):
return "OpenAI API Key starts with sk-", gr.update(visible=False)
return OPENAI_KEY, gr.update(visible=True)
def add_text(messages, message):
add_message(message, "user")
messages = messages + [(message, None)]
audio_urls, _ = extract_medias(message)
for audio_url in audio_urls:
if audio_url.startswith("http"):
ext = audio_url.split(".")[-1]
name = f"{str(uuid.uuid4())[:4]}.{ext}"
response = requests.get(audio_url)
with open(f"{agent.config['src_fold']}/{name}", "wb") as f:
f.write(response.content)
messages = messages + [(None, f"{audio_url} is saved as {name}")]
return messages, ""
def upload_audio(file, messages):
file_name = str(uuid.uuid4())[:4]
audio_load, sr = soundfile.read(file.name)
soundfile.write(f"{agent.config['src_fold']}/{file_name}.wav", audio_load, samplerate=sr)
messages = messages + [(None, f"Audio is stored in wav format as ** {file_name}.wav **"),
(None, (f"{agent.config['src_fold']}/{file_name}.wav",))]
return messages
def bot(messages):
message, results = agent.chat(messages[-1][0])
audio_urls, symbolic_urls = extract_medias(message)
add_message(message, "assistant")
messages[-1][1] = message
for audio_url in audio_urls:
if not audio_url.startswith("http") and not audio_url.startswith(agent.config['src_fold']):
audio_url = os.path.join(agent.config['src_fold'], audio_url)
messages = messages + [(None, f"** {audio_url.split('/')[-1]} **"),
(None, (audio_url,))]
for symbolic_url in symbolic_urls:
if not symbolic_url.startswith(agent.config['src_fold']):
symbolic_url = os.path.join(agent.config['src_fold'], symbolic_url)
try:
os.system(f"midi2ly {symbolic_url} -o {symbolic_url}.ly; lilypond -f png -o {symbolic_url} {symbolic_url}.ly")
except:
continue
messages = messages + [(None, f"** {symbolic_url.split('/')[-1]} **")]
if os.path.exists(f"{symbolic_url}.png"):
messages = messages + [ (None, (f"{symbolic_url}.png",))]
else:
s_page = 1
while os.path.exists(f"{symbolic_url}-page{s_page}.png"):
messages = messages + [ (None, (f"{symbolic_url}-page{s_page}.png",))]
s_page += 1
def truncate_strings(obj, max_length=128):
if isinstance(obj, str):
if len(obj) > max_length:
return obj[:max_length] + "..."
else:
return obj
elif isinstance(obj, dict):
return {key: truncate_strings(value, max_length) for key, value in obj.items()}
elif isinstance(obj, list):
return [truncate_strings(item, max_length) for item in obj]
else:
return obj
results = truncate_strings(results)
results = sorted(results.items(), key=lambda x: int(x[0]))
response = [(None, "\n\n".join([f"Subtask {r[0]}:\n{r[1]}" for r in results]))]
return messages, response
def clear_all_history(messages):
agent.clear_history()
messages = messages + [((None, "All LLM history cleared"))]
return messages
def parse_args():
parser = argparse.ArgumentParser(description="music agent config")
parser.add_argument("-c", "--config", type=str, help="a YAML file path.")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
agent = MusicAgent(args.config, mode="gradio")
with gr.Blocks() as demo:
gr.HTML("""
<h1 align="center" style=" display: flex; flex-direction: row; justify-content: center; font-size: 25pt; ">🎧 Music Agent</h1>
<h3>This is a demo page for Music Agent, a project that uses LLM to integrate music tools. For specific functions, please refer to the examples given below, or refer to the instructions in Github.</h3>
<h3>Make sure the uploaded audio resource is in flac|wav|mp3 format.</h3>
<h3>Due to RPM limitations, Music Agent requires an OpenAI key for the paid version.</h3>
<div style="display: flex;"><a href='https://github.com/microsoft/muzic/tree/main/copilot'><img src='https://img.shields.io/badge/Github-Code-blue'></a></div>
""")
with gr.Row():
openai_api_key = gr.Textbox(
show_label=False,
placeholder="Set your OpenAI API key here and press Enter",
lines=1,
type="password",
)
state = gr.State([])
with gr.Row(visible=False) as interact_window:
with gr.Column(scale=0.7, min_width=500):
chatbot = gr.Chatbot([], elem_id="chatbot", label="Music-Agent Chatbot").style(height=500)
with gr.Tab("User Input"):
with gr.Row(scale=1):
with gr.Column(scale=0.6):
txt = gr.Textbox(show_label=False, placeholder="Press ENTER or click the Run button. You can start by asking 'What can you do?'").style(container=False)
with gr.Column(scale=0.1, min_width=0):
run = gr.Button("🏃Run")
with gr.Column(scale=0.1, min_width=0):
clear_txt = gr.Button("🔄Clear")
with gr.Column(scale=0.2, min_width=0):
btn = gr.UploadButton("Upload Audio", file_types=["audio"])
with gr.Column(scale=0.3, min_width=300):
with gr.Tab("Intermediate Results"):
response = gr.Chatbot([], label="Current Progress").style(height=400)
openai_api_key.submit(set_openai_key, [openai_api_key], [openai_api_key, interact_window])
clear_txt.click(clear_all_history, [chatbot], [chatbot])
btn.upload(upload_audio, [btn, chatbot], [chatbot])
run.click(add_text, [chatbot, txt], [chatbot, txt]).then(
bot, chatbot, [chatbot, response]
)
txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
bot, chatbot, [chatbot, response]
)
gr.Examples(
examples=["What can you do?",
"Write a piece of lyric about the recent World Cup.",
"生成一首古风歌词的中文歌",
"Download a song by Jay Chou for me and separate the vocals and the accompanies.",
"Convert the vocals in /b.wav to a violin sound.",
"Give me the sheet music and lyrics in the song /a.wav",
"近一个月流行的音乐类型",
"把c.wav中的人声搭配合适的旋律变成一首歌"
],
inputs=txt
)
demo.launch(share=True)
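Assuming the example config.yaml from the repository, the demo can then be started with `python gradio_agent.py --config config.yaml`; because `demo.launch(share=True)` is used, Gradio prints both a local URL and a temporary public share link.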


@ -3,8 +3,11 @@
# Set models to download
models=(
"m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
"lewtun/distilhubert-finetuned-music-genres"
"dima806/music_genres_classification"
"sander-wood/text-to-music"
"jonatasgrosman/whisper-large-zh-cv11"
"cvssp/audioldm-m-full"
)
# Set the current directory


@ -8,11 +8,12 @@ from pydub import AudioSegment
import requests
import urllib
import librosa
import re
import torch
import torch.nn.functional as F
# import torchaudio
from fairseq.models.transformer_lm import TransformerLanguageModel
from diffusers import AudioLDMPipeline
import soundfile as sf
import os
import sys
@ -25,11 +26,16 @@ def get_task_map():
"text-to-sheet-music": [
"sander-wood/text-to-music"
],
"text-to-audio": [
"cvssp/audioldm-m-full"
],
"music-classification": [
"m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
"lewtun/distilhubert-finetuned-music-genres",
"dima806/music_genres_classification"
],
"lyric-to-melody": [
"muzic/roc"
"muzic/roc",
"muzic/telemelody"
],
"lyric-to-audio": [
"DiffSinger"
@ -90,10 +96,16 @@ def init_plugins(config):
pipes = {}
if "muzic/roc" not in disabled:
pipes["muzic/roc"] = MuzicROC(config)
if "cvssp/audioldm-m-full" not in disabled:
pipes["cvssp/audioldm-m-full"] = AudioLDM(config)
if "DiffSinger" not in disabled:
pipes["DiffSinger"] = DiffSinger(config)
if "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres" not in disabled:
pipes["m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"] = Wav2Vec2Base(config)
# if "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres" not in disabled:
# pipes["m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"] = Wav2Vec2Base(config)
if "dima806/music_genres_classification" not in disabled:
pipes["dima806/music_genres_classification"] = Text2AudioDima(config)
if "lewtun/distilhubert-finetuned-music-genres" not in disabled:
pipes["lewtun/distilhubert-finetuned-music-genres"] = Text2AudioLewtun(config)
if "jonatasgrosman/whisper-large-zh-cv11" not in disabled:
pipes["jonatasgrosman/whisper-large-zh-cv11"] = WhisperZh(config)
if "spotify" not in disabled:
@ -129,6 +141,22 @@ class BaseToolkit:
for key in kwargs:
self.attributes[key] = kwargs[key]
def mount_model(self, model, device):
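# Move a tool to the target device: use .to() when the object supports it,
# otherwise relocate the wrapped .model and record the device on the wrapper.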
try:
model.to(device)
except:
model.device = torch.device(device)
model.model.to(device)
def detach_model(self, model):
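# Return the tool to CPU and release cached GPU memory once inference is done.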
try:
model.to("cpu")
torch.cuda.empty_cache()
except:
model.device = torch.device("cpu")
model.model.to("cpu")
torch.cuda.empty_cache()
class MuzicROC(BaseToolkit):
def __init__(self, config):
@ -155,9 +183,10 @@ class MuzicROC(BaseToolkit):
if "lyric" in arg:
prompt = arg["lyric"]
prompt = " ".join(prompt)
prompt = re.sub("[^\u4e00-\u9fa5]", "", prompt)
file_name = str(uuid.uuid4())[:4]
try:
outputs = self.processer(
self.model,
[prompt],
@ -165,6 +194,10 @@ class MuzicROC(BaseToolkit):
db_path=f"{self.local_fold}/muzic/roc/database/ROC.db"
)
os.system(f"fluidsynth -l -ni -a file -z 2048 -F public/audios/{file_name}.wav 'MS Basic.sf3' public/audios/{file_name}.mid")
except:
results.append({"error": "Lyric-to-melody Error"})
continue
results.append(
{
"score": str(outputs),
@ -178,6 +211,78 @@ class MuzicROC(BaseToolkit):
return results
class Text2AudioDima(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "dima806/music_genres_classification"
self.attributes = {
"description": "The model is trained based on publicly available dataset of labeled music data — GTZAN Dataset",
"genres": "blues,classical,country,disco,hip-hop,jazz,metal,pop,reggae,rock.",
"downloads": 60
}
def _init_toolkit(self, config):
self.pipe = pipeline("audio-classification", model=os.path.join(self.local_fold, self.id), device="cpu")
def inference(self, args, task, device="cpu"):
self.mount_model(self.pipe, device)
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
sampling_rate = self.pipe.feature_extractor.sampling_rate
audio, _ = librosa.load(prompt, sr=sampling_rate)
try:
output = self.pipe(audio)
genre = output[0]["label"]
except:
results.append({"error": "Genres Classification Error"})
continue
results.append({"genre": genre})
self.detach_model(self.pipe)
return results
class Text2AudioLewtun(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "lewtun/distilhubert-finetuned-music-genres"
self.attributes = {
"description": "This model is a fine-tuned version of ntu-spml/distilhubert on the None dataset.",
"genres": "Pop,Classical,International,Ambient Electronic,Folk",
"downloads": 202
}
def _init_toolkit(self, config):
self.pipe = pipeline("audio-classification", model=os.path.join(self.local_fold, self.id), device="cpu")
def inference(self, args, task, device="cpu"):
self.mount_model(self.pipe, device)
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
sampling_rate = self.pipe.feature_extractor.sampling_rate
audio, _ = librosa.load(prompt, sr=sampling_rate)
try:
output = self.pipe(audio)
genre = output[0]["label"]
except:
results.append({"error": "Genres Classification Error"})
continue
results.append({"genre": genre})
self.detach_model(self.pipe)
return results
class DiffSinger(BaseToolkit):
def __init__(self, config):
super().__init__(config)
@ -217,9 +322,13 @@ class DiffSinger(BaseToolkit):
prompt = arg["score"]
prompt = eval(prompt)
try:
wav = self.model.infer_once(prompt)
file_name = str(uuid.uuid4())[:4]
self.processer(wav, f"public/audios/{file_name}.wav", sr=16000)
except:
results.append({"error": "Singing Voice Synthesis Error"})
continue
results.append({"audio": f"{file_name}.wav"})
@ -229,46 +338,46 @@ class DiffSinger(BaseToolkit):
return results
class Wav2Vec2Base(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
self.attributes = {
"description": "Music Genre Classification using Wav2Vec 2.0"
}
self._init_toolkit(config)
# class Wav2Vec2Base(BaseToolkit):
# def __init__(self, config):
# super().__init__(config)
# self.id = "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
# self.attributes = {
# "description": "Music Genre Classification using Wav2Vec 2.0"
# }
# self._init_toolkit(config)
def _init_toolkit(self, config):
self.processer = Wav2Vec2FeatureExtractor.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
self.model = Wav2Vec2ForSpeechClassification.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
self.config = AutoConfig.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# def _init_toolkit(self, config):
# self.processer = Wav2Vec2FeatureExtractor.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# self.model = Wav2Vec2ForSpeechClassification.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# self.config = AutoConfig.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
def inference(self, args, task, device="cpu"):
results = []
self.model.to(device)
# def inference(self, args, task, device="cpu"):
# results = []
# self.model.to(device)
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
# for arg in args:
# if "audio" in arg:
# prompt = arg["audio"]
sampling_rate = self.processer.sampling_rate
#speech_array, _sampling_rate = torchaudio.load(prompt)
#resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
#speech = resampler(speech_array).squeeze().numpy()
speech, _ = librosa.load(prompt, sr=sampling_rate)
inputs = self.processer(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
inputs = {k: v.to(device) for k, v in inputs.items()}
# sampling_rate = self.processer.sampling_rate
# #speech_array, _sampling_rate = torchaudio.load(prompt)
# #resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
# #speech = resampler(speech_array).squeeze().numpy()
# speech, _ = librosa.load(prompt, sr=sampling_rate)
# inputs = self.processer(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
# inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
logits = self.model(**inputs).logits
# with torch.no_grad():
# logits = self.model(**inputs).logits
scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
genre = self.config.id2label[np.argmax(scores)]
# outputs = [{"Label": pipes[pipe_id]["config"].id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
results.append({"genre": genre})
# scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
# genre = self.config.id2label[np.argmax(scores)]
# # outputs = [{"Label": pipes[pipe_id]["config"].id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
# results.append({"genre": genre})
self.model.to("cpu")
return results
# self.model.to("cpu")
# return results
class WhisperZh(BaseToolkit):
@ -292,6 +401,7 @@ class WhisperZh(BaseToolkit):
if "audio" in arg:
prompt = arg["audio"]
try:
self.model.model.config.forced_decoder_ids = (
self.model.tokenizer.get_decoder_prompt_ids(
language="zh",
@ -299,6 +409,9 @@ class WhisperZh(BaseToolkit):
)
)
results.append({"lyric": self.model(prompt)})
except:
results.append({"error": "Lyric-recognition Error"})
continue
self.model.model.to("cpu")
return results
@ -340,26 +453,26 @@ class Spotify(BaseToolkit):
for arg in args:
tgt = task.split("-")[0]
query = ["remaster"]
for key in arg:
if key in ["track", "album", "artist", "genre"]:
if isinstance(arg[key], list):
value = " ".join(arg[key])
else:
value = arg[key]
query.append(f"{key}:{value}")
url = "https://api.spotify.com/v1/"
endpoint = "search"
if tgt == "playlist":
query[0] = arg["description"]
data = {
"q": " ".join([f"{key}:{value}" for key, value in arg.items() if key in ["track", "album", "artist", "genre", "description"]]),
"type" : [tgt]
}
query = " ".join(query).replace(" ", "%20")
query = urllib.parse.quote(query)
url = f"https://api.spotify.com/v1/search?query={query}&type={tgt}"
headers = {"Authorization": f"Bearer {self.access_token}"}
response = requests.get(url, headers=headers)
headers = {
"Authorization": f"Bearer {self.access_token}",
"Accept" : "application/json",
"Content-Type" : "application/json"
}
response = requests.get(url=url+endpoint, params=data, headers=headers)
if response.status_code == 200:
data = response.json()[tgt + "s"]["items"][0]
data = response.json()
if not (tgt + "s" in data):
results.append({"error": "No corresponding song found."})
continue
data = data[tgt + "s"]["items"][0]
text = dict()
spotify_id = data["id"]
text[tgt] = data["name"]
@ -466,12 +579,20 @@ class Demucs(BaseToolkit):
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
try:
os.system(f"python -m demucs --two-stems=vocals -o public/audios/ {prompt}")
except:
results.append({"error": "Music Source Separation Error"})
continue
file_name = str(uuid.uuid4())[:4]
os.system(f"python -m demucs --two-stems=vocals -o public/audios/ {prompt}")
os.system(f"cp public/audios/htdemucs/{prompt.split('/')[-1].split('.')[0]}/no_vocals.wav public/audios/{file_name}.wav")
results.append({"audio": f"{file_name}.wav", "instrument": "accompaniment"})
file_name = str(uuid.uuid4())[:4]
os.system(f"cp public/audios/htdemucs/{prompt.split('/')[-1].split('.')[0]}/vocals.wav public/audios/{file_name}.wav")
results.append({"audio": f"{file_name}.wav", "instrument": "vocal"})
results.append({"audio": f"{file_name}.wav"})
return results
@ -521,7 +642,12 @@ class DDSP(BaseToolkit):
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
try:
timbre_transfer(prompt, f"public/audios/{file_name}.wav", instrument="violin")
except:
results.append({"error": "Convert Style Error"})
continue
results.append({"audio": f"{file_name}.wav"})
return results
@ -532,7 +658,7 @@ class BasicPitch(BaseToolkit):
super().__init__(config)
self.id = "basic-pitch"
self.attributes = {
"description": "Demucs Music Source Separation"
"description": "A Python library for Automatic Music Transcription (AMT), using lightweight neural network developed by Spotify's Audio Intelligence Lab."
}
def inference(self, args, task, device="cpu"):
@ -542,7 +668,12 @@ class BasicPitch(BaseToolkit):
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
try:
os.system(f"basic-pitch public/audios/ {prompt}")
except:
results.append({"error": "Music Transcription Error"})
continue
os.system(f"cp public/audios/{prompt.split('/')[-1].split('.')[0]}_basic_pitch.mid public/audios/{file_name}.mid")
results.append({"sheet music": f"{file_name}.mid"})
@ -613,3 +744,38 @@ class BasicSplice(BaseToolkit):
results.append({"audio": f"{file_name}.wav"})
return results
class AudioLDM(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "cvssp/audioldm-m-full"
self.attributes = {
"description": "Text-to-Audio Generation: Generate audio given text input. "
}
self._init_toolkit(config)
def _init_toolkit(self, config):
repo_id = "models/cvssp/audioldm-m-full"
self.pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
def inference(self, args, task, device="cpu"):
results = []
self.mount_model(self.pipe, device)
for arg in args:
if "description" in arg:
prompt = arg["description"]
try:
audio = self.pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
except:
results.append({"error": "Text-to-Audio Error"})
continue
file_name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{file_name}.wav", audio, 16000)
results.append({"audio": f"{file_name}.wav"})
self.detach_model(self.pipe)
return results


@ -1,14 +1,9 @@
transformers==4.29.2
tiktoken==0.3.3
transformers==4.33.2
google-search-results==2.4.2
flask==2.2.3
flask_cors==3.0.10
waitress==2.1.2
speechbrain==0.5.14
datasets==2.11.0
timm==0.6.13
typeguard==2.13.3
accelerate==0.18.0
accelerate==0.23.0
pytesseract==0.3.10
gradio==3.26.0
librosa==0.8.0
@ -18,7 +13,6 @@ tensorboardX==2.6
pyloudnorm==0.1.1
g2p_en==2.1.0
inflect==5.3.0
pydantic==1.8.2
webrtcvad==2.0.10
scikit-learn==0.24.1
scikit-image==0.16.2
@ -39,6 +33,6 @@ textblob==0.17.1
basic-pitch==0.2.5
hydra-core==1.0.7
ddsp==1.6.0
torch==1.12.1
fairseq==0.12.0
demucs==4.0.0
diffusers==0.21.2


@ -0,0 +1,10 @@
You are an AI music assistant called Music Agent, which handles text, symbolic music and audio-related tasks, including:
songwriting, accompaniment generation, audio synthesis, track separation, lyrics and music transcription, music classification, etc.
Users can freely combine these capabilities to achieve more complex functions.
[CONTEXT]
user: I have a question. Can you help?
assistant: Of course. Go on!
[END CONTEXT]
{{$history}}
user: {{$input}}
assistant:


@ -1,7 +1,5 @@
Response Generation Stage: With the task execution logs, the AI assistant needs to describe the process and inference results.
[DONE]
{{$history}}
[DONE]
user: {{$input}}.
assistant: I want to introduce my workflow for your request, which is shown in the following JSON data: {{$processes}}.
user: Please first think carefully and directly answer my request based on the inference results. Then detail your workflow including the used methods and inference results for my request in your friendly tone. Please filter out information that is not relevant to my request. Tell me the complete path or urls of files in inference results. If there is nothing in the results, please tell me you can't make it. Answer in the language of {{$input}}.


@ -8,8 +8,10 @@ user: I want a rock song by Imagine Dragons.
assistant: [{"task": "track-search", "id": 0, "dep": [-1], "args": [{"artist": ["Imagine Dragons"], "genre": "rock"}]}]
user: 分析近期热门音乐风格趋势.
assistant: [{"task": "playlist-search", "id": 0, "dep": [-1], "args": [{"description": "近期热门音乐"}]}, {"task": "music-classification", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: 请给我周杰伦的千里之外的伴奏.
assistant: [{"task": "track-search", "id": 0, "dep": [-1], "args": [{"track": "千里之外", "artist": ["周杰伦"], "language": "zh"}]}, {"task": "separate-track", "id": 1, "dep": [0], "args": [{"audio": "<GENERATED>-0"}]}]
user: transcribe /e.wav to music score (sheet music).
assistant: [{"task": "score-transcription", "id": 1, "dep": [0], "args": [{"audio": "/e.wav"}]]
user: transcribe /e.wav to lyrics.
assistant: [{"task": "lyric-recognition", "id": 1, "dep": [0], "args": [{"audio": "/e.wav"}]]
user: Rewind the music in /c.wav.
assistant: [{"task": "separate-track", "id": 0, "dep": [-1], "args": [{"audio": "/c.wav"}]}, {"task": "score-transcription", "id": 1, "dep": [0], "args": [{"audio": ["<GENERATED>-0"]}]}, {"task": "reverse-music", "id": 2, "dep": [1], "args": [{"sheet_music": "<GENERATED>-1"}]}]
user: Convert the vocals in /b.wav to a violin sound.
@ -20,9 +22,9 @@ user: Turning speech in /a.wav into music.
assistant: [{"task": "lyric-recognition", "id": 0, "dep": [-1], "args": [{"audio": "/a.wav"}]}, {"task": "lyric-to-melody", "id": 1, "dep": [0], "args": [{"lyric": "<GENERATED>-0"}]}, {"task": "lyric-to-audio", "id": 2, "dep": [1], "args": [{"score": "<GENERATED>-1"}]}, {"task": "audio-mixing", "id": 3, "dep": [1,2], "args": [{"audio": "<GENERATED>-1"}, {"audio": "<GENERATED>-2"}]}]
user: Write a piece of lyric about the World Cup.
assistant: [{"task": "web-search", "id": 0, "dep": [-1], "args": [{"description": "World Cup"}]}, {"task": "lyric-generation", "id": 1, "dep": [0], "args": [{"description": "<GENERATED>-0"}]}]
user: What can you do?
assistant: []
[END CONTEXT]
[DONE]
{{$history}}
[DONE]
user: {{$input}}
assistant:


@ -3,5 +3,6 @@ Given the user request and the parsed tasks, the AI assistant helps the user to
user: Determine the style of music /a.wav.
assistant: {"id": "", "reason": ""}
[END CONTEXT]
{{$history}}
user: Please choose the most suitable tool from {{$tools}} for the task {{$task}} with the input [{{$input}}]. Please focus on the {{$focus}}. The output must be in a strict JSON format: {"id": "id", "reason": "your detail reasons for the choice"}.
assistant:


@ -1,142 +0,0 @@
import uuid
import gradio as gr
import re
import requests
from agent import MusicCoplilotAgent
import soundfile
import argparse
all_messages = []
OPENAI_KEY = ""
def add_message(content, role):
message = {"role": role, "content": content}
all_messages.append(message)
def extract_medias(message):
audio_pattern = re.compile(r"(http(s?):|\/)?([\.\/_\w:-])*?\.(flac|wav|mp3)")
audio_urls = []
for match in audio_pattern.finditer(message):
if match.group(0) not in audio_urls:
audio_urls.append(match.group(0))
return audio_urls
def set_openai_key(openai_key):
global OPENAI_KEY
OPENAI_KEY = openai_key
agent._init_backend_from_input(openai_key)
return OPENAI_KEY, gr.update(visible=True), gr.update(visible=True)
def add_text(messages, message):
if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
return messages, "Please set your OpenAI API key first."
add_message(message, "user")
messages = messages + [(message, None)]
audio_urls = extract_medias(message)
for audio_url in audio_urls:
if audio_url.startswith("http"):
ext = audio_url.split(".")[-1]
name = f"{str(uuid.uuid4()[:4])}.{ext}"
response = requests.get(audio_url)
with open(f"{agent.config['src_fold']}/{name}", "wb") as f:
f.write(response.content)
messages = messages + [((f"{audio_url} is saved as {name}",), None)]
return messages, ""
def upload_audio(file, messages):
file_name = str(uuid.uuid4())[:4]
audio_load, sr = soundfile.read(file.name)
soundfile.write(f"{agent.config['src_fold']}/{file_name}.wav", audio_load, samplerate=sr)
messages = messages + [(None, f"** {file_name}.wav ** uploaded")]
return messages
def bot(messages):
if len(OPENAI_KEY) == 0 or not OPENAI_KEY.startswith("sk-"):
return messages
message = agent.chat(messages[-1][0])
audio_urls = extract_medias(message)
add_message(message, "assistant")
messages[-1][1] = message
for audio_url in audio_urls:
if not audio_url.startswith("http") and not audio_url.startswith(f"{agent.config['src_fold']}"):
audio_url = f"{agent.config['src_fold']}/{audio_url}"
messages = messages + [((None, (audio_url,)))]
return messages
def clear_all_history(messages):
agent.clear_history()
messages = messages + [(("All histories of LLM are cleared", None))]
return messages
def parse_args():
parser = argparse.ArgumentParser(description="A path to a YAML file")
parser.add_argument("--config", type=str, help="a YAML file path.")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
agent = MusicCoplilotAgent(args.config, mode="gradio")
with gr.Blocks() as demo:
gr.Markdown("<h2><center>Music Pilot (Dev)</center></h2>")
with gr.Row():
openai_api_key = gr.Textbox(
show_label=False,
placeholder="Set your OpenAI API key here and press Enter",
lines=1,
type="password",
)
state = gr.State([])
chatbot = gr.Chatbot([], elem_id="chatbot", label="music_pilot", visible=False).style(height=500)
with gr.Row(visible=False) as text_input_raws:
with gr.Column(scale=0.8):
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter or click Run button").style(container=False)
# with gr.Column(scale=0.1, min_width=0):
# run = gr.Button("🏃Run")
with gr.Column(scale=0.1, min_width=0):
clear_txt = gr.Button("🔄Clear")
with gr.Column(scale=0.1, min_width=0):
btn = gr.UploadButton("🖼Upload", file_types=["audio"])
openai_api_key.submit(set_openai_key, [openai_api_key], [openai_api_key, chatbot, text_input_raws])
clear_txt.click(clear_all_history, [chatbot], [chatbot])
btn.upload(upload_audio, [btn, chatbot], [chatbot])
txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
bot, chatbot, chatbot
)
gr.Examples(
examples=["Write a piece of lyric about the recent World Cup.",
"生成一首古风歌词的中文歌",
"Download a song by Jay Zhou for me and separate the vocals and the accompanies.",
"Convert the vocals in /b.wav to a violin sound.",
"Give me the sheet music and lyrics in the song /a.wav",
"近一个月流行的音乐类型",
"把c.wav中的人声搭配合适的旋律变成一首歌"
],
inputs=txt
)
demo.launch(share=True)


@ -1,7 +0,0 @@
This is a conversation.
[CONTEXT]
user: I have a question. Can you help?
assistant: Of course. Go on!
[END CONTEXT]
user: {{$input}}
assistant: