yuchenxie committed on
Commit
2a6dccf
1 Parent(s): c4056ea

Update modeling_arlow_gpt.py

Files changed (1)
  1. modeling_arlow_gpt.py +54 -17
modeling_arlow_gpt.py CHANGED
@@ -1,42 +1,34 @@
+# modeling_arlow_gpt.py
 import torch
 import torch.nn as nn
 from transformers import PreTrainedModel, CLIPModel, GPT2Model
-from typing import Optional, Union, Dict
+from typing import Optional, Union, Dict, Tuple
 from .configuration_arlow_gpt import ArlowGPTConfig
 
 class ArlowGPTPreTrainedModel(PreTrainedModel):
-    """
-    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
-    """
+    """Base class for ArlowGPT model."""
     config_class = ArlowGPTConfig
     base_model_prefix = "arlow_gpt"
     supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range if hasattr(self.config, "initializer_range") else 0.02)
+            module.weight.data.normal_(mean=0.0, std=0.02)
             if module.bias is not None:
                 module.bias.data.zero_()
 
-class ArlowGPT(ArlowGPTPreTrainedModel):
+class ArlowGPTModel(ArlowGPTPreTrainedModel):
     def __init__(self, config: ArlowGPTConfig):
         super().__init__(config)
 
-        # Load the models
         self.clip = CLIPModel.from_pretrained(config.clip_model_name)
         self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)
 
-        # Projection layers
         self.feature_projection = nn.Linear(
             self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
             config.projection_dim
         )
 
-        self.output_projection = nn.Linear(
-            config.projection_dim,
-            config.vocab_size
-        )
-
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -45,7 +37,6 @@ class ArlowGPT(ArlowGPTPreTrainedModel):
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         pixel_values: torch.Tensor,
-        labels: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         vision_outputs = self.clip.get_image_features(pixel_values=pixel_values)
@@ -66,8 +57,38 @@
             dim=-1
         )
 
-        projected_features = self.feature_projection(combined_features)
-        logits = self.output_projection(projected_features)
+        hidden_states = self.feature_projection(combined_features)
+
+        if return_dict:
+            return {"hidden_states": hidden_states}
+        return hidden_states
+
+class ArlowGPTForCausalLM(ArlowGPTPreTrainedModel):
+    def __init__(self, config: ArlowGPTConfig):
+        super().__init__(config)
+        self.arlow_gpt = ArlowGPTModel(config)
+        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        pixel_values: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        outputs = self.arlow_gpt(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            return_dict=True
+        )
+
+        hidden_states = outputs["hidden_states"]
+        logits = self.output_projection(hidden_states)
 
         loss = None
         if labels is not None:
@@ -79,4 +100,20 @@ class ArlowGPT(ArlowGPTPreTrainedModel):
                 "loss": loss,
                 "logits": logits
             }
-        return logits
+        return (loss, logits) if loss is not None else logits
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past=None, attention_mask=None, **kwargs
+    ):
+        # only last token for inputs_ids if past is defined in kwargs
+        if past:
+            input_ids = input_ids[:, -1].unsqueeze(-1)
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, -1].unsqueeze(-1)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "pixel_values": kwargs.get("pixel_values", None),
+            "past_key_values": past,
+        }
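
For context, a minimal usage sketch of the two classes this commit introduces, ArlowGPTModel and ArlowGPTForCausalLM. The config field names (clip_model_name, gpt2_model_name, projection_dim, vocab_size) are taken from the diff; the import path, checkpoint names, and tensor shapes below are illustrative assumptions, not part of the commit.

import torch

# Hypothetical import path: the diff uses a relative import for the config,
# so this assumes the repo is laid out as an importable package.
from arlow_gpt.configuration_arlow_gpt import ArlowGPTConfig
from arlow_gpt.modeling_arlow_gpt import ArlowGPTForCausalLM

# Field names match those read by the model in the diff; values are illustrative.
config = ArlowGPTConfig(
    clip_model_name="openai/clip-vit-base-patch32",  # illustrative CLIP checkpoint
    gpt2_model_name="gpt2",                          # illustrative GPT-2 checkpoint
    projection_dim=768,
    vocab_size=50257,
)

# Note: __init__ calls CLIPModel.from_pretrained / GPT2Model.from_pretrained,
# so constructing the model downloads both checkpoints.
model = ArlowGPTForCausalLM(config)
model.eval()

# Dummy batch: two 16-token sequences and two 224x224 RGB images.
input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)
pixel_values = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        return_dict=True,
    )

# With return_dict=True the forward in the diff returns {"loss": None, "logits": ...};
# the logits' trailing dimension is config.vocab_size.
print(outputs["logits"].shape)

Passing labels= in addition would populate outputs["loss"], per the forward method shown in the diff.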