|
import torch |
|
import torch.nn as nn |
|
from typing import Dict, List, Union |
|
from transformers import PreTrainedModel, PretrainedConfig, AutoConfig, AutoModel |
|
import torch.nn.functional as F |
|
import json, os |
|
|
|
|
|
class ConnectorConfig(PretrainedConfig):
    """Configuration for a multimodal connector (``mm_connector``).

    Holds the dimensions needed to project vision-encoder features into the
    language model's hidden space:

    - vision_hidden_size: hidden sizes of the vision encoder outputs
      (one entry per vision tower/stage, presumably — confirm against caller).
    - text_hidden_size: hidden size of the text model the features map into.
    - num_patches: number of image patches per side/sequence (default 24).
    - rms_norm_eps: epsilon for RMSNorm layers in the connector.
      NOTE(review): 1e-4 is larger than the common 1e-5/1e-6 — confirm intended.
    - token_input_shape: spatial shape of the incoming vision token grid.
    """

    model_type = "mm_connector"

    def __init__(
        self,
        vision_hidden_size: Union[List[int], None] = None,
        text_hidden_size: int = 0,
        num_patches: int = 24,
        rms_norm_eps: float = 1e-4,
        token_input_shape: Union[List[int], None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Use None sentinels instead of `[]` defaults: a mutable default list
        # is created once at function definition and shared by every instance,
        # so in-place mutation of one config would leak into all later ones.
        # Behavior is unchanged for callers: omitted args still yield [].
        self.vision_hidden_size = [] if vision_hidden_size is None else vision_hidden_size
        self.text_hidden_size = text_hidden_size
        self.num_patches = num_patches
        self.rms_norm_eps = rms_norm_eps
        self.token_input_shape = [] if token_input_shape is None else token_input_shape

    @classmethod
    def load_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "ConnectorConfig":
        """Build a :class:`ConnectorConfig` from a JSON config file.

        Note: despite the hub-style parameter name, this expects a path to a
        JSON *file* (it is opened directly), not a model directory or repo id.

        Args:
            pretrained_model_name_or_path: Path to the JSON config file.
            **kwargs: Extra options forwarded to ``from_dict`` (auth token is
                extracted first via ``_set_token_in_kwargs``).

        Returns:
            ConnectorConfig: The instantiated configuration.
        """
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_from_json(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_from_json(cls, config_file, **kwargs):
        """Read ``config_file`` as JSON and return ``(config_dict, kwargs)``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Explicit encoding so parsing does not depend on the platform locale.
        with open(config_file, "r", encoding="utf-8") as file:
            config_data = json.load(file)
        return config_data, kwargs
|
|
|
|