File size: 1,045 Bytes
5f9e8f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys
from segment_anything import sam_model_registry # pip install git+https://github.com/facebookresearch/segment-anything.git
def _medsam_weights_for_hf(ckpt_path):
    """Return MedSAM ViT-B weights renamed to the transformers ``SamModel`` layout.

    The original network is rebuilt from ``ckpt_path`` with segment-anything's
    ``vit_b`` builder, and its state dict is passed through ``replace_keys``.
    """
    original_sam = sam_model_registry['vit_b'](ckpt_path)
    return replace_keys(original_sam.state_dict())


# MedSAM ViT-B checkpoint, downloadable from
# https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
checkpoint = 'medsam_vit_b.pth'
converted_weights = _medsam_weights_for_hf(checkpoint)

# Build a SamModel from the default SamConfig and copy the converted weights in.
# load_state_dict is strict by default, so a key or shape mismatch (e.g. if the
# default config were not the ViT-B layout) raises here instead of passing silently.
hf_model = SamModel(config=SamConfig())
hf_model.load_state_dict(converted_weights)
# Write the config and weights into the current working directory.
hf_model.save_pretrained('./')
# Build a processor whose image pipeline skips mean/std normalization: MedSAM
# inputs are scaled into [0, 1] rather than standardized with ImageNet statistics.
# NOTE(review): the processor applies a *fixed* rescale (pixel * 1/255), not a
# per-image min-max scaling. The two only coincide for images whose intensities
# already span the full 0-255 range. Confirm inputs are min-max stretched to
# 0-255 before they reach the processor.
# NOTE(review): SamImageProcessor resizes the longest edge and pads to a square.
# Confirm this matches the resizing used when MedSAM was trained.
hf_processor = SamProcessor(
    image_processor=SamImageProcessor(
        # Both values equal the library defaults; they are spelled out so the
        # only scaling step the model receives is visible here.
        do_rescale=True,
        rescale_factor=1 / 255,
        do_normalize=False,
        # Unused while do_normalize=False, but these values are serialized with
        # the config. Identity statistics keep normalization a no-op if someone
        # later flips do_normalize back on.
        image_mean=[0, 0, 0],
        image_std=[1, 1, 1],
    )
)
# save the processor (preprocessor config) next to the model weights
hf_processor.save_pretrained('./')
|