Spaces:
Sleeping
Sleeping
dependencies = ['torch', 'diffusers'] | |
import torch | |
from diffusers import UNet2DConditionModel | |
# mgd is the name of entrypoint | |
def mgd(dataset: str, pretrained: bool = True, **kwargs) -> UNet2DConditionModel: | |
""" # This docstring shows up in hub.help() | |
MGD model | |
pretrained (bool): kwargs, load pretrained weights into the model | |
""" | |
config = UNet2DConditionModel.load_config("runwayml/stable-diffusion-inpainting", subfolder="unet") | |
config['in_channels'] = 28 | |
unet = UNet2DConditionModel.from_config(config) | |
if pretrained: | |
checkpoint = f"https://github.com/aimagelab/multimodal-garment-designer/releases/download/weights/{dataset}.pth" | |
unet.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=True)) | |
return unet | |