flamingo-2024 / modeling_flamingo.py
Chengxu Zhuang
minor fix for causal mask
272b3c6
import random
import pdb
from einops import rearrange
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
import transformers.models.opt.modeling_opt as modeling_opt
from transformers.models.opt.modeling_opt\
import OPTDecoderLayer, OPTPreTrainedModel, OPTConfig
from transformers import ViTModel
try:
from transformers.models.opt.modeling_opt import _prepare_4d_causal_attention_mask
except:
_prepare_4d_causal_attention_mask = None
from .utils import exists, freeze_all_layers_, unfreeze_all_layers_
from .flamingo_pytorch import GatedCrossAttentionBlock, PerceiverResampler
from .configuration_flamingo import FlamingoConfig
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """
    Learned position embedding table with a fixed maximum number of positions.

    Position ids are not passed in directly; they are derived from the
    attention mask so that padded (masked) slots do not advance the position
    counter. Every id is shifted by `self.offset` before the lookup, and the
    table is allocated with `self.offset` extra rows to make room for that.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT shifts every position id by 2 and enlarges the table accordingly
        # (a quirk of OPT; other models do not do this).
        # The offset is stored before the parent constructor runs because it
        # is needed to size the table.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
        """
        Look up position embeddings for the slots described by `attention_mask`.

        Args:
            attention_mask: mask indexed as `[batch, seq]` (the cumulative sum
                runs over dim 1); non-zero entries mark real tokens.
            past_key_values_length: number of leading positions to drop from
                the result (columns before this index are sliced away).

        Returns:
            The embeddings of the remaining positions. Unmasked tokens are
            numbered 0, 1, 2, ... in order of appearance; masked slots get
            position -1. Both are shifted by `self.offset` before the lookup.
        """
        mask = attention_mask.long()
        # Running count of real tokens; multiplying by the mask zeroes the
        # count at masked slots so they end up at -1 after the subtraction.
        running_count = torch.cumsum(mask, dim=1).type_as(mask)
        position_ids = (running_count * mask).long() - 1
        # Keep only the positions that are not already covered by the cache.
        position_ids = position_ids[:, past_key_values_length:]
        return super().forward(position_ids + self.offset)
class OPTDecoder(modeling_opt.OPTDecoder):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`OPTDecoderLayer`].

    On top of the plain OPT stack this decoder owns a `perceiver_resampler`
    applied to image embeddings and a list of gated cross-attention blocks
    (`gated_attn_layers`, one slot per decoder layer, `None` where no block is
    inserted) that are applied to the hidden states right before the matching
    decoder layer whenever image embeddings are available.

    An image encoder is not created here; callers attach one by assigning
    `decoder.img_encoder`. It is only required when `pixel_values` are passed
    to `forward`.

    Args:
        config: OPTConfig carrying the additional Flamingo fields read in
            `__init__` (`media_token_id`, `id_perceiver`, `perceiver_depth`,
            `perceiver_num_latents`, `inp_dim`, `cross_attn_every`,
            `only_attend_immediate_media`, `finetune_LM`).
    """

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        self.dropout = config.dropout
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.media_token_id = config.media_token_id

        self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

        # Projections are only needed when the word-embedding width differs
        # from the hidden width. Creation order (out, then in) is kept as-is.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
            self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False)
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility
        # with checkpoints that have been fine-tuned before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(config.hidden_size)
        else:
            self.final_layer_norm = None

        dim_head = config.hidden_size // config.num_attention_heads
        if not config.id_perceiver:
            self.perceiver_resampler = PerceiverResampler(
                dim=config.hidden_size,
                depth=config.perceiver_depth,
                dim_head=dim_head,
                heads=config.num_attention_heads,
                num_latents=config.perceiver_num_latents,
                inp_dim=config.inp_dim,
            )
        elif config.inp_dim is None:
            # "Identity perceiver" with no input width given: pass features through.
            self.perceiver_resampler = nn.Identity()
        else:
            # "Identity perceiver" with an input width: a single bias-free projection.
            self.perceiver_resampler = nn.Linear(config.inp_dim, config.hidden_size, bias=False)

        self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        # One slot per decoder layer; a cross-attention block sits in front of
        # every `cross_attn_every`-th layer (starting at layer 0), None elsewhere.
        self.gated_attn_layers = nn.ModuleList(
            [
                GatedCrossAttentionBlock(
                    dim=config.hidden_size,
                    dim_head=dim_head,
                    heads=config.num_attention_heads,
                    only_attend_immediate_media=config.only_attend_immediate_media,
                )
                if ind % config.cross_attn_every == 0
                else None
                for ind in range(config.num_hidden_layers)
            ]
        )

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        # in flamingo mode, freeze everything but perceiver and gated cross attention
        if not config.finetune_LM:
            freeze_all_layers_(self)
            unfreeze_all_layers_(self.perceiver_resampler)
            for cross_attn in self.gated_attn_layers:
                if exists(cross_attn):
                    unfreeze_all_layers_(cross_attn)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values=None,
        image_embeds=None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Run the decoder stack, optionally conditioning on images through the
        gated cross-attention blocks.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

                Required whenever `pixel_values` or `image_embeds` is given, because the media token
                positions are located by comparing `input_ids` with `self.media_token_id`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            use_cache (`bool`, *optional*):
                Whether to collect and return each layer's key/value state. Defaults to `config.use_cache`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            pixel_values (*optional*):
                Raw images to be encoded with the attached `img_encoder` (run under `torch.no_grad()`). A 4-D
                tensor gets a singleton axis inserted at dim 1 before the leading two axes are flattened for the
                encoder -- presumably `(batch, channels, height, width)` vs. `(batch, num_images, ...)`; TODO
                confirm the expected layout with the caller. Mutually exclusive with `image_embeds`.
            image_embeds (*optional*):
                Precomputed image features. They are passed through `self.perceiver_resampler` and then fed to
                the gated cross-attention blocks. Mutually exclusive with `pixel_values`.

        Returns:
            A `BaseModelOutputWithPast` when `return_dict` is true, otherwise a tuple with the non-`None`
            entries of `(last_hidden_state, past_key_values, hidden_states, attentions)`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if image inputs are
                given without `input_ids`, or if `head_mask` does not cover every layer.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            # Previously the batch size was always read from `input_ids`,
            # which crashed on this (documented) embeddings-only path.
            batch = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        flamingo_mode = exists(pixel_values) or exists(image_embeds)

        # derive the media token ids (as a boolean tensor), for calculating the masked cross attention
        media_locations = None
        if flamingo_mode:
            if input_ids is None:
                raise ValueError(
                    "`input_ids` must be provided together with `pixel_values`/`image_embeds` "
                    "so that media token positions can be located"
                )
            media_locations = input_ids == self.media_token_id
            assert not (exists(pixel_values) and exists(image_embeds))

        # encode images into embeddings
        # with the img_encoder attached to this decoder
        # it can also accept precomputed image embeddings
        if exists(pixel_values):
            # `img_encoder` is attached from outside; fall back to None so a
            # decoder without one fails with the assertion message below
            # rather than an AttributeError.
            img_encoder = getattr(self, 'img_encoder', None)
            assert exists(img_encoder), 'img_encoder must be passed in for automatic image encoding'
            if len(pixel_values.shape) == 4:
                pixel_values = torch.unsqueeze(pixel_values, 1)
            # Fold the first two axes together so the encoder sees one flat batch.
            pixel_values = rearrange(pixel_values, 'b t ... -> (b t) ...')

            with torch.no_grad():
                # Prefer the `vision_model` sub-module when the encoder has one.
                vision_tower = getattr(img_encoder, 'vision_model', None)
                if vision_tower is None:
                    vision_tower = img_encoder
                image_outputs = vision_tower(
                    pixel_values=pixel_values,
                    output_hidden_states=True, return_dict=True)

            image_embeds = image_outputs['last_hidden_state']
            # Undo the folding: split the flat axis back into (b, t).
            image_embeds = rearrange(image_embeds, '(b t) ... -> b t ...', b=batch)

        if exists(image_embeds):
            image_embeds = self.perceiver_resampler(image_embeds)

        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)

        # `_prepare_4d_causal_attention_mask` is None when the import at the
        # top of the file failed; then use the decoder's own mask builder.
        if _prepare_4d_causal_attention_mask is None:
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        if self.project_in is not None:
            inputs_embeds = self.project_in(inputs_embeds)

        hidden_states = inputs_embeds + pos_embeds
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None and head_mask.size()[0] != len(self.layers):
            raise ValueError(
                f"The `head_mask` should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            dropout_probability = random.uniform(0, 1)
            if self.training and (dropout_probability < self.layerdrop):
                continue

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            # Apply the gated cross-attention block for this layer (if one is
            # installed) before the regular decoder layer.
            flamingo_cross_attn = self.gated_attn_layers[idx]
            if exists(flamingo_cross_attn) and exists(image_embeds):
                hidden_states = flamingo_cross_attn(
                    hidden_states,
                    image_embeds,
                    media_locations=media_locations
                )

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                layer_head_mask=(head_mask[idx] if head_mask is not None else None),
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights when those are returned.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)

        if self.project_out is not None:
            hidden_states = self.project_out(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
class OPTModel(modeling_opt.OPTModel):
    """
    OPT model wrapper whose `decoder` is the `OPTDecoder` defined in this file
    (the variant with perceiver resampler and gated cross-attention blocks).
    """

    def __init__(self, config: OPTConfig):
        # The grandparent initializer is called directly, bypassing
        # `modeling_opt.OPTModel.__init__` -- presumably so that the stock
        # decoder is never constructed; confirm against the installed
        # transformers version.
        OPTPreTrainedModel.__init__(self, config)

        flamingo_decoder = OPTDecoder(config)
        self.decoder = flamingo_decoder

        # Initialize weights and apply final processing
        self.post_init()
class OPTForCausalLM(modeling_opt.OPTForCausalLM):
    """
    Causal language model head on top of this file's `OPTModel`.

    Only the constructor is overridden here; everything else comes from
    `modeling_opt.OPTForCausalLM`.
    """

    # `lm_head.weight` may be absent from checkpoints without a load-time complaint.
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        # Call the grandparent initializer directly so that the backbone
        # assigned below is the local `OPTModel`.
        OPTPreTrainedModel.__init__(self, config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        vocab_projection = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
        self.lm_head = vocab_projection

        # Initialize weights and apply final processing
        self.post_init()
def set_default_if_nonexist(config, key, value):
    """
    Fill in `config.<key>` with `value` unless it already holds a non-None value.

    An attribute that is missing and one that is explicitly `None` are treated
    the same way: both are overwritten. Any other existing value (including
    falsy ones such as `0` or `False`) is left alone.

    Args:
        config: object whose attribute is inspected and possibly set in place.
        key: attribute name.
        value: default to assign.

    Returns:
        The same `config` object.
    """
    # NOTE(review): because None counts as "unset", a caller cannot use this
    # helper and still keep an attribute deliberately set to None.
    current = getattr(config, key, None)
    if current is not None:
        return config
    setattr(config, key, value)
    return config
def setup_default_flamingo_configs(config):
    """
    Populate the Flamingo-specific config fields that are missing or None.

    Each field is filled through `set_default_if_nonexist`, so values already
    present on `config` are preserved. The object is modified in place.

    Args:
        config: configuration object to complete.

    Returns:
        The same `config` object.
    """
    # (attribute name, fallback value), applied in this order.
    fallbacks = (
        ('perceiver_depth', 2),
        ('perceiver_num_latents', 64),
        ('cross_attn_every', 3),
        ('only_attend_immediate_media', True),
        ('media_token_id', 50265),
        ('inp_dim', 768),
        ('finetune_LM', True),
        ('id_perceiver', False),
    )
    for attr_name, fallback in fallbacks:
        set_default_if_nonexist(config, attr_name, fallback)
    return config
class FlamingoForCausalLM(modeling_opt.OPTForCausalLM):
    """
    Flamingo-style causal language model: this file's `OPTModel` backbone, a
    linear LM head, and a frozen vision encoder attached to the decoder.

    The constructor loads the `facebook/dino-vitb16` ViT checkpoint and
    attaches it via `setup_vis_encoder`; call that method again to swap in a
    different encoder.
    """

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
    config_class = FlamingoConfig

    def __init__(self, config):
        OPTPreTrainedModel.__init__(self, config)
        # Fill in any Flamingo-specific fields the config does not define yet.
        config = setup_default_flamingo_configs(config)
        self.model = OPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # The vision encoder is attached only after `post_init()` has run.
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        self.setup_vis_encoder(ViTModel.from_pretrained("facebook/dino-vitb16"))

    def setup_vis_encoder(self, img_encoder):
        """Attach `img_encoder` to the decoder and freeze it via `freeze_all_layers_`."""
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the decoder, project to vocabulary logits and, when `labels` are
        given, compute the next-token cross-entropy loss.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`OPTTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
                tensors are only required when the model is used as a decoder in a Sequence to Sequence model.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
                that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
                all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
            *args, **kwargs:
                Forwarded unchanged to the decoder; this is how `pixel_values` or `image_embeds` reach it.

        Returns:
            A `CausalLMOutputWithPast` when `return_dict` is true; otherwise the tuple
            `(logits, *decoder_outputs[1:])`, prefixed with the loss when `labels` were given.

        Example:

        ```python
        >>> from transformers import GPT2Tokenizer, OPTForCausalLM

        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

        >>> prompt = "Hey, are you consciours? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
        ```"""
        # Fall back to the config for any switch the caller left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            *args,
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)

        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            predictions = logits[..., :-1, :].contiguous()
            targets = labels[..., 1:].contiguous()
            # Flatten the tokens
            flat_predictions = predictions.view(-1, self.config.vocab_size)
            loss = self.loss_fct(flat_predictions, targets.view(-1))

        if return_dict:
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output
class FlamingoForSequenceClassification(OPTPreTrainedModel):
    """
    Sequence classifier on top of this file's `OPTModel` backbone.

    A bias-free linear `score` head is applied to every position and the logits
    of one position per sequence are kept: the token just before the first pad
    token when `input_ids` and a pad token id are available, the last position
    otherwise. A frozen `facebook/dino-vitb16` ViT is attached to the decoder
    as image encoder at construction time.
    """

    _keys_to_ignore_on_load_missing = [r"score.weight"]

    def __init__(self, config: OPTConfig):
        OPTPreTrainedModel.__init__(self, config)
        # Fill in any Flamingo-specific fields the config does not define yet.
        config = setup_default_flamingo_configs(config)
        self.num_labels = config.num_labels
        self.model = OPTModel(config)

        # Classification head: one logit per label, no bias.
        self.score = nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # The vision encoder is attached only after `post_init()` has run.
        self.model.decoder.img_encoder = None
        self.loss_fct = CrossEntropyLoss()
        self.setup_vis_encoder(ViTModel.from_pretrained("facebook/dino-vitb16"))

    def setup_vis_encoder(self, img_encoder):
        """Attach `img_encoder` to the decoder and freeze it via `freeze_all_layers_`."""
        self.model.decoder.img_encoder = img_encoder
        freeze_all_layers_(img_encoder)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args, **kwargs) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        Classify each sequence from the logits at its selected position.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Extra positional and keyword arguments are forwarded unchanged to the
        decoder (this is how `pixel_values` or `image_embeds` reach it).

        Returns a `SequenceClassifierOutputWithPast` when `return_dict` is
        true; otherwise the tuple `(pooled_logits, *decoder_outputs[1:])`,
        prefixed with the loss when `labels` were given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.decoder(
            *args,
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs)
        logits = self.score(outputs[0])

        # Batch size comes from whichever input was supplied.
        shape_source = input_ids if input_ids is not None else inputs_embeds
        batch_size, _ = shape_source.shape[:2]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is not None and input_ids is not None:
            # Index of the token right before the first pad token.
            # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
            first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
            sequence_lengths = (first_pad - 1) % input_ids.shape[-1]
            sequence_lengths = sequence_lengths.to(logits.device)
        else:
            # No pad token id configured, or only `inputs_embeds` given (padding
            # cannot be detected in embeddings): use the last position.
            sequence_lengths = -1

        batch_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[batch_index, sequence_lengths]

        loss = None
        if labels is not None:
            # Infer the problem type once and remember it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (pooled_logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    def get_input_embeddings(self):
        """Return the decoder's token embedding module."""
        return self.model.decoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the decoder's token embedding module with `value`."""
        self.model.decoder.embed_tokens = value