import io

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm

from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI


def inference1(url="https://picsum.photos/120/80", timeout=10.0):
    """Download an image from the web and return it as a PIL image.

    Args:
        url: Address of the image to fetch. Defaults to the picsum.photos
            URL that was previously hard-coded here.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot hang the script indefinitely.

    Returns:
        The downloaded image, opened with ``PIL.Image.open``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the request takes longer than ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of passing an error page to PIL.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def inference(checkpoint_path="./model_final.pt", gif_path=None):
    """Sample one image from the trained diffusion model.

    Builds a ``Unet`` with ``image_channels=3``, loads its weights from
    ``checkpoint_path``, wraps it in a ``GaussianDiffusion`` sampler and
    draws a single sample.

    Args:
        checkpoint_path: Path of the ``state_dict`` file to load into the Unet.
        gif_path: Optional output path. When given, every intermediate
            version returned by the sampler is converted to an image and
            written there as an animated GIF (100 ms per frame, looping).

    Returns:
        The last intermediate version returned by ``diffusion.sample``,
        converted with ``DiffusionImageAPI.tensor_to_image``.

    Raises:
        RuntimeError: If the sampler returns no intermediate versions.
    """
    model = Unet(
        image_channels=3,
    )
    # The model above lives on the CPU, so load the tensors there as well;
    # without map_location a checkpoint saved from a GPU cannot be loaded
    # on a machine without CUDA.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Inference mode for any dropout / batch-norm layers.
    # Assumes Unet is a torch.nn.Module (it exposes load_state_dict) -- confirm.
    model.eval()

    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(120, 80),
    )
    image_api = DiffusionImageAPI(diffusion)

    # Gradients are not needed for sampling; harmless if sample() already
    # disables them internally.
    with torch.no_grad():
        _, versions = diffusion.sample(1)
    versions = list(versions)
    if not versions:
        raise RuntimeError("diffusion.sample() returned no intermediate versions")

    if gif_path is None:
        # Only the final version is returned, so convert just that one.
        # squeeze(0) drops the leading size-1 dim (presumably the batch of 1).
        return image_api.tensor_to_image(versions[-1].squeeze(0))

    frames = [image_api.tensor_to_image(version.squeeze(0)) for version in versions]
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )
    return frames[-1]


if __name__ == "__main__":
    inference().show()