BinLiunls commited on
Commit
1c8621b
·
1 Parent(s): 45d6413

update model type

Browse files

Signed-off-by: binliu <binliu@nvidia.com>

Files changed (1) hide show
  1. vista3d_model.py +5 -4
vista3d_model.py CHANGED
@@ -13,10 +13,11 @@ class VISTA3DModel(PreTrainedModel):
13
 
14
  def __init__(self, config):
15
  super().__init__(config)
16
- self.network = monai.networks.nets.vista3d132(
17
- encoder_embed_dim=config.encoder_embed_dim,
18
- in_channels=config.input_channels,
19
- )
 
20
 
21
  def forward(self, input):
22
  return self.network(input)
 
13
 
14
  def __init__(self, config):
15
  super().__init__(config)
16
+ if config.model_type == "VISTA3D":
17
+ self.network = monai.networks.nets.vista3d132(
18
+ encoder_embed_dim=config.encoder_embed_dim,
19
+ in_channels=config.input_channels,
20
+ )
21
 
22
  def forward(self, input):
23
  return self.network(input)