# movie-diffusion / inference.py
# Author: Anton Forsman
# Commit 94539f6: "fix device issue"
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
import numpy as np
from PIL import Image
import requests
import io
from unet import Unet, ConditionalUnet
from diffusion import GaussianDiffusion, DiffusionImageAPI
# Target device for the model and sampling: the GPU when PyTorch can see one,
# otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference1(url="https://picsum.photos/120/80", timeout=10.0):
    """Download an image from the web and return it opened with PIL.

    Despite its name, this does not run the diffusion model: it only fetches
    an image over HTTP. The default URL asks picsum.photos for a random
    placeholder photo.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` can block indefinitely.

    Returns:
        The image opened by ``PIL.Image.open`` from the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    response = requests.get(url, timeout=timeout)
    # Fail with a clear HTTP error instead of handing an error page to PIL,
    # which would otherwise raise an opaque "cannot identify image" error.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def _load_diffusion(checkpoint_path="./model_final.pt"):
    """Build the class-conditional U-Net, load its weights and wrap it in a sampler.

    Args:
        checkpoint_path: File holding a ``state_dict`` for ``ConditionalUnet``.

    Returns:
        A ``GaussianDiffusion`` whose model (in eval mode) has been moved to
        ``device``.
    """
    unet = Unet(
        image_channels=3,
        dropout=0.1,
    )
    model = ConditionalUnet(
        unet=unet,
        num_classes=13,
    )
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )
    model.to(device)
    diffusion.to(device)
    # The U-Net was built with dropout=0.1; switch to eval mode so dropout
    # (and any other train-only behaviour) is disabled while sampling.
    model.eval()
    return diffusion


def inference(gif_path=None, checkpoint_path="./model_final.pt"):
    """Sample one image from the trained diffusion model.

    Args:
        gif_path: If given, every intermediate step returned by
            ``diffusion.sample`` is converted to an image and saved to this
            path as an animated GIF (100 ms per frame, looping forever).
        checkpoint_path: File holding the ``ConditionalUnet`` state_dict.

    Returns:
        The result of ``DiffusionImageAPI.tensor_to_image`` for the last
        intermediate step (presumably a PIL image; it is ``.show()``-n by the
        script entry point).
    """
    diffusion = _load_diffusion(checkpoint_path)
    image_api = DiffusionImageAPI(diffusion)

    # NOTE(review): assumes sampling needs no autograd (no gradient guidance
    # inside sample()); no_grad avoids building a graph over all steps.
    with torch.no_grad():
        # The first return value is unused; as before, the output is taken
        # from the last entry of the intermediate-step sequence.
        _, versions = diffusion.sample(1)
    # Materialise once so the sequence can be both indexed and iterated.
    steps = list(versions)

    if gif_path is None:
        # Only the final step is needed: convert just that one.
        return image_api.tensor_to_image(steps[-1].squeeze(0))

    frames = [image_api.tensor_to_image(step.squeeze(0)) for step in steps]
    frames[0].save(
        gif_path,
        save_all=True,
        append_images=frames[1:],
        duration=100,
        loop=0,
    )
    return frames[-1]
if __name__ == "__main__":
    # Script entry point: sample one image and open it in the default viewer.
    sampled_image = inference()
    sampled_image.show()