yuchenxie's picture
Create README.md
a7e8050 verified
|
raw
history blame
7.01 kB
metadata
license: apache-2.0
language:
  - en
base_model:
  - yuchenxie/GPT-2
  - yuchenxie/CLiP
library_name: transformers

See the `config.json` file for the model configuration.

Merging script:

import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file


class ArlowGPTConfig(PretrainedConfig):
    """Configuration for the combined CLIP + GPT-2 ``ArlowGPT`` model.

    Holds the names of the two source checkpoints, optional raw config
    dictionaries for each of them, the projection width, and the vocabulary
    size.
    """

    model_type = "ArlowGPT"  # architecture name recorded for this config

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        """Store the merge settings on the config instance.

        Args:
            clip_model_name: Checkpoint name or path of the CLIP model.
            gpt2_model_name: Checkpoint name or path of the GPT-2 model.
            clip_config: Optional CLIP configuration dict, stored as given.
            gpt2_config: Optional GPT-2 configuration dict, stored as given.
            projection_dim: Output width of the CLIP-to-GPT-2 projection.
            vocab_size: Vocabulary size of the language model.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
        """
        # Base-class initialisation comes first; the fields below are set after it.
        super().__init__(**kwargs)

        # Vision side.
        self.clip_model_name, self.clip_config = clip_model_name, clip_config
        # Language side.
        self.gpt2_model_name, self.gpt2_config = gpt2_model_name, gpt2_config
        # Shared dimensions.
        self.projection_dim, self.vocab_size = projection_dim, vocab_size


class ArlowGPT(nn.Module):
    """Image-conditioned language model built from a CLIP encoder and GPT-2.

    The CLIP vision model encodes the image, a linear layer projects those
    features to ``config.projection_dim``, and GPT-2 (loaded with
    ``add_cross_attention=True``) receives them as encoder hidden states
    while processing the text tokens.
    """

    def __init__(self, config: "ArlowGPTConfig"):
        """Load both pretrained backbones and create the projection layer.

        Args:
            config: Names the CLIP and GPT-2 checkpoints to load and gives the
                projection width. Its ``vocab_size`` is overwritten in place
                with the vocabulary size of the loaded GPT-2 model.
        """
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Project CLIP vision features to the width handed to GPT-2.
        # NOTE(review): projection_dim presumably has to equal GPT-2's hidden
        # size for cross-attention to accept the features — confirm.
        clip_hidden_size = self.clip.config.vision_config.hidden_size
        gpt2_hidden_size = config.projection_dim
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Keep the stored vocabulary size in sync with the loaded GPT-2 model
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, Optional[torch.Tensor]]]:
        """Run the image through CLIP and the text through GPT-2.

        Args:
            input_ids: Token ids passed to GPT-2.
            attention_mask: Attention mask passed to GPT-2 for ``input_ids``.
            pixel_values: Image tensor passed to the CLIP vision model.
            labels: Optional target token ids aligned with ``input_ids``
                (``labels = input_ids`` works). They are shifted by one
                position internally so each position is scored against the
                token that follows it.
            return_dict: When true, return ``{"loss": ..., "logits": ...}``;
                otherwise return only the logits tensor.

        Returns:
            A dict with ``"loss"`` (``None`` when ``labels`` is not given) and
            ``"logits"``, or the bare logits when ``return_dict`` is false.
        """
        # Process vision inputs through CLIP
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = self.clip_projection(vision_outputs.last_hidden_state)

        # Every projected image position is attendable; build the mask directly
        # on the target device instead of allocating on CPU and copying.
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.shape[:-1],
            dtype=torch.long,
            device=encoder_hidden_states.device,
        )

        # Process text inputs through GPT-2 with cross-attention
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        if labels is not None:
            # Causal LM objective: the logits at position t predict the token at
            # t + 1, so drop the last logit and the first label before scoring.
            # Scoring unshifted pairs would train the model to copy its input.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: Union[str, Path]) -> None:
        """Write all model weights to ``<output_dir>/model.safetensors``.

        Keys are saved exactly as ``state_dict()`` names them, which is what
        ``from_merged_safetensor`` expects when loading.

        Args:
            output_dir: Directory that receives the file; created if missing.
        """
        state_dict = self.state_dict()

        # The LM head and the token embedding can share one underlying tensor;
        # clone both so each saved key owns its storage and saving does not hit
        # shared-memory issues.
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensor
        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        save_file(state_dict, target_dir / "model.safetensors")

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str) -> "ArlowGPT":
        """Rebuild the model and load weights written by ``save_merged_safetensor``.

        Args:
            config_path: Location passed to ``ArlowGPTConfig.from_pretrained``.
            safetensor_path: Path of the merged ``.safetensors`` weight file.

        Returns:
            A model instance with the merged weights loaded.
        """
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)
        state_dict = load_file(safetensor_path)

        # Load the keys unchanged: the file already uses this module's own
        # parameter names. Renaming "gpt2.transformer.wte.weight" to
        # "gpt2.weight" (as was done before) removes a required key and adds an
        # unknown one, which makes the strict load below fail.
        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    """Export the merged model, its config, tokenizer and processor.

    Any existing ``output_dir`` is removed first and recreated empty, so the
    directory afterwards holds only the files written here.

    Args:
        model: The merged model to export.
        output_dir: Destination directory for all saved files.
    """
    target = Path(output_dir)

    # Start from a clean directory.
    if target.exists():
        shutil.rmtree(target)
    target.mkdir(parents=True)

    # Configuration first, then the merged weights.
    model.config.save_pretrained(target)
    model.save_merged_safetensor(target)

    # Load both text/image preprocessors before writing either of them.
    settings = model.config
    preprocessors = [
        GPT2Tokenizer.from_pretrained(settings.gpt2_model_name),
        CLIPProcessor.from_pretrained(settings.clip_model_name),
    ]
    for preprocessor in preprocessors:
        preprocessor.save_pretrained(target)


def main():
    """Build the merged ArlowGPT model and save it to ``merged_model``."""
    checkpoint_names = {
        "clip_model_name": "yuchenxie/CLiP",
        "gpt2_model_name": "yuchenxie/GPT-2",
    }
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    merged = ArlowGPT(ArlowGPTConfig(**checkpoint_names))

    print("Saving merged ArlowGPT model...")
    save_merged_model(merged, output_dir)

    # Report where the export went and what it contains.
    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for entry in os.listdir(output_dir):
        print(f"- {entry}")


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()