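"""Training loop for the poetry diacritization models (CBHG, seq2seq, and GPT
variants): checkpointing, evaluation, DER/WER reporting, TensorBoard summaries,
and Weights & Biases logging."""
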
import os
from typing import Dict, Optional

from diacritization_evaluation import der, wer
import torch
from torch import nn
from torch import optim
from torch.cuda.amp import autocast
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm
from tqdm import trange

from .config_manager import ConfigManager
from dataset import load_iterators
from diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer, GPTDiacritizer
from poetry_diacritizer.util.learning_rates import LearningRateDecay
from poetry_diacritizer.options import OptimizerType
from poetry_diacritizer.util.utils import (
    categorical_accuracy,
    count_parameters,
    initialize_weights,
    plot_alignment,
    repeater,
)

import wandb

# wandb.login() relies on cached credentials or the WANDB_API_KEY environment variable.
wandb.login()


class Trainer:
    def run(self):
        raise NotImplementedError


class GeneralTrainer(Trainer):
    def __init__(self, config_path: str, model_kind: str, model_desc: str) -> None:
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)

        if model_desc == "":
            model_desc = self.model_kind
        wandb.init(project="diacratization", name=model_desc, config=self.config)

        self.model = self.config_manager.get_model()
        self.optimizer = self.get_optimizer()
        self.model = self.model.to(self.device)

        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()
        self.initialize_model()

        self.print_config()

    def set_device(self):
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def print_config(self):
        self.config_manager.dump_config()
        self.config_manager.print_config()

        if self.global_step > 1:
            print(f"loaded from step {self.global_step}")

        parameters_count = count_parameters(self.model)
        print(f"The model has {parameters_count} trainable parameters")

    def load_diacritizer(self):
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["gpt"]:
            self.diacritizer = GPTDiacritizer(self.config_path, self.model_kind)

    def initialize_model(self):
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def print_losses(self, step_results, tqdm):
        self.summary_manager.add_scalar(
            "loss/loss", step_results["loss"], global_step=self.global_step
        )

        tqdm.display(f"loss: {step_results['loss']}", pos=3)
        for pos, n_steps in enumerate(self.config["n_steps_avg_losses"]):
            if len(self.losses) > n_steps:
                self.summary_manager.add_scalar(
                    f"loss/loss-{n_steps}",
                    sum(self.losses[-n_steps:]) / n_steps,
                    global_step=self.global_step,
                )
                tqdm.display(
                    f"{n_steps}-steps average loss: {sum(self.losses[-n_steps:]) / n_steps}",
                    pos=pos + 4,
                )

    def evaluate(self, iterator, tqdm, use_target=True, log=True):
        epoch_loss = 0
        epoch_acc = 0
        self.model.eval()
        tqdm.set_description(f"Eval: {self.global_step}")
        with torch.no_grad():
            for batch_inputs in iterator:
                batch_inputs["src"] = batch_inputs["src"].to(self.device)
                batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
                # Keep the gold targets for the loss even when the model is run
                # without teacher forcing (use_target=False).
                targets = batch_inputs["target"].to(self.device)
                if use_target:
                    batch_inputs["target"] = targets
                else:
                    batch_inputs["target"] = None

                outputs = self.model(
                    src=batch_inputs["src"],
                    target=batch_inputs["target"],
                    lengths=batch_inputs["lengths"],
                )

                predictions = outputs["diacritics"]
                predictions = predictions.view(-1, predictions.shape[-1])
                targets = targets.view(-1)

                loss = self.criterion(predictions, targets)
                acc = categorical_accuracy(predictions, targets, self.pad_idx)
                epoch_loss += loss.item()
                epoch_acc += acc.item()
                if log:
                    wandb.log(
                        {"evaluate_loss": loss.item(), "evaluate_acc": acc.item()}
                    )
                tqdm.update()

        tqdm.reset()
        return epoch_loss / len(iterator), epoch_acc / len(iterator)

    def evaluate_with_error_rates(self, iterator, tqdm, log=True):
        all_orig = []
        all_predicted = []
        results = {}
        self.diacritizer.set_model(self.model)
        evaluated_batches = 0
        tqdm.set_description(f"Calculating DER/WER {self.global_step}: ")
        for i, batch in enumerate(iterator):
            if evaluated_batches > int(self.config["error_rates_n_batches"]):
                break
            predicted = self.diacritizer.diacritize_batch(batch)
            all_predicted += predicted
            all_orig += batch["original"]
            evaluated_batches += 1
            if i > self.config["max_eval_batches"]:
                break
            tqdm.update()

        summary_texts = []
        orig_path = os.path.join(self.config_manager.prediction_dir, "original.txt")
        predicted_path = os.path.join(
            self.config_manager.prediction_dir, "predicted.txt"
        )

        table = wandb.Table(columns=["original", "predicted"])

        with open(orig_path, "w", encoding="utf8") as file:
            for sentence in all_orig:
                file.write(f"{sentence}\n")

        with open(predicted_path, "w", encoding="utf8") as file:
            for sentence in all_predicted:
                file.write(f"{sentence}\n")

        for i in range(int(self.config["n_predicted_text_tensorboard"])):
            if i >= len(all_predicted):
                break
            summary_texts.append(
                (f"eval-text/{i}", f"{all_orig[i]} |-> {all_predicted[i]}")
            )
            if i < 10:
                table.add_data(all_orig[i], all_predicted[i])

        if log:
            wandb.log({f"prediction_{self.global_step}": table}, commit=False)

        # DER/WER are the diacritic and word error rates; the starred variants
        # ignore the case ending (the last character of each word).
        results["DER"] = der.calculate_der_from_path(orig_path, predicted_path)
        results["DER*"] = der.calculate_der_from_path(
            orig_path, predicted_path, case_ending=False
        )
        results["WER"] = wer.calculate_wer_from_path(orig_path, predicted_path)
        results["WER*"] = wer.calculate_wer_from_path(
            orig_path, predicted_path, case_ending=False
        )
        if log:
            wandb.log(results)
        tqdm.reset()
        return results, summary_texts

    def run(self):
        scaler = torch.cuda.amp.GradScaler()
        train_iterator, _, validation_iterator = load_iterators(self.config_manager)
        print("data loaded")
        print("----------------------------------------------------------")
        tqdm_eval = trange(0, len(validation_iterator), leave=True)
        tqdm_error_rates = trange(0, len(validation_iterator), leave=True)
        tqdm_eval.set_description("Eval")
        tqdm_error_rates.set_description("WER/DER : ")
        tqdm = trange(self.global_step, self.config["max_steps"] + 1, leave=True)

        for batch_inputs in repeater(train_iterator):
            tqdm.set_description(f"Global Step {self.global_step}")
            if self.config["use_decay"]:
                self.lr = self.adjust_learning_rate(
                    self.optimizer, global_step=self.global_step
                )
            self.optimizer.zero_grad()
            if self.device == "cuda" and self.config["use_mixed_precision"]:
                # Mixed-precision step: run the forward pass under autocast,
                # then scale the loss, unscale before clipping, and step the
                # optimizer through the GradScaler.
                with autocast():
                    step_results = self.run_one_step(batch_inputs)
                scaler.scale(step_results["loss"]).backward()
                scaler.unscale_(self.optimizer)
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config["CLIP"]
                    )
                scaler.step(self.optimizer)
                scaler.update()
            else:
                step_results = self.run_one_step(batch_inputs)
                loss = step_results["loss"]
                loss.backward()
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config["CLIP"]
                    )
                self.optimizer.step()

            self.losses.append(step_results["loss"].item())
            wandb.log({"train_loss": step_results["loss"].item()})

            self.print_losses(step_results, tqdm)

            self.summary_manager.add_scalar(
                "meta/learning_rate", self.lr, global_step=self.global_step
            )

            if self.global_step % self.config["model_save_frequency"] == 0:
                torch.save(
                    {
                        "global_step": self.global_step,
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                    },
                    os.path.join(
                        self.config_manager.models_dir,
                        f"{self.global_step}-snapshot.pt",
                    ),
                )

            if self.global_step % self.config["evaluate_frequency"] == 0:
                loss, acc = self.evaluate(validation_iterator, tqdm_eval)
                self.summary_manager.add_scalar(
                    "evaluate/loss", loss, global_step=self.global_step
                )
                self.summary_manager.add_scalar(
                    "evaluate/acc", acc, global_step=self.global_step
                )
                tqdm.display(
                    f"Evaluate {self.global_step}: accuracy: {acc}, loss: {loss}", pos=8
                )
                self.model.train()

            if (
                self.global_step % self.config["evaluate_with_error_rates_frequency"]
                == 0
            ):
                error_rates, summary_texts = self.evaluate_with_error_rates(
                    validation_iterator, tqdm_error_rates
                )
                if error_rates:
                    WER = error_rates["WER"]
                    DER = error_rates["DER"]
                    DER1 = error_rates["DER*"]
                    WER1 = error_rates["WER*"]

                    self.summary_manager.add_scalar(
                        "error_rates/WER",
                        WER / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/DER",
                        DER / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/DER*",
                        DER1 / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/WER*",
                        WER1 / 100,
                        global_step=self.global_step,
                    )

                    error_rates_summary = (
                        f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"
                    )
                    tqdm.display(
                        f"WER/DER {self.global_step}: {error_rates_summary}", pos=9
                    )

                    for tag, text in summary_texts:
                        self.summary_manager.add_text(tag, text)

                self.model.train()

            if self.global_step % self.config["train_plotting_frequency"] == 0:
                self.plot_attention(step_results)

            self.report(step_results, tqdm)
            self.global_step += 1
            if self.global_step > self.config["max_steps"]:
                print("Training Done.")
                return
            tqdm.update()

    def run_one_step(self, batch_inputs: Dict[str, torch.Tensor]):
        batch_inputs["src"] = batch_inputs["src"].to(self.device)
        batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
        batch_inputs["target"] = batch_inputs["target"].to(self.device)

        outputs = self.model(
            src=batch_inputs["src"],
            target=batch_inputs["target"],
            lengths=batch_inputs["lengths"],
        )

        predictions = outputs["diacritics"].contiguous()
        targets = batch_inputs["target"].contiguous()

        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1)

        loss = self.criterion(predictions.to(self.device), targets.to(self.device))
        outputs.update({"loss": loss})
        return outputs

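    # Illustration only (not part of the training flow): nn.CrossEntropyLoss
    # expects (N, C) logits and (N,) class indices, which is why run_one_step
    # flattens the (batch, seq_len, n_classes) predictions and the
    # (batch, seq_len) targets. A minimal sketch with made-up shapes:
    #
    #   logits = torch.randn(8, 50, 17)                   # (batch, seq_len, n_diacritics)
    #   gold = torch.randint(0, 17, (8, 50))              # (batch, seq_len)
    #   criterion = nn.CrossEntropyLoss(ignore_index=0)   # 0 is the padding index
    #   loss = criterion(logits.view(-1, 17), gold.view(-1))
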
    def predict(self, iterator):
        pass

    def load_model(self, model_path: Optional[str] = None, load_optimizer: bool = True):
        with open(
            self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
        ) as file:
            file.write(str(self.model))

        if model_path is None:
            last_model_path = self.config_manager.get_last_model_path()
            if last_model_path is None:
                self.global_step = 1
                return
        else:
            last_model_path = model_path

        print(f"loading from {last_model_path}")
        saved_model = torch.load(last_model_path)
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1

    def get_optimizer(self):
        if self.config["optimizer"] == OptimizerType.Adam:
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                betas=(self.config["adam_beta1"], self.config["adam_beta2"]),
                weight_decay=self.config["weight_decay"],
            )
        elif self.config["optimizer"] == OptimizerType.SGD:
            optimizer = optim.SGD(
                self.model.parameters(), lr=self.config["learning_rate"], momentum=0.9
            )
        else:
            raise ValueError("Optimizer option is not valid")

        return optimizer

    def get_learning_rate(self):
        return LearningRateDecay(
            lr=self.config["learning_rate"],
            warmup_steps=self.config.get("warmup_steps", 4000.0),
        )

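    # Note: the decay formula itself is defined in
    # poetry_diacritizer.util.learning_rates.LearningRateDecay. Given the
    # warmup_steps parameter, it is presumably a warmup-then-decay schedule
    # (for example, a Noam-style curve such as
    # lr * warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5)),
    # but the exact shape lives there, not here.
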
    def adjust_learning_rate(self, optimizer, global_step):
        learning_rate = self.get_learning_rate()(global_step=global_step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    def plot_attention(self, results):
        pass

    def report(self, results, tqdm):
        pass


class Seq2SeqTrainer(GeneralTrainer):
    def plot_attention(self, results):
        plot_alignment(
            results["attention"][0],
            str(self.config_manager.plot_dir),
            self.global_step,
        )

        self.summary_manager.add_image(
            "Train/attention",
            results["attention"][0].unsqueeze(0),
            global_step=self.global_step,
        )


class GPTTrainer(GeneralTrainer):
    pass


class CBHGTrainer(GeneralTrainer):
    pass

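
# A minimal usage sketch. The config path and model description below are
# hypothetical; in practice, use the config files and model kinds shipped
# with the project:
#
#   if __name__ == "__main__":
#       trainer = CBHGTrainer(
#           config_path="config/cbhg.yml",  # hypothetical path
#           model_kind="cbhg",
#           model_desc="cbhg-baseline",
#       )
#       trainer.run()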