# dummy_m4/m4/training/setup_vision_model.py
# ysharma (HF staff): Duplicate from HuggingFaceM4/m4-dialogue, e7d3e35
import re
from transformers import AutoModel
# Registry of the supported cv archs: each key is a regex identifying an arch, each
# value says how to extract the vision model from the loaded model - for some archs we
# have to go through a specific prefix, as in `model.vision_model` for clip.
def _extract_clip_vision_model(model):
    """Return the `vision_model` attribute of `model` (clip)."""
    return model.vision_model


def _use_model_as_is(model):
    """Return `model` unchanged (vit)."""
    return model


# NOTE: "clip" is listed before "vit"; keep that order for code that iterates the dict.
vision_model_name2model = {
    r"clip": _extract_clip_vision_model,
    r"vit": _use_model_as_is,
}
def vision_model_name_to_model(model_name_or_path, model):
    """Return the vision model extracted from `model` if its arch is supported.

    Args:
        model_name_or_path: name or path of the model. It is converted with `str()`
            and lower-cased, then searched with each regex of
            `vision_model_name2model` in iteration order.
        model: the loaded model, handed to the extractor of the first matching regex.

    Returns:
        Whatever the first matching extractor returns for `model`.

    Raises:
        ValueError: if none of the regexes in `vision_model_name2model` matches.
    """
    # str() lets path-like objects (e.g. pathlib.Path) through; plain strings are unaffected.
    model_name_lowcase = str(model_name_or_path).lower()
    for rx, lookup in vision_model_name2model.items():
        if re.search(rx, model_name_lowcase):
            return lookup(model)

    # Reaching this point means no regex matched (every match returns above).
    raise ValueError(
        f"Unknown type of backbone vision model. Got {model_name_or_path}, supported regexes:"
        f" {list(vision_model_name2model.keys())}."
    )
def get_vision_model(config):
    """Load the vision backbone described by `config` and return its vision model.

    Args:
        config: object exposing `vision_model_name` (passed to
            `AutoModel.from_pretrained`) and `vision_model_params` (extra keyword
            arguments for that call, given either as a dict or as a Python
            expression that evaluates to one).

    Returns:
        The result of `vision_model_name_to_model` applied to the loaded model.

    Raises:
        ValueError: propagated from `vision_model_name_to_model` when the model name
            matches no supported arch.
    """
    vision_model_name = config.vision_model_name
    raw_params = config.vision_model_params
    if isinstance(raw_params, dict):
        # Already-parsed params are used directly; eval() would reject a dict.
        vision_model_params = raw_params
    else:
        # SECURITY NOTE(review): eval() executes arbitrary code. This is only acceptable
        # while the config comes from a trusted source; if the params are always plain
        # literals, consider switching to ast.literal_eval - TODO confirm with config owners.
        vision_model_params = eval(raw_params)

    model = AutoModel.from_pretrained(vision_model_name, **vision_model_params)
    return vision_model_name_to_model(vision_model_name, model)