Spaces:
Running
Running
File size: 4,991 Bytes
265ae36 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
from enum import Enum
from typing import Union
import torch
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
class Weights(Enum):
    """
    Identifiers of the pretrained weight sets accepted by the model factories.
    """

    # Member name and value are the same string, so ``Weights["LVD142M"]``
    # and ``Weights("LVD142M")`` both resolve to this member.
    LVD142M = "LVD142M"
def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """
    Build a DINOv2 vision transformer and optionally load a pretrained checkpoint.

    Args:
        arch_name: Name of the constructor to look up in the
            ``vision_transformer`` module (for example ``"vit_small"``).
        img_size: Forwarded to the constructor as ``img_size``.
        patch_size: Forwarded to the constructor and used to build the
            checkpoint name.
        init_values: Forwarded to the constructor as ``init_values``.
        ffn_layer: Forwarded to the constructor as ``ffn_layer``.
        block_chunks: Forwarded to the constructor as ``block_chunks``.
        num_register_tokens: Forwarded to the constructor and used to build
            the checkpoint file name.
        interpolate_antialias: Forwarded to the constructor.
        interpolate_offset: Forwarded to the constructor.
        pretrained: When true, download the checkpoint and load it into the
            model with ``strict=True``.
        weights: A ``Weights`` member or the name of one. A string is only
            validated against the enum; the value is not otherwise used to
            select the checkpoint in this function.
        **kwargs: Extra constructor arguments; they override the explicit
            values above when the keys collide.

    Returns:
        The object returned by the selected constructor, with the checkpoint
        loaded if ``pretrained`` is true.

    Raises:
        AssertionError: If ``weights`` is a string that is not a ``Weights``
            member name.
        KeyError: If ``arch_name`` is not an attribute of the
            ``vision_transformer`` module.
    """
    # Imported lazily, as in the original, so the model package is only
    # loaded when a model is actually requested.
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError as err:
            # Same exception type and message as before; chaining keeps the
            # failed enum lookup visible as the explicit cause.
            raise AssertionError(f"Unsupported weights: {weights}") from err

    # Name built from architecture and patch size only; used as the
    # directory component of the checkpoint URL.
    model_base_name = _make_dinov2_model_name(arch_name, patch_size)

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    # Caller-supplied extras take precedence over the defaults above.
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        # The file name additionally takes the register-token count.
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-S/14 model, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_small"``; any
    additional keyword arguments are passed through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_small",
        **kwargs,
    )
    return model
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-B/14 model, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_base"``; any
    additional keyword arguments are passed through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_base",
        **kwargs,
    )
    return model
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-L/14 model, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_large"``; any
    additional keyword arguments are passed through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_large",
        **kwargs,
    )
    return model
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-g/14 model, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_giant2"`` and
    ``ffn_layer="swiglufused"``; any additional keyword arguments are passed
    through unchanged.
    """
    model = _make_dinov2_model(
        pretrained=pretrained,
        weights=weights,
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        **kwargs,
    )
    return model
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-S/14 model with registers, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_small"``,
    ``num_register_tokens=4``, ``interpolate_antialias=True`` and
    ``interpolate_offset=0.0``; any additional keyword arguments are passed
    through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_small",
        interpolate_offset=0.0,
        interpolate_antialias=True,
        num_register_tokens=4,
        **kwargs,
    )
    return model
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-B/14 model with registers, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_base"``,
    ``num_register_tokens=4``, ``interpolate_antialias=True`` and
    ``interpolate_offset=0.0``; any additional keyword arguments are passed
    through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_base",
        interpolate_offset=0.0,
        interpolate_antialias=True,
        num_register_tokens=4,
        **kwargs,
    )
    return model
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-L/14 model with registers, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_large"``,
    ``num_register_tokens=4``, ``interpolate_antialias=True`` and
    ``interpolate_offset=0.0``; any additional keyword arguments are passed
    through unchanged.
    """
    model = _make_dinov2_model(
        weights=weights,
        pretrained=pretrained,
        arch_name="vit_large",
        interpolate_offset=0.0,
        interpolate_antialias=True,
        num_register_tokens=4,
        **kwargs,
    )
    return model
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """
    Build the DINOv2 ViT-g/14 model with registers, optionally pretrained on the LVD-142M dataset.

    Delegates to ``_make_dinov2_model`` with ``arch_name="vit_giant2"``,
    ``ffn_layer="swiglufused"``, ``num_register_tokens=4``,
    ``interpolate_antialias=True`` and ``interpolate_offset=0.0``; any
    additional keyword arguments are passed through unchanged.
    """
    model = _make_dinov2_model(
        pretrained=pretrained,
        weights=weights,
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        interpolate_offset=0.0,
        interpolate_antialias=True,
        num_register_tokens=4,
        **kwargs,
    )
    return model
|