# modeling_arlow_gpt.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, CLIPModel, GPT2Model
from typing import Optional, Union, Dict, Tuple
from .configuration_arlow_gpt import ArlowGPTConfig


class ArlowGPTPreTrainedModel(PreTrainedModel):
    config_class = ArlowGPTConfig
    base_model_prefix = "arlow_gpt"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class ArlowGPTModel(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        # Load sub-models with local weights if available
        if hasattr(config, "_name_or_path"):
            self.clip = CLIPModel.from_pretrained(
                f"{config._name_or_path}/clip",
                local_files_only=True,
            )
            self.gpt2 = GPT2Model.from_pretrained(
                f"{config._name_or_path}/gpt2",
                local_files_only=True,
            )
        else:
            self.clip = CLIPModel.from_pretrained(config.clip_model_name)
            self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)

        self.feature_projection = nn.Linear(
            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
            config.projection_dim,
        )

        # Initialize weights and apply final processing
        self.post_init()
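
    # NOTE: the original file does not define a forward pass for ArlowGPTModel,
    # yet ArlowGPTForImageTextToText.forward below calls self.arlow_gpt(...) and
    # reads outputs["hidden_states"]. The method here is a minimal sketch of one
    # plausible implementation, assuming the pooled CLIP vision features are
    # broadcast over the text sequence and fused with the GPT-2 hidden states
    # through feature_projection; it is an assumption, not the original code.
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Encode the image with the CLIP vision tower and take its pooled output
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        image_features = vision_outputs.pooler_output  # (batch, clip_hidden)

        # Encode the text with GPT-2
        text_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        text_features = text_outputs.last_hidden_state  # (batch, seq_len, gpt2_hidden)

        # Broadcast image features across the sequence and concatenate with text features
        image_features = image_features.unsqueeze(1).expand(-1, text_features.size(1), -1)
        fused = torch.cat([image_features, text_features], dim=-1)

        # Project the fused features into the shared projection dimension
        hidden_states = self.feature_projection(fused)

        if return_dict:
            return {"hidden_states": hidden_states}
        return hidden_states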

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save sub-models separately."""
        super().save_pretrained(save_directory, **kwargs)
        # Save CLIP and GPT-2 models in subdirectories
        self.clip.save_pretrained(f"{save_directory}/clip")
        self.gpt2.save_pretrained(f"{save_directory}/gpt2")


class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        self.arlow_gpt = ArlowGPTModel(config)
        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save all components."""
        super().save_pretrained(save_directory, **kwargs)
        self.arlow_gpt.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic."""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config
        # Set the path for loading sub-models
        config._name_or_path = pretrained_model_name_or_path
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            return_dict=True,
        )
        hidden_states = outputs["hidden_states"]
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return {
                "loss": loss,
                "logits": logits,
            }
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, **kwargs
    ):
        # Only feed the last token of input_ids when past key values are provided
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None:
                attention_mask = attention_mask[:, -1].unsqueeze(-1)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past,
        }
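

# A minimal usage sketch, not part of the original file. The checkpoint path and
# the input tensors below are illustrative assumptions; it only demonstrates the
# round trip through the overridden save_pretrained/from_pretrained and the
# forward signature defined above.
#
#     model = ArlowGPTForImageTextToText.from_pretrained("path/to/arlow_gpt_checkpoint")
#     outputs = model(
#         input_ids=input_ids,            # (batch, seq_len) token ids from a GPT-2 tokenizer
#         attention_mask=attention_mask,  # (batch, seq_len)
#         pixel_values=pixel_values,      # (batch, 3, H, W) from a CLIP image processor
#         labels=input_ids,               # optional; enables the cross-entropy loss
#     )
#     loss, logits = outputs["loss"], outputs["logits"]
#
#     # Saving writes the wrapper weights plus clip/ and gpt2/ subdirectories
#     model.save_pretrained("path/to/arlow_gpt_checkpoint")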