m6011 commited on
Commit
aa29130
1 Parent(s): 739b974

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +12 -19
train.py CHANGED
@@ -1,26 +1,19 @@
 
 
1
  from datasets import load_dataset
2
- from transformers import FastSpeechForConditionalGeneration, Trainer, TrainingArguments
3
 
4
  # تحميل البيانات للهجة النجدية
5
  dataset = load_dataset("m6011/sada2022")
6
  najdi_data = dataset.filter(lambda example: example['SpeakerDialect'] == 'Najdi')
7
 
8
- # إعداد النموذج
9
- model = FastSpeechForConditionalGeneration.from_pretrained("facebook/fastspeech2-en-ljspeech")
 
 
 
 
 
 
10
 
11
- # إعداد المدرب
12
- training_args = TrainingArguments(
13
- output_dir="./results",
14
- per_device_train_batch_size=4,
15
- num_train_epochs=5,
16
- )
17
-
18
- trainer = Trainer(
19
- model=model,
20
- args=training_args,
21
- train_dataset=najdi_data['train'],
22
- eval_dataset=najdi_data['test']
23
- )
24
-
25
- # بدء التدريب
26
- trainer.train()
 
1
+ from espnet2.bin.tts_train import train
2
+ from espnet2.tasks.tts import TTSTask
3
  from datasets import load_dataset
 
4
 
5
  # تحميل البيانات للهجة النجدية
6
  dataset = load_dataset("m6011/sada2022")
7
  najdi_data = dataset.filter(lambda example: example['SpeakerDialect'] == 'Najdi')
8
 
9
+ # إعداد التدريب
10
+ train_config = {
11
+ 'output_dir': './results',
12
+ 'train_data_path_and_name_and_type': najdi_data['train'],
13
+ 'valid_data_path_and_name_and_type': najdi_data['test'],
14
+ 'train_batch_size': 8,
15
+ 'epochs': 10,
16
+ }
17
 
18
+ # بدء عملية التدريب باستخدام ESPnet
19
+ TTSTask.main(**train_config)