G2PTL / G2PTL_utils.py
jinyan218's picture
G2PTL Init
9c0f93c
raw
history blame contribute delete
No virus
54.6 kB
#! python3
# -*- encoding: utf-8 -*-
import torch.utils.checkpoint
from torch import nn
from transformers.utils import logging
import inspect
from typing import Set, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import re
import math
from typing import Optional, Tuple
from transformers.models.ernie.modeling_ernie import *
import torch
from fairseq import utils
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor, nn
from torch.hub import load_state_dict_from_url
import torch.distributed as dist
from torch.hub import load_state_dict_from_url
import torch.distributed as dist
# Download locations of the published Graphormer checkpoints, keyed by the
# name accepted by `load_pretrained_model`.
PRETRAINED_MODEL_URLS = {
    "pcqm4mv1_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv1.pt",
    "pcqm4mv2_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv2.pt",
    "oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt", # this pretrained model is temporarily unavailable
    "pcqm4mv1_graphormer_base_for_molhiv":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
}


def load_pretrained_model(pretrained_model_name):
    """Download a pretrained checkpoint and return its ``"model"`` entry.

    Args:
        pretrained_model_name: A key of ``PRETRAINED_MODEL_URLS``.

    Returns:
        The value stored under ``"model"`` in the downloaded checkpoint.

    Raises:
        ValueError: If ``pretrained_model_name`` is not a known checkpoint name.
    """
    if pretrained_model_name not in PRETRAINED_MODEL_URLS:
        # The message is formatted here: ValueError does not apply
        # logging-style "%s" arguments on its own.
        raise ValueError(f"Unknown pretrained model name {pretrained_model_name}")

    url = PRETRAINED_MODEL_URLS[pretrained_model_name]
    if not dist.is_initialized():
        return load_state_dict_from_url(url, progress=True)["model"]

    # In a distributed run every rank downloads into its own file name
    # (suffixed with the rank), then all ranks synchronise before returning.
    pretrained_model = load_state_dict_from_url(
        url,
        progress=True,
        file_name=f"{pretrained_model_name}_{dist.get_rank()}",
    )["model"]
    dist.barrier()
    return pretrained_model
class MultiheadAttention(nn.Module):
    """Multi-headed self-attention with an optional additive attention bias.

    See "Attention Is All You Need" for more details.

    Only self-attention is supported: the constructor asserts
    ``self_attention=True`` and ``forward`` derives queries, keys and values
    from ``query`` alone.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        q_noise=0.0,
        qn_block_size=8,
    ):
        """Build the four projection layers and initialise their weights.

        Args:
            embed_dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            kdim: Input width of the key projection (defaults to ``embed_dim``).
            vdim: Input width of the value projection (defaults to ``embed_dim``).
            dropout: Dropout probability handed to ``FairseqDropout``.
            bias: Whether the linear projections carry a bias term.
            self_attention: Must be True; other modes are rejected.
            q_noise: Forwarded to ``quant_noise`` for every projection.
            qn_block_size: Forwarded to ``quant_noise`` for every projection.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # Queries are multiplied by 1/sqrt(head_dim) before the dot product.
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention

        assert self.self_attention, "Only support self attention"

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.reset_parameters()

        self.onnx_trace = False

    def prepare_for_onnx_export_(self):
        """ONNX export is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def reset_parameters(self):
        """Xavier-initialise the projection weights and zero the output bias."""
        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
            nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
        else:
            nn.init.xavier_uniform_(self.k_proj.weight)
            nn.init.xavier_uniform_(self.v_proj.weight)
            nn.init.xavier_uniform_(self.q_proj.weight)

        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        attn_bias: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            query: tensor of shape `(tgt_len, batch, embed_dim)`. Queries,
                keys and values are all projected from this tensor.
            key (Tensor, optional): only used for shape validation.
            value (Tensor, optional): only used for shape validation.
            attn_bias (Tensor, optional): added to the raw attention scores
                after being viewed as
                `(batch * num_heads, tgt_len, src_len)`.
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: True).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.

        Returns:
            ``(attn, attn_weights)`` where ``attn`` has shape
            `(tgt_len, batch, embed_dim)` and ``attn_weights`` is None unless
            weights were requested. With ``before_softmax=True`` the pair
            ``(raw_scores, v)`` is returned instead.
        """
        if need_head_weights:
            need_weights = True

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                # Compare as a tuple: `assert src_len, bsz == ...` would only
                # test the truthiness of `src_len`.
                assert (src_len, bsz) == tuple(
                    value.shape[:2]
                ), "value must share (src_len, batch) with key"

        # Self-attention only: all three projections read `query`.
        q = self.q_proj(query)
        k = self.k_proj(query)
        v = self.v_proj(query)
        q *= self.scaling

        # Fold the heads into the batch dimension:
        # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim).
        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_bias is not None:
            attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        # Unfold the heads again and mix them with the output projection.
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights

    def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
        """Hook for sparse masking; this implementation returns the scores unchanged."""
        return attn_weights

    def upgrade_state_dict_named(self, state_dict, name):
        """Split fused ``in_proj_*`` checkpoint entries into q/k/v entries, in place.

        Args:
            state_dict: Mapping of parameter names to tensors; modified in place.
            name: Prefix of this module inside ``state_dict`` ("" for none).
        """
        prefix = name + "." if name != "" else ""
        items_to_add = {}
        keys_to_remove = []
        for k in state_dict.keys():
            if k.endswith(prefix + "in_proj_weight"):
                # in_proj_weight used to be q + k + v with same dimensions
                dim = state_dict[k].shape[0] // 3
                items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
                items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
                items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

                keys_to_remove.append(k)

                k_bias = prefix + "in_proj_bias"
                if k_bias in state_dict.keys():
                    # The fused bias is sliced with the same `dim` as the weight.
                    items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                    items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                        dim : 2 * dim
                    ]
                    items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                    keys_to_remove.append(prefix + "in_proj_bias")

        # Mutate only after iteration so the dict is not resized while looping.
        for k in keys_to_remove:
            del state_dict[k]

        for key, value in items_to_add.items():
            state_dict[key] = value
