---
license: apache-2.0
language:
- en
base_model:
- yuchenxie/GPT-2
- yuchenxie/CLiP
library_name: transformers
---

# Check the `config.json` file

The merged model's configuration (architecture name, projection dimension, vocabulary size, and the source CLIP and GPT-2 checkpoints) is written to `config.json` by the merging script below.
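As a quick check, the saved configuration can be read back directly. This is a minimal sketch; the `merged_model` path and the printed values are assumptions based on the defaults used in the script below.

```python
import json

# Assumes the merging script below was run with its default output_dir = "merged_model"
with open("merged_model/config.json") as f:
    cfg = json.load(f)

print(cfg["model_type"])      # "ArlowGPT"
print(cfg["projection_dim"])  # 768 by default
print(cfg["vocab_size"])      # 50257, copied from the GPT-2 checkpoint
```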
# Merging script

The script below loads the CLIP vision encoder and GPT-2, projects the CLIP hidden states down to the GPT-2 hidden size, enables cross-attention in GPT-2 so the text decoder can attend to the image features, and saves the merged weights, configuration, tokenizer, and processor into a single output directory.

```python
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file

class ArlowGPTConfig(PretrainedConfig):
    model_type = "ArlowGPT"  # Use the desired architecture name

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size

class ArlowGPT(nn.Module):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Extract the CLIP vision model hidden size
        clip_hidden_size = self.clip.config.vision_config.hidden_size  # Vision model hidden size (1024)
        gpt2_hidden_size = config.projection_dim  # Target hidden size (768)

        # Add a projection layer to align dimensions
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Update vocabulary size
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Process vision inputs through the CLIP vision encoder
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = vision_outputs.last_hidden_state

        # Apply the projection to align dimensions with GPT-2
        encoder_hidden_states = self.clip_projection(encoder_hidden_states)

        # Create an attention mask for the CLIP embeddings (all positions visible)
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long
        ).to(encoder_hidden_states.device)

        # Process text inputs through GPT-2, cross-attending to the image features
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        # Calculate the causal LM loss if labels are provided:
        # shift so that tokens < n predict token n
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: str) -> None:
        state_dict = self.state_dict()

        # Rename mismatched keys, if present
        if "clip_vision_model.weight" in state_dict:
            state_dict["clip.vision_model.weight"] = state_dict.pop("clip_vision_model.weight")
        if "clip_vision_model.bias" in state_dict:
            state_dict["clip.vision_model.bias"] = state_dict.pop("clip_vision_model.bias")
        if "gpt2.weight" in state_dict:
            state_dict["gpt2.transformer.wte.weight"] = state_dict.pop("gpt2.weight")
        if "gpt2.bias" in state_dict:
            state_dict["gpt2.transformer.wpe.bias"] = state_dict.pop("gpt2.bias")

        # Clone shared weights to avoid shared-memory issues when saving safetensors
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensors file
        save_path = Path(output_dir) / "model.safetensors"
        save_file(state_dict, save_path)

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str):
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)
        state_dict = load_file(safetensor_path)

        # The keys written by save_merged_safetensor already match this module's
        # state_dict layout (e.g. "clip.vision_model.*", "gpt2.transformer.*"),
        # so no renaming is needed before loading.
        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    output_path = Path(output_dir)
    if output_path.exists():
        shutil.rmtree(output_path)
    output_path.mkdir(parents=True)

    # Save the model configuration and weights
    model.config.save_pretrained(output_path)
    model.save_merged_safetensor(output_path)

    # Save the tokenizer and processor
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)
    tokenizer.save_pretrained(output_path)
    processor.save_pretrained(output_path)


def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model,
    )
    model = ArlowGPT(config)

    print("Saving merged ArlowGPT model...")
    save_merged_model(model, output_dir)

    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")


if __name__ == "__main__":
    main()
```
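
To reload the merged weights, the `from_merged_safetensor` class method defined above can be used. The sketch below is illustrative rather than a full inference pipeline: it assumes the `ArlowGPT` class and the imports from the script above are in scope, that the script was run with its default `merged_model` output directory, and that `image.jpg` is a placeholder path to any test image.

```python
from PIL import Image

# Reload the merged model; config.json and model.safetensors were written to merged_model/
model = ArlowGPT.from_merged_safetensor(
    config_path="merged_model",
    safetensor_path="merged_model/model.safetensors",
)
model.eval()

# Tokenizer and image processor from the source checkpoints used in the merge
tokenizer = GPT2Tokenizer.from_pretrained("yuchenxie/GPT-2")
processor = CLIPProcessor.from_pretrained("yuchenxie/CLiP")

# Encode a text prompt and a placeholder image
text_inputs = tokenizer("A photo of", return_tensors="pt")
image_inputs = processor(images=Image.open("image.jpg"), return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
        pixel_values=image_inputs["pixel_values"],
    )

print(outputs["logits"].shape)  # (batch_size, sequence_length, vocab_size)
```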