Gregniuki committed
Commit 29441b6
1 Parent(s): 7726dc4

Delete model/trainer.py

Files changed (1)
  1. model/trainer.py +0 -250
model/trainer.py DELETED
@@ -1,250 +0,0 @@
-from __future__ import annotations
-
-import os
-import gc
-from tqdm import tqdm
-import wandb
-
-import torch
-from torch.optim import AdamW
-from torch.utils.data import DataLoader, Dataset, SequentialSampler
-from torch.optim.lr_scheduler import LinearLR, SequentialLR
-
-from einops import rearrange
-
-from accelerate import Accelerator
-from accelerate.utils import DistributedDataParallelKwargs
-
-from ema_pytorch import EMA
-
-from model import CFM
-from model.utils import exists, default
-from model.dataset import DynamicBatchSampler, collate_fn
-
-
-# trainer
-
-class Trainer:
-    def __init__(
-        self,
-        model: CFM,
-        epochs,
-        learning_rate,
-        num_warmup_updates = 20000,
-        save_per_updates = 1000,
-        checkpoint_path = None,
-        batch_size = 32,
-        batch_size_type: str = "sample",
-        max_samples = 32,
-        grad_accumulation_steps = 1,
-        max_grad_norm = 1.0,
-        noise_scheduler: str | None = None,
-        duration_predictor: torch.nn.Module | None = None,
-        wandb_project = "test_e2-tts",
-        wandb_run_name = "test_run",
-        wandb_resume_id: str = None,
-        last_per_steps = None,
-        accelerate_kwargs: dict = dict(),
-        ema_kwargs: dict = dict()
-    ):
-
-        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
-
-        self.accelerator = Accelerator(
-            log_with = "wandb",
-            kwargs_handlers = [ddp_kwargs],
-            gradient_accumulation_steps = grad_accumulation_steps,
-            **accelerate_kwargs
-        )
-
-        if exists(wandb_resume_id):
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
-        else:
-            init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
-        self.accelerator.init_trackers(
-            project_name = wandb_project,
-            init_kwargs=init_kwargs,
-            config={"epochs": epochs,
-                    "learning_rate": learning_rate,
-                    "num_warmup_updates": num_warmup_updates,
-                    "batch_size": batch_size,
-                    "batch_size_type": batch_size_type,
-                    "max_samples": max_samples,
-                    "grad_accumulation_steps": grad_accumulation_steps,
-                    "max_grad_norm": max_grad_norm,
-                    "gpus": self.accelerator.num_processes,
-                    "noise_scheduler": noise_scheduler}
-        )
-
-        self.model = model
-
-        if self.is_main:
-            self.ema_model = EMA(
-                model,
-                include_online_model = False,
-                **ema_kwargs
-            )
-
-            self.ema_model.to(self.accelerator.device)
-
-        self.epochs = epochs
-        self.num_warmup_updates = num_warmup_updates
-        self.save_per_updates = save_per_updates
-        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
-        self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
-
-        self.batch_size = batch_size
-        self.batch_size_type = batch_size_type
-        self.max_samples = max_samples
-        self.grad_accumulation_steps = grad_accumulation_steps
-        self.max_grad_norm = max_grad_norm
-
-        self.noise_scheduler = noise_scheduler
-
-        self.duration_predictor = duration_predictor
-
-        self.optimizer = AdamW(model.parameters(), lr=learning_rate)
-        self.model, self.optimizer = self.accelerator.prepare(
-            self.model, self.optimizer
-        )
-
-    @property
-    def is_main(self):
-        return self.accelerator.is_main_process
-
-    def save_checkpoint(self, step, last=False):
-        self.accelerator.wait_for_everyone()
-        if self.is_main:
-            checkpoint = dict(
-                model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
-                optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
-                ema_model_state_dict = self.ema_model.state_dict(),
-                scheduler_state_dict = self.scheduler.state_dict(),
-                step = step
-            )
-            if not os.path.exists(self.checkpoint_path):
-                os.makedirs(self.checkpoint_path)
-            if last == True:
-                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
-                print(f"Saved last checkpoint at step {step}")
-            else:
-                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
-
-    def load_checkpoint(self):
-        if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
-            return 0
-
-        self.accelerator.wait_for_everyone()
-        if "model_last.pt" in os.listdir(self.checkpoint_path):
-            latest_checkpoint = "model_last.pt"
-        else:
-            latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
-        # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device)  # rather use accelerator.load_state ಥ_ಥ
-        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
-
-        if self.is_main:
-            self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
-
-        if 'step' in checkpoint:
-            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
-            self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
-            if self.scheduler:
-                self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
-            step = checkpoint['step']
-        else:
-            checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
-            self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
-            step = 0
-
-        del checkpoint; gc.collect()
-        return step
-
-    def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
-
-        if exists(resumable_with_seed):
-            generator = torch.Generator()
-            generator.manual_seed(resumable_with_seed)
-        else:
-            generator = None
-
-        if self.batch_size_type == "sample":
-            train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
-                                          batch_size=self.batch_size, shuffle=True, generator=generator)
-        elif self.batch_size_type == "frame":
-            self.accelerator.even_batches = False
-            sampler = SequentialSampler(train_dataset)
-            batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
-            train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
-                                          batch_sampler=batch_sampler)
-        else:
-            raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
-
-        # accelerator.prepare() dispatches batches to devices;
-        # which means the length of dataloader calculated before, should consider the number of devices
-        warmup_steps = self.num_warmup_updates * self.accelerator.num_processes  # consider a fixed warmup steps while using accelerate multi-gpu ddp
-        # otherwise by default with split_batches=False, warmup steps change with num_processes
-        total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
-        decay_steps = total_steps - warmup_steps
-        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
-        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
-        self.scheduler = SequentialLR(self.optimizer,
-                                      schedulers=[warmup_scheduler, decay_scheduler],
-                                      milestones=[warmup_steps])
-        train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler)  # actual steps = 1 gpu steps / gpus
-        start_step = self.load_checkpoint()
-        global_step = start_step
-
-        if exists(resumable_with_seed):
-            orig_epoch_step = len(train_dataloader)
-            skipped_epoch = int(start_step // orig_epoch_step)
-            skipped_batch = start_step % orig_epoch_step
-            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
-        else:
-            skipped_epoch = 0
-
-        for epoch in range(skipped_epoch, self.epochs):
-            self.model.train()
-            if exists(resumable_with_seed) and epoch == skipped_epoch:
-                progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
-                                    initial=skipped_batch, total=orig_epoch_step)
-            else:
-                progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
-
-            for batch in progress_bar:
-                with self.accelerator.accumulate(self.model):
-                    text_inputs = batch['text']
-                    mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
-                    mel_lengths = batch["mel_lengths"]
-
-                    # TODO. add duration predictor training
-                    if self.duration_predictor is not None and self.accelerator.is_local_main_process:
-                        dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
-                        self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
-
-                    loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
-                    self.accelerator.backward(loss)
-
-                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
-                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
-
-                    self.optimizer.step()
-                    self.scheduler.step()
-                    self.optimizer.zero_grad()
-
-                if self.is_main:
-                    self.ema_model.update()
-
-                global_step += 1
-
-                if self.accelerator.is_local_main_process:
-                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
-
-                progress_bar.set_postfix(step=str(global_step), loss=loss.item())
-
-                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
-                    self.save_checkpoint(global_step)
-
-                if global_step % self.last_per_steps == 0:
-                    self.save_checkpoint(global_step, last=True)
-
-        self.accelerator.end_training()
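Note on the deleted train() method: its learning-rate schedule is a LinearLR warmup chained into a LinearLR decay via SequentialLR. Below is a minimal standalone sketch of that schedule; the step counts, the single dummy parameter, and the base learning rate are illustrative assumptions, not values taken from this repository.

# Standalone sketch of the warmup + linear-decay schedule used in the deleted Trainer.
# In the original code, warmup_steps came from num_warmup_updates * num_processes and
# decay_steps from the dataloader length; here both are small fixed numbers (assumed).
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

params = [torch.nn.Parameter(torch.zeros(1))]   # dummy parameter so the optimizer has something to track
optimizer = AdamW(params, lr=1e-4)

warmup_steps, decay_steps = 10, 90              # illustrative values
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

for step in range(warmup_steps + decay_steps):
    optimizer.step()                            # no gradients needed to observe the schedule
    scheduler.step()
    if step % 20 == 0:
        print(step, scheduler.get_last_lr()[0]) # lr ramps up to 1e-4, then decays back toward ~0

The same two-phase shape applies in the deleted Trainer, except that scheduler.step() runs once per optimizer update inside the accumulate() block and the schedule is passed through accelerator.prepare() alongside the dataloader.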