mranzinger commited on
Commit
d1f280d
1 Parent(s): aa1a477

Fix double conditioning (#6)

Browse files

- Fix double conditioning (19f901633dea829072dc888c2286a14f44f5f4e4)

Files changed (3) hide show
  1. config.json +1 -0
  2. hf_model.py +4 -1
  3. radio_model.py +8 -1
config.json CHANGED
@@ -347,6 +347,7 @@
347
  "AutoConfig": "hf_model.RADIOConfig",
348
  "AutoModel": "hf_model.RADIOModel"
349
  },
 
350
  "max_resolution": 2048,
351
  "patch_size": 16,
352
  "preferred_resolution": [
 
347
  "AutoConfig": "hf_model.RADIOConfig",
348
  "AutoModel": "hf_model.RADIOModel"
349
  },
350
+ "external_conditioner": false,
351
  "max_resolution": 2048,
352
  "patch_size": 16,
353
  "preferred_resolution": [
hf_model.py CHANGED
@@ -45,6 +45,7 @@ class RADIOConfig(PretrainedConfig):
45
  preferred_resolution: Optional[Resolution] = None,
46
  adaptor_names: Union[str, List[str]] = None,
47
  vitdet_window_size: Optional[int] = None,
 
48
  **kwargs,
49
  ):
50
  self.args = args
@@ -63,6 +64,7 @@ class RADIOConfig(PretrainedConfig):
63
  )
64
  self.adaptor_names = adaptor_names
65
  self.vitdet_window_size = vitdet_window_size
 
66
  super().__init__(**kwargs)
67
 
68
 
@@ -75,7 +77,7 @@ class RADIOModel(PreTrainedModel):
75
 
76
  config_class = RADIOConfig
77
 
78
- def __init__(self, config):
79
  super().__init__(config)
80
 
81
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
@@ -115,6 +117,7 @@ class RADIOModel(PreTrainedModel):
115
  preferred_resolution=config.preferred_resolution,
116
  adaptors=adaptors,
117
  )
 
118
 
119
  @property
120
  def adaptors(self) -> nn.ModuleDict:
 
45
  preferred_resolution: Optional[Resolution] = None,
46
  adaptor_names: Union[str, List[str]] = None,
47
  vitdet_window_size: Optional[int] = None,
48
+ external_conditioner: Optional[bool] = False,
49
  **kwargs,
50
  ):
51
  self.args = args
 
64
  )
65
  self.adaptor_names = adaptor_names
66
  self.vitdet_window_size = vitdet_window_size
67
+ self.external_conditioner = external_conditioner
68
  super().__init__(**kwargs)
69
 
70
 
 
77
 
78
  config_class = RADIOConfig
79
 
80
+ def __init__(self, config: RADIOConfig):
81
  super().__init__(config)
82
 
83
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
 
117
  preferred_resolution=config.preferred_resolution,
118
  adaptors=adaptors,
119
  )
120
+ self.radio_model._external_conditioner = config.external_conditioner
121
 
122
  @property
123
  def adaptors(self) -> nn.ModuleDict:
radio_model.py CHANGED
@@ -51,6 +51,12 @@ class RADIOModel(nn.Module):
51
  self._patch_size = patch_size
52
  self._max_resolution = max_resolution
53
  self._window_size = window_size
 
 
 
 
 
 
54
 
55
  adaptors = adaptors or dict()
56
  self.adaptors = nn.ModuleDict(adaptors)
@@ -113,7 +119,8 @@ class RADIOModel(nn.Module):
113
  '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
114
  f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
115
 
116
- x = self.input_conditioner(x)
 
117
  y = self.model.forward_features(x)
118
 
119
  if isinstance(self.model, VisionTransformer):
 
51
  self._patch_size = patch_size
52
  self._max_resolution = max_resolution
53
  self._window_size = window_size
54
+ # This is a hack workaround for huggingface, since their
55
+ # data prep is annoying and complicated. If set to true,
56
+ # then will not call `self.input_conditioner` on the
57
+ # input tensor. This will be set in `hf_model.RADIOModel`
58
+ # where appropriate.
59
+ self._external_conditioner = False
60
 
61
  adaptors = adaptors or dict()
62
  self.adaptors = nn.ModuleDict(adaptors)
 
119
  '`self.get_nearest_supported_resolution(<height>, <width>) is provided as a convenience API. '
120
  f'Input: {x.shape[-2:]}, Nearest: {self.get_nearest_supported_resolution(*x.shape[-2:])}')
121
 
122
+ if not self._external_conditioner:
123
+ x = self.input_conditioner(x)
124
  y = self.model.forward_features(x)
125
 
126
  if isinstance(self.model, VisionTransformer):