Najdi_TTS_Project / train.py
m6011's picture
Update train.py
771feaa verified
raw
history blame contribute delete
No virus
3.46 kB
import os
import shutil
from espnet2.bin.tts_train import TTSTrainer
from espnet2.tasks.tts import TTSTask
from datasets import load_dataset
import yaml
# تحميل بيانات sada2022 من Hugging Face Datasets
dataset = load_dataset("m6011/sada2022")
# تقسيم البيانات إلى تدريب وتحقق
train_data = dataset['train']
valid_data = dataset['test']
# إنشاء ملفات البيانات المطلوبة
os.makedirs('data/train', exist_ok=True)
os.makedirs('data/valid', exist_ok=True)
# حفظ البيانات في ملفات نصية (wav.scp و text)
with open('data/train/wav.scp', 'w', encoding='utf-8') as wav_scp, \
open('data/train/text', 'w', encoding='utf-8') as text_file:
for idx, sample in enumerate(train_data):
wav_path = sample['audio']['path']
transcription = sample['ProcessedText']
utt_id = f'train_{idx}'
wav_scp.write(f'{utt_id} {wav_path}\n')
text_file.write(f'{utt_id} {transcription}\n')
with open('data/valid/wav.scp', 'w', encoding='utf-8') as wav_scp, \
open('data/valid/text', 'w', encoding='utf-8') as text_file:
for idx, sample in enumerate(valid_data):
wav_path = sample['audio']['path']
transcription = sample['ProcessedText']
utt_id = f'valid_{idx}'
wav_scp.write(f'{utt_id} {wav_path}\n')
text_file.write(f'{utt_id} {transcription}\n')
# تحميل إعدادات التدريب الافتراضية من ESPnet
config_path = 'conf/train.yaml'
os.makedirs('conf', exist_ok=True)
# إعدادات التدريب
config = {
'output_dir': 'exp/tts_fastspeech2',
'token_type': 'char',
'fs': 16000,
'lang': 'ar', # تحديد اللغة العربية
'train_data_path_and_name_and_type': [
('data/train/wav.scp', 'speech', 'sound'),
('data/train/text', 'text', 'text')
],
'valid_data_path_and_name_and_type': [
('data/valid/wav.scp', 'speech', 'sound'),
('data/valid/text', 'text', 'text')
],
'token_list': 'tokens.txt',
'init_param': None,
# يمكنك إضافة المزيد من الإعدادات هنا
}
with open(config_path, 'w', encoding='utf-8') as f:
yaml.dump(config, f, allow_unicode=True)
# توليد قائمة التوكينات (الأحرف) من البيانات
def generate_token_list(text_files, output_file):
tokens = set()
for text_file in text_files:
with open(text_file, 'r', encoding='utf-8') as f:
for line in f:
_, text = line.strip().split(' ', 1)
tokens.update(list(text))
tokens = sorted(tokens)
with open(output_file, 'w', encoding='utf-8') as f:
for token in tokens:
f.write(f'{token}\n')
# توليد قائمة التوكينات (tokens.txt)
generate_token_list(['data/train/text', 'data/valid/text'], 'tokens.txt')
# إعداد التدريب
train_args = [
'--config', 'conf/train.yaml',
'--use_preprocessor', 'true',
'--token_type', 'char',
'--bpemodel', None,
'--train_data_path_and_name_and_type', 'data/train/wav.scp,speech,sound',
'--train_data_path_and_name_and_type', 'data/train/text,text,text',
'--valid_data_path_and_name_and_type', 'data/valid/wav.scp,speech,sound',
'--valid_data_path_and_name_and_type', 'data/valid/text,text,text',
'--output_dir', 'exp/tts_fastspeech2',
]
# بدء عملية التدريب
TTSTask.main(train_args)