import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import timm

from .utils import FORMAT_INFO, to_device
from .tokenizer import SOS_ID, EOS_ID, PAD_ID, MASK_ID
from .inference import GreedySearch, BeamSearch
from .transformer import TransformerDecoder, Embeddings


class Encoder(nn.Module):

    def __init__(self, args, pretrained=False):
        super().__init__()
        model_name = args.encoder
        self.model_name = model_name
        if model_name.startswith('resnet'):
            self.model_type = 'resnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features  # encoder_dim
            self.cnn.global_pool = nn.Identity()
            self.cnn.fc = nn.Identity()
        elif model_name.startswith('swin'):
            self.model_type = 'swin'
            self.transformer = timm.create_model(model_name, pretrained=pretrained, pretrained_strict=False,
                                                 use_checkpoint=args.use_checkpoint)
            self.n_features = self.transformer.num_features
            self.transformer.head = nn.Identity()
        elif 'efficientnet' in model_name:
            self.model_type = 'efficientnet'
            self.cnn = timm.create_model(model_name, pretrained=pretrained)
            self.n_features = self.cnn.num_features
            self.cnn.global_pool = nn.Identity()
            self.cnn.classifier = nn.Identity()
        else:
            raise NotImplementedError

    def swin_forward(self, transformer, x):
        x = transformer.patch_embed(x)
        if transformer.absolute_pos_embed is not None:
            x = x + transformer.absolute_pos_embed
        x = transformer.pos_drop(x)

        def layer_forward(layer, x, hiddens):
            for blk in layer.blocks:
                if not torch.jit.is_scripting() and layer.use_checkpoint:
                    x = torch.utils.checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            H, W = layer.input_resolution
            B, L, C = x.shape
            hiddens.append(x.view(B, H, W, C))
            if layer.downsample is not None:
                x = layer.downsample(x)
            return x, hiddens

        hiddens = []
        for layer in transformer.layers:
            x, hiddens = layer_forward(layer, x, hiddens)
        x = transformer.norm(x)  # B L C
        hiddens[-1] = x.view_as(hiddens[-1])
        return x, hiddens

    def forward(self, x, refs=None):
        if self.model_type in ['resnet', 'efficientnet']:
            features = self.cnn(x)
            features = features.permute(0, 2, 3, 1)  # (B, H, W, C)
            hiddens = []
        elif self.model_type == 'swin':
            if 'patch' in self.model_name:
                features, hiddens = self.swin_forward(self.transformer, x)
            else:
                features, hiddens = self.transformer(x)
        else:
            raise NotImplementedError
        return features, hiddens


class TransformerDecoderBase(nn.Module):

    def __init__(self, args):
        super().__init__()
        self.args = args

        self.enc_trans_layer = nn.Sequential(
            nn.Linear(args.encoder_dim, args.dec_hidden_size)
            # nn.LayerNorm(args.dec_hidden_size, eps=1e-6)
        )
        self.enc_pos_emb = nn.Embedding(144, args.encoder_dim) if args.enc_pos_emb else None

        self.decoder = TransformerDecoder(
            num_layers=args.dec_num_layers,
            d_model=args.dec_hidden_size,
            heads=args.dec_attn_heads,
            d_ff=args.dec_hidden_size * 4,
            copy_attn=False,
            self_attn_type="scaled-dot",
            dropout=args.hidden_dropout,
            attention_dropout=args.attn_dropout,
            max_relative_positions=args.max_relative_positions,
            aan_useffn=False,
            full_context_alignment=False,
            alignment_layer=0,
            alignment_heads=0,
            pos_ffn_activation_fn='gelu'
        )

    def enc_transform(self, encoder_out):
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        max_len = encoder_out.size(1)
        device = encoder_out.device
        if self.enc_pos_emb is not None:
            pos_emb = self.enc_pos_emb(torch.arange(max_len, device=device)).unsqueeze(0)
            encoder_out = encoder_out + pos_emb
        encoder_out = self.enc_trans_layer(encoder_out)
        return encoder_out
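
# Illustrative usage sketch (not executed anywhere in this module; the `args`
# fields are the ones read by the constructors above, while the model name and
# input size are assumptions):
#
#   from argparse import Namespace
#   args = Namespace(encoder='swin_base_patch4_window12_384', use_checkpoint=False)
#   encoder = Encoder(args, pretrained=False)
#   features, hiddens = encoder(torch.randn(2, 3, 384, 384))
#
# For the swin backbone, `features` is (B, H'*W', n_features); for the CNN
# backbones it is (B, H', W', n_features). `enc_transform` then flattens the
# spatial grid to (B, num_pixels, dec_hidden_size) so it can serve as the
# cross-attention memory bank of the transformer decoder.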


class TransformerDecoderAR(TransformerDecoderBase):
    """Autoregressive Transformer decoder."""

    def __init__(self, args, tokenizer):
        super().__init__(args)
        self.tokenizer = tokenizer
        self.vocab_size = len(self.tokenizer)
        self.output_layer = nn.Linear(args.dec_hidden_size, self.vocab_size, bias=True)
        self.embeddings = Embeddings(
            word_vec_size=args.dec_hidden_size,
            word_vocab_size=self.vocab_size,
            word_padding_idx=PAD_ID,
            position_encoding=True,
            dropout=args.hidden_dropout)

    def dec_embedding(self, tgt, step=None):
        pad_idx = self.embeddings.word_padding_idx
        tgt_pad_mask = tgt.data.eq(pad_idx).transpose(1, 2)  # [B, 1, T_tgt]
        emb = self.embeddings(tgt, step=step)
        assert emb.dim() == 3  # batch x len x embedding_dim
        return emb, tgt_pad_mask

    def forward(self, encoder_out, labels, label_lengths):
        """Training mode."""
        batch_size, max_len, _ = encoder_out.size()
        memory_bank = self.enc_transform(encoder_out)

        tgt = labels.unsqueeze(-1)  # (b, t, 1)
        tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
        dec_out, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank, tgt_pad_mask=tgt_pad_mask)

        logits = self.output_layer(dec_out)  # (b, t, h) -> (b, t, v)
        # Shift so that the logits at position t are scored against the token at t+1.
        return logits[:, :-1], labels[:, 1:], dec_out

    def decode(self, encoder_out, beam_size: int, n_best: int, min_length: int = 1, max_length: int = 256,
               labels=None):
        """Inference mode. Autoregressively decode the sequence. Only greedy search is supported now;
        beam search is outdated. `labels` is used for partial prediction, i.e. when part of the
        sequence is given. In standard decoding, labels=None."""
        batch_size, max_len, _ = encoder_out.size()
        memory_bank = self.enc_transform(encoder_out)
        orig_labels = labels

        if beam_size == 1:
            decode_strategy = GreedySearch(
                sampling_temp=0.0, keep_topk=1, batch_size=batch_size, min_length=min_length,
                max_length=max_length, pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False, return_hidden=True)
        else:
            decode_strategy = BeamSearch(
                beam_size=beam_size, n_best=n_best, batch_size=batch_size, min_length=min_length,
                max_length=max_length, pad=PAD_ID, bos=SOS_ID, eos=EOS_ID,
                return_attention=False)

        # adapted from onmt.translate.translator
        results = {
            "predictions": None,
            "scores": None,
            "attention": None
        }

        # (2) prep decode_strategy. Possibly repeat src objects.
        _, memory_bank = decode_strategy.initialize(memory_bank=memory_bank)

        # (3) Begin decoding step by step:
        for step in range(decode_strategy.max_length):
            tgt = decode_strategy.current_predictions.view(-1, 1, 1)
            if labels is not None:
                # Partial prediction: keep the given (unmasked) tokens, predict the rest.
                label = labels[:, step].view(-1, 1, 1)
                mask = label.eq(MASK_ID).long()
                tgt = tgt * mask + label * (1 - mask)
            tgt_emb, tgt_pad_mask = self.dec_embedding(tgt)
            dec_out, dec_attn, *_ = self.decoder(tgt_emb=tgt_emb, memory_bank=memory_bank,
                                                 tgt_pad_mask=tgt_pad_mask, step=step)

            attn = dec_attn.get("std", None)

            dec_logits = self.output_layer(dec_out)  # [b, t, h] => [b, t, v]
            dec_logits = dec_logits.squeeze(1)
            log_probs = F.log_softmax(dec_logits, dim=-1)

            if self.tokenizer.output_constraint:
                output_mask = [self.tokenizer.get_output_mask(id) for id in tgt.view(-1).tolist()]
                output_mask = torch.tensor(output_mask, device=log_probs.device)
                log_probs.masked_fill_(output_mask, -10000)

            label = labels[:, step + 1] if labels is not None and step + 1 < labels.size(1) else None
            decode_strategy.advance(log_probs, attn, dec_out, label)
            any_finished = decode_strategy.is_finished.any()
            if any_finished:
                decode_strategy.update_finished()
                if decode_strategy.done:
                    break

            select_indices = decode_strategy.select_indices

            if any_finished:
                # Reorder states.
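                # When some sequences finish, the search strategy compacts the batch
                # to the still-active rows; the memory bank, the partial-prediction
                # labels, and the decoder's cached self-attention states must all be
                # gathered with the same select_indices to stay aligned.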
                memory_bank = memory_bank.index_select(0, select_indices)
                if labels is not None:
                    labels = labels.index_select(0, select_indices)
                self.map_state(lambda state, dim: state.index_select(dim, select_indices))

        results["scores"] = decode_strategy.scores  # fixed to be average of token scores
        results["token_scores"] = decode_strategy.token_scores
        results["predictions"] = decode_strategy.predictions
        results["attention"] = decode_strategy.attention
        results["hidden"] = decode_strategy.hidden
        if orig_labels is not None:
            # Splice the given (unmasked) tokens back into the predictions.
            for i in range(batch_size):
                pred = results["predictions"][i][0]
                label = orig_labels[i][1:len(pred) + 1]
                mask = label.eq(MASK_ID).long()
                pred = pred[:len(label)]
                results["predictions"][i][0] = pred * mask + label * (1 - mask)

        return results["predictions"], results["scores"], results["token_scores"], results["hidden"]

    # adapted from onmt.decoders.transformer
    def map_state(self, fn):
        def _recursive_map(struct, batch_dim=0):
            for k, v in struct.items():
                if v is not None:
                    if isinstance(v, dict):
                        _recursive_map(v)
                    else:
                        struct[k] = fn(v, batch_dim)

        if self.decoder.state["cache"] is not None:
            _recursive_map(self.decoder.state["cache"])


class GraphPredictor(nn.Module):

    def __init__(self, decoder_dim, coords=False):
        super(GraphPredictor, self).__init__()
        self.coords = coords
        self.mlp = nn.Sequential(
            nn.Linear(decoder_dim * 2, decoder_dim), nn.GELU(),
            nn.Linear(decoder_dim, 7)
        )
        if coords:
            self.coords_mlp = nn.Sequential(
                nn.Linear(decoder_dim, decoder_dim), nn.GELU(),
                nn.Linear(decoder_dim, 2)
            )

    def forward(self, hidden, indices=None):
        b, l, dim = hidden.size()
        if indices is None:
            index = [i for i in range(3, l, 3)]
            hidden = hidden[:, index]
        else:
            batch_id = torch.arange(b).unsqueeze(1).expand_as(indices).reshape(-1)
            indices = indices.view(-1)
            hidden = hidden[batch_id, indices].view(b, -1, dim)
        b, l, dim = hidden.size()
        results = {}
        # Pairwise features: concatenate the hidden states of atoms i and j.
        hh = torch.cat([hidden.unsqueeze(2).expand(b, l, l, dim), hidden.unsqueeze(1).expand(b, l, l, dim)], dim=3)
        results['edges'] = self.mlp(hh).permute(0, 3, 1, 2)
        if self.coords:
            results['coords'] = self.coords_mlp(hidden)
        return results


def get_edge_prediction(edge_prob):
    if not edge_prob:
        return [], []
    n = len(edge_prob)
    # Symmetrize the pairwise bond probabilities. The first five classes are
    # symmetric; the last two are directional (class 5 from i to j pairs with
    # class 6 from j to i), so they are averaged across the diagonal.
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(5):
                edge_prob[i][j][k] = (edge_prob[i][j][k] + edge_prob[j][i][k]) / 2
                edge_prob[j][i][k] = edge_prob[i][j][k]
            edge_prob[i][j][5] = (edge_prob[i][j][5] + edge_prob[j][i][6]) / 2
            edge_prob[i][j][6] = (edge_prob[i][j][6] + edge_prob[j][i][5]) / 2
            edge_prob[j][i][5] = edge_prob[i][j][6]
            edge_prob[j][i][6] = edge_prob[i][j][5]
    prediction = np.argmax(edge_prob, axis=2).tolist()
    score = np.max(edge_prob, axis=2).tolist()
    return prediction, score


class Decoder(nn.Module):
    """A wrapper for the different decoder architectures; supports multiple decoders."""

    def __init__(self, args, tokenizer):
        super(Decoder, self).__init__()
        self.args = args
        self.formats = args.formats
        self.tokenizer = tokenizer
        decoder = {}
        for format_ in args.formats:
            if format_ == 'edges':
                decoder['edges'] = GraphPredictor(args.dec_hidden_size, coords=args.continuous_coords)
            else:
                decoder[format_] = TransformerDecoderAR(args, tokenizer[format_])
        self.decoder = nn.ModuleDict(decoder)
        self.compute_confidence = args.compute_confidence

    def forward(self, encoder_out, hiddens, refs):
        """Training mode.
        Compute the logits with teacher forcing. Returns a dict keyed by format:
        (logits, shifted labels, hidden states) for the token decoders, and
        (predictions, targets) for 'edges'. The 'edges' head reuses the hidden
        states of the coordinate decoder, so 'atomtok_coords' or 'chartok_coords'
        must precede 'edges' in self.formats."""
        results = {}
        refs = to_device(refs, encoder_out.device)
        for format_ in self.formats:
            if format_ == 'edges':
                if 'atomtok_coords' in results:
                    dec_out = results['atomtok_coords'][2]
                elif 'chartok_coords' in results:
                    dec_out = results['chartok_coords'][2]
                else:
                    raise NotImplementedError
                predictions = self.decoder['edges'](dec_out, indices=refs['atom_indices'][0])
                targets = {'edges': refs['edges']}
                if 'coords' in predictions:
                    targets['coords'] = refs['coords']
                results['edges'] = (predictions, targets)
            else:
                labels, label_lengths = refs[format_]
                results[format_] = self.decoder[format_](encoder_out, labels, label_lengths)
        return results

    def decode(self, encoder_out, hiddens=None, refs=None, beam_size=1, n_best=1):
        """Inference mode. Call each decoder's decode method (if required) and convert the output
        format (e.g. token sequences to SMILES). Beam search is not supported yet."""
        results = {}
        predictions = []
        for format_ in self.formats:
            if format_ in ['atomtok', 'atomtok_coords', 'chartok_coords']:
                max_len = FORMAT_INFO[format_]['max_len']
                results[format_] = self.decoder[format_].decode(encoder_out, beam_size, n_best,
                                                                max_length=max_len)
                outputs, scores, token_scores, *_ = results[format_]
                beam_preds = [[self.tokenizer[format_].sequence_to_smiles(x.tolist()) for x in pred]
                              for pred in outputs]
                predictions = [{format_: pred[0]} for pred in beam_preds]
                if self.compute_confidence:
                    for i in range(len(predictions)):
                        # -1: y score, -2: x score, -3: symbol score
                        indices = np.array(predictions[i][format_]['indices']) - 3
                        if format_ == 'chartok_coords':
                            # A symbol may span several characters; its score is the
                            # geometric mean of its per-character token scores.
                            atom_scores = []
                            for symbol, index in zip(predictions[i][format_]['symbols'], indices):
                                atom_score = (np.prod(token_scores[i][0][index - len(symbol) + 1:index + 1])
                                              ** (1 / len(symbol))).item()
                                atom_scores.append(atom_score)
                        else:
                            atom_scores = np.array(token_scores[i][0])[indices].tolist()
                        predictions[i][format_]['atom_scores'] = atom_scores
                        predictions[i][format_]['average_token_score'] = scores[i][0]
            if format_ == 'edges':
                if 'atomtok_coords' in results:
                    atom_format = 'atomtok_coords'
                elif 'chartok_coords' in results:
                    atom_format = 'chartok_coords'
                else:
                    raise NotImplementedError
                dec_out = results[atom_format][3]  # batch x n_best x len x dim
                for i in range(len(dec_out)):
                    hidden = dec_out[i][0].unsqueeze(0)  # 1 * len * dim
                    indices = torch.LongTensor(predictions[i][atom_format]['indices']).unsqueeze(0)  # 1 * k
                    pred = self.decoder['edges'](hidden, indices)  # k * k
                    prob = F.softmax(pred['edges'].squeeze(0).permute(1, 2, 0), dim=2).tolist()  # k * k * 7
                    edge_pred, edge_score = get_edge_prediction(prob)
                    predictions[i]['edges'] = edge_pred
                    if self.compute_confidence:
                        predictions[i]['edge_scores'] = edge_score
                        predictions[i]['edge_score_product'] = np.sqrt(np.prod(edge_score)).item()
                        predictions[i]['overall_score'] = (predictions[i][atom_format]['average_token_score']
                                                           * predictions[i]['edge_score_product'])
                        predictions[i][atom_format].pop('average_token_score')
                        predictions[i].pop('edge_score_product')
        return predictions
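

# End-to-end sketch (illustrative only; `args` carries the fields read by the
# constructors above, and the tokenizer dict must come from tokenizer.py --
# its exact construction is not shown here):
#
#   encoder = Encoder(args, pretrained=True)
#   decoder = Decoder(args, tokenizer)  # tokenizer: {format_name: tokenizer}
#
#   # Training: teacher-forced logits / targets per format.
#   features, hiddens = encoder(images)
#   results = decoder(features, hiddens, refs)  # refs holds labels, coords, edges
#
#   # Inference: greedy decoding (beam_size=1), then edge prediction from the
#   # decoder hidden states at the predicted atom positions.
#   with torch.no_grad():
#       features, hiddens = encoder(images)
#       predictions = decoder.decode(features, hiddens)
#   # predictions[i][atom_format]: dict produced by sequence_to_smiles
#   # predictions[i]['edges']:     k x k matrix of bond classes (argmax over 7)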