import timm | |
from timm import data | |
import torch.nn as nn | |
from torchvision import transforms | |
from src.config import TinyCLIPVisionConfig | |
def get_vision_base(
    config: TinyCLIPVisionConfig,
) -> tuple[nn.Module, int]:
    """Create the pretrained timm backbone used as the vision encoder.

    The model named by ``config.vision_model`` is built with
    ``num_classes=0`` (classifier removed) and pretrained weights, which timm
    may download on first use.

    Args:
        config: vision config; ``config.vision_model`` is the timm model name.

    Returns:
        A ``(model, width)`` tuple, where ``width`` is the size of the
        feature vector the head-less model outputs.
    """
    base = timm.create_model(config.vision_model, num_classes=0, pretrained=True)
    # Newer timm versions distinguish ``num_features`` (channels of the final
    # feature map) from ``head_hidden_size`` (width of the pooled/pre-logits
    # vector, i.e. what forward() returns when num_classes=0). They differ for
    # e.g. MobileNetV3 (960 vs 1280). Older versions lack the attribute, in
    # which case num_features is the output width.
    head_width = getattr(base, "head_hidden_size", None)
    num_features = head_width if head_width is not None else base.num_features
    return base, num_features
def get_vision_transform(
    config: TinyCLIPVisionConfig,
    model: "nn.Module | None" = None,
) -> transforms.Compose:
    """Build the preprocessing transform matching ``config.vision_model``.

    Args:
        config: vision config; ``config.vision_model`` is the timm model name.
        model: optional already-built timm model for the same architecture
            (e.g. the backbone returned by ``get_vision_base``). Passing it
            avoids building a second copy of the network. When omitted, the
            architecture is instantiated without pretrained weights purely to
            read its data config.

    Returns:
        The transform produced by timm's ``create_transform`` using the
        model's resolved data config.
    """
    if model is None:
        # Only the model's pretrained_cfg is needed here, so skip
        # downloading/loading weights.
        model = timm.create_model(config.vision_model, pretrained=False)
    # resolve_data_config reads input size, interpolation, crop_pct and
    # mean/std from ``model.pretrained_cfg``. It must receive a model
    # instance: a name string has no such attribute, and timm then silently
    # falls back to generic ImageNet defaults.
    timm_config = data.resolve_data_config({}, model=model)
    transform = data.transforms_factory.create_transform(**timm_config)
    return transform  # type: ignore