# Spaces: Paused, Paused -- page-status text captured along with this file; commented out because it is not Python.
import io
import os
import sys

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm

from diffusion import GaussianDiffusion, DiffusionImageAPI
from unet import Unet, ConditionalUnet
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# Used below both as the map_location for the checkpoint and as the target
# of the model / diffusion `.to(...)` calls.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def inference1(url="https://picsum.photos/120/80", timeout=10):
    """Download an image over HTTP and open it with PIL.

    Args:
        url: Address to fetch. Defaults to the picsum.photos 120x80 endpoint
            the function originally hard-coded.
        timeout: Seconds passed to ``requests.get`` as ``timeout``. Without
            it the request has no time limit and can block forever.

    Returns:
        The object returned by ``Image.open`` for the downloaded bytes.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
    """
    # new image from web page
    response = requests.get(url, timeout=timeout)
    # Fail on 4xx/5xx here instead of handing an error page to the decoder.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(cond, x0=None, gif=False, callback=None):
    """Draw one sample from the conditional diffusion model.

    Builds a ``Unet`` wrapped in a 13-class ``ConditionalUnet``, loads the
    weights from ``./model_final.pt``, and calls ``GaussianDiffusion.sample``
    for a single image.

    Args:
        cond: Conditioning value forwarded unchanged to ``diffusion.sample``.
            NOTE(review): presumably a class label for the 13-class model --
            confirm the expected type against ``GaussianDiffusion.sample``.
        x0: Optional starting image. When given it is passed through
            ``diffusion.normalize_image``, permuted with ``(2, 0, 1)``, given
            a leading dimension of size 1 and moved to ``device``.
            NOTE(review): the permute suggests a channels-last input --
            confirm with the caller.
        gif: When true, every entry of the second value returned by
            ``diffusion.sample`` is converted to an image and the frames are
            written to ``./gif_output/versions.gif``.
        callback: Forwarded to ``diffusion.sample`` as ``cb``.

    Returns:
        ``DiffusionImageAPI.tensor_to_image`` applied to the sampled tensor
        with its leading dimension squeezed out.
    """
    model = ConditionalUnet(
        unet=Unet(
            image_channels=3,
            dropout=0.1,
        ),
        num_classes=13,
    )
    model.load_state_dict(torch.load("./model_final.pt", map_location=device))
    # The network is built with dropout=0.1; switch to eval mode so dropout
    # is disabled while sampling.
    model.eval()

    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )

    model.to(device)
    diffusion.to(device)

    if x0 is not None:
        x0 = diffusion.normalize_image(x0)
        x0 = x0.permute(2, 0, 1).unsqueeze(0)
        # Keep the start image on the same device as the model and diffusion.
        x0 = x0.to(device)

    imageAPI = DiffusionImageAPI(diffusion)
    new_images, versions = diffusion.sample(1, cond=cond, x0=x0, cb=callback)

    if gif:
        frames = [
            imageAPI.tensor_to_image(step.squeeze(0)) for step in versions
        ]
        # Nothing to write when sampling produced no intermediate versions.
        if frames:
            gif_path = './gif_output/versions.gif'
            # The output directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(gif_path), exist_ok=True)
            # make gif out of pillow images
            frames[0].save(
                gif_path,
                save_all=True,
                append_images=frames[1:],
                duration=100,
                loop=0,
            )

    return imageAPI.tensor_to_image(new_images.squeeze(0))
if __name__ == "__main__":
    # `inference` has a required `cond` parameter; calling it with no
    # arguments raises TypeError before any sampling happens. Take the
    # conditioning value from the command line, defaulting to 0.
    # NOTE(review): assumes `cond` may be an integer class index below the
    # model's 13 classes -- confirm against GaussianDiffusion.sample.
    class_index = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    inference(class_index).show()