Spaces:
Running
Running
from .config_manager import ConfigManager | |
import os | |
from typing import Dict | |
from torch import nn | |
from tqdm import tqdm | |
from tqdm import trange | |
from dataset import load_iterators | |
from trainer import GeneralTrainer | |
class DiacritizationTester(GeneralTrainer):
    """Evaluate a trained diacritization model on the test split.

    The model weights are restored from ``config["test_model_path"]`` without
    optimizer state. ``run`` then reports loss, accuracy, and DER/WER
    (including the starred variants) on the test data.
    """

    def __init__(self, config_path: str, model_kind: str) -> None:
        """Build the model from the configuration and load the test checkpoint.

        Args:
            config_path: Path to the configuration, forwarded to ConfigManager.
            model_kind: Model variant name, forwarded to ConfigManager.
        """
        # NOTE(review): GeneralTrainer.__init__ is not called here; only the
        # attributes needed for evaluation are set up below. Confirm that no
        # base-class method used by evaluation depends on other init state.
        self.config_path = config_path
        self.model_kind = model_kind
        self.config_manager = ConfigManager(
            config_path=config_path, model_kind=model_kind
        )
        self.config = self.config_manager.config
        # Padding index; the loss ignores targets equal to this value.
        self.pad_idx = 0
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_idx)
        # Must run before the model is moved to self.device below.
        self.set_device()
        self.text_encoder = self.config_manager.text_encoder
        self.start_symbol_id = self.text_encoder.start_symbol_id
        self.model = self.config_manager.get_model()
        self.model = self.model.to(self.device)
        # Restore weights only; optimizer state is not needed for evaluation.
        self.load_model(model_path=self.config["test_model_path"], load_optimizer=False)
        # load_diacritizer is expected to set self.diacritizer (used next).
        self.load_diacritizer()
        self.diacritizer.set_model(self.model)
        # NOTE(review): this runs after load_model; confirm initialize_model
        # does not overwrite the freshly loaded weights.
        self.initialize_model()
        self.print_config()

    def run(self) -> None:
        """Evaluate the loaded model on the test split and print the results.

        Prints the checkpoint's global step, the loss and accuracy returned by
        ``evaluate``, and the DER/WER metrics returned by
        ``evaluate_with_error_rates``.
        """
        # Only the test split is needed; skip loading training/validation data.
        config = self.config_manager.config
        config["load_training_data"] = False
        config["load_validation_data"] = False
        config["load_test_data"] = True
        _, test_iterator, _ = load_iterators(self.config_manager)

        # Each progress bar is created when its pass starts. The context
        # manager closes it even if that evaluation raises.
        with trange(0, len(test_iterator), leave=True) as tqdm_eval:
            loss, acc = self.evaluate(test_iterator, tqdm_eval, log=False)
        with trange(0, len(test_iterator), leave=True) as tqdm_error_rates:
            error_rates, _ = self.evaluate_with_error_rates(
                test_iterator, tqdm_error_rates, log=False
            )

        summary = self._format_error_rates(error_rates)
        print(f"global step : {self.global_step}")
        print(f"Evaluate {self.global_step}: accuracy, {acc}, loss: {loss}")
        print(f"WER/DER {self.global_step}: {summary}")

    @staticmethod
    def _format_error_rates(error_rates) -> str:
        """Render the error-rate metrics as one human-readable line.

        Args:
            error_rates: Mapping that must contain the keys "DER", "WER",
                "DER*" and "WER*".

        Returns:
            A string of the form "DER: ..., WER: ..., DER*: ..., WER*: ...".
        """
        der = error_rates["DER"]
        wer = error_rates["WER"]
        der_star = error_rates["DER*"]
        wer_star = error_rates["WER*"]
        return f"DER: {der}, WER: {wer}, DER*: {der_star}, WER*: {wer_star}"