import torch

from .ema import ExponentialMovingAverage


def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
    """
    Load weights of a model from a checkpoint file into ``model`` in place.

    Args:
        model (torch.nn.Module): The model to load weights into. Its parameters
            and buffers are overwritten in place.
        ckpt_path (str): Path to the checkpoint file.
        use_ema (bool): Whether to use Exponential Moving Average (EMA) weights
            if available (stored under the ``"model_ema"`` key). Otherwise the
            ``"encoder"`` entry is loaded.
        device (str or torch.device): Device every checkpoint tensor is mapped
            to while loading.

    Returns:
        int: The ``"total_it"`` value stored in the checkpoint, or 0 if absent.
    """
    # Map every stored tensor to `device`, not only those saved on cuda:0, so
    # checkpoints written on another GPU (or on CPU) still load on this machine.
    # load_state_dict copies values into the model's own tensors, so the
    # staging device does not affect where the model's weights live.
    checkpoint = torch.load(ckpt_path, map_location=str(device))
    total_iter = checkpoint.get('total_it', 0)

    if "model_ema" in checkpoint and use_ema:
        ema_state = checkpoint["model_ema"]
        # A state dict saved from the EMA wrapper carries an `n_averaged`
        # counter and/or the wrapped model's entries under a `module.` prefix.
        # Check precisely rather than a substring test on the first key only,
        # which misfires on names like `module_list.*` or `submodule.*`.
        is_wrapped = 'n_averaged' in ema_state or any(
            key.startswith('module.') for key in ema_state
        )
        if is_wrapped:
            ema_model = ExponentialMovingAverage(model, decay=1.0)
            ema_model.load_state_dict(ema_state, strict=True)
            # The wrapper may hold its own copy of `model` (AveragedModel-style
            # wrappers deep-copy the module). Copy the loaded weights back so
            # the caller's model is actually updated; when the wrapper holds a
            # reference instead, this is a harmless self-copy.
            model.load_state_dict(ema_model.module.state_dict(), strict=True)
            print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
        else:
            model.load_state_dict(ema_state, strict=True)
            print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
    else:
        model.load_state_dict(checkpoint['encoder'], strict=True)
        print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')
    return total_iter