File size: 952 Bytes
063d3d8
d573157
 
 
232b8ae
063d3d8
232b8ae
 
d573157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
063d3d8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# configuration_arlow_gpt.py
from transformers import PretrainedConfig
from typing import Dict, Optional

# Maps a Hugging Face Hub checkpoint id to the URL of that checkpoint's
# hosted config.json.
# NOTE(review): nothing in the visible code reads this map; confirm it is
# still consumed somewhere (newer transformers releases no longer use
# per-model archive maps) before relying on or extending it.
ARLOW_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "yuchenxie/GPT-2V": "https://huggingface.co/yuchenxie/GPT-2V/resolve/main/config.json",
}

class ArlowGPTConfig(PretrainedConfig):
    """Configuration for the ArlowGPT model (``model_type = "arlow_gpt"``).

    Stores the identifiers and optional configuration dicts of the two
    sub-models the model is built from (a CLIP-style component and a
    GPT-2-style component), plus a projection width and a vocabulary size.

    Args:
        clip_model_name: Hub id / path of the CLIP-style checkpoint.
        gpt2_model_name: Hub id / path of the GPT-2-style checkpoint.
        clip_config: Optional config dict for the CLIP-style component;
            stored as-is.
        gpt2_config: Optional config dict for the GPT-2-style component;
            stored as-is.
        projection_dim: Width stored as ``projection_dim``.
        vocab_size: Vocabulary size stored as ``vocab_size``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "arlow_gpt"

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size
        # Default the checkpoint name to None only when the caller gave none.
        # Overwriting it unconditionally would discard an explicit
        # ``name_or_path=`` argument, or a ``_name_or_path`` entry restored
        # from a serialized config dict, after the base class has stored it.
        if "name_or_path" not in kwargs and "_name_or_path" not in kwargs:
            self._name_or_path = None