LlTRA-model / configuration.py
Esmail Atta Gumaan
Upload 7 files
cd89176 verified
raw
history blame
1.07 kB
from pathlib import Path
def Get_configuration():
    """Build and return the default configuration dictionary.

    A brand-new dict is created on every call, so a caller may modify the
    returned mapping without affecting the result of later calls.

    Returns:
        dict: the settings below, keyed by name.
    """
    settings = dict(
        # Numeric hyper-parameters.
        batch_size=8,
        num_epochs=20,
        lr=10**-4,
        sequence_length=100,
        d_model=512,
        # Dataset identifier and language pair.
        datasource='opus_infopankki',
        source_language="ar",
        target_language="en",
        # "datasource", "model_folder" and "model_basename" are combined by
        # Get_weights_file_path / latest_weights_file_path in this file to
        # form checkpoint paths.
        model_folder="weights",
        model_basename="tmodel_",
        preload="latest",
        # NOTE(review): "{0}" looks like a str.format placeholder filled in
        # by the caller (presumably with a language code) -- confirm.
        tokenizer_file="tokenizer_{0}.json",
        experiment_name="runs/tmodel",
    )
    return settings
def Get_weights_file_path(config, epoch: str):
    """Return the relative path of the weights file for a given epoch.

    Args:
        config: mapping providing the "datasource", "model_folder" and
            "model_basename" entries.
        epoch: epoch label inserted into the file name as-is.

    Returns:
        str: "<datasource>_<model_folder>/<model_basename><epoch>.pt",
        joined with the platform's path separator.
    """
    directory = "{}_{}".format(config['datasource'], config['model_folder'])
    checkpoint_name = "{}{}.pt".format(config['model_basename'], epoch)
    # Anchoring at '.' keeps the result relative to the working directory;
    # pathlib drops the leading "." component from the rendered string.
    return str(Path('.').joinpath(directory, checkpoint_name))
def latest_weights_file_path(config):
    """Return the path of the most recent weights file, or None.

    Looks in the "<datasource>_<model_folder>" directory for entries whose
    name starts with "model_basename" and returns the one that ranks last.

    Names are ranked in natural order: runs of ASCII digits are compared as
    integers, so "tmodel_10.pt" ranks after "tmodel_9.pt". A plain
    lexicographic sort would place "tmodel_9.pt" last and return a stale
    checkpoint whenever epoch numbers are not zero-padded to equal width.
    For equally padded names the result matches a lexicographic sort.

    Args:
        config: mapping providing the "datasource", "model_folder" and
            "model_basename" entries.

    Returns:
        str | None: path of the highest-ranked matching file, or None when
        nothing matches (including when the directory does not exist).
    """
    import re  # local import so this block does not depend on file-level imports

    model_folder = f"{config['datasource']}_{config['model_folder']}"
    pattern = f"{config['model_basename']}*"

    def natural_key(path):
        # re.split with a capturing group alternates text / digits / text...,
        # so odd positions are always digit runs; comparing position by
        # position never mixes str with int.
        pieces = re.split(r"([0-9]+)", path.name)
        ranked = [int(p) if i % 2 else p for i, p in enumerate(pieces)]
        # The raw name breaks ties such as "tmodel_5.pt" vs "tmodel_05.pt"
        # deterministically, independent of directory listing order.
        return ranked, path.name

    latest = max(Path(model_folder).glob(pattern), key=natural_key, default=None)
    return None if latest is None else str(latest)