import os
import sys
import time, random
from random import choice
from typing import List, Dict
import torch
import torch.nn as nn
import torch.optim as optim
import music21
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
import wave
import struct
import ffmpeg
import tempfile
from pydub import AudioSegment
from moviepy.editor import VideoFileClip, AudioFileClip
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Base paths
Gbase = "./"
cache_dir = "./hf/"
try:
    # When running in Google Colab, mount Drive and use it for models and caches
    import google.colab
    from google.colab import drive
    IN_COLAB = True
    drive.mount('/gdrive', force_remount=True)
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    sys.path.append(Gbase)
except Exception:
    # Fall back to local paths when not running in Colab
    IN_COLAB = False
    Gbase = "./"
    cache_dir = "./hf/"
# Model checkpoint paths
ModelPath = os.path.join(Gbase, 'music_generation_model.pth')
OptimizerPath = os.path.join(Gbase, 'optimizer_state.pth')
DiscriminatorModelPath = os.path.join(Gbase, 'discriminator_model.pth')
DiscriminatorOptimizerPath = os.path.join(Gbase, 'discriminator_optimizer_state.pth')
EvaluatorPath = os.path.join(Gbase, 'music_tag_evaluator.pkl')
# Music tag vocabulary
MUSIC_TAGS = {
'emotions': ['Happy', 'Sad', 'Angry', 'Peaceful', 'Neutral'],
'genres': ['Classical', 'Jazz', 'Rock', 'Electronic'],
'tempo': ['Slow', 'Medium', 'Fast'],
'instrumentation': ['Piano', 'Guitar', 'Synthesizer'],
'harmony': ['Consonant', 'Dissonant', 'Complex', 'Simple'],
'dynamics': ['Dynamic', 'Static'],
'rhythm': ['Simple', 'Complex']
}
def randomMusicTags():
return {k: choice(MUSIC_TAGS[k]) for k in MUSIC_TAGS.keys()}
print("随机生成的音乐标签:", randomMusicTags())
def get_scale_notes(key_str: str, octave_range=(2, 6)) -> List[int]:
"""
根据调性返回所属音阶的 MIDI 音高列表。
"""
key = music21.key.Key(key_str)
scale_notes = []
for octave in range(octave_range[0], octave_range[1] + 1):
pitches = key.getScale().getPitches(f"{key_str}{octave}")
for pitch in pitches:
scale_notes.append(pitch.midi)
return scale_notes
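# Illustrative sketch (values assume C major): around octave 4 the C-major scale maps to the MIDI
# pitches 60 (C4), 62, 64, 65, 67, 69 and 71, so a call such as
#   get_scale_notes('C')
# returns a list of such in-scale MIDI numbers spanning octaves 2 through 6 (the default range).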
def composer_from_features(features: np.ndarray, key_str: str) -> music21.stream.Stream:
    """
    Convert a feature array into a music21.stream.Stream, forcing every note onto the given scale.
    """
    s = music21.stream.Stream()
    # Set the tempo (default 120 BPM)
    tempo = music21.tempo.MetronomeMark(number=120)
    s.append(tempo)
    # Set the key
    tonality = music21.key.Key(key_str)
    s.append(tonality)
    # Pitches belonging to the scale
    scale_notes = get_scale_notes(key_str)
    # Acceptable note durations (in quarter lengths)
    acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
    for feature in features:
        pitch = int(round(feature[0]))
        duration = feature[1]
        volume = feature[2]
        # Quantize the duration to the nearest acceptable value
        duration = min(acceptable_durations, key=lambda x: abs(x - duration))
        if pitch == 0:
            # A pitch of 0 encodes a rest (see midi_to_features); handle it before clamping
            s.append(music21.note.Rest(quarterLength=duration))
            continue
        # Clamp the pitch to the piano range, 21 (A0) to 108 (C8)
        pitch = max(21, min(108, pitch))
        # Snap the pitch to the nearest in-scale note
        if pitch not in scale_notes:
            pitch = min(scale_notes, key=lambda x: abs(x - pitch))
        # Clamp the velocity to the MIDI range 0-127
        volume = int(max(0, min(127, volume)))
        n = music21.note.Note(midi=pitch, quarterLength=duration)
        n.volume.velocity = volume
        s.append(n)
    return s
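# A minimal usage sketch (hypothetical input): each feature row is [midi_pitch, quarter_length, velocity], e.g.
#   composer_from_features(np.array([[60, 1.0, 80], [64, 0.5, 70]], dtype=np.float32), 'C')
# produces a two-note stream (C4 then E4) in C major with the tempo and key marks prepended.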
import pickle
class MusicTagEvaluator:
    def __init__(self):
        # Full tag vocabulary
        self.MUSIC_TAGS = MUSIC_TAGS
        # Flatten all categories into one list and drop duplicates
        all_tags = []
        for category in self.MUSIC_TAGS:
            all_tags.extend(self.MUSIC_TAGS[category])
        self.all_tags = list(set(all_tags))  # deduplicated tag list
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.all_tags])
def save(self, path):
with open(path, 'wb') as f:
pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
print(f"评估器已保存至 '{path}'。")
    @staticmethod
    def load(path):
        if os.path.exists(path):
            with open(path, 'rb') as f:
                evaluator = pickle.load(f)
                print(f"Evaluator loaded from '{path}'.")
                return evaluator
        else:
            print(f"Evaluator file '{path}' does not exist; creating a new evaluator.")
            return MusicTagEvaluator()
    def evaluate_tags_from_features(self, features: np.ndarray) -> List[str]:
        """
        Derive tags from a raw feature segment.
        """
        # Pick a random key in which to realize the features
        key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
        s = composer_from_features(features, key_str)
        tag_scores = self.evaluate_tags(s)
        tags = []
        # Keep only scores that map onto known tags
        for category in self.MUSIC_TAGS:
            tag = tag_scores.get(category)
            if tag in self.MUSIC_TAGS[category]:
                tags.append(tag)
        return tags
    def evaluate_tags(self, generated_music):
        """
        Estimate tags for a generated music21 stream.
        """
        tag_scores = {}
        # Pitch range (currently informational only)
        pitch_values = [note.pitch.midi for note in generated_music.recurse().notes if isinstance(note, music21.note.Note)]
        pitch_range = max(pitch_values) - min(pitch_values) if pitch_values else 0
        # Evaluate each aspect separately
        harmony_tag = self._evaluate_harmony(generated_music)
        rhythm_tag = self._evaluate_rhythm(generated_music)
        dynamics_tag = self._evaluate_dynamics(generated_music)
        tempo_tag = self._evaluate_tempo(generated_music)
        emotion_tag = self._evaluate_emotion(harmony_tag, rhythm_tag, dynamics_tag, tempo_tag)
        # Collect the per-category tags
        tag_scores['emotions'] = emotion_tag
        tag_scores['harmony'] = harmony_tag
        tag_scores['rhythm'] = rhythm_tag
        tag_scores['dynamics'] = dynamics_tag
        tag_scores['tempo'] = tempo_tag
        return tag_scores
def _evaluate_harmony(self, stream):
        # Chordify the stream
chords = stream.chordify()
chord_types = []
for element in chords.recurse():
if isinstance(element, music21.chord.Chord):
chord_types.append(element.commonName)
        # Judge harmonic complexity from the chord types present
if any('diminished' in str(ct) or 'augmented' in str(ct) for ct in chord_types):
harmony_tag = 'Complex'
elif any('major' in str(ct) or 'minor' in str(ct) for ct in chord_types):
harmony_tag = 'Consonant'
else:
harmony_tag = 'Simple'
return harmony_tag
def _evaluate_rhythm(self, stream):
durations = [note.quarterLength for note in stream.flat.notes]
        # Rhythmic complexity: number of distinct duration values
unique_durations = len(set(durations))
if unique_durations > 5:
rhythm_tag = 'Complex'
else:
rhythm_tag = 'Simple'
return rhythm_tag
def _evaluate_dynamics(self, stream):
volumes = [note.volume.velocity for note in stream.flat.notes if note.volume.velocity is not None]
if not volumes:
dynamics_tag = 'Static'
else:
dynamics_range = max(volumes) - min(volumes)
if dynamics_range > 40:
dynamics_tag = 'Dynamic'
else:
dynamics_tag = 'Static'
return dynamics_tag
def _evaluate_tempo(self, stream):
tempos = [metronome.number for metronome in stream.recurse() if isinstance(metronome, music21.tempo.MetronomeMark)]
        bpm = tempos[0] if tempos else 120  # default to 120 BPM
if bpm < 60:
return 'Slow'
elif 60 <= bpm < 120:
return 'Medium'
else:
return 'Fast'
def _evaluate_emotion(self, harmony_tag, rhythm_tag, dynamics_tag, tempo_tag):
        # Heuristic emotion estimate from harmony, rhythm, dynamics and tempo
if harmony_tag == 'Complex' and rhythm_tag == 'Complex':
emotion = 'Angry'
elif harmony_tag == 'Consonant' and dynamics_tag == 'Dynamic' and tempo_tag == 'Fast':
emotion = 'Happy'
elif harmony_tag == 'Simple' and dynamics_tag == 'Static' and tempo_tag == 'Slow':
emotion = 'Peaceful'
elif harmony_tag == 'Consonant' and dynamics_tag == 'Static' and tempo_tag == 'Medium':
emotion = 'Neutral'
else:
emotion = 'Sad'
return emotion
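# Usage sketch (assumes a parsed MIDI file at the hypothetical path 'example.mid'):
#   evaluator = MusicTagEvaluator()
#   stream = music21.converter.parse('example.mid')
#   print(evaluator.evaluate_tags(stream))   # e.g. {'emotions': 'Sad', 'harmony': 'Consonant', ...}
#   binarized = evaluator.mlb.transform([['Happy', 'Fast']])[0]  # multi-hot vector over all_tags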
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super(PositionalEncoding, self).__init__()
pe = torch.zeros(max_len, d_model) # [max_len, d_model]
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [max_len, 1]
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)) # [d_model/2]
pe[:, 0::2] = torch.sin(position * div_term) # even indices
pe[:, 1::2] = torch.cos(position * div_term) # odd indices
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
"""
x = x + self.pe[:, :x.size(1), :]
return x
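# The buffer above implements the standard sinusoidal encoding,
#   PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model)),
# so, for example, PositionalEncoding(d_model=512)(torch.zeros(2, 100, 512)) returns a
# [2, 100, 512] tensor whose added values depend only on position and dimension.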
class MusicGenerationModel(nn.Module):
def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags, max_seq_length=500):
super(MusicGenerationModel, self).__init__()
self.d_model = d_model
self.input_linear = nn.Linear(input_dim, d_model)
self.positional_encoding = PositionalEncoding(d_model, max_len=max_seq_length)
encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
self.fc_music = nn.Linear(d_model, output_dim)
self.fc_tags = nn.Linear(d_model, num_tags)
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout(0.1)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""
src: [batch_size, seq_len, input_dim]
"""
src = self.input_linear(src) * np.sqrt(self.d_model) # [batch_size, seq_len, d_model]
src = self.positional_encoding(src) # [batch_size, seq_len, d_model]
src = src.transpose(0, 1) # [seq_len, batch_size, d_model]
memory = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask) # [seq_len, batch_size, d_model]
memory = memory.transpose(0, 1) # [batch_size, seq_len, d_model]
memory = self.dropout(memory)
music_output = self.fc_music(memory) # [batch_size, seq_len, output_dim]
tag_probabilities = self.sigmoid(self.fc_tags(memory)) # [batch_size, seq_len, num_tags]
return music_output, tag_probabilities
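# Shape sketch (hypothetical sizes): with input_dim=3 and num_tags=N,
#   model = MusicGenerationModel(3, 512, 8, 8, 2048, 3, N)
#   music, tags = model(torch.rand(4, 50, 3))
# yields music of shape [4, 50, 3] and per-step tag probabilities of shape [4, 50, N].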
class Discriminator(nn.Module):
def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward):
super(Discriminator, self).__init__()
self.input_linear = nn.Linear(input_dim, d_model)
self.positional_encoding = PositionalEncoding(d_model)
encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
self.fc = nn.Linear(d_model, 1)
self.sigmoid = nn.Sigmoid()
self.dropout = nn.Dropout(0.1)
def forward(self, src, src_mask=None, src_key_padding_mask=None):
src = self.input_linear(src) * np.sqrt(self.input_linear.out_features)
src = self.positional_encoding(src)
src = src.transpose(0, 1) # [seq_len, batch_size, d_model]
output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
output = output.transpose(0, 1) # [batch_size, seq_len, d_model]
output = self.dropout(output)
        # Use the last time step as the decision summary; mean pooling over time would also work
output = self.fc(output[:, -1, :])
output = self.sigmoid(output)
return output
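# The discriminator scores a whole sequence with a single sigmoid unit: for feature sequences
# shaped [batch, seq_len, 3] it returns a [batch, 1] tensor of real-vs-generated probabilities,
# summarizing the sequence by its last encoder time step.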
class MidiDataset(Dataset):
    def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
        self.max_length = max_length
        self.dataset_path = dataset_path
        self.evaluator = evaluator
        # Reuse a previously preprocessed dataset if it exists
        if os.path.exists(self.dataset_path):
            print(f"Loading dataset from '{self.dataset_path}'")
            try:
                saved_data = torch.load(self.dataset_path)
                self.features = saved_data['features']
                self.labels = saved_data['labels']
                print(f"Dataset loaded successfully with {len(self.features)} samples.")
            except Exception as e:
                print(f"Error loading dataset: {e}")
                self._process_midi_files(midi_files)
        else:
            # Process the MIDI files and cache the result
            self._process_midi_files(midi_files)
def __len__(self):
return len(self.features)
    def getAug(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]      # [num_tags]
        # Apply data augmentation
        feature_aug, label_aug = self._augment_data(feature, label)
        # Return tensors
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]      # [num_tags]
        # No augmentation here; use getAug() (or MidiDatasetAug) for an augmented view
        return torch.tensor(feature, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
    def _process_midi_files(self, midi_files):
        print("Processing MIDI files to build the dataset...")
        features_list = []
        labels_list = []
        for midi_file in midi_files:
            try:
                stream = music21.converter.parse(midi_file)
                # Convert the parsed stream into a feature sequence
                features = self.midi_to_features(stream)
                if len(features) < self.max_length:
                    # Skip files that are too short
                    continue
                else:
                    # Split the features into segments of length max_length
                    num_segments = len(features) // self.max_length
                    for i in range(num_segments):
                        segment = features[i*self.max_length : (i+1)*self.max_length]
                        if len(segment) < self.max_length:
                            continue  # skip incomplete segments
                        # Let the evaluator assign tags to each segment
                        tags = self.evaluator.evaluate_tags_from_features(segment)
                        # Binarize the tags
                        tag_binarized = self.evaluator.mlb.transform([tags])[0]
                        features_list.append(segment)
                        labels_list.append(tag_binarized)
            except Exception as e:
                print(f"Error processing {midi_file}: {e}")
        self.features = features_list
        self.labels = labels_list
        # Cache the dataset
        try:
            torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
            print(f"Dataset saved to '{self.dataset_path}' with {len(self.features)} samples.")
        except Exception as e:
            print(f"Error saving dataset: {e}")
    def midi_to_features(self, stream) -> np.ndarray:
        """
        Convert a music21 stream into a sequence of [pitch, duration, volume] features.
        """
        features = []
        for note in stream.flat.notesAndRests:
            if isinstance(note, music21.note.Note):
                pitch = note.pitch.midi
                duration = note.quarterLength
                volume = note.volume.velocity if note.volume.velocity else 64  # default velocity
            elif isinstance(note, music21.note.Rest):
                pitch = 0  # rests are encoded with pitch 0
                duration = note.quarterLength
                volume = 0
            else:
                continue
            features.append([pitch, duration, volume])
        return np.array(features, dtype=np.float32)
    def _augment_data(self, feature, label):
        # Data augmentation: random changes to dynamics (volume) and timing (duration)
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Randomly scale the volume (dynamics)
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Randomly scale the durations (tempo feel)
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Adjust the 'tempo' label when the timing change is significant:
        # longer durations mean a slower feel, shorter durations a faster one
        if duration_change > 1.05:
            tempo_tags = ['Slow']
        elif duration_change < 0.95:
            tempo_tags = ['Fast']
        else:
            tempo_tags = ['Medium']
        # Reset all tempo tags, then set the new one
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
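# Dataset sketch (assumes MIDI files exist under the configured directory):
#   ds = MidiDataset(midi_files, max_length=100, dataset_path='cache.pt', evaluator=MusicTagEvaluator())
#   features, labels = ds[0]   # features: [100, 3] tensor, labels: multi-hot tag vector
# 'cache.pt' is a hypothetical cache path; the training script below uses 'mymusic.dataset' under Gbase.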
class MidiDatasetAug(MidiDataset):
    """Identical to MidiDataset, but __getitem__ returns an augmented view of each sample."""
    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]      # [num_tags]
        # Apply data augmentation
        feature_aug, label_aug = self._augment_data(feature, label)
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
class RandomDataset(Dataset):
def __init__(self, size: int, max_length: int):
"""
随机生成数据集。
参数:
size (int): 数据集大小。
max_length (int): 每个样本的序列长度。
"""
self.size = size
self.max_length = max_length
def __len__(self):
return self.size
def __getitem__(self, idx):
        # Random pitch between 21 (A0) and 108 (C8)
pitch = np.random.randint(21, 109, size=(self.max_length, 1)).astype(np.float32)
        # Random duration drawn from the acceptable values
acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
duration = np.random.choice(acceptable_durations, size=(self.max_length, 1)).astype(np.float32)
        # Random velocity between 40 and 70
volume = np.random.randint(40, 70, size=(self.max_length, 1)).astype(np.float32)
features = np.concatenate([pitch, duration, volume], axis=-1) # [max_length, 3]
return torch.tensor(features, dtype=torch.float32)
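# Sketch: RandomDataset(size=8, max_length=100)[0] is a [100, 3] float tensor of
# (pitch, duration, velocity) rows; it can serve as noise-like generator input, e.g.
#   noise_loader = DataLoader(RandomDataset(8, 100), batch_size=4)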
class MusicGenerator:
def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
self.model = model.to(device)
self.evaluator = evaluator
self.device = device
self.model_path = model_path
self.optimizer = optimizer
self.optimizer_path = optimizer_path
self.writer = writer
self._load_model()
        # Normalization / de-normalization ranges
self.min_pitch = 21
self.max_pitch = 108
self.min_duration = 0.15
self.max_duration = 1.5
self.min_volume = 40
self.max_volume = 85
    def _load_model(self):
        """Load existing model weights automatically if a checkpoint is present."""
        if os.path.exists(self.model_path):
            try:
                self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
                self.model.to(self.device)
                self.model.eval()
                print(f"Model weights loaded successfully from '{self.model_path}'.")
            except Exception as e:
                print(f"Error loading model weights: {e}; initializing a new model.")
        else:
            print("No saved model found; initializing a new model.")
        # Load the optimizer state, if available
        if self.optimizer and self.optimizer_path and os.path.exists(self.optimizer_path):
            try:
                self.optimizer.load_state_dict(torch.load(self.optimizer_path, map_location=self.device))
                print(f"Optimizer state loaded successfully from '{self.optimizer_path}'.")
            except Exception as e:
                print(f"Error loading optimizer state: {e}; initializing a new optimizer.")
        else:
            if self.optimizer and self.optimizer_path:
                print("No saved optimizer state found; initializing a new optimizer.")
    def save_model(self, epoch: int, loss: float):
        """Save the current model weights and optimizer state."""
        try:
            torch.save(self.model.state_dict(), self.model_path, _use_new_zipfile_serialization=False)
            if self.optimizer and self.optimizer_path:
                torch.save(self.optimizer.state_dict(), self.optimizer_path, _use_new_zipfile_serialization=False)
            print(f"Model and optimizer saved to '{self.model_path}' and '{self.optimizer_path}'.")
            if self.writer:
                self.writer.add_scalar('Loss/Save', loss, epoch)
        except Exception as e:
            print(f"Error saving model or optimizer: {e}")
    def train_epoch(self, dataloader: DataLoader, optimizer, criterion_music, criterion_tags, epoch: int):
        """
        Train for one epoch.
        """
        self.model.train()
        total_loss = 0.0
        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)      # [batch_size, num_tags]
            inputs = batch_features[:, :-1, :]   # [batch_size, seq_len-1, input_dim]
            targets = batch_features[:, -1, :]   # [batch_size, input_dim]
            optimizer.zero_grad()
            music_output, tag_probabilities = self.model(inputs)  # music output: [batch, seq_len-1, output_dim]
            # Compute the music loss only on the final time step (next-step prediction)
            loss_music = criterion_music(music_output[:, -1, :], targets)
            # Tag loss against the dataset labels
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
            # Total loss
            loss = loss_music + loss_tags
            loss.backward()
            # Gradient clipping
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Loss/Train', loss.item(), epoch * len(dataloader) + batch_idx)
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} average loss: {avg_loss:.4f}")
        return avg_loss
    def train_epoch_gan(self, dataloader, optimizer_generator, optimizer_discriminator, criterion_music, criterion_tags, criterion_discriminator, discriminator, epoch):
        """
        Train for one epoch with adversarial (GAN-style) training.
        """
        self.model.train()
        discriminator.train()
        total_loss = 0.0
        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)      # [batch_size, num_tags]
            batch_size = batch_features.size(0)
            seq_len = batch_features.size(1)
            # ---------------------
            # Train the discriminator
            # ---------------------
            # Real data
            real_data = batch_features  # [batch_size, seq_len, input_dim]
            real_labels = torch.ones(batch_size, 1).to(self.device)
            # Fake data produced by the generator, one step at a time
            noise = torch.rand(batch_size, seq_len, 3).to(self.device)  # noise in [0, 1], matching the normalized features
            generated_features = torch.zeros_like(batch_features).to(self.device)
            for i in range(seq_len):
                input_noise = noise[:, :i+1, :]
                fake_data, _ = self.model(input_noise)
                generated_features[:, i, :] = fake_data[:, -1, :]
            fake_data = generated_features.detach()  # detached copy for the discriminator update
            fake_labels = torch.zeros(batch_size, 1).to(self.device)
            # Discriminator loss on real data
            optimizer_discriminator.zero_grad()
            output_real = discriminator(real_data)
            loss_real = criterion_discriminator(output_real, real_labels)
            # Discriminator loss on fake data
            output_fake = discriminator(fake_data)
            loss_fake = criterion_discriminator(output_fake, fake_labels)
            # Combined discriminator loss and backward pass
            loss_discriminator = (loss_real + loss_fake) / 2
            loss_discriminator.backward()
            optimizer_discriminator.step()
            # ---------------------
            # Train the generator
            # ---------------------
            optimizer_generator.zero_grad()
            # Adversarial loss: the generator wants the discriminator to label its output as real.
            # Use the non-detached generated_features here so the gradient reaches the generator.
            output_fake_for_generator = discriminator(generated_features)
            loss_generator_adv = criterion_discriminator(output_fake_for_generator, real_labels)
            # Music feature and tag losses for the generator
            music_output, tag_probabilities = self.model(noise)
            targets = batch_features[:, -1, :]  # real final-step features
            loss_music = criterion_music(music_output[:, -1, :], targets)
            # Tag loss against the dataset labels
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
            # Total generator loss
            loss_generator = loss_generator_adv + loss_music + loss_tags
            loss_generator.backward()
            # Gradient clipping
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer_generator.step()
            total_loss += loss_generator.item()
            if self.writer:
                # self.writer.add_scalar('Loss/Generator', loss_generator.item(), epoch * len(dataloader) + batch_idx)
                # self.writer.add_scalar('Loss/Discriminator', loss_discriminator.item(), epoch * len(dataloader) + batch_idx)
                pass
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} average generator loss: {avg_loss:.4f}")
        return avg_loss
    def generate_music(self, tag_conditions: dict={
        'emotions': 'Neutral',
        'genres': 'Classical',
        'tempo': 'Medium',
        'instrumentation': 'Piano',
        'harmony': 'Simple',
        'dynamics': 'Dynamic',
        'rhythm': 'Simple'  # or 'Complex'
    }, max_length=100, temperature=1.0) -> music21.stream.Stream:
        """
        Generate music guided by the requested tags.
        Note: tag_conditions documents the intended conditioning but is not yet fed into the model.
        """
        self.model.eval()
        acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
        generated_features = []
        with torch.no_grad():
            # Pick a random key
            key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
            scale_notes = get_scale_notes(key_str)
            # Initial input (zero features)
            input_feature = torch.zeros(1, 1, 3).to(self.device)  # [batch_size=1, seq_len=1, input_dim=3]
            for _ in range(max_length):
                music_output, tag_probabilities = self.model(input_feature)  # [1, seq_len, 3] and [1, seq_len, num_tags]
                music_output_np = music_output.cpu().numpy()[0, -1]
                # Apply temperature scaling
                music_output_np = music_output_np / temperature
                # Turn the raw outputs into note parameters
                pitch = int(round(music_output_np[0]))
                duration = music_output_np[1]
                volume = int(round(music_output_np[2]))
                # Add a little random variation
                pitch += int(np.random.uniform(-2, 2))
                pitch = max(21, min(108, pitch))  # keep within the piano range
                # Snap the pitch to the nearest in-scale note
                if pitch not in scale_notes:
                    pitch = min(scale_notes, key=lambda x: abs(x - pitch))
                duration += np.random.uniform(-0.1, 0.1)
                try:
                    duration = min(acceptable_durations, key=lambda x: abs(x - duration))
                except ValueError:
                    duration = 1.0  # default duration
                volume += int(np.random.uniform(-10, 10))
                volume = max(70, min(100, volume))  # keep the velocity in a comfortable range
                # Store the feature
                generated_features.append([pitch, duration, volume])
                # Append the new step to the model input
                next_input = torch.tensor([[pitch, duration, volume]], dtype=torch.float32).to(self.device).unsqueeze(0)  # [1, 1, 3]
                input_feature = torch.cat((input_feature, next_input), dim=1)
        # Realize the generated features as a stream
        generated_features_array = np.array(generated_features, dtype=np.float32)
        generated_stream = composer_from_features(generated_features_array, key_str)
        # Evaluate the resulting tags
        tag_scores = self.evaluator.evaluate_tags(generated_stream)
        print("Tags of the generated music:", tag_scores)
        # Save pieces whose emotion scores highly
        high_score_emotions = ['Happy', 'Peaceful']
        if tag_scores.get('emotions') in high_score_emotions:
            # Export the stream as MIDI and render it to WAV
            midi_filename = f'high_score_{int(time.time())}.mid'
            generated_stream.write('midi', fp=os.path.join(Gbase, midi_filename))
            wav_file = self.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f'high_score_{int(time.time())}.wav'))
            print(f"High-scoring piece saved as WAV file: '{wav_file}'")
        return generated_stream
    def addMusicToVideo(self, videoPath, tagConditions={
        'emotions': 'Neutral',
        'genres': 'Classical',
        'tempo': 'Medium',
        'instrumentation': 'Piano',
        'harmony': 'Simple',
        'dynamics': 'Dynamic',
        'rhythm': 'Simple'  # or 'Complex'
    }, outputPath=None):
        """
        Generate music for the given tag conditions and attach it to a video, matching the video's duration.
        Args:
            videoPath (str): path of the input video.
            tagConditions (dict): tag conditions used to generate the music.
            outputPath (str, optional): path of the output video. Defaults to the input path with '_with_music' appended.
        Returns:
            str: path of the output video.
        """
        # 1. Get the video duration
        try:
            video = VideoFileClip(videoPath)
            duration = video.duration
            print(f"Video duration: {duration} seconds.")
        except Exception as e:
            print(f"Could not load video: {e}")
            return None
        if not outputPath:
            base, ext = os.path.splitext(videoPath)
            outputPath = f"{base}_with_music{ext}"
        if os.path.exists(outputPath):
            return outputPath
        # 2. Prepare the audio to be assembled
        combined_audio = AudioSegment.silent(duration=0)  # start with empty audio
        total_generated_duration = 0   # total generated duration (milliseconds)
        chunk_duration_seconds = 10    # estimated length of each generated chunk (seconds)
        crossfade_duration = 500       # crossfade length (milliseconds)
        # 3. Generate audio chunk by chunk
        print("Generating music chunk by chunk...")
        while total_generated_duration < duration * 1000:  # pydub works in milliseconds
            # Generate only as much as is still needed
            remaining_duration_ms = duration * 1000 - total_generated_duration
            remaining_duration_seconds = remaining_duration_ms / 1000.0
            current_chunk_length = min(chunk_duration_seconds, remaining_duration_seconds)
            # Rough note count, assuming about 0.5 seconds per note (currently informational only)
            estimated_max_length = int(current_chunk_length / 0.5) * 2
            # Generate a music stream for this chunk
            generated_stream = self.generate_music(tag_conditions=tagConditions, max_length=100)
            # Render it to a temporary WAV file
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as wav_temp:
                wav_filename = wav_temp.name
            wav_path = self.custom_midi_to_wav(generated_stream, wav_filename)
            print(f"Generated WAV saved as '{wav_path}'.")
            # Load the generated audio
            try:
                generated_audio = AudioSegment.from_wav(wav_path)
            except Exception as e:
                print(f"Error loading generated audio: {e}")
                os.remove(wav_path)
                continue
            # Concatenate the audio with fades
            if len(combined_audio) == 0:
                # First chunk: only fade in
                generated_audio = generated_audio.fade_in(crossfade_duration)
                combined_audio += generated_audio
            else:
                # Later chunks: fade in and crossfade into the existing audio
                generated_audio = generated_audio.fade_in(crossfade_duration)
                combined_audio = combined_audio.append(generated_audio, crossfade=crossfade_duration)
            total_generated_duration = len(combined_audio)
            # Remove the temporary WAV file
            try:
                os.remove(wav_path)
                print(f"Removed temporary WAV file '{wav_path}'.")
            except Exception as e:
                print(f"Error removing temporary WAV file: {e}")
        # 4. Trim the audio to the video duration
        final_audio = combined_audio[:int(duration * 1000)]  # milliseconds
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as final_temp:
            final_wav_path = final_temp.name
        final_audio.export(final_wav_path, format="wav")
        print(f"Final trimmed WAV saved as '{final_wav_path}'.")
        # 5. Combine the audio with the video using moviepy
        try:
            video_clip = VideoFileClip(videoPath)
            audio_clip = AudioFileClip(final_wav_path)
            # Make sure the audio length matches the video
            audio_clip = audio_clip.set_duration(video_clip.duration)
            # Attach the audio to the video
            video_with_audio = video_clip.set_audio(audio_clip)
            # Write the final video
            video_with_audio.write_videofile(outputPath, codec='libx264', audio_codec='aac', verbose=False, logger=None)
            print(f"Output video saved as '{outputPath}'.")
        except Exception as e:
            print(f"Error combining video and audio: {e}")
            return None
        finally:
            # Release moviepy resources
            if 'video_clip' in locals():
                video_clip.close()
            if 'audio_clip' in locals():
                audio_clip.close()
            if 'video_with_audio' in locals():
                video_with_audio.close()
        # 6. Clean up the temporary file
        try:
            os.remove(final_wav_path)
            print("Final temporary WAV file removed.")
        except Exception as e:
            print(f"Error removing the final temporary WAV file: {e}")
        return outputPath
    def custom_midi_to_wav(self, stream: music21.stream.Stream, wav_filename: str, sample_rate=44100) -> str:
        """
        Custom MIDI-to-WAV rendering that synthesizes audio directly from the stream.
        Uses additive sine harmonics and an ADSR envelope so notes, scales and a piano-like
        timbre come through reasonably well.
        """
        # Synthesis parameters
        envelope_attack = 0.01   # attack time (seconds)
        envelope_decay = 0.1     # decay time (seconds)
        envelope_sustain = 0.8   # sustain level
        envelope_release = 0.2   # release time (seconds)
        # Tempo information
        metronome_marks = list(stream.metronomeMarkBoundaries())
        bpm = 120  # default BPM
        if metronome_marks:
            # Use the first MetronomeMark that carries a number
            for mark in metronome_marks:
                if isinstance(mark[2], music21.tempo.MetronomeMark) and mark[2].number:
                    bpm = mark[2].number
                    break
        # Collect the events to render
        notes = list(stream.flat.getElementsByClass(['Note', 'Chord', 'Rest']))
        if not notes:
            print("No notes to render.")
            return ""
        # Total length of the output buffer
        total_duration = stream.duration.quarterLength * 60 / bpm
        total_samples = int(total_duration * sample_rate) + 1
        audio = np.zeros(total_samples)
        current_time = 0
        # Harmonic coefficients approximating a piano-like timbre
        harmonic_coeffs = [1.0, 0.5, 0.25, 0.1, 0.05]
        for element in notes:
            if isinstance(element, music21.note.Rest):
                # Rests only advance the clock
                duration = element.quarterLength * 60 / bpm  # seconds
                current_time += duration
                continue
            elif isinstance(element, music21.note.Note):
                frequencies = [element.pitch.frequency]
            elif isinstance(element, music21.chord.Chord):
                frequencies = [p.frequency for p in element.pitches]
            else:
                continue
            duration = element.quarterLength * 60 / bpm  # seconds
            # Fixed amplitude of 60%
            volume = 0.6
            # Time axis for this note
            t = np.linspace(0, duration, int(duration * sample_rate), False)
            waveform = np.zeros_like(t)
            for freq in frequencies:
                note_waveform = np.zeros_like(t)
                for idx, coeff in enumerate(harmonic_coeffs):
                    harmonic_freq = freq * (idx + 1)
                    note_waveform += coeff * np.sin(2 * np.pi * harmonic_freq * t)
                waveform += note_waveform
            # Normalize so stacked frequencies do not blow up the amplitude
            waveform /= len(frequencies) * sum(harmonic_coeffs)
            # Apply an ADSR envelope
            attack_samples = int(envelope_attack * sample_rate)
            decay_samples = int(envelope_decay * sample_rate)
            release_samples = int(envelope_release * sample_rate)
            sustain_samples = len(waveform) - attack_samples - decay_samples - release_samples
            if sustain_samples < 0:
                # Rescale the ADSR stages for very short notes
                total_envelope = envelope_attack + envelope_decay + envelope_release
                attack_ratio = envelope_attack / total_envelope
                decay_ratio = envelope_decay / total_envelope
                release_ratio = envelope_release / total_envelope
                attack_samples = int(len(waveform) * attack_ratio)
                decay_samples = int(len(waveform) * decay_ratio)
                release_samples = len(waveform) - attack_samples - decay_samples
                sustain_samples = 0
            envelope = np.concatenate([
                np.linspace(0, 1, attack_samples, False),
                np.linspace(1, envelope_sustain, decay_samples, False),
                np.full(sustain_samples, envelope_sustain),
                np.linspace(envelope_sustain, 0, release_samples, False)
            ])
            # Trim the envelope to the waveform length
            envelope = envelope[:len(waveform)]
            waveform *= envelope
            waveform *= volume
            # Mix this note into the output buffer
            start_sample = int(current_time * sample_rate)
            end_sample = start_sample + len(waveform)
            if end_sample > total_samples:
                end_sample = total_samples
                waveform = waveform[:end_sample - start_sample]
            audio[start_sample:end_sample] += waveform
            # Advance the clock
            current_time += duration
        # Prevent clipping
        max_val = np.max(np.abs(audio))
        if max_val > 1:
            audio /= max_val
        # Convert to 16-bit integers
        audio_int16 = np.int16(audio * 32767)
        # Write the WAV file (mono, 16-bit)
        wav_path = os.path.join(os.getcwd(), wav_filename)
        with wave.open(wav_path, 'w') as wav_file:
            n_channels = 1  # the buffer is mono; declaring 2 channels here would corrupt playback
            sampwidth = 2   # 2 bytes for int16
            framerate = sample_rate
            n_frames = len(audio_int16)
            comptype = "NONE"
            compname = "not compressed"
            wav_file.setparams((n_channels, sampwidth, framerate, n_frames, comptype, compname))
            wav_file.writeframes(audio_int16.tobytes())
        return wav_path
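# custom_midi_to_wav uses simple additive synthesis: each note is a sum of sine harmonics
# (coefficients 1.0, 0.5, 0.25, 0.1, 0.05 of the fundamental) shaped by an ADSR envelope
# (attack 10 ms, decay 100 ms, sustain 0.8, release 200 ms) and mixed into a mono 16-bit buffer.
# For example, an A4 quarter note at 120 BPM lasts 0.5 s, i.e. 22050 samples at 44.1 kHz,
# built from sines at 440, 880, 1320, 1760 and 2200 Hz.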
class AdvancedMusicGenerator(MusicGenerator):
def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
super().__init__(model, evaluator, device, model_path, optimizer, optimizer_path, writer)
        # Additional initialization parameters or hooks can be added here
    # Override or add methods here to extend the base generator
def trainModel():
    # Initialize TensorBoard
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))
    # Load (or create) the tag evaluator
    evaluator = MusicTagEvaluator.load(EvaluatorPath)
    # Number of unique tags
    num_tags = len(evaluator.all_tags)
    # Model hyperparameters
    input_dim = 3            # pitch, duration and volume
    d_model = 512            # Transformer model dimension
    nhead = 8                # number of attention heads
    num_encoder_layers = 8   # number of Transformer encoder layers
    dim_feedforward = 2048   # feed-forward dimension
    output_dim = 3           # predicted pitch, duration and volume
    # Build the model
    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags)
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Using device: {device}")
    # Collect the MIDI files
    midi_directory = os.path.join(Gbase, 'generateMIDI')
    midi_files = []
    if os.path.exists(midi_directory):
        midi_files = [os.path.join(midi_directory, f) for f in os.listdir(midi_directory) if f.endswith('.mid') or f.endswith('.midi')]
        print(f"Found {len(midi_files)} MIDI files for training in '{midi_directory}'.")
    else:
        print(f"MIDI directory '{midi_directory}' does not exist; please make sure it exists and contains MIDI files.")
        return
    # Build the datasets and dataloaders
    max_length = 100  # adjust as needed
    dataset_path = os.path.join(Gbase, 'mymusic.dataset')
    dataset = MidiDataset(midi_files, max_length, dataset_path, evaluator)
    datasetAug = MidiDatasetAug(midi_files, max_length, dataset_path, evaluator)
    # Number of samples to draw
    sample_size = 30000 if torch.cuda.is_available() else 15000
    sample_size1 = int(sample_size / 10)
    sample_size2 = int(sample_size / 300)
    total_samples = len(dataset)
    if total_samples < sample_size:
        print(f"The dataset only has {total_samples} samples, so {sample_size} cannot be drawn. Please check the dataset.")
        return
    # Training schedule and learning rate
    epochs = 4  # adjust as needed
    learning_rate = 0.001
    batch_size = 16 if torch.cuda.is_available() else 4
    # Set up the generator
    optimizer_generator = optim.AdamW(model.parameters(), lr=learning_rate * 0.1)
    generator = MusicGenerator(model, evaluator, device, model_path=ModelPath, optimizer=optimizer_generator, optimizer_path=OptimizerPath, writer=writer)
    # Set up the discriminator
    discriminator = Discriminator(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward).to(device)
    optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate)
    criterion_discriminator = nn.BCELoss()
    # Restore the discriminator model and optimizer state if available
    if os.path.exists(DiscriminatorModelPath):
        discriminator.load_state_dict(torch.load(DiscriminatorModelPath, map_location=device))
        print(f"Discriminator weights loaded successfully from '{DiscriminatorModelPath}'.")
    if os.path.exists(DiscriminatorOptimizerPath):
        optimizer_discriminator.load_state_dict(torch.load(DiscriminatorOptimizerPath, map_location=device))
        print(f"Discriminator optimizer state loaded successfully from '{DiscriminatorOptimizerPath}'.")
    indices = list(range(total_samples))
    random_indices = random.sample(indices, sample_size)
    random_indices1 = random.sample(indices, sample_size1)
    random_indices2 = random.sample(indices, sample_size2)
    random_indicesAug = random.sample(indices, sample_size)
    sampler = SubsetRandomSampler(random_indices)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
    sampler1 = SubsetRandomSampler(random_indices1)
    dataloaderAug = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler, num_workers=2)
    dataloader1 = DataLoader(datasetAug, batch_size=8, sampler=sampler1, num_workers=2)
    sampler2 = SubsetRandomSampler(random_indices2)
    dataloader2 = DataLoader(dataset, batch_size=batch_size, sampler=sampler2, num_workers=2)
    sampler3 = SubsetRandomSampler(random_indices2)
    dataloader3 = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler2, num_workers=2)
    # Training loop
    print("Starting training...")
    for epoch in range(1, epochs + 1):
        try:
            avg_loss = generator.train_epoch(
                dataloader,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
            # Intermediate checkpointing is disabled here; models are saved after the augmented pass below.
            # generator.save_model(epoch, avg_loss)
            # torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            # torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
        # Only run the augmented and adversarial passes on the final epoch
        if epoch != epochs:
            continue
        print("Starting training on augmented data...")
        try:
            avg_loss = generator.train_epoch(
                dataloaderAug,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
            # Save the generator, the discriminator and its optimizer
            generator.save_model(epoch, avg_loss)
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"Discriminator model and optimizer saved to '{DiscriminatorModelPath}' and '{DiscriminatorOptimizerPath}'.")
            # evaluator.save(EvaluatorPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
        print("Starting adversarial training...")
        try:
            avg_loss = generator.train_epoch_gan(
                dataloader1,
                optimizer_generator,
                optimizer_discriminator,
                nn.MSELoss(),
                nn.BCELoss(),
                criterion_discriminator,
                discriminator,
                epoch
            )
            # Checkpointing after the adversarial pass is disabled here as well.
            # generator.save_model(epoch, avg_loss)
            # torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            # torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            # evaluator.save(EvaluatorPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
            continue  # continue with the next epoch
    # Close the TensorBoard writer
    writer.close()
def loadMusicGenerator():
    # Initialize TensorBoard
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))
    # Create the tag evaluator (MusicTagEvaluator.load(EvaluatorPath) would restore a saved one)
    evaluator = MusicTagEvaluator()
    # Number of unique tags
    num_tags = len(evaluator.all_tags)
    # Model hyperparameters (must match the values used for training)
    input_dim = 3  # pitch, duration and volume
    d_model = 512
    nhead = 8
    num_encoder_layers = 8
    dim_feedforward = 2048
    output_dim = 3
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Build the model
    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags).to(device)
    # Build the generator (loads saved weights automatically if present)
    generator = AdvancedMusicGenerator(model, evaluator, device, model_path=ModelPath, writer=writer)
    return generator, evaluator
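# Usage sketch for the loaded generator (tag values must come from MUSIC_TAGS):
#   gen, ev = loadMusicGenerator()
#   stream = gen.generate_music(tag_conditions=randomMusicTags(), max_length=120, temperature=0.9)
#   gen.custom_midi_to_wav(stream, 'preview.wav')  # 'preview.wav' is a hypothetical output path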
MyMusicGenerator, MyMusicTagEvaluator = loadMusicGenerator()
import gradio as gr
# The generator, evaluator and helper functions defined above are used by the Gradio UI below.
def generate_music(*inputs):
    # Gradio passes all component values positionally: the tag dropdowns first, then the checkbox.
    *tags, use_random = inputs
    if use_random:
        tags_dict = randomMusicTags()
    else:
        # The dropdown order matches MUSIC_TAGS.keys()
        tags_dict = dict(zip(MUSIC_TAGS.keys(), tags))
    # Generate a music21 stream for the selected tags
    generated_stream = MyMusicGenerator.generate_music(tag_conditions=tags_dict, max_length=130, temperature=np.random.uniform(0.7, 1.1))
    # Save the generated stream as a MIDI file
    midi_filename = f"music_{int(time.time())}.mid"
    mid_path = os.path.join(Gbase, midi_filename)
    generated_stream.write('midi', fp=mid_path)
    # Render the stream to WAV with the custom synthesizer
    wav_file = MyMusicGenerator.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f"{midi_filename[:-4]}.wav"))
    return wav_file, tags_dict
# Define the interface
with gr.Blocks() as demo:
gr.Markdown("# Music Generation with Tags")
with gr.Row():
with gr.Column():
# List comprehension to create dropdowns for each tag category
tag_inputs = [
                gr.Dropdown(value=MUSIC_TAGS[category][0], choices=MUSIC_TAGS[category], label=category.capitalize())
for category in MUSIC_TAGS.keys()
]
with gr.Column():
use_random = gr.Checkbox(label="Use Random Tags")
generate_btn = gr.Button("Generate Music")
output_audio = gr.Audio(label="Generated Music")
output_tags = gr.JSON(label="Generated Tags")
# Pass the list of dropdowns directly instead of using gr.Group
generate_btn.click(
fn=generate_music,
inputs=[*tag_inputs, use_random],
outputs=[output_audio, output_tags]
)
# Launch the interface
demo.launch()