# modeling_arlow_gpt.py
import torch
import torch.nn as nn

from transformers import PreTrainedModel, CLIPModel, GPT2Model
from transformers.modeling_outputs import Seq2SeqLMOutput
from typing import Optional, Tuple, Union

from .configuration_arlow_gpt import ArlowGPTConfig


class ArlowGPTPreTrainedModel(PreTrainedModel):
    config_class = ArlowGPTConfig
    base_model_prefix = "arlow_gpt"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        # Standard normal initialization for linear layers, zero-init for biases.
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class ArlowGPTModel(ArlowGPTPreTrainedModel):
    # Same as before
    ...


class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        self.arlow_gpt = ArlowGPTModel(config)
        self.lm_head = nn.Linear(config.projection_dim, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save all components."""
        super().save_pretrained(save_directory, **kwargs)
        self.arlow_gpt.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic."""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config
        config._name_or_path = pretrained_model_name_or_path

        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor, Seq2SeqLMOutput]:
        # Run the vision-language backbone (CLIP/GPT-2 fusion lives in ArlowGPTModel).
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            return_dict=True,
        )

        hidden_states = outputs["hidden_states"]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return Seq2SeqLMOutput(
                loss=loss,
                logits=logits,
            )

        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, **kwargs
    ):
        # Only feed the last token (and its mask position) when past key values exist.
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None:
                attention_mask = attention_mask[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past,
        }
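

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the modeling code). It illustrates the call
# pattern for ArlowGPTForImageTextToText.forward with a dummy image tensor.
# The checkpoint directory "path/to/arlow-gpt-checkpoint" and the 224x224 image
# resolution are assumptions for illustration; adapt them to your own setup.
# Because this module uses a relative import, run it as part of its package,
# e.g. `python -m your_package.modeling_arlow_gpt`.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Assumed local checkpoint produced by save_pretrained above.
    model = ArlowGPTForImageTextToText.from_pretrained("path/to/arlow-gpt-checkpoint")
    model.eval()

    encoded = tokenizer("a photo of", return_tensors="pt")
    # Random stand-in for preprocessed image pixels (batch, channels, H, W).
    pixel_values = torch.randn(1, 3, 224, 224)

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            pixel_values=pixel_values,
        )

    # Greedy next-token prediction from the language-modeling head.
    next_token_id = outputs.logits[:, -1, :].argmax(dim=-1)
    print(tokenizer.decode(next_token_id))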