|
import os |
|
|
|
import monai.networks.nets |
|
import torch |
|
from transformers import AutoConfig, AutoModel, PreTrainedModel |
|
from vista3d_config import VISTA3DConfig |
|
|
|
|
|
class VISTA3DModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the MONAI VISTA3D network.

    The wrapped network is built by ``monai.networks.nets.vista3d132`` from
    values read off the config and is stored on ``self.network``.
    """

    config_class = VISTA3DConfig

    def __init__(self, config):
        """Build the wrapped network described by ``config``.

        Args:
            config: Configuration object. ``model_type``, ``encoder_embed_dim``
                and ``input_channels`` are read from it.

        Raises:
            ValueError: If ``config.model_type`` is not ``"VISTA3D"``.
        """
        super().__init__(config)
        if config.model_type != "VISTA3D":
            # Fail at construction time with a clear message instead of leaving
            # ``self.network`` unset, which would only surface later as an
            # AttributeError inside ``forward``.
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; expected 'VISTA3D'."
            )
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    def forward(self, input):
        """Pass ``input`` to the wrapped network and return its output.

        The parameter keeps the name ``input`` (shadowing the builtin) so that
        existing keyword callers continue to work.
        """
        return self.network(input)
|
|
|
|
|
def register_my_model():
    """Register VISTA3D with the Auto classes so ``AutoModel`` can instantiate it.

    The config class is registered with ``AutoConfig`` under the ``"VISTA3D"``
    model type, then the model class is registered with ``AutoModel`` for that
    config class.
    """
    model_type = "VISTA3D"
    config_cls = VISTA3DConfig
    model_cls = VISTA3DModel

    AutoConfig.register(model_type, config_cls)
    AutoModel.register(config_cls, model_cls)
|
|
|
|
|
if __name__ == "__main__":
    # Convert the raw checkpoint stored next to this file into a directory
    # written by ``save_pretrained``.
    FILE_PATH = os.path.dirname(__file__)
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models/model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")

    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)

    # map_location="cpu" lets the conversion run on machines without a GPU even
    # if the checkpoint was saved from CUDA tensors; load_state_dict copies the
    # values into the model's existing parameters, so the result is unchanged.
    # NOTE(review): torch.load unpickles the file -- only run this on a trusted
    # checkpoint.
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu")
    hugging_face_model.network.load_state_dict(state_dict)
    hugging_face_model.save_pretrained(MODEL_PATH)
|
|