# ImgCapt/tasks/ofa_task.py
# Web-page header residue: "cagataydag's picture"
# Duplicate from OFA-Sys/OFA-Image_Caption
# Revision: 733aa30
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass, field
import logging
import os
import math
import torch
from typing import Dict, Optional
from fairseq import search
from fairseq.data import FairseqDataset, iterators
from fairseq.optim.amp_optimizer import AMPOptimizer
from fairseq.dataclass import FairseqDataclass
from fairseq.tasks import FairseqTask, register_task
from omegaconf import DictConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class OFAConfig(FairseqDataclass):
    """Configuration options for the ``ofa`` task.

    Each attribute is a dataclass field whose ``help`` metadata carries the
    user-facing description. Comments below only note where a field is read
    inside this file; fields without such a note are not read here.
    """

    # Colon-separated list of data directories (see the help text).
    data: Optional[str] = field(
        default=None,
        metadata={
            "help": "colon separated path to data directories list, will be iterated upon during epochs "
            "in round-robin manner; however, valid and test data are always in the first directory "
            "to avoid the need for repeating them in all directories"
        },
    )
    selected_cols: Optional[str] = field(
        default=None,
        metadata={"help": "selected cols"},
    )
    # Directory that OFATask reads "dict.txt", "encoder.json" and "vocab.bpe" from.
    bpe_dir: Optional[str] = field(
        default=None,
        metadata={"help": "bpe dir"},
    )

    # Returned as a pair by OFATask.max_positions().
    max_source_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the source sequence"},
    )
    max_target_positions: int = field(
        default=1024,
        metadata={"help": "max number of tokens in the target sequence"},
    )

    max_src_length: int = field(
        default=128,
        metadata={"help": "the maximum src sequence length"},
    )
    max_tgt_length: int = field(
        default=30,
        metadata={"help": "the maximum target sequence length"},
    )

    # Number of "<code_i>" symbols OFATask.setup_task adds to each dictionary.
    code_dict_size: int = field(
        default=8192,
        metadata={"help": "code dict size"},
    )
    patch_image_size: int = field(
        default=480,
        metadata={"help": "patch image size"},
    )
    # Number of "<bin_i>" symbols OFATask.setup_task adds to each dictionary.
    num_bins: int = field(
        default=1000,
        metadata={"help": "number of quantization bins"},
    )

    imagenet_default_mean_and_std: bool = field(
        default=False,
        metadata={"help": "imagenet normalize"},
    )
    # Forwarded unchanged to the sequence generator in OFATask.build_generator.
    constraint_range: Optional[str] = field(
        default=None,
        metadata={"help": "constraint range"},
    )
