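"""Configuration manager for poetry_diacritizer.

Loads a YAML configuration, prepares the session directories, builds the text
encoder, and instantiates the requested diacritization model.
"""
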
from enum import Enum
import os
from pathlib import Path
import shutil
import subprocess
from typing import Any, Dict, Optional

import ruamel.yaml
import torch

from poetry_diacritizer.models.baseline import BaseLineModel
from poetry_diacritizer.models.cbhg import CBHGModel
from poetry_diacritizer.models.gpt import GPTModel
from poetry_diacritizer.models.seq2seq import (
    Decoder as Seq2SeqDecoder,
    Encoder as Seq2SeqEncoder,
    Seq2Seq,
)
from poetry_diacritizer.models.tacotron_based import (
    Decoder as TacotronDecoder,
    Encoder as TacotronEncoder,
    Tacotron,
)
from poetry_diacritizer.options import AttentionType, LossType, OptimizerType
from poetry_diacritizer.util.text_encoders import (
    ArabicEncoderWithStartSymbol,
    BasicArabicEncoder,
    TextEncoder,
)


class ConfigManager:
    """Config Manager"""

    def __init__(self, config_path: str, model_kind: str):
        available_models = ["baseline", "cbhg", "seq2seq", "tacotron_based", "gpt"]
        if model_kind not in available_models:
            raise TypeError(f"model_kind must be in {available_models}")
        self.config_path = Path(config_path)
        self.model_kind = model_kind
        self.yaml = ruamel.yaml.YAML()
        self.config: Dict[str, Any] = self._load_config()
        self.git_hash = self._get_git_hash()
        self.session_name = ".".join(
            [
                self.config["data_type"],
                self.config["session_name"],
                model_kind,
            ]
        )
        self.data_dir = Path(self.config["data_directory"]) / self.config["data_type"]
        self.base_dir = Path(self.config["log_directory"]) / self.session_name
        self.log_dir = self.base_dir / "logs"
        self.prediction_dir = self.base_dir / "predictions"
        self.plot_dir = self.base_dir / "plots"
        self.models_dir = self.base_dir / "models"
        self.sp_model_path = self.config.get("sp_model_path")
        self.text_encoder: TextEncoder = self.get_text_encoder()
        self.config["len_input_symbols"] = len(self.text_encoder.input_symbols)
        self.config["len_target_symbols"] = len(self.text_encoder.target_symbols)
        if self.model_kind in ["seq2seq", "tacotron_based"]:
            self.config["attention_type"] = AttentionType[self.config["attention_type"]]
        self.config["optimizer"] = OptimizerType[self.config["optimizer_type"]]

    def _load_config(self):
        """Load the YAML configuration file."""
        with open(self.config_path, "rb") as model_yaml:
            return self.yaml.load(model_yaml)

    def _get_git_hash(self):
        """Return the current git revision, or None if it cannot be read."""
        try:
            return (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
        except Exception as error:
            print(f"WARNING: could not retrieve git hash. {error}")
            return None

    def _check_hash(self):
        """Warn if the config was written under a different git revision."""
        try:
            git_hash = (
                subprocess.check_output(["git", "describe", "--always"])
                .strip()
                .decode()
            )
            if self.config["git_hash"] != git_hash:
                print(
                    f"WARNING: git hash mismatch. Current: {git_hash}. "
                    f"Config hash: {self.config['git_hash']}"
                )
        except Exception as error:
            print(f"WARNING: could not check git hash. {error}")

    def _print_dict_values(self, values, key_name, level=0, tab_size=2):
        tab = level * tab_size * " "
        print(tab + "-", key_name, ":", values)

    def _print_dictionary(self, dictionary, recursion_level=0):
        for key in dictionary.keys():
            if isinstance(dictionary[key], dict):
                recursion_level += 1
                self._print_dictionary(dictionary[key], recursion_level)
            else:
                self._print_dict_values(
                    dictionary[key], key_name=key, level=recursion_level
                )

    def print_config(self):
        print("\nCONFIGURATION", self.session_name)
        self._print_dictionary(self.config)

    def update_config(self):
        self.config["git_hash"] = self._get_git_hash()

    def dump_config(self):
        self.update_config()
        _config = {}
        for key, val in self.config.items():
            if isinstance(val, Enum):
                _config[key] = val.name
            else:
                _config[key] = val
        with open(self.base_dir / "config.yml", "w") as model_yaml:
            self.yaml.dump(_config, model_yaml)

    def create_remove_dirs(
        self,
        clear_dir: bool = False,
        clear_logs: bool = False,
        clear_weights: bool = False,
        clear_all: bool = False,
    ):
        self.base_dir.mkdir(exist_ok=True, parents=True)
        self.plot_dir.mkdir(exist_ok=True)
        self.prediction_dir.mkdir(exist_ok=True)
        if clear_dir:
            delete = input(f"Delete {self.log_dir} AND {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
                shutil.rmtree(self.models_dir, ignore_errors=True)
        if clear_logs:
            delete = input(f"Delete {self.log_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.log_dir, ignore_errors=True)
        if clear_weights:
            delete = input(f"Delete {self.models_dir}? (y/[n])")
            if delete == "y":
                shutil.rmtree(self.models_dir, ignore_errors=True)
        self.log_dir.mkdir(exist_ok=True)
        self.models_dir.mkdir(exist_ok=True)

    def get_last_model_path(self):
        """Return the path of the most recent checkpoint in models_dir.

        Checkpoints are expected to be named "<global_step>-snapshot.pt".
        Returns None if no checkpoint is found.
        """
        models = [model for model in os.listdir(self.models_dir) if model.endswith(".pt")]
        if len(models) == 0:
            return None
        _max = max(int(m.split(".")[0].split("-")[0]) for m in models)
        model_name = f"{_max}-snapshot.pt"
        return os.path.join(self.models_dir, model_name)

    def load_model(self, model_path: Optional[str] = None):
        """Build the model and load its saved state.

        Args:
            model_path (str, optional): path to a checkpoint. If None, the most
                recent checkpoint in models_dir is used; if none exists, the
                freshly initialized model is returned with global step 1.

        Returns:
            A tuple of (model, global_step).
        """
        model = self.get_model()
        with open(self.base_dir / f"{self.model_kind}_network.txt", "w") as file:
            file.write(str(model))
        if model_path is None:
            last_model_path = self.get_last_model_path()
            if last_model_path is None:
                return model, 1
        else:
            last_model_path = model_path
        saved_model = torch.load(last_model_path)
        out = model.load_state_dict(saved_model["model_state_dict"])
        print(out)
        global_step = saved_model["global_step"] + 1
        return model, global_step

    def get_model(self, ignore_hash=False):
        if not ignore_hash:
            self._check_hash()
        if self.model_kind == "cbhg":
            return self.get_cbhg()
        elif self.model_kind == "seq2seq":
            return self.get_seq2seq()
        elif self.model_kind == "tacotron_based":
            return self.get_tacotron_based()
        elif self.model_kind == "baseline":
            return self.get_baseline()
        elif self.model_kind == "gpt":
            return self.get_gpt()

    def get_gpt(self):
        model = GPTModel(
            self.config["base_model_path"],
            freeze=self.config["freeze"],
            n_layer=self.config["n_layer"],
            use_lstm=self.config["use_lstm"],
        )
        return model

    def get_baseline(self):
        model = BaseLineModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            layers_units=self.config["layers_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )
        return model

    def get_cbhg(self):
        model = CBHGModel(
            embedding_dim=self.config["embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            targ_vocab_size=self.config["len_target_symbols"],
            use_prenet=self.config["use_prenet"],
            prenet_sizes=self.config["prenet_sizes"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
            post_cbhg_layers_units=self.config["post_cbhg_layers_units"],
            post_cbhg_use_batch_norm=self.config["post_cbhg_use_batch_norm"],
        )
        return model

    def get_seq2seq(self):
        encoder = Seq2SeqEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            layers_units=self.config["encoder_units"],
            use_batch_norm=self.config["use_batch_norm"],
        )
        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )
        model = Tacotron(encoder=encoder, decoder=decoder)
        return model

    def get_tacotron_based(self):
        encoder = TacotronEncoder(
            embedding_dim=self.config["encoder_embedding_dim"],
            inp_vocab_size=self.config["len_input_symbols"],
            prenet_sizes=self.config["prenet_sizes"],
            use_prenet=self.config["use_encoder_prenet"],
            cbhg_gru_units=self.config["cbhg_gru_units"],
            cbhg_filters=self.config["cbhg_filters"],
            cbhg_projections=self.config["cbhg_projections"],
        )
        decoder = TacotronDecoder(
            self.config["len_target_symbols"],
            start_symbol_id=self.text_encoder.start_symbol_id,
            embedding_dim=self.config["decoder_embedding_dim"],
            encoder_dim=self.config["encoder_dim"],
            decoder_units=self.config["decoder_units"],
            decoder_layers=self.config["decoder_layers"],
            attention_type=self.config["attention_type"],
            attention_units=self.config["attention_units"],
            is_attention_accumulative=self.config["is_attention_accumulative"],
            use_prenet=self.config["use_decoder_prenet"],
            prenet_depth=self.config["decoder_prenet_depth"],
            teacher_forcing_probability=self.config["teacher_forcing_probability"],
        )
        model = Tacotron(encoder=encoder, decoder=decoder)
        return model

    def get_text_encoder(self):
        """Build the TextEncoder specified in the config."""
        if self.config["text_cleaner"] not in [
            "basic_cleaners",
            "valid_arabic_cleaners",
            None,
        ]:
            raise Exception(f"Unknown text cleaner: {self.config['text_cleaner']}")
        if self.config["text_encoder"] == "BasicArabicEncoder":
            text_encoder = BasicArabicEncoder(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        elif self.config["text_encoder"] == "ArabicEncoderWithStartSymbol":
            text_encoder = ArabicEncoderWithStartSymbol(
                cleaner_fn=self.config["text_cleaner"], sp_model_path=self.sp_model_path
            )
        else:
            raise Exception(f"Unknown text encoder: {self.config['text_encoder']}")
        return text_encoder

    def get_loss_type(self):
        try:
            loss_type = LossType[self.config["loss_type"]]
        except KeyError:
            raise Exception(f"Unknown loss type: {self.config['loss_type']}")
        return loss_type
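

# Illustrative sketch (not exhaustive) of config keys read by ConfigManager;
# the values below are placeholders, and each model kind reads further keys in
# its get_<model>() method above:
#
#   data_directory: <path to the data root>
#   data_type: <dataset sub-folder name>
#   log_directory: <path where sessions are written>
#   session_name: <free-form session label>
#   text_encoder: ArabicEncoderWithStartSymbol   # or BasicArabicEncoder
#   text_cleaner: valid_arabic_cleaners          # or basic_cleaners
#   optimizer_type: <an OptimizerType member name>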


if __name__ == "__main__":
    config_path = "config/tacotron-base-config.yml"
    model_kind = "tacotron_based"
    config = ConfigManager(config_path=config_path, model_kind=model_kind)
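    # Minimal usage sketch; it assumes the YAML file above exists and defines
    # the keys read in __init__. Uncomment the last two lines to create the
    # session directories and restore the most recent checkpoint, if any.
    config.print_config()
    # config.create_remove_dirs()
    # model, global_step = config.load_model()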