---
license: apache-2.0
language:
- en
base_model:
- yuchenxie/GPT-2
- yuchenxie/CLiP
library_name: transformers
---
Check the `config.json` file in this repository for the merged model's configuration.
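As a quick sanity check, the saved configuration can be inspected directly. This is a minimal sketch, assuming the merging script below has been run with its default `output_dir` of `merged_model`; the field names follow `ArlowGPTConfig` as defined in the script.

```python
import json

# Hypothetical local path: "merged_model" is the default output_dir in main() below.
with open("merged_model/config.json") as f:
    cfg = json.load(f)

print(cfg["model_type"])       # expected: "ArlowGPT"
print(cfg["clip_model_name"])  # expected: "yuchenxie/CLiP"
print(cfg["gpt2_model_name"])  # expected: "yuchenxie/GPT-2"
print(cfg["projection_dim"])   # expected: 768 (the GPT-2 hidden size)
```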
Merging script:
```python
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file


class ArlowGPTConfig(PretrainedConfig):
    model_type = "ArlowGPT"  # Use the desired architecture name

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size


class ArlowGPT(nn.Module):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Extract the CLIP vision model hidden size
        clip_hidden_size = self.clip.config.vision_config.hidden_size  # Vision model hidden size (1024)
        gpt2_hidden_size = config.projection_dim  # Target hidden size (768)

        # Add a projection layer to align dimensions
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Update vocabulary size
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Process vision inputs through CLIP
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = vision_outputs.last_hidden_state

        # Apply projection to align dimensions
        encoder_hidden_states = self.clip_projection(encoder_hidden_states)

        # Create attention mask for CLIP embeddings
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long
        ).to(encoder_hidden_states.device)

        # Process text inputs through GPT-2 with cross-attention
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        # Calculate loss if labels are provided
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: str) -> None:
        state_dict = self.state_dict()

        # Rename mismatched keys
        if "clip_vision_model.weight" in state_dict:
            state_dict["clip.vision_model.weight"] = state_dict.pop("clip_vision_model.weight")
        if "clip_vision_model.bias" in state_dict:
            state_dict["clip.vision_model.bias"] = state_dict.pop("clip_vision_model.bias")
        if "gpt2.weight" in state_dict:
            state_dict["gpt2.transformer.wte.weight"] = state_dict.pop("gpt2.weight")
        if "gpt2.bias" in state_dict:
            state_dict["gpt2.transformer.wpe.bias"] = state_dict.pop("gpt2.bias")

        # Clone shared weights to avoid shared memory issues
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensor
        save_path = Path(output_dir) / "model.safetensors"
        save_file(state_dict, save_path)

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str):
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)
        state_dict = load_file(safetensor_path)

        # Rename mismatched keys in loaded state dict
        if "clip.vision_model.weight" in state_dict:
            state_dict["clip_vision_model.weight"] = state_dict.pop("clip.vision_model.weight")
        if "clip.vision_model.bias" in state_dict:
            state_dict["clip_vision_model.bias"] = state_dict.pop("clip.vision_model.bias")
        if "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.weight"] = state_dict.pop("gpt2.transformer.wte.weight")
        if "gpt2.transformer.wpe.bias" in state_dict:
            state_dict["gpt2.bias"] = state_dict.pop("gpt2.transformer.wpe.bias")

        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    output_path = Path(output_dir)
    if output_path.exists():
        shutil.rmtree(output_path)
    output_path.mkdir(parents=True)

    # Save the model configuration and weights
    model.config.save_pretrained(output_path)
    model.save_merged_safetensor(output_path)

    # Save the tokenizer and processor
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)
    tokenizer.save_pretrained(output_path)
    processor.save_pretrained(output_path)


def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model
    )
    model = ArlowGPT(config)

    print("Saving merged ArlowGPT model...")
    save_merged_model(model, output_dir)

    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")


if __name__ == "__main__":
    main()
```
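After the script finishes, the merged directory holds the configuration, the safetensors weights, and the saved tokenizer and processor. The following is a minimal inspection sketch, assuming the default `merged_model` output directory from `main()` (the path is hypothetical; adjust it to wherever the script wrote its output):

```python
from safetensors.torch import load_file
from transformers import CLIPProcessor, GPT2Tokenizer

merged_dir = "merged_model"  # output_dir used in main() above

# The tokenizer and processor were saved alongside the weights.
tokenizer = GPT2Tokenizer.from_pretrained(merged_dir)
processor = CLIPProcessor.from_pretrained(merged_dir)

# Peek at the merged weights: CLIP encoder, projection layer, and GPT-2 tensors.
state_dict = load_file(f"{merged_dir}/model.safetensors")
print(f"{len(state_dict)} tensors saved")
print([k for k in state_dict if k.startswith("clip_projection")])
```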