import os
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import CLIPModel, GPT2Model, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_arlow_gpt import ArlowGPTConfig


class ArlowGPTPreTrainedModel(PreTrainedModel):
config_class = ArlowGPTConfig
base_model_prefix = "arlow_gpt"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        # GPT-2 style initialization for the newly added linear layers; the
        # pretrained CLIP and GPT-2 weights are loaded, not re-initialized.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class ArlowGPTModel(ArlowGPTPreTrainedModel):
def __init__(self, config: ArlowGPTConfig):
super().__init__(config)
        # Load sub-model weights from local subdirectories when loading a
        # saved ArlowGPT checkpoint; otherwise fall back to the hub
        # checkpoints named in the config. Note that `hasattr` alone is not
        # a valid test here: PretrainedConfig always defines `_name_or_path`
        # (it defaults to an empty string), so we check that the path is
        # non-empty and that the subdirectory actually exists.
        name_or_path = getattr(config, "_name_or_path", "")
        if name_or_path and os.path.isdir(os.path.join(name_or_path, "clip")):
            self.clip = CLIPModel.from_pretrained(
                os.path.join(name_or_path, "clip"), local_files_only=True
            )
            self.gpt2 = GPT2Model.from_pretrained(
                os.path.join(name_or_path, "gpt2"), local_files_only=True
            )
        else:
            self.clip = CLIPModel.from_pretrained(config.clip_model_name)
            self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)
        # Fuse image and text features by concatenation, then project the
        # result into a shared space of size `projection_dim`.
        self.feature_projection = nn.Linear(
            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
            config.projection_dim,
        )
# Initialize weights and apply final processing
self.post_init()
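
    # The ArlowGPTForImageTextToText head below calls this module and reads
    # outputs["hidden_states"], but the original file defines no forward pass
    # here. The method below is a minimal sketch of the missing step, under
    # one assumption: the pooled CLIP image embedding is broadcast across the
    # text sequence and fused with the GPT-2 hidden states by concatenation
    # (matching the feature_projection input size above). The actual fusion
    # strategy is not confirmed by the original source.
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        past_key_values: Optional[Tuple] = None,
        return_dict: bool = True,
    ) -> Dict[str, torch.Tensor]:
        # Pooled image representation: (batch, vision_hidden_size).
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        image_embeds = vision_outputs.pooler_output
        # Per-token text representation: (batch, seq_len, text_hidden_size).
        text_outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        text_embeds = text_outputs.last_hidden_state
        # Broadcast the image vector over the sequence and fuse by concatenation.
        image_embeds = image_embeds.unsqueeze(1).expand(-1, text_embeds.size(1), -1)
        hidden_states = self.feature_projection(
            torch.cat([image_embeds, text_embeds], dim=-1)
        )
        return {
            "hidden_states": hidden_states,
            "past_key_values": text_outputs.past_key_values,
        }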

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained so the CLIP and GPT-2 sub-models are
        written to their own subdirectories alongside the wrapper weights."""
        super().save_pretrained(save_directory, **kwargs)
        self.clip.save_pretrained(os.path.join(save_directory, "clip"))
        self.gpt2.save_pretrained(os.path.join(save_directory, "gpt2"))


class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
def __init__(self, config: ArlowGPTConfig):
super().__init__(config)
self.arlow_gpt = ArlowGPTModel(config)
        # Map the fused representation to vocabulary logits for generation.
        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)
# Initialize weights and apply final processing
self.post_init()

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save all components."""
        super().save_pretrained(save_directory, **kwargs)
        # Save only the sub-models here: delegating to
        # self.arlow_gpt.save_pretrained(save_directory) would write the
        # inner model's weights over the wrapper checkpoint that the
        # super() call above just produced.
        self.arlow_gpt.clip.save_pretrained(os.path.join(save_directory, "clip"))
        self.arlow_gpt.gpt2.save_pretrained(os.path.join(save_directory, "gpt2"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic."""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config
        # Record the checkpoint path so ArlowGPTModel.__init__ can find the
        # clip/ and gpt2/ subdirectories saved next to the wrapper weights.
        config._name_or_path = pretrained_model_name_or_path
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], CausalLMOutputWithPast]:
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            past_key_values=past_key_values,
            return_dict=True,
        )
        hidden_states = outputs["hidden_states"]
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            # Causal LM objective: position t predicts token t + 1, so the
            # logits and labels must be shifted relative to each other.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )

        if return_dict:
            # Return a ModelOutput rather than a plain dict so that
            # generate() can access outputs.logits / outputs.past_key_values
            # as attributes.
            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.get("past_key_values"),
            )
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        # With a populated cache, only the newest token needs to be fed;
        # the cached key/value states cover the earlier positions.
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            # Keep the full attention mask: it must cover the cached
            # positions as well as the new token, so truncating it to the
            # last column (as the original code did) would be incorrect.
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
        }