# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Code adapted from https://github.com/pytorch/fairseq
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# https://github.com/pytorch/fairseq. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import defaultdict, OrderedDict
import logging
import os
import re
import traceback

import torch
from torch.serialization import default_restore_location


def torch_persistent_save(*args, **kwargs):
    # retry a few times in case of transient I/O errors; log only on the final failure
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                logging.error(traceback.format_exc())


def convert_state_dict_type(state_dict, ttype=torch.FloatTensor):
    """Recursively convert every tensor in a (possibly nested) state dict to `ttype`."""
    if isinstance(state_dict, dict):
        cpu_dict = OrderedDict()
        for k, v in state_dict.items():
            # propagate ttype so nested entries are converted to the requested type
            cpu_dict[k] = convert_state_dict_type(v, ttype)
        return cpu_dict
    elif isinstance(state_dict, list):
        return [convert_state_dict_type(v, ttype) for v in state_dict]
    elif torch.is_tensor(state_dict):
        return state_dict.type(ttype)
    else:
        return state_dict
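

# Illustrative sketch (not part of the original fairseq module): casting a model's
# parameters to half precision before checkpointing. `model` is assumed to be any
# torch.nn.Module supplied by the caller.
def _example_half_precision_state(model):
    # Returns a copy of the state dict with every tensor cast to HalfTensor;
    # non-tensor entries are passed through unchanged.
    return convert_state_dict_type(model.state_dict(), ttype=torch.HalfTensor)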


def save_state(filename, args, model, criterion, optimizer, lr_scheduler,
               num_updates, optim_history=None, extra_state=None):
    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        'args': args,
        'model': convert_state_dict_type(model.state_dict()),
        'optimizer_history': optim_history + [
            {
                'criterion_name': criterion.__class__.__name__,
                'optimizer_name': optimizer.__class__.__name__,
                'lr_scheduler_state': lr_scheduler.state_dict(),
                'num_updates': num_updates,
            }
        ],
        'last_optimizer_state': convert_state_dict_type(optimizer.state_dict()),
        'extra_state': extra_state,
    }
    torch_persistent_save(state_dict, filename)


def load_model_state(filename, model):
    if not os.path.exists(filename):
        return None, [], None
    state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
    state = _upgrade_state_dict(state)
    model.upgrade_state_dict(state['model'])

    # load model parameters
    try:
        model.load_state_dict(state['model'], strict=True)
    except Exception:
        raise Exception('Cannot load model parameters from checkpoint, '
                        'please ensure that the architectures match')

    return state['extra_state'], state['optimizer_history'], state['last_optimizer_state']


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    # add optimizer_history
    if 'optimizer_history' not in state:
        state['optimizer_history'] = [
            {
                'criterion_name': 'CrossEntropyCriterion',
                'best_loss': state['best_loss'],
            },
        ]
        state['last_optimizer_state'] = state['optimizer']
        del state['optimizer']
        del state['best_loss']
    # move extra_state into sub-dictionary
    if 'epoch' in state and 'extra_state' not in state:
        state['extra_state'] = {
            'epoch': state['epoch'],
            'batch_offset': state['batch_offset'],
            'val_loss': state['val_loss'],
        }
        del state['epoch']
        del state['batch_offset']
        del state['val_loss']
    # reduce optimizer history's memory usage (only keep the last state)
    if 'optimizer' in state['optimizer_history'][-1]:
        state['last_optimizer_state'] = state['optimizer_history'][-1]['optimizer']
        for optim_hist in state['optimizer_history']:
            del optim_hist['optimizer']
    # record the optimizer class name
    if 'optimizer_name' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['optimizer_name'] = 'FairseqNAG'
    # move best_loss into lr_scheduler_state
    if 'lr_scheduler_state' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['lr_scheduler_state'] = {
            'best': state['optimizer_history'][-1]['best_loss'],
        }
        del state['optimizer_history'][-1]['best_loss']
    # keep track of number of updates
    if 'num_updates' not in state['optimizer_history'][-1]:
        state['optimizer_history'][-1]['num_updates'] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state['args'], 'max_positions') and not hasattr(state['args'], 'max_source_positions'):
        state['args'].max_source_positions = state['args'].max_positions
        state['args'].max_target_positions = state['args'].max_positions
    # use stateful training data iterator
    if 'train_iterator' not in state['extra_state']:
        state['extra_state']['train_iterator'] = {
            'epoch': state['extra_state']['epoch'],
            'iterations_in_epoch': 0,
        }
    return state


def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Load an ensemble of models for inference.

    model_arg_overrides allows you to pass a dictionary -- {'arg_name': arg} --
    to override model args that were used during model training.
    """
    # load model architectures and weights
    states = []
    for filename in filenames:
        if not os.path.exists(filename):
            raise IOError('Model file not found: {}'.format(filename))
        state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        state = _upgrade_state_dict(state)
        states.append(state)
    args = states[0]['args']
    if model_arg_overrides is not None:
        args = _override_model_args(args, model_arg_overrides)

    # build ensemble
    ensemble = []
    for state in states:
        model = task.build_model(args)
        model.upgrade_state_dict(state['model'])
        model.load_state_dict(state['model'], strict=True)
        ensemble.append(model)
    return ensemble, args
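

# Illustrative sketch (not part of the original fairseq module): loading an
# ensemble of checkpoints with an overridden model arg. The checkpoint paths,
# the override, and the `task` object (e.g. one returned by
# fairseq.tasks.setup_task) are assumptions supplied by the caller.
def _example_load_ensemble(task):
    filenames = ['checkpoints/checkpoint_best.pt', 'checkpoints/checkpoint_last.pt']
    models, args = load_ensemble_for_inference(
        filenames, task, model_arg_overrides={'dropout': 0.0},
    )
    for model in models:
        model.eval()  # disable dropout/batch-norm updates for inference
    return models, args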


def _override_model_args(args, model_arg_overrides):
    # Uses model_arg_overrides {'arg_name': arg} to override model args
    for arg_name, arg_val in model_arg_overrides.items():
        setattr(args, arg_name, arg_val)
    return args


def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {
                key: _move_to_cuda(value)
                for key, value in maybe_tensor.items()
            }
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)


def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_fairseq_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)


def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]


def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value
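

# Illustrative sketch (not part of the original fairseq module): how a decoder
# layer might cache its outputs across incremental decoding steps using the
# helpers above. The module, the 'prev_output' key, and the caching policy are
# assumptions for the example; state is keyed per module instance, so two
# instances of this class never share a cache.
class _ExampleCachingLayer(torch.nn.Module):

    def forward(self, x, incremental_state=None):
        prev = get_incremental_state(self, incremental_state, 'prev_output')
        if prev is not None:
            # prepend the cached timesteps to the new input (batch x time x dim)
            x = torch.cat([prev, x], dim=1)
        set_incremental_state(self, incremental_state, 'prev_output', x)
        return x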


def load_align_dict(replace_unk):
    if replace_unk is None:
        align_dict = None
    elif isinstance(replace_unk, str):
        # Load alignment dictionary for unknown word replacement if it was passed as an argument.
        align_dict = {}
        with open(replace_unk, 'r') as f:
            for line in f:
                cols = line.split()
                align_dict[cols[0]] = cols[1]
    else:
        # No alignment dictionary provided but we still want to perform unknown word
        # replacement by copying the original source word.
        align_dict = {}
    return align_dict


def print_embed_overlap(embed_dict, vocab_dict):
    embed_keys = set(embed_dict.keys())
    vocab_keys = set(vocab_dict.symbols)
    overlap = len(embed_keys & vocab_keys)
    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))


def parse_embedding(embed_path):
    """Parse an embedding text file into a dictionary of word and embedding tensors.

    The first line can have vocabulary size and dimension. The following lines
    should contain word and embedding separated by spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip header
        for line in f_embed:
            pieces = line.rstrip().split(" ")
            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
    return embed_dict


def load_embedding(embed_dict, vocab, embedding):
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token in embed_dict:
            embedding.weight.data[idx] = embed_dict[token]
    return embedding
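

# Illustrative sketch (not part of the original fairseq module): initializing an
# nn.Embedding from a pretrained embedding file. `vocab` is assumed to be a
# fairseq Dictionary (providing pad()) and `embed_path` a text file in the format
# documented in parse_embedding above.
def _example_init_pretrained_embedding(embed_path, vocab, embed_dim):
    embed_dict = parse_embedding(embed_path)
    print_embed_overlap(embed_dict, vocab)
    embedding = torch.nn.Embedding(len(vocab), embed_dim, padding_idx=vocab.pad())
    # rows for words present in the file are overwritten; the rest keep their
    # random initialization
    return load_embedding(embed_dict, vocab, embedding)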


def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    from fairseq import tokenizer

    # Tokens are strings here
    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # TODO: Very rare cases where the replacement is '<eos>' should be handled gracefully
    src_tokens = tokenizer.tokenize_line(src_str) + ['<eos>']
    for i, ht in enumerate(hypo_tokens):
        if ht == unk:
            src_token = src_tokens[alignment[i]]
            # Either take the corresponding value in the alignment dictionary or just copy the original value.
            hypo_tokens[i] = align_dict.get(src_token, src_token)
    return ' '.join(hypo_tokens)


def post_process_prediction(hypo_tokens, src_str, alignment, align_dict, tgt_dict, remove_bpe):
    from fairseq import tokenizer

    hypo_str = tgt_dict.string(hypo_tokens, remove_bpe)
    if align_dict is not None:
        hypo_str = replace_unk(hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string())
    if align_dict is not None or remove_bpe is not None:
        # Convert back to tokens for evaluating with unk replacement or without BPE
        # Note that the dictionary can be modified inside the method.
        hypo_tokens = tokenizer.Tokenizer.tokenize(hypo_str, tgt_dict, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment


def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
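

# Illustrative sketch (not part of the original fairseq module): position numbers
# for a right-padded batch with padding_idx=1. Non-pad tokens get positions
# 2, 3, 4, ... while pad positions keep the padding index.
def _example_make_positions():
    tokens = torch.tensor([
        [5, 6, 7, 1],   # one pad on the right
        [8, 9, 1, 1],   # two pads on the right
    ])
    # expected result: [[2, 3, 4, 1], [2, 3, 1, 1]]
    return make_positions(tokens, padding_idx=1, left_pad=False)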


def strip_pad(tensor, pad):
    return tensor[tensor.ne(pad)]


def buffered_arange(max):
    if not hasattr(buffered_arange, 'buf'):
        buffered_arange.buf = torch.LongTensor()
    if max > buffered_arange.buf.numel():
        torch.arange(max, out=buffered_arange.buf)
    return buffered_arange.buf[:max]


def convert_padding_direction(src_tokens, padding_idx, right_to_left=False, left_to_right=False):
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    if not pad_mask.any():
        # no padding, return early
        return src_tokens
    if left_to_right and not pad_mask[:, 0].any():
        # already right padded
        return src_tokens
    if right_to_left and not pad_mask[:, -1].any():
        # already left padded
        return src_tokens
    max_len = src_tokens.size(1)
    # renamed from `range` to avoid shadowing the built-in
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
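

# Illustrative sketch (not part of the original fairseq module): moving padding
# from the left side to the right side of a batch, with padding_idx=1.
def _example_convert_padding():
    left_padded = torch.tensor([
        [1, 1, 5, 6],
        [1, 7, 8, 9],
    ])
    # expected result: [[5, 6, 1, 1], [7, 8, 9, 1]]
    return convert_padding_direction(left_padded, padding_idx=1, left_to_right=True)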


def item(tensor):
    if hasattr(tensor, 'item'):
        return tensor.item()
    if hasattr(tensor, '__getitem__'):
        return tensor[0]
    return tensor


def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm


def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)
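

# Illustrative sketch (not part of the original fairseq module): building a
# causal (future-masking) additive attention mask of the kind decoder
# self-attention typically uses, with -inf above the diagonal and 0 elsewhere.
def _example_future_mask(tgt_len, dtype=torch.float32):
    mask = fill_with_neg_inf(torch.zeros(tgt_len, tgt_len, dtype=dtype))
    # keep the diagonal and everything below it at 0; positions above it stay -inf
    return torch.triu(mask, diagonal=1)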


def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = os.listdir(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = int(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
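

# Illustrative sketch (not part of the original fairseq module): pruning a
# checkpoint directory so only the `keep_last` most recent epoch checkpoints
# remain. The directory layout (checkpoint1.pt, checkpoint2.pt, ...) matches the
# default pattern above; the pruning policy itself is an assumption.
def _example_prune_checkpoints(path, keep_last=5):
    paths = checkpoint_paths(path)  # sorted newest (highest epoch) first
    for old_checkpoint in paths[keep_last:]:
        os.remove(old_checkpoint)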