# modeling_arlow_gpt.py
import os
from typing import Optional, Union, Dict, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedModel, CLIPModel, GPT2Model

from .configuration_arlow_gpt import ArlowGPTConfig

class ArlowGPTPreTrainedModel(PreTrainedModel):
    config_class = ArlowGPTConfig
    base_model_prefix = "arlow_gpt"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"clip", r"gpt2"]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

class ArlowGPTModel(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        
        # Load the sub-models. Note that PretrainedConfig always defines
        # `_name_or_path` (it defaults to ""), so checking hasattr() alone
        # would always take the local branch; load locally only when the
        # path actually contains saved sub-model directories.
        name_or_path = getattr(config, "_name_or_path", "")
        if name_or_path and os.path.isdir(os.path.join(name_or_path, "clip")):
            self.clip = CLIPModel.from_pretrained(
                os.path.join(name_or_path, "clip"),
                local_files_only=True,
            )
            self.gpt2 = GPT2Model.from_pretrained(
                os.path.join(name_or_path, "gpt2"),
                local_files_only=True,
            )
        else:
            self.clip = CLIPModel.from_pretrained(config.clip_model_name)
            self.gpt2 = GPT2Model.from_pretrained(config.gpt2_model_name)

        self.feature_projection = nn.Linear(
            self.clip.vision_model.config.hidden_size + self.gpt2.config.hidden_size,
            config.projection_dim
        )

        # Initialize weights and apply final processing
        self.post_init()
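
    # NOTE: the original file defines no forward pass for ArlowGPTModel, yet
    # ArlowGPTForImageTextToText calls it and reads outputs["hidden_states"].
    # The sketch below is a minimal reconstruction under one assumption: the
    # fusion concatenates the pooled CLIP image embedding with every GPT-2
    # token state, which is the only combination that matches the input width
    # of feature_projection (vision hidden size + GPT-2 hidden size).
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        # Encode the image; pooler_output is (batch, vision_hidden_size).
        vision_outputs = self.clip.vision_model(pixel_values=pixel_values)
        image_embeds = vision_outputs.pooler_output

        # Encode the text; last_hidden_state is (batch, seq_len, gpt2_hidden_size).
        text_outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        text_hidden = text_outputs.last_hidden_state

        # Broadcast the image embedding across the sequence and concatenate
        # it onto each token state along the feature dimension.
        image_expanded = image_embeds.unsqueeze(1).expand(-1, text_hidden.size(1), -1)
        fused = torch.cat([image_expanded, text_hidden], dim=-1)

        # Project the fused features to (batch, seq_len, projection_dim).
        hidden_states = self.feature_projection(fused)

        if return_dict:
            return {"hidden_states": hidden_states}
        return hidden_states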

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to save sub-models separately"""
        super().save_pretrained(save_directory, **kwargs)
        
        # Save CLIP and GPT-2 models in subdirectories
        self.clip.save_pretrained(f"{save_directory}/clip")
        self.gpt2.save_pretrained(f"{save_directory}/gpt2")

class ArlowGPTForImageTextToText(ArlowGPTPreTrainedModel):
    def __init__(self, config: ArlowGPTConfig):
        super().__init__(config)
        self.arlow_gpt = ArlowGPTModel(config)
        self.output_projection = nn.Linear(config.projection_dim, config.vocab_size)
        
        # Initialize weights and apply final processing
        self.post_init()

    def save_pretrained(self, save_directory, **kwargs):
        """Override save_pretrained to also save the CLIP and GPT-2 sub-models."""
        super().save_pretrained(save_directory, **kwargs)
        # Calling self.arlow_gpt.save_pretrained(save_directory) here would
        # overwrite the config and weights just written by the super() call
        # with the base model's; save only the sub-models instead.
        self.arlow_gpt.clip.save_pretrained(f"{save_directory}/clip")
        self.arlow_gpt.gpt2.save_pretrained(f"{save_directory}/gpt2")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Override from_pretrained to handle custom loading logic"""
        config = kwargs.get("config", None)
        if config is None:
            config = ArlowGPTConfig.from_pretrained(pretrained_model_name_or_path)
            kwargs["config"] = config
        
        # Set the path for loading sub-models
        config._name_or_path = pretrained_model_name_or_path
        
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        pixel_values: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        # Accepted so generate() can pass it without a TypeError; this model
        # does not thread a key/value cache through its forward pass.
        past_key_values: Optional[Tuple] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, Dict[str, torch.Tensor], torch.Tensor]:
        outputs = self.arlow_gpt(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            return_dict=True,
        )

        hidden_states = outputs["hidden_states"]
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM objective: predict token t+1 from position t,
            # so shift the logits and labels by one before the cross-entropy.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        if return_dict:
            return {
                "loss": loss,
                "logits": logits
            }
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, **kwargs
    ):
        # Recent transformers versions pass the cache as `past_key_values`;
        # the old `past` name would land in **kwargs and never trigger the
        # truncation below. Since this model's forward never returns a cache,
        # the branch is effectively dead, but it is kept for API parity.
        # Keep only the last token of input_ids once a cache is present.
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if attention_mask is not None:
                attention_mask = attention_mask[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": kwargs.get("pixel_values", None),
            "past_key_values": past_key_values,
        }
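

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal smoke test, assuming
# ArlowGPTConfig (defined in configuration_arlow_gpt.py, not shown here)
# accepts the keyword arguments below and that the CLIP / GPT-2 names are
# valid Hub checkpoints. Because this module uses a relative import, run it
# with `python -m <package>.modeling_arlow_gpt` rather than as a bare script.
if __name__ == "__main__":
    config = ArlowGPTConfig(
        clip_model_name="openai/clip-vit-base-patch32",  # assumed config field
        gpt2_model_name="gpt2",                          # assumed config field
        projection_dim=768,                              # assumed config field
        vocab_size=50257,                                # GPT-2 vocabulary size
    )
    model = ArlowGPTForImageTextToText(config)
    model.eval()

    # Dummy batch: pixel_values must match CLIP's expected 3x224x224 input.
    batch_size, seq_len = 2, 16
    inputs = {
        "input_ids": torch.randint(0, config.vocab_size, (batch_size, seq_len)),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "pixel_values": torch.randn(batch_size, 3, 224, 224),
    }
    with torch.no_grad():
        out = model(**inputs, labels=inputs["input_ids"])
    print(out["logits"].shape)  # torch.Size([2, 16, 50257])
    print(out["loss"])          # scalar cross-entropy loss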