# Ashaar / poetry_diacritizer/config_manager.py
# Uploaded by Ababababababbababa (duplicated from arbml/Ashaar at commit 6faf7e7).
from enum import Enum
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Dict
import ruamel.yaml
import torch
from poetry_diacritizer.models.baseline import BaseLineModel
from poetry_diacritizer.models.cbhg import CBHGModel
from poetry_diacritizer.models.gpt import GPTModel
from poetry_diacritizer.models.seq2seq import Decoder as Seq2SeqDecoder, Encoder as Seq2SeqEncoder, Seq2Seq
from poetry_diacritizer.models.tacotron_based import (
Decoder as TacotronDecoder,
Encoder as TacotronEncoder,
Tacotron,
)
from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
from poetry_diacritizer.util.text_encoders import (
ArabicEncoderWithStartSymbol,
BasicArabicEncoder,
TextEncoder,
)
class ConfigManager:
    """Config Manager.

    Loads a YAML experiment configuration, derives the session directory
    layout from it, builds the text encoder, and constructs (and optionally
    restores from a checkpoint) the model selected by ``model_kind``.
    """

    def __init__(self, config_path: str, model_kind: str):
        """Load the configuration and derive paths and the text encoder from it.

        Args:
            config_path: path to the YAML configuration file.
            model_kind: one of ``"baseline"``, ``"cbhg"``, ``"seq2seq"``,
                ``"tacotron_based"`` or ``"gpt"``.

        Raises:
            TypeError: if ``model_kind`` is not one of the supported kinds.
        """
        available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
        if model_kind not in available_models:
            raise TypeError(f"model_kind must be in {available_models}")
        self.config_path = Path(config_path)
        self.model_kind = model_kind
        self.yaml = ruamel.yaml.YAML()
        self.config: Dict[str, Any] = self._load_config()
        self.git_hash = self._get_git_hash()
        # Session name has the form "<data_type>.<session_name>.<model_kind>".
        self.session_name = ".".join(
            [self.config["data_type"], self.config["session_name"], f"{model_kind}"]
        )
        self.data_dir = Path(
            os.path.join(self.config["data_directory"], self.config["data_type"])
        )
        # All artifacts of this session live under <log_directory>/<session_name>.
        self.base_dir = Path(
            os.path.join(self.config["log_directory"], self.session_name)
        )
        self.log_dir = self.base_dir / "logs"
        self.prediction_dir = self.base_dir / "predictions"
        self.plot_dir = self.base_dir / "plots"
        self.models_dir = self.base_dir / "models"
        self.sp_model_path = self.config.get("sp_model_path")
        self.text_encoder: TextEncoder = self.get_text_encoder()
        # Vocabulary sizes come from the text encoder, not from the YAML file.
        self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
        self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
        # Enum-valued options are stored in YAML by member name.
        if self.model_kind in ["seq2seq", "tacotron_based"]:
            self.config["attention_type"] = AttentionType[self.config["attention_type"]]
        self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]

    def _load_config(self):
        """Parse ``self.config_path`` with ``self.yaml`` and return the result."""
        # Opened in binary mode; the YAML loader handles decoding itself.
        with open(self.config_path, "rb") as model_yaml:
            return self.yaml.load(model_yaml)

    @staticmethod
    def _run_git_describe() -> str:
        """Return the output of ``git describe --always``, stripped and decoded.

        Raises whatever ``subprocess.check_output`` raises (e.g. git missing or
        not inside a repository).
        """
        return (
            subprocess.check_output(["git", "describe", "--always"])
            .strip()
            .decode()
        )

    @staticmethod
    def _get_git_hash():
        """Return the current git hash, or ``None`` (after a warning) on failure."""
        try:
            return ConfigManager._run_git_describe()
        except Exception as e:
            # Best effort: running outside a git checkout is not fatal.
            print(f"WARNING: could not retrieve git hash. {e}")
            return None

    def _check_hash(self):
        """Warn if the current git hash differs from ``config["git_hash"]``.

        Best effort: any failure (git unavailable, no ``git_hash`` key in the
        config) only prints a warning.
        """
        try:
            git_hash = self._run_git_describe()
            if self.config["git_hash"] != git_hash:
                print(
                    f"WARNING: git hash mismatch. Current: {git_hash}.\n"
                    f"Config hash: {self.config['git_hash']}"
                )
        except Exception as e:
            print(f"WARNING: could not check git hash. {e}")

    @staticmethod
    def _print_dict_values(values, key_name, level=0, tab_size=2):
        """Print one ``- key : value`` line indented by ``level * tab_size`` spaces."""
        tab = level * tab_size * " "
        print(tab + "-", key_name, ":", values)

    def _print_dictionary(self, dictionary, recursion_level=0, tab_size=2):
        """Print ``dictionary`` as an indented bullet list.

        A nested dict prints its key as a header line, followed by its own
        entries one indentation level deeper.

        Args:
            dictionary: mapping to print.
            recursion_level: indentation level of the top-level entries.
            tab_size: spaces per indentation level.
        """
        for key, value in dictionary.items():
            if isinstance(value, dict):
                print(recursion_level * tab_size * " " + "-", key, ":")
                # Pass level + 1 down instead of mutating recursion_level, so
                # later sibling keys keep this level's indentation.
                self._print_dictionary(value, recursion_level + 1, tab_size)
            else:
                self._print_dict_values(
                    value, key_name=key, level=recursion_level, tab_size=tab_size
                )

    def print_config(self):
        """Print the session name followed by every configuration entry."""
        print("\nCONFIGURATION", self.session_name)
        self._print_dictionary(self.config)

    def update_config(self):
        """Store the current git hash (or ``None``) under ``config["git_hash"]``."""
        self.config["git_hash"] = self._get_git_hash()

    def dump_config(self):
        """Write the config, with a refreshed git hash, to ``base_dir/config.yml``.

        Top-level ``Enum`` values are written as their member names.
        """
        self.update_config()
        _config = {
            key: val.name if isinstance(val, Enum) else val
            for key, val in self.config.items()
        }
        # Explicit UTF-8 so the output does not depend on the platform locale.
        with open(self.base_dir / "config.yml", "w", encoding="utf-8") as model_yaml:
            self.yaml.dump(_config, model_yaml)

    def create_remove_dirs(
        self,
        clear_dir: bool = False,
        clear_logs: bool = False,
        clear_weights: bool = False,
        clear_all: bool = False,
    ):
        """Create the session directories, optionally deleting old logs/weights.

        Every deletion is confirmed interactively (only an answer of ``y``
        deletes).

        Args:
            clear_dir: offer to delete both ``log_dir`` and ``models_dir``.
            clear_logs: offer to delete ``log_dir``.
            clear_weights: offer to delete ``models_dir``.
            clear_all: accepted for compatibility; currently unused.
        """
        self.base_dir.mkdir(exist_ok=True, parents=True)
        self.plot_dir.mkdir(exist_ok=True)
        self.prediction_dir.mkdir(exist_ok=True)
        if clear_dir:
            delete = input(f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
                shutil.rmtree(self.models_dir, ignore_errors=True)
        if clear_logs:
            delete = input(f"Delete {self.log_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
        if clear_weights:
            delete = input(f"Delete {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.models_dir, ignore_errors=True)
        # Recreated last so they exist even if they were just removed.
        self.log_dir.mkdir(exist_ok=True)
        self.models_dir.mkdir(exist_ok=True)

    def get_last_model_path(self):
        """Return the path of the newest ``<step>-snapshot.pt`` in ``models_dir``.

        Only ``.pt`` files whose name starts with a decimal step number (the
        part before the first ``.`` and ``-``) are considered. Other ``.pt``
        files are ignored instead of crashing the lookup.

        Returns:
            ``os.path.join(models_dir, f"{max_step}-snapshot.pt")`` as a string,
            or ``None`` if no numbered checkpoint exists.
        """
        steps = []
        for name in os.listdir(self.models_dir):
            if not name.endswith(".pt"):
                continue
            prefix = name.split(".")[0].split("-")[0]
            if prefix.isdecimal():
                steps.append(int(prefix))
        if not steps:
            return None
        return os.path.join(self.models_dir, f"{max(steps)}-snapshot.pt")

    def load_model(self, model_path: str = None):
        """Build the model and restore its weights from a checkpoint.

        Also writes ``str(model)`` to ``base_dir/<model_kind>_network.txt``.

        Args:
            model_path: checkpoint to load. If ``None``, the newest snapshot in
                ``models_dir`` is used. If there is none, the freshly built model
                is returned untouched.

        Returns:
            ``(model, global_step)``. ``global_step`` is the checkpoint's
            ``global_step + 1``, or ``1`` when nothing was loaded.
        """
        model = self.get_model()
        with open(self.base_dir / f"{self.model_kind}_network.txt", "w") as file:
            file.write(str(model))
        if model_path is None:
            model_path = self.get_last_model_path()
            if model_path is None:
                return model, 1
        # load_state_dict copies into the model's own parameters, so mapping
        # the checkpoint to CPU only lets GPU-saved checkpoints load on
        # CPU-only machines; it does not move the model.
        saved_model = torch.load(model_path, map_location="cpu")
        out = model.load_state_dict(saved_model["model_state_dict"])
        print(out)
        global_step = saved_model["global_step"] + 1
        return model, global_step

    def get_model(self, ignore_hash=False):
        """Construct a fresh model for ``self.model_kind``.

        Args:
            ignore_hash: skip the git-hash consistency warning.

        Returns:
            The new model, or ``None`` for an unknown ``model_kind``. ``__init__``
            rejects unknown kinds, so ``None`` is only possible if the attribute
            was changed afterwards.
        """
        if not ignore_hash:
            self._check_hash()
        builders = {
            "cbhg": self.get_cbhg,
            "seq2seq": self.get_seq2seq,
            "tacotron_based": self.get_tacotron_based,
            "baseline": self.get_baseline,
            "gpt": self.get_gpt,
        }
        builder = builders.get(self.model_kind)
        return builder() if builder is not None else None

    def get_gpt(self):
        """Build a ``GPTModel`` from ``base_model_path`` and its GPT options."""
        return GPTModel(
            self.config["base_model_path"],
            freeze=self.config["freeze"],
            n_layer=self.config["n_layer"],
            use_lstm=self.config["use_lstm"],
        )

    def get_baseline(self):
        """Build a ``BaseLineModel`` sized by the text encoder's vocabularies."""
        return BaseLineModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            layers_units=self.config["layers_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )

    def get_cbhg(self):
        """Build a ``CBHGModel`` sized by the text encoder's vocabularies."""
        return CBHGModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            use_prenet=self.config["use_prenet"],
            prenet_sizes=self.config["prenet_sizes"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
            post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
            post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
        )

    def _get_tacotron_decoder(self):
        """Build the attention decoder shared by the seq2seq and tacotron models."""
        return TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )

    def get_seq2seq(self):
        """Build the seq2seq model: a seq2seq encoder with the Tacotron decoder."""
        encoder = Seq2SeqEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            layers_units=self.config["encoder_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )
        # NOTE(review): Seq2SeqDecoder / Seq2Seq are imported but the Tacotron
        # decoder and wrapper are used here; presumably intentional — confirm.
        return Tacotron(encoder=encoder, decoder=self._get_tacotron_decoder())

    def get_tacotron_based(self):
        """Build the Tacotron-style model (CBHG encoder + attention decoder)."""
        encoder = TacotronEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            prenet_sizes=self.config["prenet_sizes"],
            use_prenet=self.config["use_encoder_prenet"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
        )
        return Tacotron(encoder=encoder, decoder=self._get_tacotron_decoder())

    def get_text_encoder(self):
        """Instantiate the TextEncoder named by ``config["text_encoder"]``.

        Raises:
            Exception: if ``text_cleaner`` or ``text_encoder`` is not recognized.
        """
        cleaner = self.config["text_cleaner"]
        if cleaner not in ["basic_cleaners", "valid_arabic_cleaners", None]:
            raise Exception(f"cleaner is not known {cleaner}")
        encoder_classes = {
            "BasicArabicEncoder": BasicArabicEncoder,
            "ArabicEncoderWithStartSymbol": ArabicEncoderWithStartSymbol,
        }
        encoder_name = self.config["text_encoder"]
        if encoder_name not in encoder_classes:
            raise Exception(f"the text encoder is not found {encoder_name}")
        return encoder_classes[encoder_name](
            cleaner_fn=cleaner, sp_model_path=self.sp_model_path
        )

    def get_loss_type(self):
        """Return the ``LossType`` member named by ``config["loss_type"]``.

        Raises:
            Exception: if the key is missing or names no ``LossType`` member.
        """
        # .get() so a missing key also yields the descriptive error below
        # instead of a KeyError raised while formatting the message.
        loss_name = self.config.get("loss_type")
        try:
            return LossType[loss_name]
        except KeyError as err:
            raise Exception(f"The loss type is not correct {loss_name}") from err
if __name__ == "__main__":
    # Smoke run: load a config and derive paths / the text encoder from it.
    # model_kind must be one of the kinds ConfigManager accepts
    # ("baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"); plain
    # "tacotron" is rejected with a TypeError.
    config_path = "config/tacotron-base-config.yml"
    model_kind = "tacotron_based"
    config = ConfigManager(config_path=config_path, model_kind=model_kind)