File size: 1,946 Bytes
108b1ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from typing import Optional

from efficientvit.models.efficientvit import (EfficientViTSam,
                                              efficientvit_sam_l0,
                                              efficientvit_sam_l1,
                                              efficientvit_sam_l2,
                                              efficientvit_sam_xl0,
                                              efficientvit_sam_xl1)
from efficientvit.models.nn.norm import set_norm_eps
from efficientvit.models.utils import load_state_dict_from_file

# Public API of this module: only the factory function is exported on star-import.
__all__ = ["create_sam_model"]


# Default checkpoint path for each registered model name. create_sam_model
# falls back to this table when pretrained weights are requested without an
# explicit weight_url. Paths are relative (presumably to the working
# directory / repo root -- confirm against deployment layout).
REGISTERED_SAM_MODEL: dict[str, str] = dict(
    [
        ("l0", "assets/checkpoints/sam/l0.pt"),
        ("l1", "assets/checkpoints/sam/l1.pt"),
        ("l2", "assets/checkpoints/sam/l2.pt"),
        ("xl0", "assets/checkpoints/sam/xl0.pt"),
        ("xl1", "assets/checkpoints/sam/xl1.pt"),
    ]
)


def create_sam_model(
    name: str, pretrained: bool = True, weight_url: Optional[str] = None, **kwargs
) -> EfficientViTSam:
    """Build an EfficientViT-SAM model by name, optionally loading pretrained weights.

    Args:
        name: Model name. The prefix before the first "-" selects the
            architecture ("l0", "l1", "l2", "xl0" or "xl1"). The full,
            unsplit name is used to look up the default checkpoint in
            ``REGISTERED_SAM_MODEL``.
        pretrained: When True, load a checkpoint into the freshly built model.
        weight_url: Explicit checkpoint location. If falsy (None or ""), the
            registered checkpoint for ``name`` is used instead.
        **kwargs: Forwarded unchanged to the selected architecture builder.

    Returns:
        The constructed ``EfficientViTSam`` model.

    Raises:
        ValueError: If the architecture prefix of ``name`` is unknown, or if
            ``pretrained`` is True and no checkpoint location can be resolved.
    """
    model_dict = {
        "l0": efficientvit_sam_l0,
        "l1": efficientvit_sam_l1,
        "l2": efficientvit_sam_l2,
        "xl0": efficientvit_sam_xl0,
        "xl1": efficientvit_sam_xl1,
    }

    # Only the architecture prefix picks the builder; suffixes after "-" are
    # ignored here but still matter for the checkpoint lookup below.
    model_id = name.split("-")[0]
    if model_id not in model_dict:
        raise ValueError(
            f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
        )
    model = model_dict[model_id](**kwargs)
    # NOTE(review): presumably overrides the epsilon of the model's normalization
    # layers to 1e-6 (helper lives in efficientvit.models.nn.norm -- confirm).
    set_norm_eps(model, 1e-6)

    if pretrained:
        # An explicit weight_url wins; otherwise fall back to the registry,
        # keyed by the full name (not just model_id).
        weight_url = weight_url or REGISTERED_SAM_MODEL.get(name)
        if weight_url is None:
            raise ValueError(f"Do not find the pretrained weight of {name}.")
        weight = load_state_dict_from_file(weight_url)
        model.load_state_dict(weight)
    return model