def init_graphormer_params(module):
    """
    Apply the Graphormer-specific weight initialisation to one module.

    Linear and embedding weights, and the q/k/v projection weights of
    `MultiheadAttention`, are redrawn from a normal distribution with mean 0
    and std 0.02. Linear biases and the embedding row at `padding_idx` are
    zeroed.
    """

    def _redraw(tensor):
        # with FSDP, module params will be on CUDA, so we cast them back to CPU
        # so that the RNG is consistent with and without FSDP
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _redraw(module.weight.data)
        linear_bias = module.bias
        if linear_bias is not None:
            linear_bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _redraw(module.weight.data)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()

    if isinstance(module, MultiheadAttention):
        # Same q, k, v order as before so the RNG stream is consumed identically.
        for projection in (module.q_proj, module.k_proj, module.v_proj):
            _redraw(projection.weight.data)
def add_start_docstrings(*docstr):
    """Return a decorator that prepends the concatenated `docstr` pieces to a function's docstring."""

    def docstring_decorator(fn):
        existing = fn.__doc__
        if existing is None:
            existing = ""
        fn.__doc__ = "".join(docstr) + existing
        return fn

    return docstring_decorator
def add_start_docstrings_to_model_forward(*docstr):
    """Return a decorator that rewrites a forward method's docstring.

    The new docstring is an intro sentence naming the owning class (the first
    component of the function's `__qualname__`), a fixed `<Tip>` note, the
    concatenated `docstr` pieces and finally the function's original docstring.
    """

    def docstring_decorator(fn):
        previous_doc = fn.__doc__ if fn.__doc__ is not None else ""
        docstring = "".join(docstr) + previous_doc
        owner = fn.__qualname__.split(".")[0]
        class_name = f"[`{owner}`]"
        intro = f" The {class_name} forward method, overrides the `__call__` special method."
        note = r"""
<Tip>
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
</Tip>
"""
        fn.__doc__ = intro + note + docstring
        return fn

    return docstring_decorator
def add_end_docstrings(*docstr):
    """Return a decorator that appends the concatenated `docstr` pieces to a function's docstring."""

    def docstring_decorator(fn):
        existing = fn.__doc__
        if existing is None:
            existing = ""
        fn.__doc__ = existing + "".join(docstr)
        return fn

    return docstring_decorator
