m6011 commited on
Commit
8c93826
1 Parent(s): 8188f68

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +17 -7
train.py CHANGED
@@ -1,16 +1,26 @@
1
  from datasets import load_dataset
2
- from transformers import Trainer, TrainingArguments, Tacotron2ForConditionalGeneration
3
 
4
- # تحميل البيانات من Hugging Face Datasets
5
  dataset = load_dataset("m6011/sada2022")
6
  najdi_data = dataset.filter(lambda example: example['SpeakerDialect'] == 'Najdi')
7
 
8
- # إعداد النموذج والمعالج
9
- model = Tacotron2ForConditionalGeneration.from_pretrained("facebook/tacotron2")
10
 
11
- # إعداد التدريب
12
- training_args = TrainingArguments(output_dir="./results", per_device_train_batch_size=16, num_train_epochs=3)
13
- trainer = Trainer(model=model, args=training_args, train_dataset=najdi_data)
 
 
 
 
 
 
 
 
 
 
14
 
15
  # بدء التدريب
16
  trainer.train()
 
1
  from datasets import load_dataset
2
+ from transformers import FastSpeechForConditionalGeneration, Trainer, TrainingArguments
3
 
4
+ # تحميل البيانات للهجة النجدية
5
  dataset = load_dataset("m6011/sada2022")
6
  najdi_data = dataset.filter(lambda example: example['SpeakerDialect'] == 'Najdi')
7
 
8
+ # إعداد النموذج
9
+ model = FastSpeechForConditionalGeneration.from_pretrained("facebook/fastspeech2-en-ljspeech")
10
 
11
+ # إعداد المدرب
12
+ training_args = TrainingArguments(
13
+ output_dir="./results",
14
+ per_device_train_batch_size=4,
15
+ num_train_epochs=5,
16
+ )
17
+
18
+ trainer = Trainer(
19
+ model=model,
20
+ args=training_args,
21
+ train_dataset=najdi_data['train'],
22
+ eval_dataset=najdi_data['test']
23
+ )
24
 
25
  # بدء التدريب
26
  trainer.train()