File size: 633 Bytes
3b13f40 24d96ab 3b13f40 24d96ab 3b13f40 24d96ab 3b13f40 24d96ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
import timm
from timm import data
import torch.nn as nn
from torchvision import transforms
from src.config import TinyCLIPVisionConfig
def get_vision_base(
    config: TinyCLIPVisionConfig,
) -> tuple[nn.Module, int]:
    """Create the pretrained timm image backbone named by ``config.vision_model``.

    The model is built with ``pretrained=True`` (weights are loaded) and
    ``num_classes=0``, which tells timm to omit the classification head so
    the network outputs features rather than class logits.

    Args:
        config: Vision configuration; only ``vision_model`` (a timm model
            name) is read here.

    Returns:
        A ``(backbone, feature_dim)`` pair, where ``feature_dim`` is the
        backbone's ``num_features`` attribute.
    """
    backbone = timm.create_model(
        config.vision_model,
        pretrained=True,
        num_classes=0,  # drop the classifier head; keep the feature extractor
    )
    return backbone, backbone.num_features
def get_vision_transform(config: TinyCLIPVisionConfig) -> transforms.Compose:
    """Build the timm preprocessing pipeline matching ``config.vision_model``.

    The data config (input size, interpolation, mean/std, crop settings) is
    resolved from the model's own pretrained config, so images are
    preprocessed the way the backbone expects. ``create_transform`` is called
    with its default ``is_training=False``, i.e. the evaluation pipeline.

    Args:
        config: Vision configuration; only ``vision_model`` (a timm model
            name) is read here.

    Returns:
        The transform produced by timm's ``create_transform``.
    """
    # ``resolve_data_config`` reads ``model.pretrained_cfg``. Passing the model
    # *name* (a str) meant that attribute was never found and timm silently
    # fell back to generic defaults (224px, ImageNet mean/std) regardless of
    # the backbone. Instantiate the architecture -- without downloading
    # weights -- purely to obtain its pretrained config.
    reference_model = timm.create_model(
        config.vision_model, pretrained=False, num_classes=0
    )
    timm_config = data.resolve_data_config({}, model=reference_model)
    del reference_model  # only its config was needed; release the parameters
    transform = data.transforms_factory.create_transform(**timm_config)
    return transform  # type: ignore
|