---
license: apache-2.0
language:
- en
base_model:
- yuchenxie/GPT-2
- yuchenxie/CLiP
library_name: transformers
---

# Check the config.json file.

# Merging script:

```python
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file


class ArlowGPTConfig(PretrainedConfig):
    model_type = "ArlowGPT"  # Use the desired architecture name

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size


class ArlowGPT(nn.Module):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Extract the CLIP vision model hidden size
        clip_hidden_size = self.clip.config.vision_config.hidden_size  # Vision model hidden size (1024)
        gpt2_hidden_size = config.projection_dim  # Target hidden size (768)

        # Add a projection layer to align dimensions
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Update vocabulary size
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Process vision inputs through CLIP
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = vision_outputs.last_hidden_state

        # Apply projection to align dimensions
        encoder_hidden_states = self.clip_projection(encoder_hidden_states)

        # Create attention mask for CLIP embeddings
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long
        ).to(encoder_hidden_states.device)

        # Process text inputs through GPT-2 with cross-attention
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        # Calculate loss if labels are provided
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: str) -> None:
        state_dict = self.state_dict()

        # Rename legacy flattened key names, if present, to the module paths used here
        if "clip_vision_model.weight" in state_dict:
            state_dict["clip.vision_model.weight"] = state_dict.pop("clip_vision_model.weight")
        if "clip_vision_model.bias" in state_dict:
            state_dict["clip.vision_model.bias"] = state_dict.pop("clip_vision_model.bias")
        if "gpt2.weight" in state_dict:
            state_dict["gpt2.transformer.wte.weight"] = state_dict.pop("gpt2.weight")
        if "gpt2.bias" in state_dict:
            state_dict["gpt2.transformer.wpe.bias"] = state_dict.pop("gpt2.bias")

        # Clone shared weights to avoid shared memory issues
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensor
        save_path = Path(output_dir) / "model.safetensors"
        save_file(state_dict, save_path)

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str):
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)
        state_dict = load_file(safetensor_path)

        # Keys written by save_merged_safetensor already match this module's
        # state_dict, so only map legacy flattened names back to module paths
        # if they appear; renaming in the other direction would break strict loading.
        if "clip_vision_model.weight" in state_dict:
            state_dict["clip.vision_model.weight"] = state_dict.pop("clip_vision_model.weight")
        if "clip_vision_model.bias" in state_dict:
            state_dict["clip.vision_model.bias"] = state_dict.pop("clip_vision_model.bias")
        if "gpt2.weight" in state_dict:
            state_dict["gpt2.transformer.wte.weight"] = state_dict.pop("gpt2.weight")
        if "gpt2.bias" in state_dict:
            state_dict["gpt2.transformer.wpe.bias"] = state_dict.pop("gpt2.bias")

        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    output_path = Path(output_dir)
    if output_path.exists():
        shutil.rmtree(output_path)
    output_path.mkdir(parents=True)

    # Save the model configuration and weights
    model.config.save_pretrained(output_path)
    model.save_merged_safetensor(output_path)

    # Save the tokenizer and processor
    # Note: GPT2Tokenizer and the CLIP processor's tokenizer may write overlapping
    # files (vocab.json, merges.txt) into the same directory, so the later save wins.
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)
    tokenizer.save_pretrained(output_path)
    processor.save_pretrained(output_path)


def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model
    )
    model = ArlowGPT(config)

    print("Saving merged ArlowGPT model...")
    save_merged_model(model, output_dir)

    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")


if __name__ == "__main__":
    main()
```