# "Returns:" header used for PyTorch output types. `_prepare_output_docstrings`
# fills `{full_output_type}` and `{config_class}` through `str.format`.
PT_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([`{config_class}`]) and inputs.
"""
# "Returns:" header chosen by `_prepare_output_docstrings` when the output
# type's name starts with "TF"; same placeholders as the PyTorch variant.
TF_RETURN_INTRODUCTION = r"""
Returns:
[`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
`return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
configuration ([`{config_class}`]) and inputs.
"""
def _get_indent(t):
"""Returns the indentation in the first line of t"""
search = re.search(r"^(\s*)\S", t)
return "" if search is None else search.groups()[0]
def _convert_output_args_doc(output_args_doc):
    """Reformat an argument list so each argument renders as a `- **name**` bullet."""
    # Group the text into one chunk per argument: a line indented like the
    # very first one opens a new argument, anything else continues it.
    base_indent = _get_indent(output_args_doc)
    chunks = []
    pending = ""
    for line in output_args_doc.split("\n"):
        if _get_indent(line) != base_indent:
            # Description line: drop its first two characters so it lines up
            # with the bullet text.
            pending += f"{line[2:]}\n"
            continue
        if pending:
            chunks.append(pending[:-1])
        pending = f"{line}\n"
    chunks.append(pending[:-1])

    # Turn the argument name into a bold bullet and fold "name:\n description"
    # into "name -- description".
    rendered = []
    for chunk in chunks:
        chunk = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", chunk)
        chunk = re.sub(r":\s*\n\s*(\S)", r" -- \1", chunk)
        rendered.append(chunk)
    return "\n".join(rendered)
def _prepare_output_docstrings(output_type, config_class, min_indent=None):
"""
Prepares the return part of the docstring using `output_type`.
"""
output_docstring = output_type.__doc__
# Remove the head of the docstring to keep the list of args only
lines = output_docstring.split("\n")
i = 0
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
i += 1
if i < len(lines):
params_docstring = "\n".join(lines[(i + 1):])
params_docstring = _convert_output_args_doc(params_docstring)
# Add the return introduction
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
intro = intro.format(full_output_type=full_output_type, config_class=config_class)
result = intro + params_docstring
# Apply minimum indent if necessary
if min_indent is not None:
lines = result.split("\n")
# Find the indent of the first nonempty line
i = 0
while len(lines[i]) == 0:
i += 1
indent = len(_get_indent(lines[i]))
# If too small, add indentation to all nonempty lines
if indent < min_indent:
to_add = " " * (min_indent - indent)
lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
result = "\n".join(lines)
return result
# Doctest-style example for PyTorch token-classification models, registered
# under "TokenClassification" in PT_SAMPLE_DOCSTRINGS. Placeholders:
# {processor_class}, {model_class}, {checkpoint}, {expected_output}, {expected_loss}.
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer(
... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
... )
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_token_class_ids = logits.argmax(-1)
>>> # Note that tokens are classified rather than input words which means that
>>> # there might be more predicted token classes than words.
>>> # Multiple token classes might account for the same word
>>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
>>> predicted_tokens_classes
{expected_output}
```
```python
>>> labels = predicted_token_class_ids
>>> loss = model(**inputs, labels=labels).loss
>>> round(loss.item(), 2)
{expected_loss}
```
"""
# Doctest-style example for PyTorch question-answering models, registered
# under "QuestionAnswering" in PT_SAMPLE_DOCSTRINGS. Besides the common
# placeholders it uses {qa_target_start_index} and {qa_target_end_index}.
PT_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> answer_start_index = outputs.start_logits.argmax()
>>> answer_end_index = outputs.end_logits.argmax()
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> tokenizer.decode(predict_answer_tokens)
{expected_output}
```
```python
>>> # target is "nice puppet"
>>> target_start_index = torch.tensor([{qa_target_start_index}])
>>> target_end_index = torch.tensor([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = outputs.loss
>>> round(loss.item(), 2)
{expected_loss}
```
"""
# Doctest-style examples (single-label and multi-label) for PyTorch
# sequence-classification models, registered under "SequenceClassification"
# in PT_SAMPLE_DOCSTRINGS.
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example of single-label classification:
```python
>>> import torch
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_id = logits.argmax().item()
>>> model.config.id2label[predicted_class_id]
{expected_output}
```
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
>>> labels = torch.tensor([1])
>>> loss = model(**inputs, labels=labels).loss
>>> round(loss.item(), 2)
{expected_loss}
```
Example of multi-label classification:
```python
>>> import torch
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_id = logits.argmax().item()
>>> model.config.id2label[predicted_class_id]
{expected_output}
```
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = {model_class}.from_pretrained(
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
... )
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
... torch.float
... )
>>> loss = model(**inputs, labels=labels).loss
>>> loss.backward() # doctest: +IGNORE_RESULT
```
"""
# Doctest-style example for PyTorch masked-LM models, registered under
# "MaskedLM" in PT_SAMPLE_DOCSTRINGS. Uses a {mask} placeholder in addition
# to the common ones.
PT_MASKED_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
>>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
```
```python
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
>>> # mask labels of non-{mask} tokens
>>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
>>> outputs = model(**inputs, labels=labels)
>>> round(outputs.loss.item(), 2)
{expected_loss}
```
"""
# Doctest-style example for PyTorch base models, registered under "BaseModel"
# in PT_SAMPLE_DOCSTRINGS. Placeholders: {processor_class}, {model_class},
# {checkpoint}.
PT_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Doctest-style example for PyTorch multiple-choice models, registered under
# "MultipleChoice" in PT_SAMPLE_DOCSTRINGS. The doubled braces are escapes
# that become literal braces if the template is filled with `str.format`.
PT_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
>>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
# Doctest-style example for PyTorch causal-LM models, registered under
# "LMHead" in PT_SAMPLE_DOCSTRINGS. Placeholders: {processor_class},
# {model_class}, {checkpoint}.
PT_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> import torch
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs, labels=inputs["input_ids"])
>>> loss = outputs.loss
>>> logits = outputs.logits
```
"""
# Doctest-style example for PyTorch speech base models, registered under
# "SpeechBaseModel" in PT_SAMPLE_DOCSTRINGS. Placeholders: {processor_class},
# {model_class}, {checkpoint}, {expected_output}.
PT_SPEECH_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
# Doctest-style example for PyTorch CTC speech models, registered under "CTC"
# in PT_SAMPLE_DOCSTRINGS. Placeholders: {processor_class}, {model_class},
# {checkpoint}, {expected_output}, {expected_loss}.
PT_SPEECH_CTC_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
{expected_output}
```
```python
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
{expected_loss}
```
"""
# Doctest-style example for PyTorch audio-classification models, registered
# under "AudioClassification" in PT_SAMPLE_DOCSTRINGS. Placeholders:
# {processor_class}, {model_class}, {checkpoint}, {expected_output}, {expected_loss}.
PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
{expected_output}
```
```python
>>> # compute loss - target_label is e.g. "down"
>>> target_label = model.config.id2label[0]
>>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
>>> loss = model(**inputs).loss
>>> round(loss.item(), 2)
{expected_loss}
```
"""
# Doctest-style example for PyTorch audio frame-classification models,
# registered under "AudioFrameClassification" in PT_SAMPLE_DOCSTRINGS.
# Placeholders: {processor_class}, {model_class}, {checkpoint}, {expected_output}.
PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> probabilities = torch.sigmoid(logits[0])
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
>>> labels = (probabilities > 0.5).long()
>>> labels[0].tolist()
{expected_output}
```
"""
# Doctest-style example for PyTorch x-vector (speaker embedding) models,
# registered under "AudioXVector" in PT_SAMPLE_DOCSTRINGS. Placeholders:
# {processor_class}, {model_class}, {checkpoint}, {expected_output}.
PT_SPEECH_XVECTOR_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import torch
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(
... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
... )
>>> with torch.no_grad():
... embeddings = model(**inputs).embeddings
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
>>> if similarity < threshold:
... print("Speakers are not the same!")
>>> round(similarity.item(), 2)
{expected_output}
```
"""
# Doctest-style example for PyTorch vision base models, registered under
# "VisionBaseModel" in PT_SAMPLE_DOCSTRINGS. Placeholders: {processor_class},
# {model_class}, {checkpoint}, {expected_output}.
PT_VISION_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
# Doctest-style example for PyTorch image-classification models, registered
# under "ImageClassification" in PT_SAMPLE_DOCSTRINGS. Placeholders:
# {processor_class}, {model_class}, {checkpoint}, {expected_output}.
PT_VISION_SEQ_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import torch
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="pt")
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
{expected_output}
```
"""
# Lookup table from a task key to the PyTorch example template defined above.
# NOTE(review): the consumer of this table is not visible in this part of the
# file; presumably a decorator selects a template by task key - confirm.
PT_SAMPLE_DOCSTRINGS = {
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
"MaskedLM": PT_MASKED_LM_SAMPLE,
"LMHead": PT_CAUSAL_LM_SAMPLE,
"BaseModel": PT_BASE_MODEL_SAMPLE,
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
"CTC": PT_SPEECH_CTC_SAMPLE,
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
"VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
"ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
}
# Doctest-style example for TensorFlow token-classification models.
# Placeholders: {processor_class}, {model_class}, {checkpoint},
# {expected_output}, {expected_loss}.
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer(
... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
... )
>>> logits = model(**inputs).logits
>>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
>>> # Note that tokens are classified rather than input words which means that
>>> # there might be more predicted token classes than words.
>>> # Multiple token classes might account for the same word
>>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
>>> predicted_tokens_classes
{expected_output}
```
```python
>>> labels = predicted_token_class_ids
>>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
>>> round(float(loss), 2)
{expected_loss}
```
"""
# Doctest-style example for TensorFlow question-answering models. Besides the
# common placeholders it uses {qa_target_start_index} and {qa_target_end_index}.
TF_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="tf")
>>> outputs = model(**inputs)
>>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
>>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
>>> tokenizer.decode(predict_answer_tokens)
{expected_output}
```
```python
>>> # target is "nice puppet"
>>> target_start_index = tf.constant([{qa_target_start_index}])
>>> target_end_index = tf.constant([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = tf.math.reduce_mean(outputs.loss)
>>> round(float(loss), 2)
{expected_loss}
```
"""
# Doctest-style example for TensorFlow sequence-classification models.
# Placeholders: {processor_class}, {model_class}, {checkpoint},
# {expected_output}, {expected_loss}.
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> logits = model(**inputs).logits
>>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
>>> model.config.id2label[predicted_class_id]
{expected_output}
```
```python
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
>>> num_labels = len(model.config.id2label)
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
>>> labels = tf.constant(1)
>>> loss = model(**inputs, labels=labels).loss
>>> round(float(loss), 2)
{expected_loss}
```
"""
# Doctest-style example for TensorFlow masked-LM models. Uses a {mask}
# placeholder in addition to {processor_class}, {model_class}, {checkpoint},
# {expected_output} and {expected_loss}.
TF_MASKED_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
>>> logits = model(**inputs).logits
>>> # retrieve index of {mask}
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
>>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
>>> tokenizer.decode(predicted_token_id)
{expected_output}
```
```python
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
>>> # mask labels of non-{mask} tokens
>>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
>>> outputs = model(**inputs, labels=labels)
>>> round(float(outputs.loss), 2)
{expected_loss}
```
"""
# Doctest-style example for TensorFlow base models. Placeholders:
# {processor_class}, {model_class}, {checkpoint}.
TF_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Doctest-style example for TensorFlow multiple-choice models. The doubled
# braces are escapes that become literal braces if the template is filled
# with `str.format`.
TF_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
>>> outputs = model(inputs) # batch size is 1
>>> # the linear classifier still needs to be trained
>>> logits = outputs.logits
```
"""
# Example template for TensorFlow causal-LM heads (registered under "LMHead" in
# TF_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the {processor_class},
# {model_class} and {checkpoint} placeholders.
TF_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
>>> outputs = model(inputs)
>>> logits = outputs.logits
```
"""
# Example template for TensorFlow speech base models (registered under
# "SpeechBaseModel" in TF_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the
# {processor_class}, {model_class}, {checkpoint} and {expected_output} placeholders.
TF_SPEECH_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
# Example template for TensorFlow CTC speech-recognition heads (registered under "CTC"
# in TF_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the {processor_class},
# {model_class}, {checkpoint}, {expected_output} and {expected_loss} placeholders.
TF_SPEECH_CTC_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> import tensorflow as tf
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
>>> logits = model(**inputs).logits
>>> predicted_ids = tf.math.argmax(logits, axis=-1)
>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
{expected_output}
```
```python
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
>>> # compute loss
>>> loss = model(**inputs).loss
>>> round(float(loss), 2)
{expected_loss}
```
"""
# Example template for TensorFlow vision base models (registered under
# "VisionBaseModel" in TF_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the
# {processor_class}, {model_class}, {checkpoint} and {expected_output} placeholders.
TF_VISION_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="tf")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
{expected_output}
```
"""
# Example template for TensorFlow image-classification heads (registered under
# "ImageClassification" in TF_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the
# {processor_class}, {model_class}, {checkpoint} and {expected_output} placeholders.
TF_VISION_SEQ_CLASS_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> import tensorflow as tf
>>> from datasets import load_dataset
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = feature_extractor(image, return_tensors="tf")
>>> logits = model(**inputs).logits
>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = int(tf.math.argmax(logits, axis=-1))
>>> print(model.config.id2label[predicted_label])
{expected_output}
```
"""
# Lookup table from the task key resolved by `add_code_sample_docstrings` to the
# TensorFlow example template for that task.
# NOTE(review): this literal has no "AudioClassification", "AudioFrameClassification"
# or "AudioXVector" entry, so a TF model resolving to one of those keys fails with a
# KeyError unless the entry is added elsewhere.
TF_SAMPLE_DOCSTRINGS = dict(
    SequenceClassification=TF_SEQUENCE_CLASSIFICATION_SAMPLE,
    QuestionAnswering=TF_QUESTION_ANSWERING_SAMPLE,
    TokenClassification=TF_TOKEN_CLASSIFICATION_SAMPLE,
    MultipleChoice=TF_MULTIPLE_CHOICE_SAMPLE,
    MaskedLM=TF_MASKED_LM_SAMPLE,
    LMHead=TF_CAUSAL_LM_SAMPLE,
    BaseModel=TF_BASE_MODEL_SAMPLE,
    SpeechBaseModel=TF_SPEECH_BASE_MODEL_SAMPLE,
    CTC=TF_SPEECH_CTC_SAMPLE,
    VisionBaseModel=TF_VISION_BASE_MODEL_SAMPLE,
    ImageClassification=TF_VISION_SEQ_CLASS_SAMPLE,
)
# Example template for Flax token-classification heads (registered under
# "TokenClassification" in FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses
# the {processor_class}, {model_class} and {checkpoint} placeholders.
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Example template for Flax question-answering heads (registered under
# "QuestionAnswering" in FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the
# {processor_class}, {model_class} and {checkpoint} placeholders.
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
>>> inputs = tokenizer(question, text, return_tensors="jax")
>>> outputs = model(**inputs)
>>> start_scores = outputs.start_logits
>>> end_scores = outputs.end_logits
```
"""
# Example template for Flax sequence-classification heads (registered under
# "SequenceClassification" in FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`;
# uses the {processor_class}, {model_class} and {checkpoint} placeholders.
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Example template for Flax masked-language-modeling heads (registered under
# "MaskedLM" in FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the
# {processor_class}, {model_class}, {checkpoint} and {mask} placeholders.
FLAX_MASKED_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
```
"""
# Example template for Flax base (headless) models (registered under "BaseModel" in
# FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the {processor_class},
# {model_class} and {checkpoint} placeholders.
FLAX_BASE_MODEL_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```
"""
# Example template for Flax multiple-choice heads (registered under "MultipleChoice"
# in FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the {processor_class},
# {model_class} and {checkpoint} placeholders. The doubled braces `{{ ... }}` are
# format escapes that render as a literal dict comprehension.
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> choice0 = "It is eaten with a fork and a knife."
>>> choice1 = "It is eaten while held in the hand."
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
>>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
>>> logits = outputs.logits
```
"""
# Example template for Flax causal-LM heads (registered under "LMHead" in
# FLAX_SAMPLE_DOCSTRINGS). Rendered with `str.format`; uses the {processor_class},
# {model_class} and {checkpoint} placeholders.
FLAX_CAUSAL_LM_SAMPLE = r"""
Example:
```python
>>> from transformers import {processor_class}, {model_class}
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
>>> outputs = model(**inputs)
>>> # retrieve logits for next token
>>> next_token_logits = outputs.logits[:, -1]
```
"""
# Lookup table from the task key resolved by `add_code_sample_docstrings` to the Flax
# example template for that task.
# NOTE(review): only the seven text tasks below are listed here; a Flax model that
# resolves to any other key (e.g. "CTC", "SpeechBaseModel", "VisionBaseModel",
# "ImageClassification") fails with a KeyError unless the entry is added elsewhere.
FLAX_SAMPLE_DOCSTRINGS = dict(
    SequenceClassification=FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
    QuestionAnswering=FLAX_QUESTION_ANSWERING_SAMPLE,
    TokenClassification=FLAX_TOKEN_CLASSIFICATION_SAMPLE,
    MultipleChoice=FLAX_MULTIPLE_CHOICE_SAMPLE,
    MaskedLM=FLAX_MASKED_LM_SAMPLE,
    BaseModel=FLAX_BASE_MODEL_SAMPLE,
    LMHead=FLAX_CAUSAL_LM_SAMPLE,
)
def add_code_sample_docstrings(
    *docstr,
    processor_class=None,
    checkpoint=None,
    output_type=None,
    config_class=None,
    mask="[MASK]",
    qa_target_start_index=14,
    qa_target_end_index=15,
    model_cls=None,
    modality=None,
    expected_output="",
    expected_loss="",
):
    """
    Build a decorator that appends a rendered usage example to a function's docstring.

    The sample table is chosen from the model class name prefix ("TF" -> TensorFlow, "Flax" -> Flax,
    anything else -> PyTorch). The task entry inside that table is chosen from substrings of the model
    class name, together with `modality` for the audio/vision variants.

    Args:
        *docstr: Extra text appended right after the function's existing docstring.
        processor_class: Value substituted for `{processor_class}` in the template.
        checkpoint: Value substituted for `{checkpoint}` in the template.
        output_type: If given, it is passed with `config_class` to `_prepare_output_docstrings`
            and the result is inserted before the example.
        config_class: Forwarded to `_prepare_output_docstrings` together with `output_type`.
        mask: Value substituted for `{mask}` in the template.
        qa_target_start_index: Value substituted for `{qa_target_start_index}`.
        qa_target_end_index: Value substituted for `{qa_target_end_index}`.
        model_cls: Model class name to use instead of the one derived from `fn.__qualname__`.
        modality: `"audio"` or `"vision"` to select the modality-specific samples.
        expected_output: Value substituted for `{expected_output}`.
        expected_loss: Value substituted for `{expected_loss}`.

    Returns:
        A decorator that mutates `fn.__doc__` in place and returns `fn`.

    Raises:
        ValueError: When the decorator is applied and no task rule matches the model class name.
        KeyError: When the decorator is applied and the selected table lacks the resolved task key.
    """

    def docstring_decorator(fn):
        # Unless overridden, the model class is the first component of the qualified name.
        model_class = model_cls if model_cls is not None else fn.__qualname__.split(".")[0]

        if model_class.startswith("TF"):
            sample_docstrings = TF_SAMPLE_DOCSTRINGS
        elif model_class.startswith("Flax"):
            sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
        else:
            sample_docstrings = PT_SAMPLE_DOCSTRINGS

        is_audio = modality == "audio"
        is_vision = modality == "vision"
        # Ordered (task key, matches) rules: the first match wins, so specific rules must
        # stay ahead of the generic "Model" / "Encoder" fallbacks.
        rules = (
            ("AudioClassification", "SequenceClassification" in model_class and is_audio),
            ("SequenceClassification", "SequenceClassification" in model_class),
            ("QuestionAnswering", "QuestionAnswering" in model_class),
            ("TokenClassification", "TokenClassification" in model_class),
            ("MultipleChoice", "MultipleChoice" in model_class),
            (
                "MaskedLM",
                "MaskedLM" in model_class or model_class in ("FlaubertWithLMHeadModel", "XLMWithLMHeadModel"),
            ),
            ("LMHead", "LMHead" in model_class or "CausalLM" in model_class),
            ("CTC", "CTC" in model_class),
            ("AudioFrameClassification", "AudioFrameClassification" in model_class),
            ("AudioXVector", "XVector" in model_class and is_audio),
            ("SpeechBaseModel", "Model" in model_class and is_audio),
            ("VisionBaseModel", "Model" in model_class and is_vision),
            ("BaseModel", "Model" in model_class or "Encoder" in model_class),
            ("ImageClassification", "ImageClassification" in model_class),
        )
        sample_key = next((key for key, matches in rules if matches), None)
        if sample_key is None:
            raise ValueError(f"Docstring can't be built for model {model_class}")
        code_sample = sample_docstrings[sample_key]

        existing_doc = (fn.__doc__ or "") + "".join(docstr)
        if output_type is None:
            output_doc = ""
        else:
            output_doc = _prepare_output_docstrings(output_type, config_class)

        # All substitutions are always supplied; templates that do not use a key simply ignore it.
        substitutions = {
            "model_class": model_class,
            "processor_class": processor_class,
            "checkpoint": checkpoint,
            "mask": mask,
            "qa_target_start_index": qa_target_start_index,
            "qa_target_end_index": qa_target_end_index,
            "expected_output": expected_output,
            "expected_loss": expected_loss,
        }
        rendered_sample = code_sample.format(**substitutions)

        fn.__doc__ = existing_doc + output_doc + rendered_sample
        return fn

    return docstring_decorator
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
    """
    Prune a linear layer to keep only entries in index.

    Used to remove heads.

    Args:
        layer (`torch.nn.Linear`): The layer to prune.
        index (`torch.LongTensor`): The indices to keep in the layer.
        dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.

    Returns:
        `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`, placed on the same
        device and using the same dtype as `layer`.
    """
    weight = layer.weight
    has_bias = layer.bias is not None
    index = index.to(weight.device)

    pruned_size = list(weight.size())
    pruned_size[dim] = len(index)

    # `nn.Linear` stores its weight as (out_features, in_features). The new layer is moved to the
    # source dtype as well as the source device, so half/double layers are not silently cast to the
    # default dtype when their values are copied in.
    new_layer = nn.Linear(pruned_size[1], pruned_size[0], bias=has_bias).to(
        device=weight.device, dtype=weight.dtype
    )

    # Copy outside of autograd so the fresh parameters stay leaf tensors with `requires_grad=True`.
    with torch.no_grad():
        new_layer.weight.copy_(weight.index_select(dim, index))
        if has_bias:
            # Pruning input features (dim == 1) leaves the per-output bias untouched.
            new_layer.bias.copy_(layer.bias if dim == 1 else layer.bias[index])
    return new_layer
def apply_chunking_to_forward(
    forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
) -> torch.Tensor:
    """
    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
    `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.

    If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
    applying `forward_fn` to `input_tensors`.

    Args:
        forward_fn (`Callable[..., torch.Tensor]`):
            The forward function of the model.
        chunk_size (`int`):
            The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`. A value
            that is not positive disables chunking and calls `forward_fn` once on the full inputs.
        chunk_dim (`int`):
            The dimension over which the `input_tensors` should be chunked.
        input_tensors (`Tuple[torch.Tensor]`):
            The input tensors of `forward_fn` which will be chunked

    Returns:
        `torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied.

    Raises:
        ValueError: If the number of tensors does not match the number of parameters of `forward_fn`, if the
            tensors differ in size along `chunk_dim`, or if that size is not a multiple of `chunk_size`.

    Examples:

    ```python
    # rename the usual forward() fn to forward_chunk()
    def forward_chunk(self, hidden_states):
        hidden_states = self.decoder(hidden_states)
        return hidden_states


    # implement a chunked forward function
    def forward(self, hidden_states):
        return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
    ```"""
    assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"

    # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):
        # The mismatch can go either way, so the message does not claim there are too few tensors.
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but {len(input_tensors)} input "
            "tensors are given"
        )

    if chunk_size <= 0:
        # Chunking disabled: a single plain forward pass.
        return forward_fn(*input_tensors)

    tensor_shape = input_tensors[0].shape[chunk_dim]
    for input_tensor in input_tensors:
        if input_tensor.shape[chunk_dim] != tensor_shape:
            raise ValueError(
                f"All input tensors have to be of the same shape: {tensor_shape}, "
                f"found shape {input_tensor.shape[chunk_dim]}"
            )

    if tensor_shape % chunk_size != 0:
        raise ValueError(
            f"The dimension to be chunked {tensor_shape} has to be a multiple of the chunk "
            f"size {chunk_size}"
        )

    num_chunks = tensor_shape // chunk_size

    # chunk input tensor into tuples
    input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
    # apply forward fn to every tuple
    output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
    # concatenate output at same dimension
    return torch.cat(output_chunks, dim=chunk_dim)
def find_pruneable_heads_and_indices(
    heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
) -> Tuple[Set[int], torch.LongTensor]:
    """
    Work out which heads still have to be pruned and which flat positions survive.

    Args:
        heads (`List[int]`): List of the indices of heads to prune.
        n_heads (`int`): The number of heads in the model.
        head_size (`int`): The size of each head.
        already_pruned_heads (`Set[int]`): A set of already pruned heads.

    Returns:
        `Tuple[Set[int], torch.LongTensor]`: The requested heads that were not already pruned, and the
        indices of the entries to keep in the flattened `(n_heads, head_size)` layout.
    """
    # Heads that were pruned earlier are dropped from the request.
    heads = set(heads) - already_pruned_heads

    keep = torch.ones(n_heads, head_size, dtype=torch.bool)
    for original_head in heads:
        # Each previously pruned head with a smaller index shifts this head one row down.
        shift = sum(pruned < original_head for pruned in already_pruned_heads)
        keep[original_head - shift] = False

    flat_keep = keep.reshape(-1)
    index: torch.LongTensor = torch.arange(flat_keep.numel())[flat_keep].long()
    return heads, index