File size: 1,045 Bytes
5f9e8f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from transformers import SamConfig, SamModel, SamProcessor, SamImageProcessor
from transformers.models.sam.convert_sam_original_to_hf_format import replace_keys

from segment_anything import sam_model_registry  # pip install git+https://github.com/facebookresearch/segment-anything.git

# Step 1 - instantiate the original (segment-anything) MedSAM ViT-B model.
# Checkpoint download:
# https://drive.google.com/file/d/1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_/view?usp=drive_link
checkpoint = 'medsam_vit_b.pth'
build_vit_b = sam_model_registry['vit_b']
pt_model = build_vit_b(checkpoint)

# Step 2 - pull the weights out and pass them through the transformers
# conversion helper so they fit the transformers SAM design.
pt_state_dict = pt_model.state_dict()
hf_state_dict = replace_keys(pt_state_dict)

# Step 3 - create a SamModel from a default SamConfig, load the converted
# weights into it and write the model to the current directory.
hf_config = SamConfig()
hf_model = SamModel(config=hf_config)
hf_model.load_state_dict(hf_state_dict)
hf_model.save_pretrained('./')

# Step 4 - build the matching processor. MedSAM inputs are min-max scaled
# instead of normalized, so normalization is turned off and a zero mean /
# unit std are supplied to the image processor.
image_processor = SamImageProcessor(
    do_normalize=False,
    image_mean=[0, 0, 0],
    image_std=[1, 1, 1],
)
hf_processor = SamProcessor(image_processor=image_processor)

# Step 5 - write the processor next to the model.
hf_processor.save_pretrained('./')