import os
import sys
import time, random
from random import choice
from typing import List, Dict
import torch
import torch.nn as nn
import torch.optim as optim
import music21
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
import wave
import struct
import ffmpeg
import tempfile
from pydub import AudioSegment
from moviepy.editor import VideoFileClip, AudioFileClip
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils import clip_grad_norm_
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Base paths
Gbase = "./"
cache_dir = "./hf/"
try:
    import google.colab
    from google.colab import drive
    IN_COLAB = True
    drive.mount('/gdrive', force_remount=True)
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    sys.path.append(Gbase)
except Exception:
    IN_COLAB = False
    Gbase = "./"
    cache_dir = "./hf/"
# Model save paths
ModelPath = os.path.join(Gbase, 'music_generation_model.pth')
OptimizerPath = os.path.join(Gbase, 'optimizer_state.pth')
DiscriminatorModelPath = os.path.join(Gbase, 'discriminator_model.pth')
DiscriminatorOptimizerPath = os.path.join(Gbase, 'discriminator_optimizer_state.pth')
EvaluatorPath = os.path.join(Gbase, 'music_tag_evaluator.pkl')
# Music tag definitions
MUSIC_TAGS = {
    'emotions': ['Happy', 'Sad', 'Angry', 'Peaceful', 'Neutral'],
    'genres': ['Classical', 'Jazz', 'Rock', 'Electronic'],
    'tempo': ['Slow', 'Medium', 'Fast'],
    'instrumentation': ['Piano', 'Guitar', 'Synthesizer'],
    'harmony': ['Consonant', 'Dissonant', 'Complex', 'Simple'],
    'dynamics': ['Dynamic', 'Static'],
    'rhythm': ['Simple', 'Complex']
}
def randomMusicTags():
    return {k: choice(MUSIC_TAGS[k]) for k in MUSIC_TAGS.keys()}
print("Randomly generated music tags:", randomMusicTags())
def get_scale_notes(key_str: str, octave_range=(2, 6)) -> List[int]:
    """
    Return the MIDI pitches of the scale belonging to the given key.
    """
    key = music21.key.Key(key_str)
    scale_notes = []
    for octave in range(octave_range[0], octave_range[1] + 1):
        pitches = key.getScale().getPitches(f"{key_str}{octave}")
        for pitch in pitches:
            scale_notes.append(pitch.midi)
    return scale_notes
def composer_from_features(features: np.ndarray, key_str: str) -> music21.stream.Stream:
    """
    Convert features into a music21.stream.Stream object, making sure the notes follow the given scale.
    """
    s = music21.stream.Stream()
    # Set the tempo (BPM); default is 120 BPM
    tempo = music21.tempo.MetronomeMark(number=120)
    s.append(tempo)
    # Set the key
    tonality = music21.key.Key(key_str)
    s.append(tonality)
    # Get the scale notes
    scale_notes = get_scale_notes(key_str)
    # Acceptable note durations (in quarter lengths)
    acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
    for feature in features:
        pitch = int(round(feature[0]))
        duration = feature[1]
        volume = feature[2]
        # Quantize the duration to the nearest acceptable value
        duration = min(acceptable_durations, key=lambda x: abs(x - duration))
        # Clamp the volume to the MIDI range 0-127
        volume = max(0, min(127, volume))
        if pitch == 0:
            # Pitch 0 encodes a rest (checked before clamping, otherwise rests would be lost)
            r = music21.note.Rest(quarterLength=duration)
            s.append(r)
        else:
            # Clamp the pitch to the piano range 21 (A0) - 108 (C8)
            pitch = max(21, min(108, pitch))
            # Snap the pitch to the nearest scale note
            if pitch not in scale_notes:
                pitch = min(scale_notes, key=lambda x: abs(x - pitch))
            n = music21.note.Note(midi=pitch, quarterLength=duration)
            n.volume.velocity = int(volume)
            s.append(n)
    return s
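# Illustrative sketch only (not called anywhere): how a tiny feature array maps onto a
# music21 stream via composer_from_features. The three columns are assumed to be
# [MIDI pitch, quarter-length duration, velocity], matching midi_to_features below;
# a pitch of 0 is treated as a rest.
def _demo_composer_from_features():
    demo_features = np.array([
        [60, 1.0, 80],   # middle C, quarter note
        [0,  0.5, 0],    # rest
        [67, 0.5, 90],   # G4, eighth note
    ], dtype=np.float32)
    demo_stream = composer_from_features(demo_features, 'C')
    print("Demo stream contains", len(list(demo_stream.flat.notesAndRests)), "notes/rests")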
import pickle
class MusicTagEvaluator:
    def __init__(self):
        # All tag definitions
        self.MUSIC_TAGS = MUSIC_TAGS
        # Flatten all tags and remove duplicates
        all_tags = []
        for category in self.MUSIC_TAGS:
            all_tags.extend(self.MUSIC_TAGS[category])
        self.all_tags = list(set(all_tags))  # remove duplicate tags
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.all_tags])
    def save(self, path):
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"Evaluator saved to '{path}'.")
    @staticmethod
    def load(path):
        if os.path.exists(path):
            with open(path, 'rb') as f:
                evaluator = pickle.load(f)
            print(f"Evaluator loaded from '{path}'.")
            return evaluator
        else:
            print(f"Evaluator file '{path}' does not exist; creating a new evaluator.")
            return MusicTagEvaluator()
    def evaluate_tags_from_features(self, features: np.ndarray) -> List[str]:
        """
        Evaluate tags from a feature sequence.
        """
        # Randomly pick a key in which to render the music
        key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
        s = composer_from_features(features, key_str)
        tag_scores = self.evaluate_tags(s)
        tags = []
        # Assign tags based on the scores
        for category in self.MUSIC_TAGS:
            tag = tag_scores.get(category)
            if tag in self.MUSIC_TAGS[category]:
                tags.append(tag)
        return tags
    def evaluate_tags(self, generated_music):
        """
        Evaluate tags for the generated music.
        """
        tag_scores = {}
        # Pitch range
        pitch_values = [note.pitch.midi for note in generated_music.recurse().notes if isinstance(note, music21.note.Note)]
        pitch_range = max(pitch_values) - min(pitch_values) if pitch_values else 0
        # Evaluate each aspect separately
        harmony_tag = self._evaluate_harmony(generated_music)
        rhythm_tag = self._evaluate_rhythm(generated_music)
        dynamics_tag = self._evaluate_dynamics(generated_music)
        tempo_tag = self._evaluate_tempo(generated_music)
        emotion_tag = self._evaluate_emotion(harmony_tag, rhythm_tag, dynamics_tag, tempo_tag)
        # Collect the tags
        tag_scores['emotions'] = emotion_tag
        tag_scores['harmony'] = harmony_tag
        tag_scores['rhythm'] = rhythm_tag
        tag_scores['dynamics'] = dynamics_tag
        tag_scores['tempo'] = tempo_tag
        return tag_scores
    def _evaluate_harmony(self, stream):
        # Chordify the stream
        chords = stream.chordify()
        chord_types = []
        for element in chords.recurse():
            if isinstance(element, music21.chord.Chord):
                chord_types.append(element.commonName)
        # Assess harmonic complexity from the chord types
        if any('diminished' in str(ct) or 'augmented' in str(ct) for ct in chord_types):
            harmony_tag = 'Complex'
        elif any('major' in str(ct) or 'minor' in str(ct) for ct in chord_types):
            harmony_tag = 'Consonant'
        else:
            harmony_tag = 'Simple'
        return harmony_tag
    def _evaluate_rhythm(self, stream):
        durations = [note.quarterLength for note in stream.flat.notes]
        # Measure rhythmic complexity by the number of distinct durations
        unique_durations = len(set(durations))
        if unique_durations > 5:
            rhythm_tag = 'Complex'
        else:
            rhythm_tag = 'Simple'
        return rhythm_tag
    def _evaluate_dynamics(self, stream):
        volumes = [note.volume.velocity for note in stream.flat.notes if note.volume.velocity is not None]
        if not volumes:
            dynamics_tag = 'Static'
        else:
            dynamics_range = max(volumes) - min(volumes)
            if dynamics_range > 40:
                dynamics_tag = 'Dynamic'
            else:
                dynamics_tag = 'Static'
        return dynamics_tag
    def _evaluate_tempo(self, stream):
        tempos = [metronome.number for metronome in stream.recurse() if isinstance(metronome, music21.tempo.MetronomeMark)]
        bpm = tempos[0] if tempos else 120  # default BPM is 120
        if bpm < 60:
            return 'Slow'
        elif 60 <= bpm < 120:
            return 'Medium'
        else:
            return 'Fast'
    def _evaluate_emotion(self, harmony_tag, rhythm_tag, dynamics_tag, tempo_tag):
        # Estimate emotion from harmony, rhythm, dynamics and tempo
        if harmony_tag == 'Complex' and rhythm_tag == 'Complex':
            emotion = 'Angry'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Dynamic' and tempo_tag == 'Fast':
            emotion = 'Happy'
        elif harmony_tag == 'Simple' and dynamics_tag == 'Static' and tempo_tag == 'Slow':
            emotion = 'Peaceful'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Static' and tempo_tag == 'Medium':
            emotion = 'Neutral'
        else:
            emotion = 'Sad'
        return emotion
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)  # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))  # [d_model/2]
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        x = x + self.pe[:, :x.size(1), :]
        return x
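# Quick shape sanity check (a minimal sketch, never invoked): the positional encoding
# adds the sinusoidal offsets and leaves the input shape unchanged.
def _demo_positional_encoding():
    pe = PositionalEncoding(d_model=16, max_len=50)
    x = torch.zeros(2, 10, 16)  # [batch_size, seq_len, d_model]
    assert pe(x).shape == (2, 10, 16)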
class MusicGenerationModel(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags, max_seq_length=500):
        super(MusicGenerationModel, self).__init__()
        self.d_model = d_model
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_seq_length)
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_encoder_layers)
        self.fc_music = nn.Linear(d_model, output_dim)
        self.fc_tags = nn.Linear(d_model, num_tags)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        src: [batch_size, seq_len, input_dim]
        """
        src = self.input_linear(src) * np.sqrt(self.d_model)  # [batch_size, seq_len, d_model]
        src = self.positional_encoding(src)  # [batch_size, seq_len, d_model]
        src = src.transpose(0, 1)  # [seq_len, batch_size, d_model]
        memory = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)  # [seq_len, batch_size, d_model]
        memory = memory.transpose(0, 1)  # [batch_size, seq_len, d_model]
        memory = self.dropout(memory)
        music_output = self.fc_music(memory)  # [batch_size, seq_len, output_dim]
        tag_probabilities = self.sigmoid(self.fc_tags(memory))  # [batch_size, seq_len, num_tags]
        return music_output, tag_probabilities
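# Minimal forward-pass sketch (assumed toy hyperparameters, never invoked, not part of
# training): the two heads return [batch, seq_len, output_dim] and [batch, seq_len, num_tags].
def _demo_music_generation_model():
    m = MusicGenerationModel(input_dim=3, d_model=32, nhead=4, num_encoder_layers=2,
                             dim_feedforward=64, output_dim=3, num_tags=10)
    src = torch.rand(2, 20, 3)  # [batch_size, seq_len, input_dim]
    music_out, tag_probs = m(src)
    assert music_out.shape == (2, 20, 3) and tag_probs.shape == (2, 20, 10)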
class Discriminator(nn.Module):
    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward):
        super(Discriminator, self).__init__()
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)
    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        src = self.input_linear(src) * np.sqrt(self.input_linear.out_features)
        src = self.positional_encoding(src)
        src = src.transpose(0, 1)  # [seq_len, batch_size, d_model]
        output = self.transformer_encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output = output.transpose(0, 1)  # [batch_size, seq_len, d_model]
        output = self.dropout(output)
        # Use the last time step of the sequence for the decision; mean pooling or another scheme would also work
        output = self.fc(output[:, -1, :])
        output = self.sigmoid(output)
        return output
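# Minimal sketch (assumed toy hyperparameters, never invoked): the discriminator maps a
# feature sequence to one real/fake probability per batch item.
def _demo_discriminator():
    d = Discriminator(input_dim=3, d_model=32, nhead=4, num_layers=2, dim_feedforward=64)
    scores = d(torch.rand(2, 20, 3))  # [batch_size, seq_len, input_dim]
    assert scores.shape == (2, 1)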
class MidiDataset(Dataset):
    def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
        self.max_length = max_length
        self.dataset_path = dataset_path
        self.evaluator = evaluator
        # Check whether a preprocessed dataset file already exists
        if os.path.exists(self.dataset_path):
            # Load the preprocessed dataset
            print(f"Loading dataset from '{self.dataset_path}'")
            try:
                saved_data = torch.load(self.dataset_path)
                self.features = saved_data['features']
                self.labels = saved_data['labels']
                print(f"Dataset loaded successfully with {len(self.features)} samples.")
            except Exception as e:
                print(f"Error loading dataset: {e}")
                self._process_midi_files(midi_files)
        else:
            # Process the MIDI files and save the dataset
            self._process_midi_files(midi_files)
    def __len__(self):
        return len(self.features)
    def getAug(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # Apply data augmentation
        feature_aug, label_aug = self._augment_data(feature, label)
        # Return tensors
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # No augmentation in the base dataset
        feature_aug, label_aug = feature, label
        # Return tensors
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
    def _process_midi_files(self, midi_files):
        print("Processing MIDI files to build the dataset...")
        features_list = []
        labels_list = []
        for midi_file in midi_files:
            try:
                stream = music21.converter.parse(midi_file)
                # Convert the stream into features
                features = self.midi_to_features(stream)
                if len(features) < self.max_length:
                    # Skip samples that are too short
                    continue
                else:
                    # Split the features into segments of length max_length
                    num_segments = len(features) // self.max_length
                    for i in range(num_segments):
                        segment = features[i*self.max_length : (i+1)*self.max_length]
                        if len(segment) < self.max_length:
                            continue  # skip incomplete segments
                        # Use the evaluator to assign tags to each segment
                        tags = self.evaluator.evaluate_tags_from_features(segment)
                        # Binarize the tags
                        tag_binarized = self.evaluator.mlb.transform([tags])[0]
                        features_list.append(segment)
                        labels_list.append(tag_binarized)
            except Exception as e:
                print(f"Error processing {midi_file}: {e}")
        self.features = features_list
        self.labels = labels_list
        # Save the dataset
        try:
            torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
            print(f"Dataset saved to '{self.dataset_path}' with {len(self.features)} samples.")
        except Exception as e:
            print(f"Error saving dataset: {e}")
    def midi_to_features(self, stream) -> np.ndarray:
        """
        Convert a music21 stream into a feature sequence.
        """
        features = []
        for note in stream.flat.notesAndRests:
            if isinstance(note, music21.note.Note):
                pitch = note.pitch.midi
                duration = note.quarterLength
                volume = note.volume.velocity if note.volume.velocity else 64  # default velocity
            elif isinstance(note, music21.note.Rest):
                pitch = 0  # rests are encoded as pitch 0
                duration = note.quarterLength
                volume = 0
            else:
                continue
            features.append([pitch, duration, volume])
        return np.array(features, dtype=np.float32)
    def _augment_data(self, feature, label):
        # Data augmentation: randomly adjust dynamics (volume) and tempo (durations)
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Randomly scale the volume (dynamics)
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Randomly scale the durations (tempo change)
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Adjust the labels according to the change,
        # e.g. if the tempo change is significant, update the 'tempo' tag
        if duration_change > 1.05:
            # faster tempo
            tempo_tags = ['Fast']
        elif duration_change < 0.95:
            # slower tempo
            tempo_tags = ['Slow']
        else:
            tempo_tags = ['Medium']
        # Update the 'tempo' tag
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
class MidiDatasetAug(Dataset):
    def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
        self.max_length = max_length
        self.dataset_path = dataset_path
        self.evaluator = evaluator
        # Check whether a preprocessed dataset file already exists
        if os.path.exists(self.dataset_path):
            # Load the preprocessed dataset
            print(f"Loading dataset from '{self.dataset_path}'")
            try:
                saved_data = torch.load(self.dataset_path)
                self.features = saved_data['features']
                self.labels = saved_data['labels']
                print(f"Dataset loaded successfully with {len(self.features)} samples.")
            except Exception as e:
                print(f"Error loading dataset: {e}")
                self._process_midi_files(midi_files)
        else:
            # Process the MIDI files and save the dataset
            self._process_midi_files(midi_files)
    def __len__(self):
        return len(self.features)
    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # Apply data augmentation
        feature_aug, label_aug = self._augment_data(feature, label)
        # Return tensors
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)
    def _process_midi_files(self, midi_files):
        print("Processing MIDI files to build the dataset...")
        features_list = []
        labels_list = []
        for midi_file in midi_files:
            try:
                stream = music21.converter.parse(midi_file)
                # Convert the stream into features
                features = self.midi_to_features(stream)
                if len(features) < self.max_length:
                    # Skip samples that are too short
                    continue
                else:
                    # Split the features into segments of length max_length
                    num_segments = len(features) // self.max_length
                    for i in range(num_segments):
                        segment = features[i*self.max_length : (i+1)*self.max_length]
                        if len(segment) < self.max_length:
                            continue  # skip incomplete segments
                        # Use the evaluator to assign tags to each segment
                        tags = self.evaluator.evaluate_tags_from_features(segment)
                        # Binarize the tags
                        tag_binarized = self.evaluator.mlb.transform([tags])[0]
                        features_list.append(segment)
                        labels_list.append(tag_binarized)
            except Exception as e:
                print(f"Error processing {midi_file}: {e}")
        self.features = features_list
        self.labels = labels_list
        # Save the dataset
        try:
            torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
            print(f"Dataset saved to '{self.dataset_path}' with {len(self.features)} samples.")
        except Exception as e:
            print(f"Error saving dataset: {e}")
    def midi_to_features(self, stream) -> np.ndarray:
        """
        Convert a music21 stream into a feature sequence.
        """
        features = []
        for note in stream.flat.notesAndRests:
            if isinstance(note, music21.note.Note):
                pitch = note.pitch.midi
                duration = note.quarterLength
                volume = note.volume.velocity if note.volume.velocity else 64  # default velocity
            elif isinstance(note, music21.note.Rest):
                pitch = 0  # rests are encoded as pitch 0
                duration = note.quarterLength
                volume = 0
            else:
                continue
            features.append([pitch, duration, volume])
        return np.array(features, dtype=np.float32)
    def _augment_data(self, feature, label):
        # Data augmentation: randomly adjust dynamics (volume) and tempo (durations)
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Randomly scale the volume (dynamics)
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Randomly scale the durations (tempo change)
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Adjust the labels according to the change,
        # e.g. if the tempo change is significant, update the 'tempo' tag
        if duration_change > 1.05:
            # faster tempo
            tempo_tags = ['Fast']
        elif duration_change < 0.95:
            # slower tempo
            tempo_tags = ['Slow']
        else:
            tempo_tags = ['Medium']
        # Update the 'tempo' tag
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
class RandomDataset(Dataset):
    def __init__(self, size: int, max_length: int):
        """
        Randomly generated dataset.
        Args:
            size (int): dataset size.
            max_length (int): sequence length of each sample.
        """
        self.size = size
        self.max_length = max_length
    def __len__(self):
        return self.size
    def __getitem__(self, idx):
        # Random pitch between 21 (A0) and 108 (C8)
        pitch = np.random.randint(21, 109, size=(self.max_length, 1)).astype(np.float32)
        # Randomly choose among the acceptable durations
        acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
        duration = np.random.choice(acceptable_durations, size=(self.max_length, 1)).astype(np.float32)
        # Random velocity between 40 and 70
        volume = np.random.randint(40, 70, size=(self.max_length, 1)).astype(np.float32)
        features = np.concatenate([pitch, duration, volume], axis=-1)  # [max_length, 3]
        return torch.tensor(features, dtype=torch.float32)
class MusicGenerator:
    def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
        self.model = model.to(device)
        self.evaluator = evaluator
        self.device = device
        self.model_path = model_path
        self.optimizer = optimizer
        self.optimizer_path = optimizer_path
        self.writer = writer
        self._load_model()
        # Normalization / denormalization parameters
        self.min_pitch = 21
        self.max_pitch = 108
        self.min_duration = 0.15
        self.max_duration = 1.5
        self.min_volume = 40
        self.max_volume = 85
    def _load_model(self):
        """Automatically load existing model weights, if any."""
        if os.path.exists(self.model_path):
            try:
                self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
                self.model.to(self.device)
                self.model.eval()
                print(f"Model weights loaded from '{self.model_path}'.")
            except Exception as e:
                print(f"Error loading model weights: {e}; initializing a new model.")
        else:
            print("No saved model found; initializing a new model.")
        # Load the optimizer state
        if self.optimizer and self.optimizer_path and os.path.exists(self.optimizer_path):
            try:
                self.optimizer.load_state_dict(torch.load(self.optimizer_path, map_location=self.device))
                print(f"Optimizer state loaded from '{self.optimizer_path}'.")
            except Exception as e:
                print(f"Error loading optimizer state: {e}; initializing a new optimizer.")
        else:
            if self.optimizer and self.optimizer_path:
                print("No saved optimizer state found; initializing a new optimizer.")
    def save_model(self, epoch: int, loss: float):
        """Save the current model weights and optimizer state."""
        try:
            torch.save(self.model.state_dict(), self.model_path, _use_new_zipfile_serialization=False)
            if self.optimizer and self.optimizer_path:
                torch.save(self.optimizer.state_dict(), self.optimizer_path, _use_new_zipfile_serialization=False)
            print(f"Model and optimizer saved to '{self.model_path}' and '{self.optimizer_path}'.")
            if self.writer:
                self.writer.add_scalar('Loss/Save', loss, epoch)
        except Exception as e:
            print(f"Error saving model or optimizer: {e}")
    def train_epoch(self, dataloader: DataLoader, optimizer, criterion_music, criterion_tags, epoch: int):
        """
        Train for one epoch.
        """
        self.model.train()
        total_loss = 0.0
        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)  # [batch_size, num_tags]
            inputs = batch_features[:, :-1, :]  # [batch_size, seq_len-1, input_dim]
            targets = batch_features[:, -1, :]  # [batch_size, input_dim]
            optimizer.zero_grad()
            music_output, tag_probabilities = self.model(inputs)  # music output: [batch, seq_len-1, output_dim]
            # Compute the music loss only on the last time step
            loss_music = criterion_music(music_output[:, -1, :], targets)
            # Use the labels from the dataset
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
            # Total loss
            loss = loss_music + loss_tags
            loss.backward()
            # Gradient clipping
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()
            total_loss += loss.item()
            if self.writer:
                self.writer.add_scalar('Loss/Train', loss.item(), epoch * len(dataloader) + batch_idx)
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} average loss: {avg_loss:.4f}")
        return avg_loss
    def train_epoch_gan(self, dataloader, optimizer_generator, optimizer_discriminator, criterion_music, criterion_tags, criterion_discriminator, discriminator, epoch):
        """
        Train for one epoch with adversarial training.
        """
        self.model.train()
        discriminator.train()
        total_loss = 0.0
        for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
            batch_features = batch_features.to(self.device)  # [batch_size, seq_len, input_dim]
            batch_labels = batch_labels.to(self.device)  # [batch_size, num_tags]
            batch_size = batch_features.size(0)
            seq_len = batch_features.size(1)
            # ---------------------
            # Train the discriminator
            # ---------------------
            # Real data
            real_data = batch_features  # [batch_size, seq_len, input_dim]
            real_labels = torch.ones(batch_size, 1).to(self.device)
            # Generate fake data with the generator
            noise = torch.rand(batch_size, seq_len, 3).to(self.device)  # random noise in [0, 1], consistent with the normalized features
            generated_features = torch.zeros_like(batch_features).to(self.device)
            for i in range(seq_len):
                input_noise = noise[:, :i+1, :]
                fake_data, _ = self.model(input_noise)
                generated_features[:, i, :] = fake_data[:, -1, :]
            fake_data = generated_features.detach()  # [batch_size, seq_len, input_dim]
            fake_labels = torch.zeros(batch_size, 1).to(self.device)
            # Discriminator loss on real data
            optimizer_discriminator.zero_grad()
            output_real = discriminator(real_data)
            loss_real = criterion_discriminator(output_real, real_labels)
            # Discriminator loss on fake data
            output_fake = discriminator(fake_data)
            loss_fake = criterion_discriminator(output_fake, fake_labels)
            # Combine the losses and backpropagate
            loss_discriminator = (loss_real + loss_fake) / 2
            loss_discriminator.backward()
            optimizer_discriminator.step()
            # ---------------------
            # Train the generator
            # ---------------------
            optimizer_generator.zero_grad()
            # Adversarial loss: the generator tries to make the discriminator believe the fake data is real
            output_fake_for_generator = discriminator(fake_data)
            loss_generator_adv = criterion_discriminator(output_fake_for_generator, real_labels)  # generator adversarial loss
            # Music feature and tag losses for the generator
            music_output, tag_probabilities = self.model(noise)
            targets = batch_features[:, -1, :]  # the real final feature
            loss_music = criterion_music(music_output[:, -1, :], targets)
            # Use the labels from the dataset
            loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
            # Total loss
            loss_generator = loss_generator_adv + loss_music + loss_tags
            loss_generator.backward()
            # Gradient clipping
            clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer_generator.step()
            total_loss += loss_generator.item()
            if self.writer:
                #self.writer.add_scalar('Loss/Generator', loss_generator.item(), epoch * len(dataloader) + batch_idx)
                #self.writer.add_scalar('Loss/Discriminator', loss_discriminator.item(), epoch * len(dataloader) + batch_idx)
                pass
        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch} average generator loss: {avg_loss:.4f}")
        return avg_loss
    def generate_music(self, tag_conditions: dict={
            'emotions': 'Neutral',
            'genres': 'Classical',
            'tempo': 'Medium',
            'instrumentation': 'Piano',
            'harmony': 'Simple',
            'dynamics': 'Dynamic',
            'rhythm': 'Simple'  # or 'Complex'
        }, max_length=100, temperature=1.0) -> music21.stream.Stream:
        """
        Generate music conditioned on the tags.
        """
        self.model.eval()
        acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
        generated_features = []
        with torch.no_grad():
            # Randomly choose a key
            key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
            scale_notes = get_scale_notes(key_str)
            # Initial seed input
            input_feature = torch.zeros(1, 1, 3).to(self.device)  # [batch_size=1, seq_len=1, input_dim=3]
            for _ in range(max_length):
                music_output, tag_probabilities = self.model(input_feature)  # [1, seq_len, 3] and [1, seq_len, num_tags]
                music_output_np = music_output.cpu().numpy()[0, -1]
                # Apply temperature scaling
                music_output_np = music_output_np / temperature
                # Read out pitch, duration and volume
                pitch = int(round(music_output_np[0]))
                duration = music_output_np[1]
                volume = int(round(music_output_np[2]))
                # Add random variation
                pitch += int(np.random.uniform(-2, 2))
                pitch = max(21, min(108, pitch))  # clamp to the piano range
                # Snap the pitch to the nearest scale note
                if pitch not in scale_notes:
                    pitch = min(scale_notes, key=lambda x: abs(x - pitch))
                duration += np.random.uniform(-0.1, 0.1)
                try:
                    duration = min(acceptable_durations, key=lambda x: abs(x - duration))
                except ValueError:
                    duration = 1.0  # default duration
                volume += int(np.random.uniform(-10, 10))
                volume = max(70, min(100, volume))  # clamp the velocity range
                # Store the feature
                generated_features.append([pitch, duration, volume])
                # Update the input with the new feature, extending the sequence
                next_input = torch.tensor([[pitch, duration, volume]], dtype=torch.float32).to(self.device).unsqueeze(0)  # [1, 1, 3]
                input_feature = torch.cat((input_feature, next_input), dim=1)
        # Convert to a numpy array
        generated_features_array = np.array(generated_features, dtype=np.float32)
        generated_stream = composer_from_features(generated_features_array, key_str)
        # Evaluate the tags
        tag_scores = self.evaluator.evaluate_tags(generated_stream)
        print("Generated music tags:", tag_scores)
        # Save the result if the emotion scores well
        high_score_emotions = ['Happy', 'Peaceful']
        if tag_scores.get('emotions') in high_score_emotions:
            # Convert the generated MIDI to WAV
            midi_filename = f'high_score_{int(time.time())}.mid'
            generated_stream.write('midi', fp=os.path.join(Gbase, midi_filename))
            wav_file = self.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f'high_score_{int(time.time())}.wav'))
            print(f"High-scoring music saved as WAV file: '{wav_file}'")
        return generated_stream
    def addMusicToVideo(self, videoPath, tagConditions={
            'emotions': 'Neutral',
            'genres': 'Classical',
            'tempo': 'Medium',
            'instrumentation': 'Piano',
            'harmony': 'Simple',
            'dynamics': 'Dynamic',
            'rhythm': 'Simple'  # or 'Complex'
        }, outputPath=None):
        """
        Generate music from the given tag conditions and attach it to the video,
        making sure the music length matches the video.
        Args:
            videoPath (str): path to the input video.
            tagConditions (dict): tag conditions used to generate the music.
            outputPath (str, optional): path of the output video. If omitted, '_with_music' is appended to the original path.
        Returns:
            str: path of the output video.
        """
        # 1. Get the video duration
        try:
            video = VideoFileClip(videoPath)
            duration = video.duration
            print(f"Video duration: {duration} seconds.")
        except Exception as e:
            print(f"Failed to load video: {e}")
            return None
        if not outputPath:
            base, ext = os.path.splitext(videoPath)
            outputPath = f"{base}_with_music{ext}"
        if os.path.exists(outputPath):
            return outputPath
        # 2. Initialize the audio concatenation
        combined_audio = AudioSegment.silent(duration=0)  # start with empty audio
        total_generated_duration = 0  # total generated duration (ms)
        chunk_duration_seconds = 10  # estimated duration of each generated chunk (seconds); adjust as needed
        crossfade_duration = 500  # crossfade duration (ms)
        # 3. Generate the audio chunk by chunk
        print("Generating music chunk by chunk...")
        while total_generated_duration < duration * 1000:  # pydub works in milliseconds
            # Generate only for the remaining duration so we don't overshoot
            remaining_duration_ms = duration * 1000 - total_generated_duration
            remaining_duration_seconds = remaining_duration_ms / 1000.0
            current_chunk_length = min(chunk_duration_seconds, remaining_duration_seconds)
            # Estimate the number of notes needed, assuming roughly 0.5 s per note
            estimated_max_length = int(current_chunk_length / 0.5) * 2  # adjust the factor as needed
            # Generate a music stream
            generated_stream = self.generate_music(max_length=100)
            # Convert to a WAV file
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as wav_temp:
                wav_filename = wav_temp.name
            wav_path = self.custom_midi_to_wav(generated_stream, wav_filename)
            print(f"Generated WAV saved to '{wav_path}'.")
            # Load the generated audio
            try:
                generated_audio = AudioSegment.from_wav(wav_path)
            except Exception as e:
                print(f"Error loading generated audio: {e}")
                if os.path.exists(wav_path):
                    os.remove(wav_path)
                continue
            # Concatenate the audio, applying fades
            if len(combined_audio) == 0:
                # First chunk: fade-in only
                generated_audio = generated_audio.fade_in(crossfade_duration)
                combined_audio += generated_audio
            else:
                # Later chunks: fade in and crossfade into the previous audio
                generated_audio = generated_audio.fade_in(crossfade_duration)
                combined_audio = combined_audio.append(generated_audio, crossfade=crossfade_duration)
            total_generated_duration = len(combined_audio)
            # Remove the temporary WAV file
            try:
                os.remove(wav_path)
                print(f"Temporary WAV file '{wav_path}' removed.")
            except Exception as e:
                print(f"Error removing temporary WAV file: {e}")
        # 4. Trim the audio to the video duration
        final_audio = combined_audio[:int(duration * 1000)]  # pydub works in milliseconds
        final_wav_path = tempfile.mktemp(suffix='.wav')
        final_audio.export(final_wav_path, format="wav")
        print(f"Final trimmed WAV saved to '{final_wav_path}'.")
        # 5. Combine the audio and video with moviepy
        try:
            # Load the video and audio
            video_clip = VideoFileClip(videoPath)
            audio_clip = AudioFileClip(final_wav_path)
            # Make the audio length match the video
            audio_clip = audio_clip.set_duration(video_clip.duration)
            # Attach the audio to the video
            video_with_audio = video_clip.set_audio(audio_clip)
            # Write the final video
            video_with_audio.write_videofile(outputPath, codec='libx264', audio_codec='aac', verbose=False, logger=None)
            print(f"Output video saved to '{outputPath}'.")
        except Exception as e:
            print(f"Error combining video and audio: {e}")
            return None
        finally:
            # Release the moviepy resources
            if 'video_clip' in locals():
                video_clip.close()
            if 'audio_clip' in locals():
                audio_clip.close()
            if 'video_with_audio' in locals():
                video_with_audio.close()
        # 6. Clean up the temporary file
        try:
            os.remove(final_wav_path)
            print("Final temporary WAV file removed.")
        except Exception as e:
            print(f"Error removing final temporary WAV file: {e}")
        return outputPath
    def custom_midi_to_wav(self, stream: music21.stream.Stream, wav_filename: str, sample_rate=44100) -> str:
        """
        Custom MIDI-to-WAV rendering that synthesizes the audio directly from sine waves.
        The output is tuned to sound pleasant and to respect notes, scales and a basic instrument timbre.
        """
        # Synthesis parameters
        envelope_attack = 0.01  # attack time
        envelope_decay = 0.1  # decay time
        envelope_sustain = 0.8  # sustain level
        envelope_release = 0.2  # release time
        # Get the tempo information
        metronome_marks = list(stream.metronomeMarkBoundaries())
        bpm = 120  # default BPM
        if metronome_marks:
            # Look for a MetronomeMark
            for mark in metronome_marks:
                if isinstance(mark[2], music21.tempo.MetronomeMark) and mark[2].number:
                    bpm = mark[2].number
                    break
        # Build the timeline
        notes = list(stream.flat.getElementsByClass(['Note', 'Chord', 'Rest']))
        if not notes:
            print("No notes to render.")
            return ""
        # Compute the total duration
        total_duration = stream.duration.quarterLength * 60 / bpm
        total_samples = int(total_duration * sample_rate) + 1
        audio = np.zeros(total_samples)
        current_time = 0
        # Harmonic coefficients approximating a piano timbre
        harmonic_coeffs = [1.0, 0.5, 0.25, 0.1, 0.05]
        for element in notes:
            if isinstance(element, music21.note.Rest):
                # Rest: just advance the current time
                duration = element.quarterLength * 60 / bpm  # seconds
                current_time += duration
                continue
            elif isinstance(element, music21.note.Note):
                frequencies = [element.pitch.frequency]
            elif isinstance(element, music21.chord.Chord):
                frequencies = [p.frequency for p in element.pitches]
            else:
                continue
            duration = element.quarterLength * 60 / bpm  # seconds
            # Fixed volume at 60%
            volume = 0.6
            # Build the waveform time axis
            t = np.linspace(0, duration, int(duration * sample_rate), False)
            waveform = np.zeros_like(t)
            for freq in frequencies:
                note_waveform = np.zeros_like(t)
                for idx, coeff in enumerate(harmonic_coeffs):
                    harmonic_freq = freq * (idx + 1)
                    note_waveform += coeff * np.sin(2 * np.pi * harmonic_freq * t)
                waveform += note_waveform
            # Normalize the amplitude (avoid excessive volume when several frequencies overlap)
            waveform /= len(frequencies) * sum(harmonic_coeffs)
            # Apply an ADSR envelope
            attack_samples = int(envelope_attack * sample_rate)
            decay_samples = int(envelope_decay * sample_rate)
            release_samples = int(envelope_release * sample_rate)
            sustain_samples = len(waveform) - attack_samples - decay_samples - release_samples
            if sustain_samples < 0:
                # Shrink the ADSR to fit short notes
                total_envelope = envelope_attack + envelope_decay + envelope_release
                attack_ratio = envelope_attack / total_envelope
                decay_ratio = envelope_decay / total_envelope
                release_ratio = envelope_release / total_envelope
                attack_samples = int(len(waveform) * attack_ratio)
                decay_samples = int(len(waveform) * decay_ratio)
                release_samples = len(waveform) - attack_samples - decay_samples
                sustain_samples = 0
            envelope = np.concatenate([
                np.linspace(0, 1, attack_samples, False),
                np.linspace(1, envelope_sustain, decay_samples, False),
                np.full(sustain_samples, envelope_sustain),
                np.linspace(envelope_sustain, 0, release_samples, False)
            ])
            # Match the envelope length to the waveform
            envelope = envelope[:len(waveform)]
            waveform *= envelope
            waveform *= volume
            # Compute the sample indices
            start_sample = int(current_time * sample_rate)
            end_sample = start_sample + len(waveform)
            if end_sample > total_samples:
                end_sample = total_samples
                waveform = waveform[:end_sample - start_sample]
            # Mix into the audio buffer
            audio[start_sample:end_sample] += waveform
            # Advance the current time
            current_time += duration
        # Prevent clipping
        max_val = np.max(np.abs(audio))
        if max_val > 1:
            audio /= max_val
        # Convert to 16-bit integers
        audio_int16 = np.int16(audio * 32767)
        # Write the WAV file (the synthesized buffer is mono)
        wav_path = os.path.join(os.getcwd(), wav_filename)
        with wave.open(wav_path, 'w') as wav_file:
            n_channels = 1  # mono
            sampwidth = 2  # 2 bytes for int16
            framerate = sample_rate
            n_frames = len(audio_int16)
            comptype = "NONE"
            compname = "not compressed"
            wav_file.setparams((n_channels, sampwidth, framerate, n_frames, comptype, compname))
            wav_file.writeframes(audio_int16.tobytes())
        return wav_path
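# Illustrative usage sketch (never invoked; the file name is hypothetical): render a short
# two-note stream to WAV with the pure-NumPy synthesizer defined above.
def _demo_custom_midi_to_wav(generator: MusicGenerator):
    s = music21.stream.Stream()
    s.append(music21.tempo.MetronomeMark(number=120))
    s.append(music21.note.Note(midi=60, quarterLength=1.0))
    s.append(music21.note.Note(midi=64, quarterLength=1.0))
    return generator.custom_midi_to_wav(s, "demo_two_notes.wav")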
class AdvancedMusicGenerator(MusicGenerator):
    def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
        super().__init__(model, evaluator, device, model_path, optimizer, optimizer_path, writer)
        # Additional initialization parameters could go here
    # Methods can be overridden or added here to extend the functionality
def trainModel():
    # Initialize TensorBoard
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))
    # Initialize the tag evaluator
    evaluator = MusicTagEvaluator.load(EvaluatorPath)
    # Number of unique tags
    num_tags = len(evaluator.all_tags)
    # Model hyperparameters
    input_dim = 3  # pitch, duration and volume
    d_model = 512  # larger Transformer model dimension
    nhead = 8  # number of attention heads
    num_encoder_layers = 8  # more Transformer encoder layers
    dim_feedforward = 2048  # larger feed-forward dimension
    output_dim = 3  # predict pitch, duration and volume
    # Initialize the model
    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags)
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"Using device: {device}")
    # Collect the MIDI files
    midi_directory = os.path.join(Gbase, 'generateMIDI')
    midi_files = []
    if os.path.exists(midi_directory):
        midi_files = [os.path.join(midi_directory, f) for f in os.listdir(midi_directory) if f.endswith('.mid') or f.endswith('.midi')]
        print(f"Found {len(midi_files)} MIDI files for training in '{midi_directory}'.")
    else:
        print(f"MIDI directory '{midi_directory}' does not exist; make sure it exists and contains MIDI files.")
        return  # exit
    # Build the dataset and data loaders
    max_length = 100  # adjust as needed
    dataset_path = os.path.join(Gbase, 'mymusic.dataset')
    dataset = MidiDataset(midi_files, max_length, dataset_path, evaluator)
    datasetAug = MidiDatasetAug(midi_files, max_length, dataset_path, evaluator)
    # Number of samples to draw
    sample_size = 30000 if torch.cuda.is_available() else 15000
    sample_size1 = int(sample_size / 10)
    sample_size2 = int(sample_size / 300)
    total_samples = len(dataset)
    if total_samples < sample_size:
        print(f"The dataset only has {total_samples} samples; cannot sample {sample_size}. Please check the dataset.")
        return
    # Training epochs and learning rate
    epochs = 4  # adjust as needed
    learning_rate = 0.001
    batch_size = 16 if torch.cuda.is_available() else 4
    # Initialize the generator
    optimizer_generator = optim.AdamW(model.parameters(), lr=learning_rate * 0.1)
    generator = MusicGenerator(model, evaluator, device, model_path=ModelPath, optimizer=optimizer_generator, optimizer_path=OptimizerPath, writer=writer)
    # Initialize the discriminator
    discriminator = Discriminator(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward).to(device)
    optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate)
    criterion_discriminator = nn.BCELoss()
    # Try to load the discriminator model and optimizer state
    if os.path.exists(DiscriminatorModelPath):
        discriminator.load_state_dict(torch.load(DiscriminatorModelPath, map_location=device))
        print(f"Discriminator weights loaded from '{DiscriminatorModelPath}'.")
    if os.path.exists(DiscriminatorOptimizerPath):
        optimizer_discriminator.load_state_dict(torch.load(DiscriminatorOptimizerPath, map_location=device))
        print(f"Discriminator optimizer state loaded from '{DiscriminatorOptimizerPath}'.")
    indices = list(range(total_samples))
    random_indices = random.sample(indices, sample_size)
    random_indices1 = random.sample(indices, sample_size1)
    random_indices2 = random.sample(indices, sample_size2)
    random_indicesAug = random.sample(indices, sample_size)
    sampler = SubsetRandomSampler(random_indices)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
    samplerAug = SubsetRandomSampler(random_indicesAug)
    dataloaderAug = DataLoader(datasetAug, batch_size=batch_size, sampler=samplerAug, num_workers=2)
    sampler1 = SubsetRandomSampler(random_indices1)
    dataloader1 = DataLoader(datasetAug, batch_size=8, sampler=sampler1, num_workers=2)
    sampler2 = SubsetRandomSampler(random_indices2)
    dataloader2 = DataLoader(dataset, batch_size=batch_size, sampler=sampler2, num_workers=2)
    sampler3 = SubsetRandomSampler(random_indices2)
    dataloader3 = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler3, num_workers=2)
    # Start training
    print("Starting training...")
    for epoch in range(1, epochs + 1):
        try:
            avg_loss = generator.train_epoch(
                dataloader,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
            # Saving after the plain training pass is disabled; the model and the
            # discriminator are saved after the reinforcement pass below.
            # generator.save_model(epoch, avg_loss)
            # torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            # torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
        if epoch != 4:
            continue  # only run the reinforcement and adversarial passes on the final epoch
        print("Starting reinforcement training...")
        try:
            avg_loss = generator.train_epoch(
                dataloaderAug,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
            # Save the generator, the discriminator model and its optimizer
            generator.save_model(epoch, avg_loss)
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"Discriminator model and optimizer saved to '{DiscriminatorModelPath}' and '{DiscriminatorOptimizerPath}'.")
            # Save the evaluator
            # evaluator.save(EvaluatorPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
        print("Starting adversarial training...")
        try:
            avg_loss = generator.train_epoch_gan(
                dataloader1,
                optimizer_generator,
                optimizer_discriminator,
                nn.MSELoss(),
                nn.BCELoss(),
                criterion_discriminator,
                discriminator,
                epoch
            )
            # Saving after the adversarial pass is disabled as well.
            # generator.save_model(epoch, avg_loss)
            # torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            # torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            # evaluator.save(EvaluatorPath)
        except KeyboardInterrupt:
            print("Training interrupted manually.")
            break
        except Exception as e:
            print(f"Error during training epoch {epoch}: {e}")
            continue  # continue with the next epoch
    # Close the TensorBoard writer
    writer.close()
def loadMusicGenerator():
    # Initialize TensorBoard
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))
    # Load the tag evaluator (a fresh one; MusicTagEvaluator.load(EvaluatorPath) could be used instead)
    evaluator = MusicTagEvaluator()
    # Number of unique tags
    num_tags = len(evaluator.all_tags)
    # Model hyperparameters
    input_dim = 3  # pitch, duration and volume
    d_model = 512  # must match the parameters used during training
    nhead = 8
    num_encoder_layers = 8
    dim_feedforward = 2048
    output_dim = 3
    # Select the device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Initialize the model
    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags).to(device)
    # Initialize the generator
    generator = AdvancedMusicGenerator(model, evaluator, device, model_path=ModelPath, writer=writer)
    return generator, evaluator
MyMusicGenerator, MyMusicTagEvaluator = loadMusicGenerator()
import gradio as gr
# Assuming the existing functions and setup are defined above
def generate_music(*inputs):
    # Gradio passes all inputs positionally: the tag dropdowns first, then the checkbox
    *tags, use_random = inputs
    if use_random:
        tags_dict = randomMusicTags()
    else:
        # The order of the dropdowns matches MUSIC_TAGS.keys()
        tags_dict = dict(zip(MUSIC_TAGS.keys(), tags))
    # Generate music with the existing generator (returns a music21 stream)
    generated_stream = MyMusicGenerator.generate_music(tag_conditions=tags_dict, max_length=130, temperature=np.random.uniform(0.7, 1.1))
    # Save the generated stream as a MIDI file
    midi_filename = f"music_{int(time.time())}.mid"
    mid_path = os.path.join(Gbase, midi_filename)
    generated_stream.write('midi', fp=mid_path)
    # Convert the MIDI stream to WAV
    wav_file = MyMusicGenerator.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f"{midi_filename[:-4]}.wav"))
    return wav_file, tags_dict
# Define the interface
with gr.Blocks() as demo:
    gr.Markdown("# Music Generation with Tags")
    with gr.Row():
        with gr.Column():
            # One dropdown per tag category
            tag_inputs = [
                gr.Dropdown(value=MUSIC_TAGS[category][0], choices=MUSIC_TAGS[category], label=category.capitalize())
                for category in MUSIC_TAGS.keys()
            ]
        with gr.Column():
            use_random = gr.Checkbox(label="Use Random Tags")
            generate_btn = gr.Button("Generate Music")
            output_audio = gr.Audio(label="Generated Music")
            output_tags = gr.JSON(label="Generated Tags")
    # Pass the list of dropdowns directly instead of using gr.Group
    generate_btn.click(
        fn=generate_music,
        inputs=[*tag_inputs, use_random],
        outputs=[output_audio, output_tags]
    )
# Launch the interface
demo.launch() | |