# muzic/musicagent/plugins.py

""" Models and APIs"""
import uuid
import numpy as np
import importlib
from transformers import pipeline, AutoConfig, Wav2Vec2FeatureExtractor
from model_utils import Wav2Vec2ForSpeechClassification, timbre_transfer
from pydub import AudioSegment
import requests
import urllib
import librosa
import re
import torch
import torch.nn.functional as F
# import torchaudio
from fairseq.models.transformer_lm import TransformerLanguageModel
from diffusers import AudioLDMPipeline
import soundfile as sf
import os
import sys
import json
import ast
def get_task_map():
task_map = {
"text-to-sheet-music": [
"sander-wood/text-to-music"
],
"text-to-audio": [
"cvssp/audioldm-m-full"
],
"music-classification": [
"lewtun/distilhubert-finetuned-music-genres",
"dima806/music_genres_classification"
],
"lyric-to-melody": [
"muzic/roc",
"muzic/telemelody"
],
"lyric-to-audio": [
"DiffSinger"
],
"web-search": [
"google-search"
],
"artist-search": [
"spotify"
],
"track-search": [
"spotify"
],
"album-search": [
"spotify"
],
"playlist-search": [
"spotify"
],
"separate-track": [
"demucs"
],
"lyric-recognition": [
"jonatasgrosman/whisper-large-zh-cv11"
],
"score-transcription": [
"basic-pitch"
],
"timbre-transfer": [
"ddsp"
],
"accompaniment": [
"getmusic"
],
"audio-mixing": [
"basic-merge"
],
"audio-crop": [
"basic-crop"
],
"audio-splice": [
"basic-splice"
],
"web-search": [
"google-search"
],
}
return task_map
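# A hypothetical lookup sketch: the task map pairs each task name with the
# tool ids that can serve it, and those ids double as keys into the `pipes`
# dict built by init_plugins() below.
#
#   task_map = get_task_map()
#   tool_id = task_map["lyric-to-melody"][0]          # -> "muzic/roc"
#   outputs = pipes[tool_id].inference([{"lyric": ["你好"]}], "lyric-to-melody")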
def init_plugins(config):
if config["disabled_tools"] is not None:
disabled = [tool.strip() for tool in config["disabled_tools"].split(",")]
else:
disabled = []
pipes = {}
if "muzic/roc" not in disabled:
pipes["muzic/roc"] = MuzicROC(config)
if "cvssp/audioldm-m-full" not in disabled:
pipes["cvssp/audioldm-m-full"] = AudioLDM(config)
if "DiffSinger" not in disabled:
pipes["DiffSinger"] = DiffSinger(config)
# if "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres" not in disabled:
# pipes["m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"] = Wav2Vec2Base(config)
if "dima806/music_genres_classification" not in disabled:
pipes["dima806/music_genres_classification"] = Text2AudioDima(config)
if "lewtun/distilhubert-finetuned-music-genres" not in disabled:
pipes["lewtun/distilhubert-finetuned-music-genres"] = Text2AudioLewtun(config)
if "jonatasgrosman/whisper-large-zh-cv11" not in disabled:
pipes["jonatasgrosman/whisper-large-zh-cv11"] = WhisperZh(config)
if "spotify" not in disabled:
pipes["spotify"] = Spotify(config)
if "ddsp" not in disabled:
pipes["ddsp"] = DDSP(config)
if "demucs" not in disabled:
pipes["demucs"] = Demucs(config)
if "basic-merge" not in disabled:
pipes["basic-merge"] = BasicMerge(config)
if "basic-crop" not in disabled:
pipes["basic-crop"] = BasicCrop(config)
if "basic-splice" not in disabled:
pipes["basic-splice"] = BasicSplice(config)
if "basic-pitch" not in disabled:
pipes["basic-pitch"] = BasicPitch(config)
if "google-search" not in disabled:
pipes["google-search"] = GoogleSearch(config)
return pipes
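# A minimal sketch of the config dict this module assumes; the keys are
# inferred from their uses in this file, and the credential values are
# placeholders.
#
#   config = {
#       "local_fold": "models",                  # root folder for checkpoints
#       "disabled_tools": "ddsp, demucs",        # comma-separated, or None
#       "spotify": {"client_id": "...", "client_secret": "..."},
#       "google": {"api_key": "...", "custom_search_engine_id": "..."},
#   }
#   pipes = init_plugins(config)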
class BaseToolkit:
def __init__(self, config):
self.local_fold = config["local_fold"]
self.id = "basic toolkit"
self.attributes = {}
def get_attributes(self):
return json.dumps(self.attributes)
def update_attributes(self, **kwargs):
for key in kwargs:
self.attributes[key] = kwargs[key]
    def mount_model(self, model, device):
        # transformers pipelines have no .to(); fall back to moving the inner model.
        try:
            model.to(device)
        except AttributeError:
            model.device = torch.device(device)
            model.model.to(device)
    def detach_model(self, model):
        try:
            model.to("cpu")
            torch.cuda.empty_cache()
        except AttributeError:
            model.device = torch.device("cpu")
            model.model.to("cpu")
            torch.cuda.empty_cache()
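# The contract every toolkit below follows (MyTool is illustrative, not a real
# plugin): set a unique `id` and an `attributes` dict, optionally load weights
# in `_init_toolkit`, and implement `inference(args, task, device)`, which maps
# a list of argument dicts to a list of result dicts ({"error": ...} on failure).
#
#   class MyTool(BaseToolkit):
#       def __init__(self, config):
#           super().__init__(config)
#           self.id = "my-tool"
#           self.attributes = {"description": "What the tool does."}
#       def inference(self, args, task, device="cpu"):
#           return [{"audio": "out.wav"} for arg in args if "audio" in arg]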
class MuzicROC(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "muzic/roc"
self.attributes = {
"description": "ROC is a new paradigm for lyric-to-melody generation"
}
self._init_toolkit(config)
def _init_toolkit(self, config):
sys.path.append(os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc"))
from main import main as roc_processer
self.processer = roc_processer
        self.model = TransformerLanguageModel.from_pretrained(
            os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc/music-ckps/"),
            "checkpoint_best.pt",
            tokenizer="space",
            batch_size=8192
        )
sys.path.remove(os.path.join(os.getcwd(), f"{self.local_fold}/muzic/roc"))
def inference(self, args, task, device="cpu"):
results = []
self.model.to(device)
for arg in args:
if "lyric" in arg:
prompt = arg["lyric"]
prompt = " ".join(prompt)
                # ROC expects pure Chinese lyrics; strip everything outside the CJK range.
                prompt = re.sub(r"[^\u4e00-\u9fa5]", "", prompt)
file_name = str(uuid.uuid4())[:4]
try:
outputs = self.processer(
self.model,
[prompt],
output_file_name=f"public/audios/{file_name}",
db_path=f"{self.local_fold}/muzic/roc/database/ROC.db"
)
os.system(f"fluidsynth -l -ni -a file -z 2048 -F public/audios/{file_name}.wav 'MS Basic.sf3' public/audios/{file_name}.mid")
                except Exception:
results.append({"error": "Lyric-to-melody Error"})
continue
results.append(
{
"score": str(outputs),
"audio": f"{file_name}.wav",
"sheet_music": f"{file_name}.mid"
}
)
self.model.to("cpu")
return results
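# Illustrative call (the output file names are random): ROC takes a list of
# Chinese lyric lines, joins them, and drops any non-CJK characters.
#
#   pipes["muzic/roc"].inference([{"lyric": ["小酒窝长睫毛", "是你最美的记号"]}],
#                                "lyric-to-melody", device="cuda:0")
#   # -> [{"score": "...", "audio": "ab12.wav", "sheet_music": "ab12.mid"}]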
class Text2AudioDima(BaseToolkit):
    # Despite the class name, this wraps an audio-classification (music genre) model.
def __init__(self, config):
super().__init__(config)
self.id = "dima806/music_genres_classification"
self.attributes = {
"description": "The model is trained based on publicly available dataset of labeled music data — GTZAN Dataset",
"genres": "blues,classical,country,disco,hip-hop,jazz,metal,pop,reggae,rock.",
"downloads": 60
        }
        self._init_toolkit(config)
def _init_toolkit(self, config):
self.pipe = pipeline("audio-classification", model=os.path.join(self.local_fold, self.id), device="cpu")
def inference(self, args, task, device="cpu"):
self.mount_model(self.pipe, device)
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
sampling_rate = self.pipe.feature_extractor.sampling_rate
audio, _ = librosa.load(prompt, sr=sampling_rate)
try:
output = self.pipe(audio)
genre = output[0]["label"]
                except Exception:
results.append({"error": "Genres Classification Error"})
continue
results.append({"genre": genre})
self.detach_model(self.pipe)
return results
class Text2AudioLewtun(BaseToolkit):
    # Despite the class name, this wraps an audio-classification (music genre) model.
def __init__(self, config):
super().__init__(config)
self.id = "lewtun/distilhubert-finetuned-music-genres"
self.attributes = {
"description": "This model is a fine-tuned version of ntu-spml/distilhubert on the None dataset.",
"genres": "Pop,Classical,International,Ambient Electronic,Folk",
"downloads": 202
        }
        self._init_toolkit(config)
def _init_toolkit(self, config):
self.pipe = pipeline("audio-classification", model=os.path.join(self.local_fold, self.id), device="cpu")
def inference(self, args, task, device="cpu"):
self.mount_model(self.pipe, device)
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
sampling_rate = self.pipe.feature_extractor.sampling_rate
audio, _ = librosa.load(prompt, sr=sampling_rate)
try:
output = self.pipe(audio)
genre = output[0]["label"]
                except Exception:
results.append({"error": "Genres Classification Error"})
continue
results.append({"genre": genre})
self.detach_model(self.pipe)
return results
class DiffSinger(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "DiffSinger"
self.attributes = {
"description": "Singing Voice Synthesis via Shallow Diffusion Mechanism",
"star": 3496
}
self._init_toolkit(config)
def _init_toolkit(self, config):
sys.path.append(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
import utils
importlib.reload(utils)
from inference.svs.ds_e2e import DiffSingerE2EInfer
from utils.hparams import hparams, set_hparams
from utils.audio import save_wav
work_dir = os.getcwd()
os.chdir(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
set_hparams('usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml',
'0228_opencpop_ds100_rel', print_hparams=False)
self.processer = save_wav
self.model = DiffSingerE2EInfer(hparams, device="cuda:0")
self.model.model.to("cpu")
self.model.vocoder.to("cpu")
os.chdir(work_dir)
sys.path.remove(os.path.join(os.getcwd(), f"{self.local_fold}/DiffSinger"))
def inference(self, args, task, device="cpu"):
results = []
self.model.model.to(device)
self.model.vocoder.to(device)
self.model.device = device
for arg in args:
if "score" in arg:
prompt = arg["score"]
                # Parse the stringified score dict safely (literals only).
                prompt = ast.literal_eval(prompt)
try:
wav = self.model.infer_once(prompt)
file_name = str(uuid.uuid4())[:4]
self.processer(wav, f"public/audios/{file_name}.wav", sr=16000)
                except Exception:
results.append({"error": "Singing Voice Synthesis Error"})
continue
results.append({"audio": f"{file_name}.wav"})
self.model.model.to("cpu")
self.model.vocoder.to("cpu")
self.model.device = "cpu"
return results
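# Illustrative "score" payload, modeled on the example input in the DiffSinger
# repository (the note and duration values here are made up):
#
#   score = str({
#       "text": "小酒窝长睫毛AP是你最美的记号",
#       "notes": "C#4/Db4 | F#4/Gb4 | G#4/Ab4",
#       "notes_duration": "0.407 | 0.376 | 0.242",
#       "input_type": "word",
#   })
#   pipes["DiffSinger"].inference([{"score": score}], "lyric-to-audio")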
# class Wav2Vec2Base(BaseToolkit):
# def __init__(self, config):
# super().__init__(config)
# self.id = "m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres"
# self.attributes = {
# "description": "Music Genre Classification using Wav2Vec 2.0"
# }
# self._init_toolkit(config)
# def _init_toolkit(self, config):
# self.processer = Wav2Vec2FeatureExtractor.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# self.model = Wav2Vec2ForSpeechClassification.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# self.config = AutoConfig.from_pretrained(f"{self.local_fold}/m3hrdadfi/wav2vec2-base-100k-gtzan-music-genres")
# def inference(self, args, task, device="cpu"):
# results = []
# self.model.to(device)
# for arg in args:
# if "audio" in arg:
# prompt = arg["audio"]
# sampling_rate = self.processer.sampling_rate
# #speech_array, _sampling_rate = torchaudio.load(prompt)
# #resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
# #speech = resampler(speech_array).squeeze().numpy()
# speech, _ = librosa.load(prompt, sr=sampling_rate)
# inputs = self.processer(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
# inputs = {k: v.to(device) for k, v in inputs.items()}
# with torch.no_grad():
# logits = self.model(**inputs).logits
# scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
# genre = self.config.id2label[np.argmax(scores)]
# # outputs = [{"Label": pipes[pipe_id]["config"].id2label[i], "Score": f"{round(score * 100, 3):.1f}%"} for i, score in enumerate(scores)]
# results.append({"genre": genre})
# self.model.to("cpu")
# return results
class WhisperZh(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "jonatasgrosman/whisper-large-zh-cv11"
self.attributes = {
"description": "a fine-tuned version of openai/whisper-large-v2 on Chinese (Mandarin)"
}
self._init_toolkit(config)
def _init_toolkit(self, config):
self.model = pipeline("automatic-speech-recognition", model=f"{self.local_fold}/jonatasgrosman/whisper-large-zh-cv11", device="cuda:0")
self.model.model.to("cpu")
def inference(self, args, task, device="cpu"):
results = []
self.model.model.to(device)
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
                try:
                    # Force Chinese transcription so Whisper does not translate.
                    self.model.model.config.forced_decoder_ids = (
                        self.model.tokenizer.get_decoder_prompt_ids(
                            language="zh",
                            task="transcribe"
                        )
                    )
                    # The ASR pipeline returns {"text": ...}; keep the transcription text.
                    results.append({"lyric": self.model(prompt)["text"]})
                except Exception:
                    results.append({"error": "Lyric-recognition Error"})
                    continue
self.model.model.to("cpu")
return results
class Spotify(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "spotify"
self.attributes = {
"description": "Spotify is a digital music service that gives you access to millions of songs."
}
self._init_toolkit(config)
def _init_toolkit(self, config):
client_id = config["spotify"]["client_id"]
client_secret = config["spotify"]["client_secret"]
url = "https://accounts.spotify.com/api/token"
headers = {
"Content-Type": "application/x-www-form-urlencoded"
}
data = {
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": client_secret
}
response = requests.post(url, headers=headers, data=data)
if response.status_code == 200:
token_data = response.json()
self.access_token = token_data["access_token"]
print("Access Token:", self.access_token)
else:
print("Error:", response.status_code)
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
tgt = task.split("-")[0]
url = "https://api.spotify.com/v1/"
endpoint = "search"
data = {
"q": " ".join([f"{key}:{value}" for key, value in arg.items() if key in ["track", "album", "artist", "genre", "description"]]),
"type" : [tgt]
}
headers = {
"Authorization": f"Bearer {self.access_token}",
"Accept" : "application/json",
"Content-Type" : "application/json"
}
response = requests.get(url=url+endpoint, params=data, headers=headers)
if response.status_code == 200:
data = response.json()
                if tgt + "s" not in data or not data[tgt + "s"]["items"]:
                    results.append({"error": "No corresponding song found."})
                    continue
                data = data[tgt + "s"]["items"][0]
text = dict()
spotify_id = data["id"]
text[tgt] = data["name"]
if tgt == "track":
if "preview_url" in data and len(data["preview_url"]) > 0:
url = data["preview_url"]
file_name = str(uuid.uuid4())[:4]
with open(f"public/audios/{file_name}.mp3", "wb") as f:
f.write(requests.get(url).content)
text["audio"] = f"{file_name}.mp3"
text["album"] = data["album"]["name"]
text["artist"] = [d["name"] for d in data["artists"]]
if tgt == "album":
text["date"] = data["release_date"]
text["artist"] = [d["name"] for d in data["artists"]]
url = f"https://api.spotify.com/v1/albums/{spotify_id}"
album = requests.get(url, headers=headers).json()
if len(album["genres"]) > 0:
text["genre"] = album["genres"]
text["track"] = [d["name"] for d in album["tracks"]["items"]]
if tgt == "playlist":
url = f"https://api.spotify.com/v1/playlists/{spotify_id}"
album = requests.get(url, headers=headers).json()
text["track"] = [d["track"]["name"] for d in album["tracks"]["items"]]
if tgt == "artist":
if len(data["genres"]) > 0:
text["genre"] = data["genres"]
results.append(text)
else:
results.append({"error": "No corresponding song found."})
return results
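# Illustrative search call (result values elided): the task prefix picks the
# Spotify object type, and recognised argument keys become query filters.
#
#   pipes["spotify"].inference([{"track": "Hey Jude", "artist": "The Beatles"}],
#                              "track-search")
#   # -> [{"track": "Hey Jude", "audio": "1a2b.mp3",
#   #      "album": "...", "artist": ["The Beatles"]}]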
class GoogleSearch(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "google"
self.attributes = {
"description": "Google Custom Search Engine."
}
self._init_toolkit(config)
def _init_toolkit(self, config):
api_key = config["google"]["api_key"]
custom_search_engine_id = config["google"]["custom_search_engine_id"]
self.url = "https://www.googleapis.com/customsearch/v1"
        self.params = {
            "key": api_key,
            "cx": custom_search_engine_id,
            # The Custom Search JSON API limits results per page via "num" (1-10).
            "num": 5
        }
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "description" in arg:
self.params["q"] = arg["description"]
response = requests.get(self.url, self.params)
if response.status_code == 200:
data = response.json()
items = data.get("items")
descriptions = []
for item in items:
descriptions.append(
{
"title": item.get("title"),
"snippet": item.get("snippet")
}
)
results.append(
{
"description": json.dumps(descriptions)
}
)
return results
class Demucs(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "demucs"
self.attributes = {
"description": "Demucs Music Source Separation"
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
                # os.system does not raise on failure, so check the exit status.
                if os.system(f"python -m demucs --two-stems=vocals -o public/audios/ {prompt}") != 0:
                    results.append({"error": "Music Source Separation Error"})
                    continue
file_name = str(uuid.uuid4())[:4]
os.system(f"cp public/audios/htdemucs/{prompt.split('/')[-1].split('.')[0]}/no_vocals.wav public/audios/{file_name}.wav")
results.append({"audio": f"{file_name}.wav", "instrument": "accompaniment"})
file_name = str(uuid.uuid4())[:4]
os.system(f"cp public/audios/htdemucs/{prompt.split('/')[-1].split('.')[0]}/vocals.wav public/audios/{file_name}.wav")
results.append({"audio": f"{file_name}.wav", "instrument": "vocal"})
return results
class BasicMerge(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "pad-wave-merge"
self.attributes = {
"description": "Merge audios."
}
def inference(self, args, task, device="cpu"):
audios = []
        sr = 16000
for arg in args:
if "audio" in arg:
audio, _ = librosa.load(arg["audio"], sr=sr)
audios.append(audio)
        if not audios:
            return [{"error": "No audio inputs to merge."}]
        # Pad every clip to the longest length, then mix by summing waveforms.
        max_len = max(len(audio) for audio in audios)
        audios = [librosa.util.fix_length(audio, size=max_len) for audio in audios]
        mixed_audio = sum(audios)
file_name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{file_name}.wav", mixed_audio, sr)
results = [{"audio": f"{file_name}.wav"}]
return results
class DDSP(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "ddsp"
self.attributes = {
"description": "Convert audio between sound sources with pretrained models."
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
                try:
                    # timbre_transfer (from model_utils) renders the input with a violin timbre.
                    timbre_transfer(prompt, f"public/audios/{file_name}.wav", instrument="violin")
                except Exception:
                    results.append({"error": "Timbre-transfer Error"})
                    continue
results.append({"audio": f"{file_name}.wav"})
return results
class BasicPitch(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "basic-pitch"
self.attributes = {
"description": "A Python library for Automatic Music Transcription (AMT), using lightweight neural network developed by Spotify's Audio Intelligence Lab."
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg:
prompt = arg["audio"]
file_name = str(uuid.uuid4())[:4]
                # os.system does not raise on failure, so check the exit status.
                if os.system(f"basic-pitch public/audios/ {prompt}") != 0:
                    results.append({"error": "Music Transcription Error"})
                    continue
os.system(f"cp public/audios/{prompt.split('/')[-1].split('.')[0]}_basic_pitch.mid public/audios/{file_name}.mid")
results.append({"sheet music": f"{file_name}.mid"})
return results
class BasicCrop(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "audio-crop"
self.attributes = {
"description": "Trim audio based on time"
}
def inference(self, args, task, device="cpu"):
results = []
for arg in args:
if "audio" in arg and "time" in arg:
prompt = arg["audio"]
time = arg["time"]
file_name = str(uuid.uuid4())[:4]
audio = AudioSegment.from_file(prompt)
start_ms = int(float(time[0]) * 1000)
end_ms = int(float(time[1]) * 1000)
if start_ms < 0:
start_ms += len(audio)
if end_ms < 0:
end_ms += len(audio)
                # Clamp both endpoints to the clip length.
                start_ms = min(start_ms, len(audio))
                end_ms = min(end_ms, len(audio))
if start_ms > end_ms:
continue
trimmed_audio = audio[start_ms:end_ms]
trimmed_audio.export(f"public/audios/{file_name}.wav", format="wav")
results.append({"audio": f"{file_name}.wav"})
return results
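# Illustrative crop call: "time" holds start and end in seconds, and negative
# values count from the end of the clip.
#
#   pipes["basic-crop"].inference([{"audio": "song.wav", "time": ["0", "30"]}],
#                                 "audio-crop")   # keep the first 30 seconds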
class BasicSplice(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "audio-splice"
self.attributes = {
"description": "Basic audio splice"
}
def inference(self, args, task, device="cpu"):
audios = []
results = []
for arg in args:
if "audio" in arg:
audios.append(arg["audio"])
        if not audios:
            return [{"error": "No audio inputs to splice."}]
        # Concatenate the clips end-to-end in the order given.
        audio = AudioSegment.from_file(audios[0])
        for i in range(1, len(audios)):
            audio = audio + AudioSegment.from_file(audios[i])
file_name = str(uuid.uuid4())[:4]
audio.export(f"public/audios/{file_name}.wav", format="wav")
results.append({"audio": f"{file_name}.wav"})
return results
class AudioLDM(BaseToolkit):
def __init__(self, config):
super().__init__(config)
self.id = "cvssp/audioldm-m-full"
self.attributes = {
"description": "Text-to-Audio Generation: Generate audio given text input. "
}
self._init_toolkit(config)
def _init_toolkit(self, config):
repo_id = "models/cvssp/audioldm-m-full"
        # fp16 weights halve memory use; run inference on GPU, since many CPU kernels lack float16 support.
        self.pipe = AudioLDMPipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
def inference(self, args, task, device="cpu"):
results = []
self.mount_model(self.pipe, device)
for arg in args:
if "description" in arg:
prompt = arg["description"]
try:
audio = self.pipe(prompt, num_inference_steps=10, audio_length_in_s=5.0).audios[0]
                except Exception:
results.append({"error": "Text-to-Audio Error"})
continue
file_name = str(uuid.uuid4())[:4]
sf.write(f"public/audios/{file_name}.wav", audio, 16000)
results.append({"audio": f"{file_name}.wav"})
self.detach_model(self.pipe)
return results
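# A minimal end-to-end sketch (hypothetical paths and credentials; in
# MusicAgent the agent layer normally drives these calls):
#
#   if __name__ == "__main__":
#       config = {"local_fold": "models", "disabled_tools": None,
#                 "spotify": {"client_id": "...", "client_secret": "..."},
#                 "google": {"api_key": "...", "custom_search_engine_id": "..."}}
#       pipes = init_plugins(config)
#       print(pipes["cvssp/audioldm-m-full"].inference(
#           [{"description": "a calm piano melody"}], "text-to-audio",
#           device="cuda:0"))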