metadata
license: mit
language:
  - en
base_model:
  - yuchenxie/CLiP
  - yuchenxie/GPT-2
library_name: transformers
inference: true

GPT-2V Model Card

Model Overview

GPT-2V is a multimodal transformer model that combines CLIP (vision) with GPT-2 (text generation) to generate text conditioned on both a text prompt and an image. CLIP encodes each image into a single embedding, GPT-2 encodes the text into one hidden state per token, and two projection layers fuse the two representations and map them to logits over the GPT-2 vocabulary. The merging script below creates these projection layers with newly initialized weights, extending GPT-2 with image features while reusing both pretrained backbones.
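
To make the data flow concrete, here is a minimal sketch that traces tensor shapes through the forward pass using plain tuples instead of tensors. The function name gpt2v_shapes is hypothetical, and the example sizes (a 512-dimensional CLIP image embedding, as in ViT-B/32-style CLIP models, and the 768-dimensional hidden states of GPT-2 small) are illustrative assumptions, not values read from the yuchenxie/CLiP or yuchenxie/GPT-2 configs.

```python
def gpt2v_shapes(batch, seq_len, clip_embed_dim, gpt2_hidden, projection_dim, vocab_size):
    """Return the shape produced at each stage of the GPT-2V forward pass."""
    return {
        # CLIPModel.get_image_features: one pooled, projected embedding per image.
        "image_features": (batch, clip_embed_dim),
        # GPT2Model last_hidden_state: one hidden vector per token.
        "text_hidden": (batch, seq_len, gpt2_hidden),
        # unsqueeze(1).expand(...) repeats the image embedding at every position.
        "image_expanded": (batch, seq_len, clip_embed_dim),
        # torch.cat along the last dimension.
        "combined": (batch, seq_len, clip_embed_dim + gpt2_hidden),
        # feature_projection output.
        "projected": (batch, seq_len, projection_dim),
        # output_projection output: vocabulary logits per position.
        "logits": (batch, seq_len, vocab_size),
    }


if __name__ == "__main__":
    # Assumed sizes: 512-d CLIP image embedding, 768-d GPT-2 hidden states.
    for stage, shape in gpt2v_shapes(2, 16, 512, 768, 768, 50257).items():
        print(f"{stage:>15}: {shape}")
```

The concatenated width depends on CLIP's projected embedding size, not on the vision encoder's internal hidden size; in many CLIP variants (ViT-B/32 among them) these two numbers differ.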

Model Architecture

  • Model Type: arlow_gpt
  • Base Vision Model: CLIP (yuchenxie/CLiP)
  • Base Text Model: GPT-2 (yuchenxie/GPT-2)
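  • Config: ArlowGPTConfig, a custom configuration that stores the base model names, the fusion projection dimension (default 768), and the vocabulary size (default 50257).
  • Tokenizer: GPT-2 tokenizer (GPT2Tokenizer); images are preprocessed with CLIPProcessor.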
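
The two projection layers are the only parameters GPT-2V adds on top of the base models. The sketch below counts them, assuming both are torch.nn.Linear layers with bias (as in the merging script) and using the ArlowGPTConfig defaults for projection_dim and vocab_size. The helper names are hypothetical, and the 512/768 input sizes in the example are illustrative assumptions rather than the actual yuchenxie checkpoint dimensions.

```python
def linear_params(in_features: int, out_features: int, bias: bool = True) -> int:
    """Number of parameters in torch.nn.Linear(in_features, out_features, bias=bias)."""
    return in_features * out_features + (out_features if bias else 0)


def added_params(clip_embed_dim: int, gpt2_hidden: int,
                 projection_dim: int = 768, vocab_size: int = 50257) -> dict:
    """Parameters in the two fusion layers (defaults mirror ArlowGPTConfig)."""
    # feature_projection: concatenated [image | text] features -> projection_dim
    feature = linear_params(clip_embed_dim + gpt2_hidden, projection_dim)
    # output_projection: projection_dim -> vocabulary logits
    output = linear_params(projection_dim, vocab_size)
    return {"feature_projection": feature, "output_projection": output, "total": feature + output}


if __name__ == "__main__":
    # Assumed sizes: 512-d CLIP image embedding, 768-d GPT-2 hidden states.
    for name, count in added_params(512, 768).items():
        print(f"{name}: {count:,}")
```

With these assumed sizes the sketch reports 983,808 parameters for feature_projection and 38,647,633 for output_projection. The output layer dominates because it spans the whole vocabulary; unlike GPT2LMHeadModel, which ties its LM head to the token embedding matrix by default, this output projection is a separate, untied weight matrix.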

Key Features

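  • Multimodal Input: Takes tokenized text (input_ids and attention_mask) and an image (pixel_values) as inputs.
  • Text Generation: Produces logits over the GPT-2 vocabulary at every text position.
  • Vision-Text Fusion: Concatenates the CLIP image embedding with each GPT-2 hidden state before projection, so every output position is conditioned on the image.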

Merging Script

The following script merges the CLIP and GPT-2 checkpoints (safetensors variants published by Yuchen as yuchenxie/CLiP and yuchenxie/GPT-2) into a single multimodal model, GPT-2V. It saves the combined model together with its configuration, the GPT-2 tokenizer, and the CLIP image-processor files, so that all components live in one directory.

import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import (
    CLIPModel,
    GPT2Model,
    CLIPProcessor,
    GPT2Tokenizer,
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
    AutoModelForCausalLM,
)

class ArlowGPTConfig(PretrainedConfig):
    model_type = "arlow_gpt"
    
    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size

class ArlowGPT(PreTrainedModel):
    config_class = ArlowGPTConfig
    
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        
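        # Load the pretrained base models. Because this runs in __init__, the base
        # checkpoints are fetched (or read from the local cache) every time the class
        # is instantiated, including when a saved GPT-2V checkpoint is reloaded.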
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)
        self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)

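        # Projection layers (newly initialized, not loaded from either checkpoint).
        # get_image_features() returns CLIP's projected image embedding, whose size is
        # clip.config.projection_dim, not the vision encoder's hidden_size.
        self.feature_projection = nn.Linear(
            self.clip.config.projection_dim + self.gpt2.config.hidden_size,
            config.projection_dim
        )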

        self.output_projection = nn.Linear(
            config.projection_dim,
            config.vocab_size
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        vision_outputs = self.clip.get_image_features(pixel_values=pixel_values)
        text_outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask
        ).last_hidden_state

        batch_size = text_outputs.shape[0]
        seq_length = text_outputs.shape[1]

        vision_features = vision_outputs.unsqueeze(1).expand(
            batch_size, seq_length, -1
        )

        combined_features = torch.cat(
            [vision_features, text_outputs],
            dim=-1
        )

        projected_features = self.feature_projection(combined_features)
        logits = self.output_projection(projected_features)

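        loss = None
        if labels is not None:
            # Causal LM objective: the logits at position i predict the token at i + 1,
            # so drop the last logit position and the first label before the loss.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1)
            )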

        if return_dict:
            return {
                "loss": loss,
                "logits": logits
            }
        return logits

    @staticmethod
    def register_auto_classes():
        """Register the model with Auto* classes."""
        try:
            AutoConfig.register("arlow_gpt", ArlowGPTConfig)
            AutoModelForCausalLM.register(ArlowGPTConfig, ArlowGPT)
        except ValueError:
            # Already registered
            pass

def save_merged_model(
    model: ArlowGPT,
    output_dir: str
) -> None:
    """Save the merged model with all necessary components in standard format."""
    output_path = Path(output_dir)

    # Remove existing directory if it exists
    if output_path.exists():
        shutil.rmtree(output_path)

    # Create new directory
    output_path.mkdir(parents=True)

    # Register auto classes
    model.register_auto_classes()

    # Save the model
    model.save_pretrained(output_path)

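    # Save the GPT-2 tokenizer and only the image-processing part of the CLIP processor.
    # Saving the full CLIPProcessor here would also write CLIP's tokenizer files
    # (e.g. vocab.json, merges.txt) and overwrite the GPT-2 tokenizer in this directory.
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)

    tokenizer.save_pretrained(output_path)
    processor.image_processor.save_pretrained(output_path)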

def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Initializing merged model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model
    )
    model = ArlowGPT(config)

    print("Saving merged model...")
    save_merged_model(model, output_dir)
    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")

if __name__ == "__main__":
    main()
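
When labels are supplied, GPT-2's causal language-modeling objective scores the logits at position i against the token at position i + 1; Hugging Face's GPT2LMHeadModel shifts logits and labels this way internally. Below is a minimal pure-Python sketch of that shifted cross-entropy for a single sequence, which can help when checking a training loss by hand. The function names are hypothetical, and -100 is used as the ignored label because it is the default ignore_index of torch.nn.CrossEntropyLoss.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    peak = max(row)
    log_norm = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_norm for x in row]


def shifted_causal_lm_loss(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    ignore_index: int = -100,
) -> float:
    """Mean cross-entropy of logits[i] against labels[i + 1] for one sequence.

    logits has one row of vocabulary scores per position and labels one token id
    per position; the last logits row has no next token and is therefore unused.
    """
    losses = []
    for position in range(len(labels) - 1):
        target = labels[position + 1]
        if target == ignore_index:
            continue  # skipped, like padding positions labeled -100
        losses.append(-log_softmax(logits[position])[target])
    if not losses:
        raise ValueError("every shifted label is ignored; the loss is undefined")
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Toy sequence over a hypothetical 4-token vocabulary; the first two rows
    # favor the label that follows them.
    toy_logits = [[0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 2.0]]
    toy_labels = [0, 1, 2]
    print(f"shifted loss: {shifted_causal_lm_loss(toy_logits, toy_labels):.4f}")
```

Averaging only over non-ignored positions matches the default "mean" reduction of torch.nn.CrossEntropyLoss.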

License

Use of this model is subject to the licenses of the original CLIP and GPT-2 models used in the merge. Please refer to the license agreements provided by OpenAI and the respective contributors for details.

Citation

If you use GPT-2V in your research or application, please cite the original works of CLIP and GPT-2, along with this model card.