File size: 1,833 Bytes
f01c2b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
import torch.nn as nn
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import os
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from transformers import AutoConfig
from collections import OrderedDict


class HybridTowerConfig(PretrainedConfig):
    """Configuration for a vision tower composed of several named sub-components.

    Each component's configuration is kept in serialized (plain ``dict``) form in
    ``self.configs``, keyed by component name. Because the serialized form is also
    accepted by ``__init__``, the composite config survives a
    ``to_dict()`` -> ``from_dict()`` / ``save_pretrained()`` -> ``from_pretrained()``
    round trip.
    """

    model_type = "hybrid_vision_tower"

    def __init__(self, configs=None, **kwargs):
        """
        Initializes the HybridTowerConfig.

        Args:
            configs (dict, optional): A dictionary mapping component names to either
                configuration objects that have a ``to_dict()`` method, or plain
                dictionaries (the already-serialized form produced by ``to_dict``,
                which is what gets passed back in when a saved config is reloaded).
            **kwargs: Additional keyword arguments that are passed to the superclass.

        Raises:
            TypeError: If ``configs`` is not a dict, or if one of its values is
                neither an object with a ``to_dict()`` method nor a dict.
        """
        import copy

        super().__init__(**kwargs)
        self.configs = {}

        if configs is None:
            return
        if not isinstance(configs, dict):
            raise TypeError("configs must be a dictionary where keys are component names and values are configuration objects.")

        for component_name, config in configs.items():
            if hasattr(config, 'to_dict'):
                # Live configuration object: store its serialized form.
                self.configs[component_name] = config.to_dict()
            elif isinstance(config, dict):
                # Already serialized, e.g. when PretrainedConfig.from_dict rebuilds
                # this config via cls(**config_dict). Deep-copy so later mutation of
                # the caller's dict cannot leak into this instance.
                self.configs[component_name] = copy.deepcopy(config)
            else:
                raise TypeError(f"The configuration for '{component_name}' does not have a to_dict() method and cannot be serialized.")

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.

        Returns:
            dict: A dictionary containing all the keys and values of this configuration
            instance, with the per-component configurations under ``'configs'``. The
            ``'configs'`` entry is a deep copy, so mutating the result does not
            modify this instance.
        """
        import copy

        config_dict = super().to_dict()
        # Deep-copy rather than alias self.configs, so callers that edit the returned
        # dict (e.g. before saving) cannot silently change this config object.
        config_dict['configs'] = copy.deepcopy(self.configs)
        return config_dict