File size: 952 Bytes
063d3d8 d573157 232b8ae 063d3d8 232b8ae d573157 063d3d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
# configuration_arlow_gpt.py
from transformers import PretrainedConfig
from typing import Dict, Optional
# Maps each published checkpoint id to the URL of its config.json on the
# Hugging Face Hub. The URL is derived from the repo id so that adding a new
# checkpoint only requires adding its id to the tuple below.
ARLOW_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    repo_id: f"https://huggingface.co/{repo_id}/resolve/main/config.json"
    for repo_id in ("yuchenxie/GPT-2V",)
}
class ArlowGPTConfig(PretrainedConfig):
    """Configuration for ArlowGPT, holding settings for a CLIP and a GPT-2 component.

    Args:
        clip_model_name: Identifier of the CLIP component (presumably a Hugging
            Face Hub repo id; default ``"yuchenxie/CLiP"``).
        gpt2_model_name: Identifier of the GPT-2 component (presumably a Hub
            repo id; default ``"yuchenxie/GPT-2"``).
        clip_config: Optional dict of settings for the CLIP component; stored
            as-is without validation.
        gpt2_config: Optional dict of settings for the GPT-2 component; stored
            as-is without validation.
        projection_dim: Projection size (default 768). How it is used is up to
            the model code; this config only stores it.
        vocab_size: Vocabulary size (default 50257).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier for this configuration type within transformers.
    model_type = "arlow_gpt"

    def __init__(
        self,
        clip_model_name: str = "yuchenxie/CLiP",
        gpt2_model_name: str = "yuchenxie/GPT-2",
        clip_config: Optional[Dict] = None,
        gpt2_config: Optional[Dict] = None,
        projection_dim: int = 768,
        vocab_size: int = 50257,
        **kwargs,
    ):
        # Record whether the caller supplied a path before the base class runs.
        # `kwargs` is this call's own dict, so the base class popping keys from
        # its copy does not affect this check.
        path_supplied = "name_or_path" in kwargs or "_name_or_path" in kwargs
        super().__init__(**kwargs)
        self.clip_model_name = clip_model_name
        self.gpt2_model_name = gpt2_model_name
        self.clip_config = clip_config
        self.gpt2_config = gpt2_config
        self.projection_dim = projection_dim
        self.vocab_size = vocab_size
        # Reset the path to None only when none was given. Resetting it
        # unconditionally would discard a caller-provided `name_or_path` /
        # `_name_or_path` (presumably including one restored from a saved
        # config) that the base class has just stored.
        if not path_supplied:
            self._name_or_path = None