m6011 commited on
Commit
8a1a95a
1 Parent(s): 03e41c0

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +91 -50
train.py CHANGED
@@ -1,54 +1,95 @@
1
- import torch
2
- from torch.utils.data import DataLoader
 
 
 
 
 
3
  from datasets import load_dataset
4
- from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
- from transformers import Trainer, TrainingArguments
6
 
7
- # تحميل بيانات SADA
8
  dataset = load_dataset("m6011/sada2022")
9
 
10
- # تحميل نموذج Wav2Vec2 لتحويل الصوت إلى نص (يمكنك تغييره إذا كنت تود استخدام نموذج آخر)
11
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
12
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
13
-
14
- # معالجة البيانات - تحويل النص إلى رموز صوتية مناسبة (حسب النموذج المختار)
15
- def preprocess_data(batch):
16
- audio = batch["audio"]
17
- inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt", padding=True)
18
- batch["input_values"] = inputs.input_values[0]
19
- batch["attention_mask"] = inputs.attention_mask[0]
20
-
21
- # تحويل النص إلى رموز
22
- with processor.as_target_processor():
23
- batch["labels"] = processor(batch["ProcessedText"]).input_ids
24
- return batch
25
-
26
- # تطبيق المعالجة المسبقة على البيانات
27
- dataset = dataset.map(preprocess_data, remove_columns=["audio", "ProcessedText"])
28
-
29
- # إعدادات التدريب
30
- training_args = TrainingArguments(
31
- output_dir="./wav2vec2-saudi-tts",
32
- group_by_length=True,
33
- per_device_train_batch_size=4,
34
- evaluation_strategy="steps",
35
- num_train_epochs=3,
36
- save_steps=400,
37
- eval_steps=400,
38
- logging_steps=400,
39
- learning_rate=3e-4,
40
- warmup_steps=500,
41
- save_total_limit=2,
42
- )
43
-
44
- # إعداد المدرب
45
- trainer = Trainer(
46
- model=model,
47
- args=training_args,
48
- train_dataset=dataset["train"],
49
- eval_dataset=dataset["test"],
50
- tokenizer=processor.feature_extractor,
51
- )
52
-
53
- # بدء التدريب
54
- trainer.train()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train.py
2
+
3
+ import os
4
+ import shutil
5
+ from espnet2.bin.tts_train import TTSTrainer
6
+ from espnet2.tasks.tts import TTSTask
7
+ from espnet_model_zoo.downloader import ModelDownloader
8
  from datasets import load_dataset
9
+ import yaml
 
10
 
11
+ # تحميل بيانات sada2022
12
  dataset = load_dataset("m6011/sada2022")
13
 
14
+ # تقسيم البيانات إلى تدريب وتحقق
15
+ train_data = dataset['train']
16
+ valid_data = dataset['test']
17
+
18
+ # إنشاء ملفات البيانات المطلوبة
19
+ os.makedirs('data/train', exist_ok=True)
20
+ os.makedirs('data/valid', exist_ok=True)
21
+
22
+ # حفظ البيانات في ملفات نصية
23
+ with open('data/train/wav.scp', 'w', encoding='utf-8') as wav_scp, \
24
+ open('data/train/text', 'w', encoding='utf-8') as text_file:
25
+ for idx, sample in enumerate(train_data):
26
+ wav_path = sample['audio']['path']
27
+ transcription = sample['ProcessedText']
28
+ utt_id = f'train_{idx}'
29
+ wav_scp.write(f'{utt_id} {wav_path}\n')
30
+ text_file.write(f'{utt_id} {transcription}\n')
31
+
32
+ with open('data/valid/wav.scp', 'w', encoding='utf-8') as wav_scp, \
33
+ open('data/valid/text', 'w', encoding='utf-8') as text_file:
34
+ for idx, sample in enumerate(valid_data):
35
+ wav_path = sample['audio']['path']
36
+ transcription = sample['ProcessedText']
37
+ utt_id = f'valid_{idx}'
38
+ wav_scp.write(f'{utt_id} {wav_path}\n')
39
+ text_file.write(f'{utt_id} {transcription}\n')
40
+
41
+ # تحميل إعدادات التدريب الافتراضية من ESPnet
42
+ config_path = 'conf/train.yaml'
43
+ os.makedirs('conf', exist_ok=True)
44
+
45
+ # يمكنك تخصيص إعدادات التدريب هنا أو استخدام الإعدادات الافتراضية
46
+ config = {
47
+ 'output_dir': 'exp/tts_fastspeech2',
48
+ 'token_type': 'char',
49
+ 'fs': 16000,
50
+ 'lang': 'ar', # تحديد اللغة العربية
51
+ 'train_data_path_and_name_and_type': [
52
+ ('data/train/wav.scp', 'speech', 'sound'),
53
+ ('data/train/text', 'text', 'text')
54
+ ],
55
+ 'valid_data_path_and_name_and_type': [
56
+ ('data/valid/wav.scp', 'speech', 'sound'),
57
+ ('data/valid/text', 'text', 'text')
58
+ ],
59
+ 'token_list': 'tokens.txt',
60
+ 'init_param': None,
61
+ # يمكنك إضافة المزيد من الإعدادات هنا
62
+ }
63
+
64
+ with open(config_path, 'w', encoding='utf-8') as f:
65
+ yaml.dump(config, f, allow_unicode=True)
66
+
67
+ # توليد قائمة التوكينات (الأحرف) من البيانات
68
+ def generate_token_list(text_files, output_file):
69
+ tokens = set()
70
+ for text_file in text_files:
71
+ with open(text_file, 'r', encoding='utf-8') as f:
72
+ for line in f:
73
+ _, text = line.strip().split(' ', 1)
74
+ tokens.update(list(text))
75
+ tokens = sorted(tokens)
76
+ with open(output_file, 'w', encoding='utf-8') as f:
77
+ for token in tokens:
78
+ f.write(f'{token}\n')
79
+
80
+ generate_token_list(['data/train/text', 'data/valid/text'], 'tokens.txt')
81
+
82
+ # بدء عملية التدريب
83
+ train_args = [
84
+ '--config', 'conf/train.yaml',
85
+ '--use_preprocessor', 'true',
86
+ '--token_type', 'char',
87
+ '--bpemodel', None,
88
+ '--train_data_path_and_name_and_type', 'data/train/wav.scp,speech,sound',
89
+ '--train_data_path_and_name_and_type', 'data/train/text,text,text',
90
+ '--valid_data_path_and_name_and_type', 'data/valid/wav.scp,speech,sound',
91
+ '--valid_data_path_and_name_and_type', 'data/valid/text,text,text',
92
+ '--output_dir', 'exp/tts_fastspeech2',
93
+ ]
94
+
95
+ TTSTask.main(train_args)