Najdi_TTS_Project / train.py
m6011's picture
Update train.py
8a1a95a verified
raw
history blame
3.49 kB
# train.py
import os
import shutil
from espnet2.bin.tts_train import TTSTrainer
from espnet2.tasks.tts import TTSTask
from espnet_model_zoo.downloader import ModelDownloader
from datasets import load_dataset
import yaml
# تحميل بيانات sada2022
dataset = load_dataset("m6011/sada2022")
# تقسيم البيانات إلى تدريب وتحقق
train_data = dataset['train']
valid_data = dataset['test']
# إنشاء ملفات البيانات المطلوبة
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/valid', exist_ok=True)
# حفظ البيانات في ملفات نصية
with open('data/train/wav.scp', 'w', encoding='utf-8') as wav_scp, \
open('data/train/text', 'w', encoding='utf-8') as text_file:
for idx, sample in enumerate(train_data):
wav_path = sample['audio']['path']
transcription = sample['ProcessedText']
utt_id = f'train_{idx}'
wav_scp.write(f'{utt_id} {wav_path}\n')
text_file.write(f'{utt_id} {transcription}\n')
with open('data/valid/wav.scp', 'w', encoding='utf-8') as wav_scp, \
open('data/valid/text', 'w', encoding='utf-8') as text_file:
for idx, sample in enumerate(valid_data):
wav_path = sample['audio']['path']
transcription = sample['ProcessedText']
utt_id = f'valid_{idx}'
wav_scp.write(f'{utt_id} {wav_path}\n')
text_file.write(f'{utt_id} {transcription}\n')
# تحميل إعدادات التدريب الافتراضية من ESPnet
config_path = 'conf/train.yaml'
os.makedirs('conf', exist_ok=True)
# يمكنك تخصيص إعدادات التدريب هنا أو استخدام الإعدادات الافتراضية
config = {
'output_dir': 'exp/tts_fastspeech2',
'token_type': 'char',
'fs': 16000,
'lang': 'ar', # تحديد اللغة العربية
'train_data_path_and_name_and_type': [
('data/train/wav.scp', 'speech', 'sound'),
('data/train/text', 'text', 'text')
],
'valid_data_path_and_name_and_type': [
('data/valid/wav.scp', 'speech', 'sound'),
('data/valid/text', 'text', 'text')
],
'token_list': 'tokens.txt',
'init_param': None,
# يمكنك إضافة المزيد من الإعدادات هنا
}
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, allow_unicode=True)
# توليد قائمة التوكينات (الأحرف) من البيانات
def generate_token_list(text_files, output_file):
tokens = set()
for text_file in text_files:
with open(text_file, 'r', encoding='utf-8') as f:
for line in f:
_, text = line.strip().split(' ', 1)
tokens.update(list(text))
tokens = sorted(tokens)
with open(output_file, 'w', encoding='utf-8') as f:
for token in tokens:
f.write(f'{token}\n')
generate_token_list(['data/train/text', 'data/valid/text'], 'tokens.txt')
# بدء عملية التدريب
train_args = [
'--config', 'conf/train.yaml',
'--use_preprocessor', 'true',
'--token_type', 'char',
'--bpemodel', None,
'--train_data_path_and_name_and_type', 'data/train/wav.scp,speech,sound',
'--train_data_path_and_name_and_type', 'data/train/text,text,text',
'--valid_data_path_and_name_and_type', 'data/valid/wav.scp,speech,sound',
'--valid_data_path_and_name_and_type', 'data/valid/text,text,text',
'--output_dir', 'exp/tts_fastspeech2',
]
TTSTask.main(train_args)