import os
from typing import Dict

from diacritization_evaluation import der, wer
import torch
from torch import nn
from torch import optim
from torch.cuda.amp import autocast
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm
from tqdm import trange

from .config_manager import ConfigManager
from dataset import load_iterators
from diacritizer import CBHGDiacritizer, Seq2SeqDiacritizer, GPTDiacritizer
from poetry_diacritizer.util.learning_rates import LearningRateDecay
from poetry_diacritizer.options import OptimizerType
from poetry_diacritizer.util.utils import (
    categorical_accuracy,
    count_parameters,
    initialize_weights,
    plot_alignment,
    repeater,
)

import wandb

wandb.login()


class Trainer:
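    """Minimal trainer interface; concrete trainers implement run()."""
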
    def run(self):
        raise NotImplementedError


class GeneralTrainer(Trainer):
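    """Shared training and evaluation loop for the diacritization models (CBHG, seq2seq, GPT)."""
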
    def __init__(self, config_path: str, model_kind: str, model_desc: str) -> None:
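        """Load the config, build the model and optimizer, optionally resume from a
        checkpoint, and set up TensorBoard and wandb logging."""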
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        self.losses = []
        self.lr = 0
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        self.set_device()

        self.config_manager.create_remove_dirs()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.summary_manager = SummaryWriter(log_dir=self.config_manager.log_dir)
        if model_desc == "":
            model_desc = self.model_kind
        wandb.init(project="diacratization", name=model_desc, config=self.config)
        self.model = self.config_manager.get_model()

        self.optimizer = self.get_optimizer()
        self.model = self.model.to(self.device)

        self.load_model(model_path=self.config.get("train_resume_model_path"))
        self.load_diacritizer()

        self.initialize_model()

        self.print_config()

    def set_device(self):
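        """Use the device named in the config when present, otherwise prefer CUDA if available."""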
        if self.config.get("device"):
            self.device = self.config["device"]
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def print_config(self):
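        """Dump and print the config, then report the resume step and the trainable parameter count."""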
        self.config_manager.dump_config()
        self.config_manager.print_config()

        if self.global_step > 1:
            print(f"loaded from step {self.global_step}")

        parameters_count = count_parameters(self.model)
        print(f"The model has {parameters_count} trainable parameters")

    def load_diacritizer(self):
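        """Pick the diacritizer wrapper that matches the configured model kind."""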
        if self.model_kind in ["cbhg", "baseline"]:
            self.diacritizer = CBHGDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["seq2seq", "tacotron_based"]:
            self.diacritizer = Seq2SeqDiacritizer(self.config_path, self.model_kind)
        elif self.model_kind in ["gpt"]:
            self.diacritizer = GPTDiacritizer(self.config_path, self.model_kind)

    def initialize_model(self):
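        """Apply Xavier-uniform initialization, but only for fresh transformer runs."""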
        if self.global_step > 1:
            return
        if self.model_kind == "transformer":
            print("Initializing using xavier_uniform_")
            self.model.apply(initialize_weights)

    def print_losses(self, step_results, tqdm):
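        """Log the step loss and running-average losses to TensorBoard and the progress bar."""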
        self.summary_manager.add_scalar(
            "loss/loss", step_results["loss"], global_step=self.global_step
        )

        tqdm.display(f"loss: {step_results['loss']}", pos=3)
        for pos, n_steps in enumerate(self.config["n_steps_avg_losses"]):
            if len(self.losses) > n_steps:

                self.summary_manager.add_scalar(
                    f"loss/loss-{n_steps}",
                    sum(self.losses[-n_steps:]) / n_steps,
                    global_step=self.global_step,
                )
                tqdm.display(
                    f"{n_steps}-steps average loss: {sum(self.losses[-n_steps:]) / n_steps}",
                    pos=pos + 4,
                )

    def evaluate(self, iterator, tqdm, use_target=True, log=True):
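        """Run the model over an iterator and return the average loss and categorical accuracy.

        Note: even with use_target=False the loss and accuracy below still read
        batch_inputs["target"], so the batches must carry ground-truth targets.
        """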
        epoch_loss = 0
        epoch_acc = 0
        self.model.eval()
        tqdm.set_description(f"Eval: {self.global_step}")
        with torch.no_grad():
            for batch_inputs in iterator:
                batch_inputs["src"] = batch_inputs["src"].to(self.device)
                batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
                if use_target:
                    batch_inputs["target"] = batch_inputs["target"].to(self.device)
                else:
                    batch_inputs["target"] = None

                outputs = self.model(
                    src=batch_inputs["src"],
                    target=batch_inputs["target"],
                    lengths=batch_inputs["lengths"],
                )

                predictions = outputs["diacritics"]

                predictions = predictions.view(-1, predictions.shape[-1])
                targets = batch_inputs["target"]
                targets = targets.view(-1)
                loss = self.criterion(predictions, targets.to(self.device))
                acc = categorical_accuracy(
                    predictions, targets.to(self.device), self.pad_idx
                )

                epoch_loss += loss.item()
                epoch_acc += acc.item()
                if log:
                    wandb.log({"evaluate_loss": loss.item(), "evaluate_acc": acc.item()})
                tqdm.update()

        tqdm.reset()
        return epoch_loss / len(iterator), epoch_acc / len(iterator)

    def evaluate_with_error_rates(self, iterator, tqdm, log=True):
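        """Diacritize a limited number of batches, write originals and predictions to
        text files, and compute DER/WER (with and without case endings) from them."""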
        all_orig = []
        all_predicted = []
        results = {}
        self.diacritizer.set_model(self.model)
        evaluated_batches = 0
        tqdm.set_description(f"Calculating DER/WER {self.global_step}: ")
        for i, batch in enumerate(iterator):
            if evaluated_batches > int(self.config["error_rates_n_batches"]):
                break

            predicted = self.diacritizer.diacritize_batch(batch)
            evaluated_batches += 1
            all_predicted += predicted
            all_orig += batch["original"]
            if i > self.config["max_eval_batches"]:
                break
            tqdm.update()

        summary_texts = []
        orig_path = os.path.join(self.config_manager.prediction_dir, "original.txt")
        predicted_path = os.path.join(
            self.config_manager.prediction_dir, "predicted.txt"
        )

        table = wandb.Table(columns=["original", "predicted"])
        with open(orig_path, "w", encoding="utf8") as file:
            for sentence in all_orig:
                file.write(f"{sentence}\n")

        with open(predicted_path, "w", encoding="utf8") as file:
            for sentence in all_predicted:
                file.write(f"{sentence}\n")

        for i in range(int(self.config["n_predicted_text_tensorboard"])):
            if i >= len(all_predicted):
                break

            summary_texts.append(
                (f"eval-text/{i}", f"{all_orig[i]} |-> {all_predicted[i]}")
            )
            if i < 10:
                table.add_data(all_orig[i], all_predicted[i])

        if log:
            wandb.log({f"prediction_{self.global_step}": table}, commit=False)

        results["DER"] = der.calculate_der_from_path(orig_path, predicted_path)
        results["DER*"] = der.calculate_der_from_path(
            orig_path, predicted_path, case_ending=False
        )
        results["WER"] = wer.calculate_wer_from_path(orig_path, predicted_path)
        results["WER*"] = wer.calculate_wer_from_path(
            orig_path, predicted_path, case_ending=False
        )
        if log:
            wandb.log(results)
        tqdm.reset()
        return results, summary_texts

    def run(self):
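        """Main training loop: optional learning-rate decay and mixed precision, gradient
        clipping, periodic checkpointing, evaluation, and DER/WER reporting."""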
        scaler = torch.cuda.amp.GradScaler()
        train_iterator, _, validation_iterator = load_iterators(self.config_manager)
        print("data loaded")
        print("----------------------------------------------------------")
        tqdm_eval = trange(0, len(validation_iterator), leave=True)
        tqdm_error_rates = trange(0, len(validation_iterator), leave=True)
        tqdm_eval.set_description("Eval")
        tqdm_error_rates.set_description("WER/DER : ")
        tqdm = trange(self.global_step, self.config["max_steps"] + 1, leave=True)

        for batch_inputs in repeater(train_iterator):
            tqdm.set_description(f"Global Step {self.global_step}")
            if self.config["use_decay"]:
                self.lr = self.adjust_learning_rate(
                    self.optimizer, global_step=self.global_step
                )
            self.optimizer.zero_grad()
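            # With CUDA and mixed precision enabled, scale the loss for backward,
            # unscale before gradient clipping, and step through the GradScaler;
            # otherwise run a plain full-precision step.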
            if self.device == "cuda" and self.config["use_mixed_precision"]:
                with autocast():
                    step_results = self.run_one_step(batch_inputs)
                scaler.scale(step_results["loss"]).backward()
                scaler.unscale_(self.optimizer)
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config["CLIP"]
                    )

                scaler.step(self.optimizer)

                scaler.update()
            else:
                step_results = self.run_one_step(batch_inputs)

                loss = step_results["loss"]
                loss.backward()
                if self.config.get("CLIP"):
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.config["CLIP"]
                    )
                self.optimizer.step()

            self.losses.append(step_results["loss"].item())
            wandb.log({"train_loss": step_results["loss"].item()})

            self.print_losses(step_results, tqdm)

            self.summary_manager.add_scalar(
                "meta/learning_rate", self.lr, global_step=self.global_step
            )

            if self.global_step % self.config["model_save_frequency"] == 0:
                torch.save(
                    {
                        "global_step": self.global_step,
                        "model_state_dict": self.model.state_dict(),
                        "optimizer_state_dict": self.optimizer.state_dict(),
                    },
                    os.path.join(
                        self.config_manager.models_dir,
                        f"{self.global_step}-snapshot.pt",
                    ),
                )

            if self.global_step % self.config["evaluate_frequency"] == 0:
                loss, acc = self.evaluate(validation_iterator, tqdm_eval)
                self.summary_manager.add_scalar(
                    "evaluate/loss", loss, global_step=self.global_step
                )
                self.summary_manager.add_scalar(
                    "evaluate/acc", acc, global_step=self.global_step
                )
                tqdm.display(
                    f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}", pos=8
                )
                self.model.train()

            if (
                self.global_step % self.config["evaluate_with_error_rates_frequency"]
                == 0
            ):
                error_rates, summary_texts = self.evaluate_with_error_rates(
                    validation_iterator, tqdm_error_rates
                )
                if error_rates:
                    WER = error_rates["WER"]
                    DER = error_rates["DER"]
                    DER1 = error_rates["DER*"]
                    WER1 = error_rates["WER*"]

                    self.summary_manager.add_scalar(
                        "error_rates/WER",
                        WER / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/DER",
                        DER / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/DER*",
                        DER1 / 100,
                        global_step=self.global_step,
                    )
                    self.summary_manager.add_scalar(
                        "error_rates/WER*",
                        WER1 / 100,
                        global_step=self.global_step,
                    )

                    error_rates_msg = f"DER: {DER}, WER: {WER}, DER*: {DER1}, WER*: {WER1}"
                    tqdm.display(f"WER/DER {self.global_step}: {error_rates_msg}", pos=9)

                    for tag, text in summary_texts:
                        self.summary_manager.add_text(tag, text)

                self.model.train()

            if self.global_step % self.config["train_plotting_frequency"] == 0:
                self.plot_attention(step_results)

            self.report(step_results, tqdm)

            self.global_step += 1
            if self.global_step > self.config["max_steps"]:
                print("Training Done.")
                return

            tqdm.update()

    def run_one_step(self, batch_inputs: Dict[str, torch.Tensor]):
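        """Forward one batch and compute the cross-entropy loss over flattened diacritic logits."""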
        batch_inputs["src"] = batch_inputs["src"].to(self.device)
        batch_inputs["lengths"] = batch_inputs["lengths"].to("cpu")
        batch_inputs["target"] = batch_inputs["target"].to(self.device)

        outputs = self.model(
            src=batch_inputs["src"],
            target=batch_inputs["target"],
            lengths=batch_inputs["lengths"],
        )

        predictions = outputs["diacritics"].contiguous()
        targets = batch_inputs["target"].contiguous()
        predictions = predictions.view(-1, predictions.shape[-1])
        targets = targets.view(-1)
        loss = self.criterion(predictions.to(self.device), targets.to(self.device))
        outputs.update({"loss": loss})
        return outputs

    def predict(self, iterator):
        pass

    def load_model(self, model_path: str = None, load_optimizer: bool = True):
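        """Write the model architecture to a text file and, when a checkpoint is found,
        restore the weights (and optionally the optimizer state and global step)."""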
        with open(
            self.config_manager.base_dir / f"{self.model_kind}_network.txt", "w"
        ) as file:
            file.write(str(self.model))

        if model_path is None:
            last_model_path = self.config_manager.get_last_model_path()
            if last_model_path is None:
                self.global_step = 1
                return
        else:
            last_model_path = model_path

        print(f"loading from {last_model_path}")
        saved_model = torch.load(last_model_path, map_location=self.device)
        self.model.load_state_dict(saved_model["model_state_dict"])
        if load_optimizer:
            self.optimizer.load_state_dict(saved_model["optimizer_state_dict"])
        self.global_step = saved_model["global_step"] + 1

    def get_optimizer(self):
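        """Build the optimizer (Adam or SGD) selected in the config."""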
        if self.config["optimizer"] == OptimizerType.Adam:
            optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.config["learning_rate"],
                betas=(self.config["adam_beta1"], self.config["adam_beta2"]),
                weight_decay=self.config["weight_decay"],
            )
        elif self.config["optimizer"] == OptimizerType.SGD:
            optimizer = optim.SGD(
                self.model.parameters(), lr=self.config["learning_rate"], momentum=0.9
            )
        else:
            raise ValueError("Optimizer option is not valid")

        return optimizer

    def get_learning_rate(self):
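        """Return the learning-rate decay schedule configured with the base LR and warmup steps."""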
        return LearningRateDecay(
            lr=self.config["learning_rate"],
            warmup_steps=self.config.get("warmup_steps", 4000.0),
        )

    def adjust_learning_rate(self, optimizer, global_step):
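        """Recompute the learning rate for the current step and apply it to every param group."""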
        learning_rate = self.get_learning_rate()(global_step=global_step)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate
        return learning_rate

    def plot_attention(self, results):
        pass

    def report(self, results, tqdm):
        pass


class Seq2SeqTrainer(GeneralTrainer):
    def plot_attention(self, results):
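        """Save the first attention alignment of the batch as a plot and log it as an image."""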
        plot_alignment(
            results["attention"][0],
            str(self.config_manager.plot_dir),
            self.global_step,
        )

        self.summary_manager.add_image(
            "Train/attention",
            results["attention"][0].unsqueeze(0),
            global_step=self.global_step,
        )


class GPTTrainer(GeneralTrainer):
    pass


class CBHGTrainer(GeneralTrainer):
    pass
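

# A minimal usage sketch (not part of the original training flow): instantiate one of
# the trainers above and start the loop. The config path and run description below are
# placeholders; point them at a real config for the chosen model kind.
if __name__ == "__main__":
    trainer = CBHGTrainer(
        config_path="config/cbhg.yml",  # hypothetical path, adjust to your setup
        model_kind="cbhg",
        model_desc="cbhg-baseline",
    )
    trainer.run()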