# VISTA3D-HF / vista3d_model.py
# Last change: "update model type" (commit 1c8621b, author: BinLiunls).
import os
import monai.networks.nets
import torch
from transformers import AutoConfig, AutoModel, PreTrainedModel
from vista3d_config import VISTA3DConfig
class VISTA3DModel(PreTrainedModel):
    """VISTA3D model for Hugging Face.

    Thin ``PreTrainedModel`` adapter: the actual network is built by MONAI's
    ``monai.networks.nets.vista3d132`` and stored as ``self.network``. This
    class only plugs it into the transformers config / save / load machinery.
    """

    config_class = VISTA3DConfig

    def __init__(self, config):
        """Build the wrapped MONAI network from ``config``.

        Args:
            config: a ``VISTA3DConfig``. Its ``encoder_embed_dim`` and
                ``input_channels`` fields are forwarded to ``vista3d132``
                (as ``encoder_embed_dim`` and ``in_channels``).

        Raises:
            ValueError: if ``config.model_type`` is not ``"VISTA3D"``.
        """
        super().__init__(config)
        # Fail fast: previously an unsupported model_type left the instance
        # without ``self.network``, which only surfaced later as a confusing
        # AttributeError in forward()/save_pretrained().
        if config.model_type != "VISTA3D":
            raise ValueError(
                f"Unsupported model_type {config.model_type!r}; "
                "VISTA3DModel only supports 'VISTA3D'."
            )
        self.network = monai.networks.nets.vista3d132(
            encoder_embed_dim=config.encoder_embed_dim,
            in_channels=config.input_channels,
        )

    def forward(self, input, *args, **kwargs):
        """Run the wrapped MONAI network.

        Args:
            input: first positional argument of the MONAI network (the image
                input). The name shadows the ``input`` builtin but is kept so
                existing keyword callers keep working.
            *args, **kwargs: forwarded unchanged to ``self.network``. This lets
                callers supply extra arguments such as segmentation prompts.
                See MONAI's ``VISTA3D.forward`` for the accepted names.

        Returns:
            Whatever ``self.network`` returns.
        """
        return self.network(input, *args, **kwargs)
def register_my_model():
    """Register VISTA3D with the transformers ``Auto*`` factories.

    After this call, ``AutoConfig`` maps the ``"VISTA3D"`` model type to
    ``VISTA3DConfig``, and ``AutoModel`` maps ``VISTA3DConfig`` to
    ``VISTA3DModel``. The model can then be instantiated through
    ``AutoModel``.
    """
    model_type = "VISTA3D"
    # The config mapping must be registered first so AutoModel can resolve it.
    AutoConfig.register(model_type, VISTA3DConfig)
    AutoModel.register(VISTA3DConfig, VISTA3DModel)
if __name__ == "__main__":
    # Convert the raw MONAI checkpoint at <script dir>/models/model.pt into a
    # Hugging Face ``save_pretrained`` directory next to this script.
    # abspath: before Python 3.9 ``__file__`` of the main script could be
    # relative, making dirname() return "" and resolving paths against the
    # current working directory instead of the script's folder.
    FILE_PATH = os.path.dirname(os.path.abspath(__file__))
    MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models", "model.pt")
    MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")
    config = VISTA3DConfig()
    hugging_face_model = VISTA3DModel(config)
    # map_location="cpu": the checkpoint may contain GPU tensors, and without it
    # torch.load fails on CPU-only machines. load_state_dict copies the values
    # into the (CPU) model parameters either way.
    # NOTE: torch.load unpickles the file -- only run this on a trusted checkpoint.
    state_dict = torch.load(MODEL_WEIGHT_PATH, map_location="cpu")
    hugging_face_model.network.load_state_dict(state_dict)
    hugging_face_model.save_pretrained(MODEL_PATH)