|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from collections import namedtuple |
|
from typing import Callable, Optional, List, Union |
|
|
|
from timm.models import VisionTransformer |
|
import torch |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
|
|
from .common import RESOURCE_MAP, DEFAULT_VERSION |
|
|
|
|
|
from .eradio_model import eradio |
|
from .radio_model import create_model_from_args |
|
from .radio_model import RADIOModel as RADIOModelBase, Resolution |
|
from .input_conditioner import get_default_conditioner, InputConditioner |
|
|
|
|
|
|
|
from .extra_timm_models import * |
|
|
|
|
|
class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models.

    Args:
        args: Dictionary of model-construction arguments, or ``None``. The dict
            is copied (the caller's object is never modified), and any
            ``dtype`` / ``amp_dtype`` entries are normalized to their bare name,
            e.g. ``torch.float32`` -> ``"float32"``.
        version: Key into ``RESOURCE_MAP``; the selected entry supplies the
            defaults for ``patch_size``, ``max_resolution`` and
            ``preferred_resolution``.
        patch_size: Overrides the version default when truthy.
        max_resolution: Overrides the version default when truthy.
        preferred_resolution: Overrides the version default when truthy.
        adaptor_names: Adaptor name(s); stored as given.
        vitdet_window_size: Stored as given.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        KeyError: If ``version`` is not a key of ``RESOURCE_MAP``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = DEFAULT_VERSION,
        patch_size: Optional[int] = None,
        max_resolution: Optional[int] = None,
        preferred_resolution: Optional[Resolution] = None,
        adaptor_names: Optional[Union[str, List[str]]] = None,
        vitdet_window_size: Optional[int] = None,
        **kwargs,
    ):
        if args is not None:
            # Work on a shallow copy so normalizing dtypes below does not
            # silently rewrite the dict the caller passed in.
            args = dict(args)
            for field in ("dtype", "amp_dtype"):
                if field in args:
                    # str(torch.float32) == "torch.float32" -> keep "float32".
                    # RADIOModel maps the name back with getattr(torch, name).
                    # Presumably done so the config stays JSON-serializable.
                    args[field] = str(args[field]).split(".")[-1]
        self.args = args

        self.version = version
        try:
            resource = RESOURCE_MAP[version]
        except KeyError:
            raise KeyError(
                f"Unknown RADIO version {version!r}; "
                f"known versions: {list(RESOURCE_MAP)}"
            ) from None

        # Explicit values win; falsy values (None, 0) fall back to the
        # defaults registered for this version.
        self.patch_size = patch_size or resource.patch_size
        self.max_resolution = max_resolution or resource.max_resolution
        self.preferred_resolution = (
            preferred_resolution or resource.preferred_resolution
        )
        self.adaptor_names = adaptor_names
        self.vitdet_window_size = vitdet_window_size
        super().__init__(**kwargs)
|
|
|
|
|
class RADIOModel(PreTrainedModel):
    """Pretrained Hugging Face model for RADIO.

    This class inherits from PreTrainedModel, which provides
    HuggingFace's functionality for loading and saving models.

    The backbone is built from ``config.args`` and wrapped, together with the
    default input conditioner, in a ``RADIOModelBase`` stored as
    ``self.radio_model``.
    """

    config_class = RADIOConfig

    def __init__(self, config):
        """Build the RADIO backbone described by ``config``.

        Raises:
            ValueError: If ``config.args`` is not set.
            NotImplementedError: If ``config.adaptor_names`` is set; adaptors
                are not supported by this wrapper.
        """
        super().__init__(config)

        if config.args is None:
            raise ValueError(
                "RADIOConfig.args must be set to build a RADIOModel, got None"
            )

        # Give the plain-dict arguments attribute-style access
        # (args.dtype, args.teachers, ...).
        RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
        args = RADIOArgs(**config.args)
        self.config = config

        model = create_model_from_args(args)
        input_conditioner: InputConditioner = get_default_conditioner()

        dtype = getattr(args, "dtype", torch.float32)
        if isinstance(dtype, str):
            # RADIOConfig stores dtypes by bare name, e.g. "float32".
            dtype = getattr(torch, dtype)
        model.to(dtype=dtype)
        input_conditioner.dtype = dtype

        # Indices of the teachers whose summary output is used; a teacher
        # without an explicit "use_summary" entry counts as used.
        summary_idxs = torch.tensor(
            [i for i, t in enumerate(args.teachers) if t.get("use_summary", True)],
            dtype=torch.int64,
        )

        adaptor_names = config.adaptor_names
        if adaptor_names is not None:
            raise NotImplementedError(
                f"Adaptors are not yet supported in Hugging Face models. Adaptor names: {adaptor_names}"
            )

        adaptors = {}

        self.radio_model = RADIOModelBase(
            model,
            input_conditioner,
            summary_idxs=summary_idxs,
            patch_size=config.patch_size,
            max_resolution=config.max_resolution,
            window_size=config.vitdet_window_size,
            preferred_resolution=config.preferred_resolution,
            adaptors=adaptors,
        )

    @property
    def model(self) -> VisionTransformer:
        """The backbone held by the wrapped ``RADIOModelBase``."""
        return self.radio_model.model

    @property
    def input_conditioner(self) -> InputConditioner:
        """The input conditioner held by the wrapped ``RADIOModelBase``."""
        return self.radio_model.input_conditioner

    @input_conditioner.setter
    def input_conditioner(self, v: InputConditioner):
        self.radio_model.input_conditioner = v

    def make_preprocessor_external(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """Detach the current input conditioner and return it.

        The wrapped model's conditioner is replaced with ``torch.nn.Identity``.
        Callers that still want conditioning presumably need to apply the
        returned object to their inputs themselves.
        """
        ret = self.input_conditioner
        # Fully qualified: ``nn`` is not imported explicitly in this module.
        self.input_conditioner = torch.nn.Identity()
        return ret

    def forward(self, x: torch.Tensor):
        """Delegate to the wrapped ``RADIOModelBase.forward``."""
        return self.radio_model.forward(x)
|
|