import torch
import torch.nn as nn
from transformers import PreTrainedModel, CLIPModel, GPT2Model
from typing import Optional, Union, Dict, Tuple

from .configuration_arlow_gpt import ArlowGPTConfig


class ArlowGPTPreTrainedModel(PreTrainedModel):
    config_class = ArlowGPTConfig
    base_model_prefix = "arlow_gpt"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


class ArlowGPTModel(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)

        # Load the sub-models from local subdirectories when the config
        # carries a resolved checkpoint path (set by from_pretrained below);
        # otherwise fall back to the hub checkpoints named in the config.
        # A truthiness check is used rather than hasattr alone, because
        # PretrainedConfig defines _name_or_path as "" by default, which would
        # make the hasattr branch always taken.
        if getattr(config, "_name_or_path", None):
            self.clip = CLIPModel.from_pretrained(
                f"{config._name_or_path}/clip", local_files_only=True
            )
            self.gpt2 = GPT2Model.from_pretrained(
                f"{config._name_or_path}/gpt2", local_files_only=True
            )
        else:
            self.clip = CLIPModel.from_pretrained(config.clip_model_name)
            self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)

        self.feature_projection = nn.Linear(
            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
            config.projection_dim,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[Tuple[torch.Tensor], Dict[str, torch.Tensor]]:
        # NOTE: the original file defined no forward on this class, although
        # ArlowGPTForImageTextToText.forward consumes outputs["hidden_states"].
        # The fusion below is a minimal sketch consistent with the input size
        # of feature_projection (CLIP vision hidden size + GPT-2 hidden size),
        # not the author's verified implementation.
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        image_embeds = vision_outputs.pooler_output  # (batch, clip_hidden)

        text_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        text_embeds = text_outputs.last_hidden_state  # (batch, seq, gpt2_hidden)

        # Broadcast the pooled image embedding across the text sequence,
        # concatenate along the feature axis, and project to projection_dim.
        image_embeds = image_embeds.unsqueeze(1).expand(-1, text_embeds.size(1), -1)
        fused = torch.cat([image_embeds, text_embeds], dim=-1)
        hidden_states = self.feature_projection(fused)

        if return_dict:
            return {"hidden_states": hidden_states}
        return (hidden_states,)

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save sub-models separately."""
        super().save_pretrained(save_directory, **kwargs)
        # Save CLIP and GPT-2 models in subdirectories
        self.clip.save_pretrained(f"{save_directory}/clip")
        self.gpt2.save_pretrained(f"{save_directory}/gpt2")


class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        self.arlow_gpt = ArlowGPTModel(config)
        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save all components."""
        super().save_pretrained(save_directory, **kwargs)
        self.arlow_gpt.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic."""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config

        # Set the path so ArlowGPTModel loads the sub-models from local
        # subdirectories of this checkpoint.
        config._name_or_path = pretrained_model_name_or_path
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Tuple, Dict[str, torch.Tensor]]:
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            return_dict=True,
        )

        hidden_states = outputs["hidden_states"]
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return {"loss": loss, "logits": logits}
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, **kwargs
    ):
        # Only feed the last token once cached past key/values are available.
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        # Keep the full attention mask: even with past_key_values, the mask
        # must cover every position, cached ones included. (The original code
        # sliced it to the last column, which breaks cached generation.)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past,
        }
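

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): shows one forward pass of
# ArlowGPTForImageTextToText on an image/prompt pair. The checkpoint path
# "path/to/arlow_gpt" and the image file "example.jpg" are placeholders, and
# the preprocessor checkpoints are assumed to match config.clip_model_name
# and config.gpt2_model_name.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image
    from transformers import CLIPImageProcessor, GPT2TokenizerFast

    # Must point at a checkpoint produced by the save_pretrained overrides
    # above, so that the clip/ and gpt2/ subdirectories exist.
    model = ArlowGPTForImageTextToText.from_pretrained("path/to/arlow_gpt")
    model.eval()

    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    pixel_values = processor(
        images=Image.open("example.jpg"), return_tensors="pt"
    ).pixel_values
    inputs = tokenizer("Describe this image:", return_tensors="pt")

    with torch.no_grad():
        out = model(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            pixel_values=pixel_values,
        )

    # logits: (batch, seq_len, vocab_size); greedy pick of the next token.
    next_token_id = out["logits"][:, -1, :].argmax(dim=-1)
    print(tokenizer.decode(next_token_id))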