import torch
import torch.nn as nn
from transformers import PreTrainedModel, CLIPModel, GPT2Model
from typing import Optional, Union, Dict, Tuple

from .configuration_arlow_gpt import ArlowGPTConfig


class ArlowGPTPreTrainedModel(PreTrainedModel):
    config_class = ArlowGPTConfig
    base_model_prefix = "arlow_gpt"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        # GPT-2-style init for the added projection layers; recent versions of
        # transformers skip modules whose weights were already loaded from a
        # checkpoint, so the pretrained sub-models are left untouched.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class ArlowGPTModel(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)

        # `_name_or_path` always exists on a PretrainedConfig (it defaults to
        # ""), so test its value rather than its presence. A non-empty value
        # points at a local checkpoint directory containing the clip/ and
        # gpt2/ subfolders written by save_pretrained below.
        if getattr(config, "_name_or_path", ""):
            self.clip = CLIPModel.from_pretrained(
                f"{config._name_or_path}/clip",
                local_files_only=True,
            )
            self.gpt2 = GPT2Model.from_pretrained(
                f"{config._name_or_path}/gpt2",
                local_files_only=True,
            )
        else:
            self.clip = CLIPModel.from_pretrained(config.clip_model_name)
            self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)

        # Fuses the two encoders: the projection consumes the concatenation of
        # a CLIP vision feature and a GPT-2 hidden state.
        self.feature_projection = nn.Linear(
            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
            config.projection_dim,
        )

        self.post_init()
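
    # NOTE (editor's sketch): ArlowGPTForImageTextToText.forward below calls
    # this model and reads outputs["hidden_states"], but no forward was
    # defined in the original file. The fusion here is a minimal, hedged
    # reconstruction consistent with the input size of `feature_projection`:
    # the pooled CLIP image embedding is broadcast across the GPT-2 token
    # states, concatenated along the feature axis, and projected. The original
    # fusion strategy may differ.
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        past_key_values: Optional[Tuple] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, Dict[str, torch.Tensor]]:
        # Image branch: the pooled output has size
        # vision_model.config.hidden_size, matching the first summand of
        # feature_projection's input dimension.
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        image_embeds = vision_outputs.pooler_output  # (batch, vision_hidden)

        # Text branch: forward the cache so generation can reuse past states.
        text_outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        text_embeds = text_outputs.last_hidden_state  # (batch, seq, gpt2_hidden)

        # Broadcast the image embedding over the token sequence and fuse.
        image_embeds = image_embeds.unsqueeze(1).expand(-1, text_embeds.size(1), -1)
        hidden_states = self.feature_projection(
            torch.cat([image_embeds, text_embeds], dim=-1)
        )

        if return_dict:
            return {
                "hidden_states": hidden_states,
                "past_key_values": text_outputs.past_key_values,
            }
        return (hidden_states, text_outputs.past_key_values)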

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save sub-models separately"""
        super().save_pretrained(save_directory, **kwargs)
        self.clip.save_pretrained(f"{save_directory}/clip")
        self.gpt2.save_pretrained(f"{save_directory}/gpt2")


class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        self.arlow_gpt = ArlowGPTModel(config)
        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)

        self.post_init()
    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save all components"""
        super().save_pretrained(save_directory, **kwargs)
        # Save only the sub-models here. Calling self.arlow_gpt.save_pretrained
        # on the same directory would overwrite the config and weights that
        # super().save_pretrained just wrote for this model.
        self.arlow_gpt.clip.save_pretrained(f"{save_directory}/clip")
        self.arlow_gpt.gpt2.save_pretrained(f"{save_directory}/gpt2")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic"""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config

        # Record the checkpoint location so ArlowGPTModel.__init__ can load
        # the clip/ and gpt2/ sub-models from it (assumes a local directory).
        config._name_or_path = pretrained_model_name_or_path

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Tuple, Dict[str, torch.Tensor]]:
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            return_dict=True,
        )

        hidden_states = outputs["hidden_states"]
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            # Causal-LM objective: shift so that the hidden state at position
            # i predicts the token at position i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if return_dict:
            return {
                "loss": loss,
                "logits": logits,
                "past_key_values": outputs.get("past_key_values"),
            }
        return (loss, logits) if loss is not None else logits
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        # With a cache, only the newest token needs to be fed forward.
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # Unlike input_ids, the attention mask must keep covering the cached
        # positions as well, so it is passed through untruncated.
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
        }
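

# A minimal usage sketch (editor's addition; the checkpoint path is a
# placeholder, and the tokenizer/image-processor choices assume the default
# GPT-2 and CLIP backbones):
#
#   from transformers import GPT2Tokenizer, CLIPImageProcessor
#
#   model = ArlowGPTForImageTextToText.from_pretrained("path/to/checkpoint")
#   tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#   image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
#
#   text = tokenizer("a photo of", return_tensors="pt")
#   pixels = image_processor(images=image, return_tensors="pt").pixel_values
#   out = model(
#       input_ids=text.input_ids,
#       attention_mask=text.attention_mask,
#       pixel_values=pixels,
#       labels=text.input_ids,  # optional: enables the LM loss
#   )
#   print(out["loss"], out["logits"].shape)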