File size: 20,507 Bytes
6a62ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""
import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
convert_namespace_to_omegaconf,
gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def check_type(module, expected_type):
    """Assert that *module* is an instance of *expected_type*.

    When *module* exposes an ``unwrapped_module`` attribute, that inner
    object is checked instead of *module* itself.

    Raises:
        AssertionError: if the checked object is not an instance of
            *expected_type*; the message shows the actual and expected types.
    """
    # Fall back to the module itself when there is no wrapped inner module.
    target = getattr(module, "unwrapped_module", module)
    assert isinstance(target, expected_type), f"{type(target)} != {expected_type}"
class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        # Flipped by make_generation_fast_ so that it only runs once.
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser.

        Arguments are generated only when the class carries a
        ``__dataclass`` attribute; otherwise the parser is left untouched.
        """
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults so that settings defaults from various
            # architectures still works
            gen_parser_from_dataclass(parser, dc(), delete_default=True)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output.

        The base implementation returns ``sample["target"]`` and ignores
        *net_output*.
        """
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support super(), so a scriptable subclass can't
    # reach the base class implementation through it. The current workaround
    # is to expose the logic under a different name that the scriptable
    # subclass can call directly.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for get_normalized_probs in ~BaseFairseqModel

        Delegates to ``self.decoder`` when the model has one. Otherwise, if
        *net_output* is a plain tensor, applies a (log-)softmax over its last
        dimension.

        Raises:
            NotImplementedError: if there is no decoder and *net_output* is
                not a tensor.
        """
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models which don't have a decoder
            # (e.g., the classification tutorial)
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features.

        The base implementation simply calls the model itself.
        """
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model (``None`` in the base class)."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict (dict): state to load; it is upgraded in place first
            strict (bool, optional): forwarded to
                :meth:`nn.Module.load_state_dict`
            model_cfg (DictConfig, optional): model config handed to
                ``prune_state_dict``
            args (Namespace, optional): deprecated; only used to derive
                *model_cfg* when *model_cfg* is ``None``
        """
        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Walks the child modules recursively and lets each one that defines
        ``upgrade_state_dict_named`` (or, failing that,
        ``upgrade_state_dict``) upgrade the shared *state_dict*.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                child_name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, child_name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, child_name)

        do_upgrade(self, name)

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            # self.modules() also yields self; skip it by identity so this
            # method does not call itself.
            if hasattr(m, "set_num_updates") and m is not self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference.

        Reads options from ``cfg.generation`` and forwards them as keyword
        arguments to :meth:`make_generation_fast_`.
        """
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer prepare_for_inference_.

        Removes weight norm from every submodule, calls
        ``make_generation_fast_`` on submodules that provide their own
        implementation, switches the model to eval mode and replaces
        ``self.train`` so that re-enabling training raises.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m is not self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training; eval() must run
        # before self.train is replaced by the guard above
        self.eval()
        self.train = train

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace.

        Calls ``prepare_for_onnx_export_(**kwargs)`` once on every submodule
        (other than this model) that defines it.
        """
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module is not self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])

    @classmethod
    def hub_models(cls):
        """Return the archive map passed to ``hub_utils.from_pretrained`` (empty here)."""
        return {}
class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder, self.decoder = encoder, decoder

        # Both components must be of the expected fairseq base types.
        check_type(self.encoder, FairseqEncoder)
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        The source batch goes through the encoder first; the encoder output,
        together with the previous decoder outputs (i.e., teacher forcing),
        is then handed to the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Extra keyword arguments are passed to both the encoder and the
        decoder.

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        enc_state = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder(prev_output_tokens, encoder_out=enc_state, **kwargs)

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run only the decoder on *prev_output_tokens*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        enc_state = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        return self.decoder.extract_features(
            prev_output_tokens, encoder_out=enc_state, **kwargs
        )

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model, as an (encoder, decoder) pair."""
        encoder_limit = self.encoder.max_positions()
        decoder_limit = self.decoder.max_positions()
        return (encoder_limit, decoder_limit)

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()
class FairseqModel(FairseqEncoderDecoderModel):
    """Deprecated name for :class:`FairseqEncoderDecoderModel`.

    Behaves like its parent but emits a deprecation warning on construction.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        message = (
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead"
        )
        utils.deprecation_warning(message, stacklevel=4)
class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        """Pair up encoders and decoders that share the same keys.

        Args:
            encoders (dict): mapping from key to a FairseqEncoder
            decoders (dict): mapping from the same keys to a FairseqDecoder
        """
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        # One encoder-decoder model per key, registered as submodules.
        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings

        Raises:
            ValueError: if any language's dictionary differs from the
                dictionary of the first language in *langs*.
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """Not implemented at this level; always raises NotImplementedError."""
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model.

        Returns:
            dict: for every key, an (encoder, decoder) pair of max positions.
        """
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder (smallest across all models)."""
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        """Encoder of the model stored under the first key."""
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        """Decoder of the model stored under the first key."""
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the first model's decoder on *prev_output_tokens*."""
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.

        Args:
            state_dict (dict): state to load; it is upgraded in place first
            strict (bool, optional): forwarded to the parent implementation
            model_cfg (optional): model config handed to ``prune_state_dict``
            args (Namespace, optional): deprecated; only used to derive
                *model_cfg* when *model_cfg* is ``None``
        """
        if model_cfg is None and args is not None:
            # Logger.warn is a deprecated alias; use Logger.warning.
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        # NOTE(review): super() here is BaseFairseqModel, whose
        # load_state_dict upgrades and prunes again (with model_cfg=None)
        # before reaching nn.Module.load_state_dict.
        return super().load_state_dict(new_state_dict, strict)
class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        # The single component must be a fairseq decoder.
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.
        Any extra keyword arguments are forwarded to the decoder unchanged.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        decoder_out = self.decoder(src_tokens, **kwargs)
        return decoder_out

    def forward_decoder(self, prev_output_tokens, **kwargs):
        """Run the decoder on *prev_output_tokens*."""
        decoder_out = self.decoder(prev_output_tokens, **kwargs)
        return decoder_out

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        features = self.decoder.extract_features(src_tokens, **kwargs)
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        projected = self.decoder.output_layer(features, **kwargs)
        return projected

    def max_positions(self):
        """Maximum length supported by the model (the decoder's limit)."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        """Set of supported target types; only ``"future"`` here."""
        return {"future"}
class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # The single component must be a fairseq encoder.
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for a encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        encoder_result = self.encoder(src_tokens, src_lengths, **kwargs)
        return encoder_result

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output.

        Reads ``net_output["encoder_out"]`` and applies a (log-)softmax over
        its last dimension.

        Raises:
            NotImplementedError: if ``net_output["encoder_out"]`` is not a
                tensor.
        """
        encoder_out = net_output["encoder_out"]
        if not torch.is_tensor(encoder_out):
            raise NotImplementedError

        # Pick the normaliser once, then apply it to the float logits.
        normalize = F.log_softmax if log_probs else F.softmax
        return normalize(encoder_out.float(), dim=-1)

    def max_positions(self):
        """Maximum length supported by the model (the encoder's limit)."""
        return self.encoder.max_positions()
|