# RxnIM/molscribe/utils.py
import os
import random
import numpy as np
import torch
import math
import time
import datetime
import json

# Decoding settings for each output format: tokenizer vocabulary file and
# maximum sequence length used when generating that representation.
FORMAT_INFO = {
    "inchi": {
        "name": "InChI_text",
        "tokenizer": "tokenizer_inchi.json",
        "max_len": 300
    },
    "atomtok": {
        "name": "SMILES_atomtok",
        "tokenizer": "tokenizer_smiles_atomtok.json",
        "max_len": 256
    },
    "nodes": {"max_len": 384},
    "atomtok_coords": {"max_len": 480},
    "chartok_coords": {"max_len": 480}
}
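
# Example (hypothetical caller): look up the maximum decoding length for a format.
#   max_len = FORMAT_INFO["atomtok_coords"]["max_len"]   # -> 480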


def init_logger(log_file='train.log'):
    """Create a logger that writes plain messages to both stdout and `log_file`."""
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger
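
# Example (sketch; `args.save_path` is hypothetical):
#   logger = init_logger(os.path.join(args.save_path, 'train.log'))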


def init_summary_writer(save_path):
    from tensorboardX import SummaryWriter
    summary = SummaryWriter(save_path)
    return summary


def save_args(args):
    """Write all argparse arguments to a timestamped log file under `args.save_path`."""
    dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M")
    path = os.path.join(args.save_path, f'train_{dt}.log')
    with open(path, 'w') as f:
        for k, v in vars(args).items():
            f.write(f"**** {k} = *{v}*\n")


def seed_torch(seed=42):
    """Seed Python, NumPy and PyTorch RNGs (and cuDNN) for reproducible runs."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
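
# Example (sketch): seed once at startup, before building models and dataloaders.
#   seed_torch(seed=42)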


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class EpochMeter(AverageMeter):
    """AverageMeter that additionally accumulates an epoch-level average."""

    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()

    def update(self, val, n=1):
        super().update(val, n)
        self.epoch.update(val, n)


class LossMeter(EpochMeter):
    """Meter for the total loss, with one EpochMeter per loss component in `subs`."""

    def __init__(self):
        self.subs = {}
        super().__init__()

    def reset(self):
        super().reset()
        for k in self.subs:
            self.subs[k].reset()

    def update(self, loss, losses, n=1):
        loss = loss.item()
        super().update(loss, n)
        losses = {k: v.item() for k, v in losses.items()}
        for k, v in losses.items():
            if k not in self.subs:
                self.subs[k] = EpochMeter()
            self.subs[k].update(v, n)
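
# Example (sketch of a hypothetical training loop; `loss` is a scalar tensor and
# `losses` a dict of per-component scalar tensors):
#   loss_meter = LossMeter()
#   for batch in loader:
#       loss, losses = compute_loss(batch)   # hypothetical helper
#       loss_meter.update(loss, losses, n=len(batch))
#   print(loss_meter.avg, {k: m.avg for k, m in loss_meter.subs.items()})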


def asMinutes(s):
    """Format a duration in seconds as 'Xm Ys'."""
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    """Return elapsed time since `since` plus the estimated remaining time,
    given `percent`, the fraction of work completed so far (in (0, 1])."""
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))
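
# Example (hypothetical progress print inside a loop of `total_steps` steps):
#   start = time.time()
#   print(f"step {step}: {timeSince(start, (step + 1) / total_steps)}")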


def print_rank_0(message):
    """Print only from rank 0 when torch.distributed is initialized; otherwise always print."""
    if torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            print(message, flush=True)
    else:
        print(message, flush=True)


def to_device(data, device):
    """Recursively move tensors (possibly nested inside lists/dicts) to `device`."""
    if torch.is_tensor(data):
        return data.to(device)
    if type(data) is list:
        return [to_device(v, device) for v in data]
    if type(data) is dict:
        return {k: to_device(v, device) for k, v in data.items()}
    # Non-tensor leaves (ints, strings, None, ...) are returned unchanged
    # instead of silently becoming None.
    return data
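
# Example (sketch):
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   batch = to_device({'image': torch.zeros(2, 3), 'ids': [1, 2]}, device)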


def round_floats(o):
    """Recursively round every float in a nested structure to 3 decimal places."""
    if isinstance(o, float):
        return round(o, 3)
    if isinstance(o, dict):
        return {k: round_floats(v) for k, v in o.items()}
    if isinstance(o, (list, tuple)):
        return [round_floats(x) for x in o]
    return o
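
# Example: round_floats({'x': 0.123456, 'pts': [(1.0, 2.34567)]})
#   -> {'x': 0.123, 'pts': [[1.0, 2.346]]}   (tuples come back as lists)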


def format_df(df):
    """JSON-encode the prediction columns of a dataframe, rounding floats to
    3 decimals and stripping spaces so the serialized strings stay compact."""
    def _dumps(obj):
        if obj is None:
            return obj
        return json.dumps(round_floats(obj)).replace(" ", "")

    for field in ['node_coords', 'node_symbols', 'edges']:
        if field in df.columns:
            df[field] = [_dumps(obj) for obj in df[field]]
    return df
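
# Example (sketch; assumes pandas is available to the caller):
#   import pandas as pd
#   df = pd.DataFrame({'node_coords': [[[0.1234, 0.5678]]], 'edges': [[[0, 1]]]})
#   df = format_df(df)   # 'node_coords' becomes the string "[[0.123,0.568]]"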