@register_task("ofa", dataclass=OFAConfig)
class OFATask(FairseqTask):
    """Fairseq task registered under the name ``"ofa"``.

    The task owns a source and a target dictionary, builds a fixed-size batch
    iterator, a GPT-2 BPE encoder and a sequence generator, and runs a single
    forward/backward training step.
    """

    def __init__(self, cfg: OFAConfig, src_dict, tgt_dict):
        """Store the task configuration and the two dictionaries.

        Args:
            cfg (OFAConfig): task configuration, passed to the base class.
            src_dict: dictionary exposed through :attr:`source_dictionary`.
            tgt_dict: dictionary exposed through :attr:`target_dictionary`.
        """
        super().__init__(cfg)
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    @classmethod
    def setup_task(cls, cfg: DictConfig, **kwargs):
        """Load the dictionaries and create the task.

        Both dictionaries are loaded from ``<bpe_dir>/dict.txt`` and then
        extended, in this order, with ``<mask>``, ``cfg.code_dict_size``
        ``<code_i>`` symbols and ``cfg.num_bins`` ``<bin_i>`` symbols.

        Args:
            cfg (DictConfig): task configuration; ``bpe_dir``,
                ``code_dict_size`` and ``num_bins`` are read here.
            **kwargs: accepted but not used.

        Returns:
            OFATask: a new task instance holding the two dictionaries.

        Raises:
            ValueError: if ``cfg.bpe_dir`` is ``None``.
        """
        if cfg.bpe_dir is None:
            # Fail with a clear message instead of the opaque TypeError that
            # os.path.join(None, ...) would raise.
            raise ValueError(
                "bpe_dir must be set; dict.txt is loaded from that directory"
            )

        dict_path = os.path.join(cfg.bpe_dir, "dict.txt")
        # Two separate loads so source and target are independent objects.
        src_dict = cls.load_dictionary(dict_path)
        tgt_dict = cls.load_dictionary(dict_path)

        for dictionary in (src_dict, tgt_dict):
            dictionary.add_symbol("<mask>")
            for i in range(cfg.code_dict_size):
                dictionary.add_symbol(f"<code_{i}>")
            # quantization
            for i in range(cfg.num_bins):
                dictionary.add_symbol(f"<bin_{i}>")

        logger.info("source dictionary: %s types", len(src_dict))
        logger.info("target dictionary: %s types", len(tgt_dict))
        return cls(cfg, src_dict, tgt_dict)

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Build an epoch iterator over consecutive fixed-size batches.

        Batches are contiguous index ranges of at most ``max_sentences``
        items, in dataset order. ``max_tokens``, ``max_positions``,
        ``ignore_invalid_inputs``, ``required_batch_size_multiple``,
        ``shard_id`` and ``disable_iterator_cache`` are accepted but not
        read by this method.

        Args:
            dataset (~fairseq.data.FairseqDataset): dataset to batch; it must
                expose ``dataset.dataset.get_total_row_count()``.
            max_sentences (int): number of items per batch (required).
            seed (int): seed forwarded to the iterator.
            num_shards (int): only used to compute the expected number of
                batches per shard.
            num_workers (int): forwarded to the iterator.
            epoch (int): starting epoch, set on the dataset and the iterator.
            data_buffer_size (int): forwarded as ``buffer_size``.

        Returns:
            ~fairseq.data.iterators.EpochBatchIterator: the batch iterator.

        Raises:
            ValueError: if ``max_sentences`` is ``None`` or not positive.
        """
        assert isinstance(dataset, FairseqDataset)
        if max_sentences is None or max_sentences <= 0:
            # The batching below divides by and steps over max_sentences, so
            # the signature default of None cannot be used as-is.
            raise ValueError(
                "max_sentences must be a positive integer, got {!r}".format(
                    max_sentences
                )
            )

        # initialize the dataset with the correct starting epoch
        dataset.set_epoch(epoch)

        # create mini-batches with given size constraints
        num_items = len(dataset)
        batch_sampler = [
            list(range(start, min(start + max_sentences, num_items)))
            for start in range(0, num_items, max_sentences)
        ]

        # Expected number of batches for one shard, derived from the total
        # row count. If this dataset yields fewer, pad with one empty batch.
        # NOTE(review): presumably this keeps every shard stepping the same
        # number of times -- confirm against the dataset implementation.
        total_row_count = dataset.dataset.get_total_row_count()
        num_batches = math.ceil(
            math.ceil(total_row_count / num_shards) / max_sentences
        )
        if len(batch_sampler) < num_batches:
            batch_sampler.append([])

        # return a reusable iterator
        # NOTE(review): num_shards=1 / shard_id=0 are hard-coded here,
        # presumably because the dataset is already sharded -- confirm.
        return iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=1,
            shard_id=0,
            num_workers=num_workers,
            epoch=epoch,
            buffer_size=data_buffer_size,
        )

    def build_model(self, cfg: FairseqDataclass):
        """Build the model and attach a GPT-2 BPE encoder to the task.

        The encoder is built from ``encoder.json`` and ``vocab.bpe`` inside
        ``self.cfg.bpe_dir`` and stored as ``self.bpe``.

        Args:
            cfg (FairseqDataclass): model configuration, passed to the base
                class implementation.

        Returns:
            The model returned by the base class ``build_model``.
        """
        model = super().build_model(cfg)
        bpe_cfg = DictConfig(
            {
                "_name": "gpt2",
                "gpt2_encoder_json": os.path.join(self.cfg.bpe_dir, "encoder.json"),
                "gpt2_vocab_bpe": os.path.join(self.cfg.bpe_dir, "vocab.bpe"),
            }
        )
        self.bpe = self.build_bpe(bpe_cfg)
        return model

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """
        Build a :class:`~fairseq.SequenceGenerator` instance for this
        task.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            args (fairseq.dataclass.configs.GenerationConfig):
                configuration object (dataclass) for generation
            seq_gen_cls: generator class to instantiate; when ``None`` it is
                chosen from ``args.print_alignment``
            extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass
                through to SequenceGenerator; the caller's dict is not
                modified
            prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]):
                If provided, this function constrains the beam search to
                allowed tokens only at each step. The provided function
                should take 2 arguments: the batch ID (`batch_id: int`)
                and a unidimensional tensor of token ids (`inputs_ids:
                torch.Tensor`). It has to return a `List[int]` with the
                allowed tokens for the next generation step conditioned
                on the previously generated tokens (`inputs_ids`) and
                the batch ID (`batch_id`). This argument is useful for
                constrained generation conditioned on the prefix, as
                described in "Autoregressive Entity Retrieval"
                (https://arxiv.org/abs/2010.00904) and
                https://github.com/facebookresearch/GENRE.

        Returns:
            A ``SequenceScorer`` when ``args.score_reference`` is set,
            otherwise an instance of the selected generator class.

        Raises:
            ValueError: if more than one of the mutually exclusive search
                options is enabled.
        """
        if getattr(args, "score_reference", False):
            from fairseq.sequence_scorer import SequenceScorer

            return SequenceScorer(
                self.target_dictionary,
                compute_alignment=getattr(args, "print_alignment", False),
            )

        from fairseq.sequence_generator import SequenceGeneratorWithAlignment

        # The project's own SequenceGenerator is used as the default instead
        # of fairseq's.
        from models.sequence_generator import SequenceGenerator

        # Choose search strategy. Defaults to Beam Search.
        sampling = getattr(args, "sampling", False)
        sampling_topk = getattr(args, "sampling_topk", -1)
        sampling_topp = getattr(args, "sampling_topp", -1.0)
        diverse_beam_groups = getattr(args, "diverse_beam_groups", -1)
        diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5)
        match_source_len = getattr(args, "match_source_len", False)
        diversity_rate = getattr(args, "diversity_rate", -1)
        constrained = getattr(args, "constraints", False)
        if prefix_allowed_tokens_fn is None:
            prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None)

        if (
            sum(
                int(cond)
                for cond in [
                    sampling,
                    diverse_beam_groups > 0,
                    match_source_len,
                    diversity_rate > 0,
                ]
            )
            > 1
        ):
            raise ValueError("Provided Search parameters are mutually exclusive.")
        assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling"
        assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling"

        if sampling:
            search_strategy = search.Sampling(
                self.target_dictionary, sampling_topk, sampling_topp
            )
        elif diverse_beam_groups > 0:
            search_strategy = search.DiverseBeamSearch(
                self.target_dictionary, diverse_beam_groups, diverse_beam_strength
            )
        elif match_source_len:
            # this is useful for tagging applications where the output
            # length should match the input length, so we hardcode the
            # length constraints for simplicity
            search_strategy = search.LengthConstrainedBeamSearch(
                self.target_dictionary,
                min_len_a=1,
                min_len_b=0,
                max_len_a=1,
                max_len_b=0,
            )
        elif diversity_rate > -1:
            # NOTE(review): the exclusivity check above tests `> 0` while
            # this branch tests `> -1`; kept as-is, confirm it is intended.
            search_strategy = search.DiverseSiblingsSearch(
                self.target_dictionary, diversity_rate
            )
        elif constrained:
            search_strategy = search.LexicallyConstrainedBeamSearch(
                self.target_dictionary, args.constraints
            )
        elif prefix_allowed_tokens_fn:
            search_strategy = search.PrefixConstrainedBeamSearch(
                self.target_dictionary, prefix_allowed_tokens_fn
            )
        else:
            search_strategy = search.BeamSearch(self.target_dictionary)

        # Work on a copy so adding "print_alignment" below never leaks into
        # the dict the caller passed in.
        extra_gen_cls_kwargs = dict(extra_gen_cls_kwargs or {})
        if seq_gen_cls is None:
            if getattr(args, "print_alignment", False):
                seq_gen_cls = SequenceGeneratorWithAlignment
                extra_gen_cls_kwargs["print_alignment"] = args.print_alignment
            else:
                seq_gen_cls = SequenceGenerator

        # NOTE(review): constraint_range is passed to whichever generator
        # class was selected; confirm that SequenceGeneratorWithAlignment and
        # any caller-supplied seq_gen_cls accept this keyword.
        return seq_gen_cls(
            models,
            self.target_dictionary,
            beam_size=getattr(args, "beam", 5),
            max_len_a=getattr(args, "max_len_a", 0),
            max_len_b=getattr(args, "max_len_b", 200),
            min_len=getattr(args, "min_len", 1),
            normalize_scores=(not getattr(args, "unnormalized", False)),
            len_penalty=getattr(args, "lenpen", 1),
            unk_penalty=getattr(args, "unkpen", 0),
            temperature=getattr(args, "temperature", 1.0),
            match_source_len=getattr(args, "match_source_len", False),
            no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0),
            search_strategy=search_strategy,
            constraint_range=self.cfg.constraint_range,
            **extra_gen_cls_kwargs,
        )

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False, **extra_kwargs
    ):
        """
        Do forward and backward, and return the loss as computed by *criterion*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~fairseq.data.FairseqDataset`.
            model (~fairseq.models.BaseFairseqModel): the model
            criterion (~fairseq.criterions.FairseqCriterion): the criterion;
                it is called as ``criterion(model, sample, update_num)``
            optimizer (~fairseq.optim.FairseqOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True
            **extra_kwargs: accepted but not used

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        model.set_num_updates(update_num)
        # Autocast is enabled only when the optimizer is an AMPOptimizer.
        use_amp = isinstance(optimizer, AMPOptimizer)
        with torch.autograd.profiler.record_function("forward"):
            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, sample_size, logging_output = criterion(model, sample, update_num)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output

    def max_positions(self):
        """Return the max sentence length allowed by the task."""
        return (self.cfg.max_source_positions, self.cfg.max_target_positions)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.src_dict

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.tgt_dict