|
import os |
|
import random |
|
import numpy as np |
|
import torch |
|
import math |
|
import time |
|
import datetime |
|
import json |
|
from json import encoder |
|
|
|
|
|
FORMAT_INFO = { |
|
"inchi": { |
|
"name": "InChI_text", |
|
"tokenizer": "tokenizer_inchi.json", |
|
"max_len": 300 |
|
}, |
|
"atomtok": { |
|
"name": "SMILES_atomtok", |
|
"tokenizer": "tokenizer_smiles_atomtok.json", |
|
"max_len": 256 |
|
}, |
|
"nodes": {"max_len": 384}, |
|
"atomtok_coords": {"max_len": 480}, |
|
"chartok_coords": {"max_len": 480} |
|
} |
|
|
|
|
|
def init_logger(log_file='train.log'): |
|
from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler |
|
logger = getLogger(__name__) |
|
logger.setLevel(INFO) |
|
handler1 = StreamHandler() |
|
handler1.setFormatter(Formatter("%(message)s")) |
|
handler2 = FileHandler(filename=log_file) |
|
handler2.setFormatter(Formatter("%(message)s")) |
|
logger.addHandler(handler1) |
|
logger.addHandler(handler2) |
|
return logger |
|
|
|
|
|
def init_summary_writer(save_path): |
|
from tensorboardX import SummaryWriter |
|
summary = SummaryWriter(save_path) |
|
return summary |
|
|
|
|
|
def save_args(args): |
|
dt = datetime.datetime.strftime(datetime.datetime.now(), "%y%m%d-%H%M") |
|
path = os.path.join(args.save_path, f'train_{dt}.log') |
|
with open(path, 'w') as f: |
|
for k, v in vars(args).items(): |
|
f.write(f"**** {k} = *{v}*\n") |
|
return |
|
|
|
|
|
def seed_torch(seed=42): |
|
random.seed(seed) |
|
os.environ['PYTHONHASHSEED'] = str(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
torch.backends.cudnn.deterministic = True |
|
|
|
|
|
class AverageMeter(object): |
|
"""Computes and stores the average and current value""" |
|
def __init__(self): |
|
self.reset() |
|
|
|
def reset(self): |
|
self.val = 0 |
|
self.avg = 0 |
|
self.sum = 0 |
|
self.count = 0 |
|
|
|
def update(self, val, n=1): |
|
self.val = val |
|
self.sum += val * n |
|
self.count += n |
|
self.avg = self.sum / self.count |
|
|
|
|
|
class EpochMeter(AverageMeter): |
|
def __init__(self): |
|
super().__init__() |
|
self.epoch = AverageMeter() |
|
|
|
def update(self, val, n=1): |
|
super().update(val, n) |
|
self.epoch.update(val, n) |
|
|
|
|
|
class LossMeter(EpochMeter): |
|
def __init__(self): |
|
self.subs = {} |
|
super().__init__() |
|
|
|
def reset(self): |
|
super().reset() |
|
for k in self.subs: |
|
self.subs[k].reset() |
|
|
|
def update(self, loss, losses, n=1): |
|
loss = loss.item() |
|
super().update(loss, n) |
|
losses = {k: v.item() for k, v in losses.items()} |
|
for k, v in losses.items(): |
|
if k not in self.subs: |
|
self.subs[k] = EpochMeter() |
|
self.subs[k].update(v, n) |
|
|
|
|
|
def asMinutes(s): |
|
m = math.floor(s / 60) |
|
s -= m * 60 |
|
return '%dm %ds' % (m, s) |
|
|
|
|
|
def timeSince(since, percent): |
|
now = time.time() |
|
s = now - since |
|
es = s / (percent) |
|
rs = es - s |
|
return '%s (remain %s)' % (asMinutes(s), asMinutes(rs)) |
|
|
|
|
|
def print_rank_0(message): |
|
if torch.distributed.is_initialized(): |
|
if torch.distributed.get_rank() == 0: |
|
print(message, flush=True) |
|
else: |
|
print(message, flush=True) |
|
|
|
|
|
def to_device(data, device): |
|
if torch.is_tensor(data): |
|
return data.to(device) |
|
if type(data) is list: |
|
return [to_device(v, device) for v in data] |
|
if type(data) is dict: |
|
return {k: to_device(v, device) for k, v in data.items()} |
|
|
|
|
|
def round_floats(o): |
|
if isinstance(o, float): |
|
return round(o, 3) |
|
if isinstance(o, dict): |
|
return {k: round_floats(v) for k, v in o.items()} |
|
if isinstance(o, (list, tuple)): |
|
return [round_floats(x) for x in o] |
|
return o |
|
|
|
|
|
def format_df(df): |
|
def _dumps(obj): |
|
if obj is None: |
|
return obj |
|
return json.dumps(round_floats(obj)).replace(" ", "") |
|
for field in ['node_coords', 'node_symbols', 'edges']: |
|
if field in df.columns: |
|
df[field] = [_dumps(obj) for obj in df[field]] |
|
return df |
|
|