---
license: apache-2.0
language:
- en
base_model:
- yuchenxie/GPT-2
- yuchenxie/CLiP
library_name: transformers
---

# Check the config.json file
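For reference, the `config.json` written by the merging script below records the fields defined in `ArlowGPTConfig`. With the default values it should look roughly like the sketch below; the actual file will also carry version metadata added by `transformers`:

```json
{
  "model_type": "ArlowGPT",
  "clip_model_name": "yuchenxie/CLiP",
  "gpt2_model_name": "yuchenxie/GPT-2",
  "clip_config": null,
  "gpt2_config": null,
  "projection_dim": 768,
  "vocab_size": 50257
}
```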
# Merging script
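The script below builds the merged checkpoint: it loads the CLIP vision encoder and a GPT-2 language model, adds a linear projection that maps the vision hidden size (1024) down to GPT-2's hidden size (768), enables cross-attention in GPT-2 so the text decoder can attend to the projected image features, and writes the combined weights to a single safetensors file alongside the config, tokenizer, and processor.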
```python
import os
import shutil
from pathlib import Path
from typing import Optional, Dict, Union

import torch
from torch import nn
from transformers import (
    CLIPModel,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    CLIPProcessor,
    PretrainedConfig,
    AutoConfig,
)
from safetensors.torch import save_file, load_file


class ArlowGPTConfig(PretrainedConfig):
    model_type = "ArlowGPT"  # Architecture name recorded in config.json

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size


class ArlowGPT(nn.Module):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__()
        print("Initializing ArlowGPT model...")

        # Load the CLIP model (vision encoder)
        self.clip = CLIPModel.from_pretrained(config.clip_model_name)

        # Hidden sizes: CLIP vision model (1024) vs. GPT-2 target (768)
        clip_hidden_size = self.clip.config.vision_config.hidden_size
        gpt2_hidden_size = config.projection_dim

        # Projection layer to align the vision features with GPT-2's width
        self.clip_projection = nn.Linear(clip_hidden_size, gpt2_hidden_size)

        # Load GPT-2 with cross-attention enabled so the decoder can
        # attend to the projected image features
        self.gpt2_config = AutoConfig.from_pretrained(config.gpt2_model_name)
        self.gpt2_config.add_cross_attention = True
        self.gpt2 = GPT2LMHeadModel.from_pretrained(
            config.gpt2_model_name, config=self.gpt2_config
        )

        # Keep the config in sync with the actual GPT-2 vocabulary size
        self.config = config
        self.config.vocab_size = self.gpt2.config.vocab_size

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Encode the image through the CLIP vision tower
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        encoder_hidden_states = vision_outputs.last_hidden_state

        # Project the vision features to GPT-2's hidden size
        encoder_hidden_states = self.clip_projection(encoder_hidden_states)

        # Attention mask over the image tokens (all visible)
        encoder_attention_mask = torch.ones(
            encoder_hidden_states.size()[:-1], dtype=torch.long
        ).to(encoder_hidden_states.device)

        # Run GPT-2 with cross-attention over the projected image features
        outputs = self.gpt2(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        logits = outputs.logits
        loss = None

        # Causal LM loss: shift so each position predicts the next token,
        # matching GPT2LMHeadModel's built-in label handling
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if return_dict:
            return {"loss": loss, "logits": logits}
        return logits

    def save_merged_safetensor(self, output_dir: str) -> None:
        state_dict = self.state_dict()

        # safetensors refuses to serialize tensors that share memory.
        # GPT-2 ties lm_head.weight to transformer.wte.weight, so clone
        # both entries into independent tensors before saving.
        if "gpt2.lm_head.weight" in state_dict and "gpt2.transformer.wte.weight" in state_dict:
            state_dict["gpt2.lm_head.weight"] = state_dict["gpt2.lm_head.weight"].clone()
            state_dict["gpt2.transformer.wte.weight"] = state_dict["gpt2.transformer.wte.weight"].clone()

        # Save the state dictionary as a safetensors file
        save_path = Path(output_dir) / "model.safetensors"
        save_file(state_dict, str(save_path))

    @classmethod
    def from_merged_safetensor(cls, config_path: str, safetensor_path: str):
        config = ArlowGPTConfig.from_pretrained(config_path)
        model = cls(config)

        # The merged checkpoint already uses this module's native key
        # names, so the weights can be loaded directly.
        state_dict = load_file(safetensor_path)
        model.load_state_dict(state_dict)
        return model


def save_merged_model(model: ArlowGPT, output_dir: str) -> None:
    output_path = Path(output_dir)
    if output_path.exists():
        shutil.rmtree(output_path)
    output_path.mkdir(parents=True)

    # Save the model configuration and weights
    model.config.save_pretrained(output_path)
    model.save_merged_safetensor(output_path)

    # Save the tokenizer and processor
    tokenizer = GPT2Tokenizer.from_pretrained(model.config.gpt2_model_name)
    processor = CLIPProcessor.from_pretrained(model.config.clip_model_name)
    tokenizer.save_pretrained(output_path)
    processor.save_pretrained(output_path)


def main():
    clip_model = "yuchenxie/CLiP"
    gpt2_model = "yuchenxie/GPT-2"
    output_dir = "merged_model"

    print("Merging ArlowGPT model...")
    config = ArlowGPTConfig(
        clip_model_name=clip_model,
        gpt2_model_name=gpt2_model,
    )
    model = ArlowGPT(config)

    print("Saving merged ArlowGPT model...")
    save_merged_model(model, output_dir)

    print(f"Merged model saved to {output_dir}")
    print("Saved files:")
    for file in os.listdir(output_dir):
        print(f"- {file}")


if __name__ == "__main__":
    main()
```
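# Loading the merged model

As a quick sanity check, the merged checkpoint can be reloaded with the `from_merged_safetensor` classmethod defined above and run on dummy inputs. This is a minimal sketch assuming the script has been run with the default `output_dir = "merged_model"` and that `ArlowGPT` is in scope; the 224x224 input resolution follows the standard CLIP vision configuration.

```python
import torch
from transformers import GPT2Tokenizer

# Reload the merged checkpoint produced by the merging script above.
model = ArlowGPT.from_merged_safetensor(
    config_path="merged_model",
    safetensor_path="merged_model/model.safetensors",
)
model.eval()

tokenizer = GPT2Tokenizer.from_pretrained("merged_model")

# Dummy batch: one 224x224 RGB image and a short text prompt.
pixel_values = torch.randn(1, 3, 224, 224)
inputs = tokenizer("A photo of", return_tensors="pt")

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=pixel_values,
    )

print(outputs["logits"].shape)  # (batch, seq_len, vocab_size)
```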