update model type
Browse filesSigned-off-by: binliu <binliu@nvidia.com>
- vista3d_model.py +5 -4
vista3d_model.py
CHANGED
@@ -13,10 +13,11 @@ class VISTA3DModel(PreTrainedModel):
|
|
13 |
|
14 |
def __init__(self, config):
|
15 |
super().__init__(config)
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
20 |
|
21 |
def forward(self, input):
|
22 |
return self.network(input)
|
|
|
13 |
|
14 |
def __init__(self, config):
|
15 |
super().__init__(config)
|
16 |
+
if config.model_type == "VISTA3D":
|
17 |
+
self.network = monai.networks.nets.vista3d132(
|
18 |
+
encoder_embed_dim=config.encoder_embed_dim,
|
19 |
+
in_channels=config.input_channels,
|
20 |
+
)
|
21 |
|
22 |
def forward(self, input):
|
23 |
return self.network(input)
|