File size: 1,138 Bytes
e7d3e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import re

from transformers import AutoModel


# Registry of supported vision architectures. Each key is a regex searched against the
# lower-cased model name; each value extracts the usable vision backbone from the loaded
# model. Some architectures nest the backbone under an attribute (e.g. `model.vision_model`
# for clip), others are used as-is. Entries are tried in insertion order.
vision_model_name2model = {
    # clip: the vision tower lives under `.vision_model`
    r"clip": lambda loaded: loaded.vision_model,
    # vit: the loaded model is already the backbone
    r"vit": lambda loaded: loaded,
}


def vision_model_name_to_model(model_name_or_path, model):
    """Return the vision backbone of `model`, selected by the model's name.

    The lower-cased `model_name_or_path` is searched against every regex key of
    `vision_model_name2model`, in the dict's iteration order. The extractor attached to
    the first matching pattern is called on `model` and its result is returned.

    Raises:
        ValueError: if none of the registered patterns matches the name.
    """
    lowered_name = model_name_or_path.lower()

    for pattern, extract in vision_model_name2model.items():
        if re.search(pattern, lowered_name) is not None:
            return extract(model)

    # No registered pattern matched: report the offending name and what is supported.
    supported_patterns = list(vision_model_name2model.keys())
    raise ValueError(
        f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
        f" {supported_patterns}."
    )


def get_vision_model(config):
    """Load the pretrained vision model named in `config` and return its backbone.

    Args:
        config: object exposing
            - `vision_model_name`: passed to `AutoModel.from_pretrained` and also used to
              pick the backbone extractor;
            - `vision_model_params`: extra keyword arguments for `from_pretrained`, given
              either as a string containing a Python expression that evaluates to a
              mapping, or as an already-built dict.

    Returns:
        The result of `vision_model_name_to_model` applied to the loaded model.

    Raises:
        ValueError: propagated from `vision_model_name_to_model` when the model name
            matches no supported architecture. Note this check happens only after the
            model has been loaded.
    """
    vision_model_name = config.vision_model_name
    raw_params = config.vision_model_params

    if isinstance(raw_params, dict):
        # Already parsed (e.g. a programmatically built config): use it directly
        # instead of failing inside `eval`, which only accepts source text / code objects.
        vision_model_params = raw_params
    else:
        # NOTE(review): `eval` executes arbitrary Python. This is only acceptable while
        # the config comes from a trusted source. If params are always plain literals,
        # consider `ast.literal_eval` or JSON instead -- confirm against existing configs.
        vision_model_params = eval(raw_params)

    model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
    return vision_model_name_to_model(vision_model_name, model)