---
license: apache-2.0
language:
- en
base_model:
- yuchenxie/GPT-2
- yuchenxie/CLiP
library_name: transformers
---

# Check the `config.json` file
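
The fields worth checking are the ones the merging script below writes into the config: `model_type`, the source checkpoint names, `projection_dim`, and `vocab_size`. A minimal sketch for inspecting them, assuming `config.json` has been downloaded into the current directory (the local path is illustrative):

```python
# Minimal sketch: inspect the merged config.json locally.
# Assumption: this repository's config.json (or the one written to the
# merging script's "merged_model" output directory) is in the current directory.
import json

with open("config.json") as f:
    cfg = json.load(f)

print(cfg.get("model_type"))                                   # expected: "ArlowGPT"
print(cfg.get("clip_model_name"), cfg.get("gpt2_model_name"))  # source checkpoints
print(cfg.get("projection_dim"), cfg.get("vocab_size"))        # projection / vocab sizes
```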

# Merging script

```python
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file


class ArlowGPTConfig(PretrainedConfig):
    model_type = "ArlowGPT"  # Use the desired architecture name

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size


class ArlowGPT(nn.Module):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Extract the CLIP vision model hidden size
        clip_hidden_size = self.clip.config.vision_config.hidden_size  # Vision model hidden size (1024)
        gpt2_hidden_size = config.projection_dim  # Target hidden size (768)

        # Add a projection layer to align dimensions
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Update vocabulary size
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Process vision inputs through CLIP
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = vision_outputs.last_hidden_state

        # Apply projection to align dimensions
        encoder_hidden_states = self.clip_projection(encoder_hidden_states)

        # Create attention mask for CLIP embeddings
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long
        ).to(encoder_hidden_states.device)

        # Process text inputs through GPT-2 with cross-attention
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        # Calculate loss if labels are provided
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: str) -> None:
        state_dict = self.state_dict()

        # Rename mismatched keys
        if "clip_vision_model.weight" in state_dict:
            state_dict["clip.vision_model.weight"] = state_dict.pop("clip_vision_model.weight")
        if "clip_vision_model.bias" in state_dict:
            state_dict["clip.vision_model.bias"] = state_dict.pop("clip_vision_model.bias")
        if "gpt2.weight" in state_dict:
            state_dict["gpt2.transformer.wte.weight"] = state_dict.pop("gpt2.weight")
        if "gpt2.bias" in state_dict:
            state_dict["gpt2.transformer.wpe.bias"] = state_dict.pop("gpt2.bias")

        # Clone shared weights to avoid shared memory issues
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensor
        save_path = Path(output_dir) / "model.safetensors"
        save_file(state_dict, save_path)

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str):
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)
        state_dict = load_file(safetensor_path)

        # The keys saved by save_merged_safetensor already match this module's
        # state_dict layout, so the tensors can be loaded without renaming.
        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    output_path = Path(output_dir)
    if output_path.exists():
        shutil.rmtree(output_path)
    output_path.mkdir(parents=True)

    # Save the model configuration and weights
    model.config.save_pretrained(output_path)
    model.save_merged_safetensor(output_path)

    # Save the tokenizer and processor
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)
    tokenizer.save_pretrained(output_path)
    processor.save_pretrained(output_path)


def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model
    )
    model = ArlowGPT(config)

    print("Saving merged ArlowGPT model...")
    save_merged_model(model, output_dir)

    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")


if __name__ == "__main__":
    main()
```
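
# Loading the merged model (sketch)

A minimal usage sketch, not part of the original script: it assumes the `merged_model` directory produced above is on disk, that the `ArlowGPT` class from the script is in scope (run or import the script first), and that `example.jpg` stands in for any local image.

```python
# Minimal usage sketch: reload the merged weights and run one forward pass.
# Assumptions: "merged_model" was produced by the script above, ArlowGPT from
# that script is importable, and "example.jpg" is a placeholder image path.
import torch
from PIL import Image
from transformers import GPT2Tokenizer, CLIPProcessor

model = ArlowGPT.from_merged_safetensor(
    config_path="merged_model",
    safetensor_path="merged_model/model.safetensors",
)
model.eval()

# Reuse the original tokenizer/processor checkpoints referenced by the script.
tokenizer = GPT2Tokenizer.from_pretrained("yuchenxie/GPT-2")
processor = CLIPProcessor.from_pretrained("yuchenxie/CLiP")

image = Image.open("example.jpg")
pixel_values = processor(images=image, return_tensors="pt").pixel_values
text = tokenizer("A photo of", return_tensors="pt")

with torch.no_grad():
    out = model(
        input_ids=text.input_ids,
        attention_mask=text.attention_mask,
        pixel_values=pixel_values,
    )

print(out["logits"].shape)  # (batch_size, sequence_length, vocab_size)
```

The forward pass returns raw logits only; since the wrapper is a plain `nn.Module`, it does not expose `generate`, so any decoding strategy has to be built on top of these logits.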