m6011 commited on
Commit
13cef53
1 Parent(s): 99d4e63

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +54 -0
train.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from datasets import load_dataset
4
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
5
+ from transformers import Trainer, TrainingArguments
6
+
7
+ # تحميل بيانات SADA
8
+ dataset = load_dataset("m6011/sada2022")
9
+
10
+ # تحميل نموذج Wav2Vec2 لتحويل الصوت إلى نص (يمكنك تغييره إذا كنت تود استخدام نموذج آخر)
11
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53")
12
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53")
13
+
14
+ # معالجة البيانات - تحويل النص إلى رموز صوتية مناسبة (حسب النموذج المختار)
15
+ def preprocess_data(batch):
16
+ audio = batch["audio"]
17
+ inputs = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt", padding=True)
18
+ batch["input_values"] = inputs.input_values[0]
19
+ batch["attention_mask"] = inputs.attention_mask[0]
20
+
21
+ # تحويل النص إلى رموز
22
+ with processor.as_target_processor():
23
+ batch["labels"] = processor(batch["ProcessedText"]).input_ids
24
+ return batch
25
+
26
+ # تطبيق المعالجة المسبقة على البيانات
27
+ dataset = dataset.map(preprocess_data, remove_columns=["audio", "ProcessedText"])
28
+
29
+ # إعدادات التدريب
30
+ training_args = TrainingArguments(
31
+ output_dir="./wav2vec2-saudi-tts",
32
+ group_by_length=True,
33
+ per_device_train_batch_size=4,
34
+ evaluation_strategy="steps",
35
+ num_train_epochs=3,
36
+ save_steps=400,
37
+ eval_steps=400,
38
+ logging_steps=400,
39
+ learning_rate=3e-4,
40
+ warmup_steps=500,
41
+ save_total_limit=2,
42
+ )
43
+
44
+ # إعداد المدرب
45
+ trainer = Trainer(
46
+ model=model,
47
+ args=training_args,
48
+ train_dataset=dataset["train"],
49
+ eval_dataset=dataset["test"],
50
+ tokenizer=processor.feature_extractor,
51
+ )
52
+
53
+ # بدء التدريب
54
+ trainer.train()