# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import warnings
from argparse import Namespace
from typing import Any, Callable, Dict, List

import torch
from fairseq import metrics, search, tokenizer, utils
from fairseq.data import Dictionary, FairseqDataset, data_utils, encoders, iterators
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.optim.amp_optimizer import AMPOptimizer
from omegaconf import DictConfig


logger = logging.getLogger(__name__)

class StatefulContainer(object):
    def __init__(self):
        self._state = dict()
        self._factories = dict()

    def add_factory(self, name, factory: Callable[[], Any]):
        self._factories[name] = factory

    def merge_state_dict(self, state_dict: Dict[str, Any]):
        self._state.update(state_dict)

    @property
    def state_dict(self) -> Dict[str, Any]:
        return self._state

    def __getattr__(self, name):
        if name not in self._state and name in self._factories:
            self._state[name] = self._factories[name]()

        if name in self._state:
            return self._state[name]

        raise AttributeError(f"Task state has no factory for attribute {name}")

class FairseqTask(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.

    Tasks have limited statefulness. In particular, state that needs to be
    saved to/loaded from checkpoints needs to be stored in the `self.state`
    :class:`StatefulContainer` object. For example::

        self.state.add_factory("dictionary", self.load_dictionary)
        print(self.state.dictionary)  # calls self.load_dictionary()

    This is necessary so that when loading checkpoints, we can properly
    recreate the task state after initializing the task instance.
    """

    @classmethod
    def add_args(cls, parser):
        """Add task-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            gen_parser_from_dataclass(parser, dc())

    @staticmethod
    def logging_outputs_can_be_summed(criterion) -> bool:
        """
        Whether the logging outputs returned by `train_step` and `valid_step` can
        be summed across workers prior to calling `aggregate_logging_outputs`.
        Setting this to True will improve distributed training speed.
        """
        return criterion.logging_outputs_can_be_summed()

    def __init__(self, cfg: FairseqDataclass, **kwargs):
        self.cfg = cfg
        self.datasets = dict()
        self.dataset_to_epoch_iter = dict()
        self.state = StatefulContainer()

    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        return Dictionary.load(filename)

    @classmethod
    def build_dictionary(
        cls, filenames, workers=1, threshold=-1, nwords=-1, padding_factor=8
    ):
        """Build the dictionary

        Args:
            filenames (list): list of filenames
            workers (int): number of concurrent workers
            threshold (int): defines the minimum word count
            nwords (int): defines the total number of words in the final dictionary,
                including special symbols
            padding_factor (int): can be used to pad the dictionary size to be a
                multiple of 8, which is important on some hardware (e.g., Nvidia
                Tensor Cores).
        """
        d = Dictionary()
        for filename in filenames:
            Dictionary.add_file_to_dictionary(
                filename, d, tokenizer.tokenize_line, workers
            )
        d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
        return d

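    # Usage sketch (hypothetical filenames): build a shared vocabulary from raw
    # text and persist it with Dictionary.save().
    #
    #     d = FairseqTask.build_dictionary(
    #         ["train.src", "train.tgt"],   # hypothetical corpus files
    #         workers=4,
    #         threshold=5,                  # drop words seen fewer than 5 times
    #         padding_factor=8,             # pad vocab size for tensor-core-friendly shapes
    #     )
    #     d.save("dict.txt")
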
    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (omegaconf.DictConfig): parsed command-line arguments
        """
        return cls(cfg, **kwargs)

    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.cfg, "data", "")

    def load_dataset(
        self,
        split: str,
        combine: bool = False,
        task_cfg: FairseqDataclass = None,
        **kwargs,
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            combine (bool): combines a split segmented into pieces into one dataset
            task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used
                to load datasets
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~fairseq.data.FairseqDataset` corresponding to *split*
        """
        from fairseq.data import FairseqDataset

        if split not in self.datasets:
            raise KeyError("Dataset not loaded: " + split)
        if not isinstance(self.datasets[split], FairseqDataset):
            raise TypeError("Datasets are expected to be of type FairseqDataset")
        return self.datasets[split]

    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).

        Returns:
            np.array: array of filtered sample indices
        """
        indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
        if len(ignored) > 0:
            if not ignore_invalid_inputs:
                raise Exception(
                    (
                        "Size of sample #{} is invalid (={}) since max_positions={}, "
                        "skip this example with --skip-invalid-size-inputs-valid-test"
                    ).format(ignored[0], dataset.size(ignored[0]), max_positions)
                )
            logger.warning(
                (
                    "{:,} samples have invalid sizes and will be skipped, "
                    "max_positions={}, first few sample ids={}"
                ).format(len(ignored), max_positions, ignored[:10])
            )
        return indices

    def can_reuse_epoch_itr(self, dataset):
        # We can reuse the epoch iterator across epochs as long as the dataset
        # hasn't disabled it. We default to ``False`` here, although in practice
        # this will be ``True`` for most datasets that inherit from
        # ``FairseqDataset`` due to the base implementation there.
        return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
        skip_remainder_batch=False,
        grouped_shuffling=False,
        update_epoch_batch_itr=False,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size to
                be a multiple of N (default: 1).
            seed (int, optional): seed for random number generator for
                reproducibility (default: 1).
            num_shards (int, optional): shard the data iterator into N
                shards (default: 1).
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0).
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0).
            epoch (int, optional): the epoch to start the iterator from
                (default: 1).
            data_buffer_size (int, optional): number of batches to
                preload (default: 0).
            disable_iterator_cache (bool, optional): don't cache the
                EpochBatchIterator (ignores `FairseqTask::can_reuse_epoch_itr`)
                (default: False).
            skip_remainder_batch (bool, optional): if set, discard the last
                batch in each training epoch, as the last batch is often smaller than
                local_batch_size * distributed_world_size (default: ``False``).
            grouped_shuffling (bool, optional): group batches with each group
                containing num_shards batches and shuffle groups. Reduces difference
                between sequence lengths among workers for batches sorted by length.
            update_epoch_batch_itr (bool, optional): if true then do not use the
                cached batch iterator for the epoch

        Returns:
            ~fairseq.iterators.EpochBatchIterator: a batched iterator over the
                given dataset split
        """
        can_reuse_epoch_itr = (
            not disable_iterator_cache
            and not update_epoch_batch_itr
            and self.can_reuse_epoch_itr(dataset)
        )
        if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter:
            logger.debug("reusing EpochBatchIterator for epoch {}".format(epoch))
            return self.dataset_to_epoch_iter[dataset]

        assert isinstance(dataset, FairseqDataset)

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            indices = self.filter_indices_by_size(
                indices, dataset, max_positions, ignore_invalid_inputs
            )

        # create mini-batches with given size constraints
        batch_sampler = dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

        reuse_dataloader = getattr(self.cfg, "reuse_dataloader", True)
        persistent_workers = getattr(self.cfg, "persistent_workers", False)

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
            skip_remainder_batch=skip_remainder_batch,
            grouped_shuffling=grouped_shuffling,
            reuse_dataloader=reuse_dataloader,
            persistent_workers=persistent_workers,
        )

        if can_reuse_epoch_itr:
            self.dataset_to_epoch_iter[dataset] = epoch_iter

        return epoch_iter

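    # Usage sketch (assumes a concrete task subclass with a loaded "train" split):
    # the trainer typically pairs load_dataset() with get_batch_iterator().
    #
    #     task.load_dataset("train")
    #     epoch_itr = task.get_batch_iterator(
    #         dataset=task.dataset("train"),
    #         max_tokens=4096,
    #         max_positions=(1024, 1024),   # model-dependent; shown for illustration
    #         seed=1,
    #         num_workers=2,
    #     )
    #     for batch in epoch_itr.next_epoch_itr(shuffle=True):
    #         ...  # forward/backward on each mini-batch
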
    def build_model(self, cfg: FairseqDataclass, from_checkpoint=False):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            cfg (FairseqDataclass): configuration object

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(cfg, self, from_checkpoint)
        model = quantization_utils.quantize_model_scalar(model, cfg)
        return model

    def build_criterion(self, cfg: DictConfig):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            cfg (omegaconf.DictConfig): configuration object

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(cfg, self)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
        prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.
        """
if getattr(args, "score_reference", False): | |
from fairseq.sequence_scorer import SequenceScorer | |
return SequenceScorer( | |
self.target_dictionary, | |
compute_alignment=getattr(args, "print_alignment", False), | |
) | |
from fairseq.sequence_generator import ( | |
SequenceGenerator, | |
SequenceGeneratorWithAlignment, | |
) | |
# Choose search strategy. Defaults to Beam Search. | |
sampling = getattr(args, "sampling", False) | |
sampling_topk = getattr(args, "sampling_topk", -1) | |
sampling_topp = getattr(args, "sampling_topp", -1.0) | |
diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) | |
diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) | |
match_source_len = getattr(args, "match_source_len", False) | |
diversity_rate = getattr(args, "diversity_rate", -1) | |
constrained = getattr(args, "constraints", False) | |
if prefix_allowed_tokens_fn is None: | |
prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) | |
if ( | |
sum( | |
int(cond) | |
for cond in [ | |
sampling, | |
diverse_beam_groups > 0, | |
match_source_len, | |
diversity_rate > 0, | |
] | |
) | |
> 1 | |
): | |
raise ValueError("Provided Search parameters are mutually exclusive.") | |
assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" | |
assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" | |
if sampling: | |
search_strategy = search.Sampling( | |
self.target_dictionary, sampling_topk, sampling_topp | |
) | |
elif diverse_beam_groups > 0: | |
search_strategy = search.DiverseBeamSearch( | |
self.target_dictionary, diverse_beam_groups, diverse_beam_strength | |
) | |
elif match_source_len: | |
# this is useful for tagging applications where the output | |
# length should match the input length, so we hardcode the | |
# length constraints for simplicity | |
search_strategy = search.LengthConstrainedBeamSearch( | |
self.target_dictionary, | |
min_len_a=1, | |
min_len_b=0, | |
max_len_a=1, | |
max_len_b=0, | |
) | |
elif diversity_rate > -1: | |
search_strategy = search.DiverseSiblingsSearch( | |
self.target_dictionary, diversity_rate | |
) | |
elif constrained: | |
search_strategy = search.LexicallyConstrainedBeamSearch( | |
self.target_dictionary, args.constraints | |
) | |
elif prefix_allowed_tokens_fn: | |
search_strategy = search.PrefixConstrainedBeamSearch( | |
self.target_dictionary, prefix_allowed_tokens_fn | |
) | |
else: | |
search_strategy = search.BeamSearch(self.target_dictionary) | |
extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} | |
if seq_gen_cls is None: | |
if getattr(args, "print_alignment", False): | |
seq_gen_cls = SequenceGeneratorWithAlignment | |
extra_gen_cls_kwargs["print_alignment"] = args.print_alignment | |
else: | |
seq_gen_cls = SequenceGenerator | |
return seq_gen_cls( | |
models, | |
self.target_dictionary, | |
beam_size=getattr(args, "beam", 5), | |
max_len_a=getattr(args, "max_len_a", 0), | |
max_len_b=getattr(args, "max_len_b", 200), | |
min_len=getattr(args, "min_len", 1), | |
normalize_scores=(not getattr(args, "unnormalized", False)), | |
len_penalty=getattr(args, "lenpen", 1), | |
unk_penalty=getattr(args, "unkpen", 0), | |
temperature=getattr(args, "temperature", 1.0), | |
match_source_len=getattr(args, "match_source_len", False), | |
no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), | |
search_strategy=search_strategy, | |
**extra_gen_cls_kwargs, | |
) | |
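    # Sketch of a prefix_allowed_tokens_fn (hypothetical): restrict every decoding
    # step to a fixed set of token ids, matching the
    # Callable[[int, torch.Tensor], List[int]] contract documented above.
    #
    #     allowed = [5, 6, 7]  # hypothetical token ids
    #
    #     def allow_fixed_tokens(batch_id: int, input_ids: torch.Tensor) -> List[int]:
    #         return allowed
    #
    #     generator = task.build_generator(
    #         models, gen_args, prefix_allowed_tokens_fn=allow_fixed_tokens
    #     )
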
    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=(isinstance(optimizer, AMPOptimizer))):
                loss, sample_size, logging_output = criterion(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

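    # Note (hedged, based on how fairseq's Trainer is expected to call this):
    # ignore_grad=True is typically passed for dummy batches (e.g., when recovering
    # from an OOM retry), keeping the forward/backward pass intact while zeroing
    # its contribution to the gradient.
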
    def valid_step(self, sample, model, criterion):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = criterion(model, sample)
        return loss, sample_size, logging_output

    def optimizer_step(self, optimizer, model, update_num):
        optimizer.step()

    def build_dataset_for_inference(
        self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs
    ) -> torch.utils.data.Dataset:
        raise NotImplementedError

    def inference_step(
        self, generator, models, sample, prefix_tokens=None, constraints=None
    ):
        with torch.no_grad():
            return generator.generate(
                models, sample, prefix_tokens=prefix_tokens, constraints=constraints
            )

    def begin_epoch(self, epoch, model):
        """Hook function called before the start of each epoch."""
        pass

    def begin_valid_epoch(self, epoch, model):
        """Hook function called before the start of each validation epoch."""
        pass

    def aggregate_logging_outputs(self, logging_outputs, criterion):
        """[deprecated] Aggregate logging outputs from data parallel training."""
        utils.deprecation_warning(
            "The aggregate_logging_outputs API is deprecated. "
            "Please use the reduce_metrics API instead."
        )
        with metrics.aggregate() as agg:
            self.reduce_metrics(logging_outputs, criterion)
        return agg.get_smoothed_values()

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, "aggregate_logging_outputs").__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                "Tasks should implement the reduce_metrics API. "
                "Falling back to deprecated aggregate_logging_outputs API."
            )
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion
            )
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any("ntokens" in log for log in logging_outputs):
            warnings.warn(
                "ntokens not found in Criterion logging outputs, cannot log wpb or wps"
            )
        else:
            ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
            metrics.log_scalar("wpb", ntokens, priority=180, round=1)
            metrics.log_speed("wps", ntokens, priority=90, round=1)

        if not any("nsentences" in log for log in logging_outputs):
            warnings.warn(
                "nsentences not found in Criterion logging outputs, cannot log bsz"
            )
        else:
            nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
            metrics.log_scalar("bsz", nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)

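    # Worked illustration (hedged): with two workers each reporting
    # {"ntokens": 1000, "nsentences": 32, ...} in logging_outputs, this logs
    # wpb=2000 and bsz=64 plus a smoothed wps rate, then delegates loss
    # aggregation to criterion.__class__.reduce_metrics.
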
    def state_dict(self):
        if self.state is not None:
            return self.state.state_dict
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]):
        if self.state is not None:
            self.state.merge_state_dict(state_dict)

    def max_positions(self):
        """Return the max input length allowed by the task."""
        return None

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary` (if applicable
        for this task)."""
        raise NotImplementedError

    def build_tokenizer(self, args):
        """Build the pre-tokenizer for this task."""
        return encoders.build_tokenizer(args)

    def build_bpe(self, args):
        """Build the BPE tokenizer for this task."""
        return encoders.build_bpe(args)

    def get_interactive_tokens_and_lengths(self, lines, encode_fn):
        tokens = [
            self.source_dictionary.encode_line(
                encode_fn(src_str), add_if_not_exist=False
            ).long()
            for src_str in lines
        ]
        lengths = [t.numel() for t in tokens]
        return tokens, lengths

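
# Sketch (hypothetical names) of a minimal task built on FairseqTask: register a
# factory for checkpoint-friendly state, implement load_dataset(), and expose the
# dictionaries that models, criterions and generators expect.
#
#     from fairseq.tasks import register_task
#
#     @register_task("my_task", dataclass=MyTaskConfig)  # MyTaskConfig is hypothetical
#     class MyTask(FairseqTask):
#         def __init__(self, cfg):
#             super().__init__(cfg)
#             self.state.add_factory(
#                 "dictionary", lambda: self.load_dictionary(cfg.dict_path)  # hypothetical field
#             )
#
#         def load_dataset(self, split, combine=False, task_cfg=None, **kwargs):
#             self.datasets[split] = ...  # build a FairseqDataset for this split
#
#         @property
#         def source_dictionary(self):
#             return self.state.dictionary
#
#         @property
#         def target_dictionary(self):
#             return self.state.dictionary
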

class LegacyFairseqTask(FairseqTask):
    def __init__(self, args: Namespace):
        super().__init__(None)
        self.args = args
        self.datasets = {}
        self.dataset_to_epoch_iter = {}

    @classmethod
    def setup_task(cls, args: Namespace, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        return cls(args, **kwargs)

    def has_sharded_data(self, split):
        return os.pathsep in getattr(self.args, "data", "")

    def build_model(self, args: Namespace, from_checkpoint=False):
        """
        Build the :class:`~fairseq.models.BaseFairseqModel` instance for this
        task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.models.BaseFairseqModel` instance
        """
        from fairseq import models, quantization_utils

        model = models.build_model(args, self, from_checkpoint)
        model = quantization_utils.quantize_model_scalar(model, args)
        return model

    def build_criterion(self, args: Namespace):
        """
        Build the :class:`~fairseq.criterions.FairseqCriterion` instance for
        this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~fairseq.criterions.FairseqCriterion` instance
        """
        from fairseq import criterions

        return criterions.build_criterion(args, self)
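
# Note (hedged): LegacyFairseqTask keeps the older argparse-based interface, so a
# legacy task is driven by a Namespace rather than a dataclass config, e.g.
#
#     args = Namespace(data="/path/to/data")  # hypothetical arguments
#     task = MyLegacyTask.setup_task(args)    # MyLegacyTask subclasses LegacyFairseqTask
#     model = task.build_model(args)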