import json
import os
from typing import Dict, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class ConnectorConfig(PretrainedConfig):
    """Configuration for the multimodal (vision-to-text) connector module."""

    model_type = "mm_connector"

    def __init__(
        self,
        vision_hidden_size: List[int] = [],
        text_hidden_size: int = 0,
        num_patches: int = 24,
        rms_norm_eps: float = 1e-4,
        token_input_shape: List[int] = [],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vision_hidden_size = vision_hidden_size
        self.text_hidden_size = text_hidden_size
        self.num_patches = num_patches
        self.rms_norm_eps = rms_norm_eps
        self.token_input_shape = token_input_shape

    @classmethod
    def load_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "ConnectorConfig":
        # Resolve any auth token into kwargs, read the raw JSON config,
        # then build a ConnectorConfig from the resulting dict.
        cls._set_token_in_kwargs(kwargs)
        config_dict, kwargs = cls.get_config_from_json(pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def get_config_from_json(cls, config_file, **kwargs):
        # Load the connector configuration from a plain JSON file on disk.
        with open(config_file, "r") as file:
            config_data = json.load(file)
        return config_data, kwargs
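

# Usage sketch (assumption, not part of the original source): "connector_config.json"
# is a hypothetical path to a JSON file containing the fields defined above.
if __name__ == "__main__":
    config = ConnectorConfig.load_config("connector_config.json")
    print(config.vision_hidden_size, config.text_hidden_size, config.num_patches)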