Spaces:
Paused
Paused
import torch | |
import torch.optim as optim | |
import torch.nn as nn | |
from tqdm import tqdm | |
import numpy as np | |
from PIL import Image | |
import requests | |
import io | |
from unet import Unet, ConditionalUnet | |
from diffusion import GaussianDiffusion, DiffusionImageAPI | |
# Compute device for the model and sampler: CUDA when PyTorch reports it as
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def inference1(url="https://picsum.photos/120/80", timeout=10.0):
    """Download an image over HTTP and return it as a PIL image.

    With the default arguments this fetches a random 120x80 placeholder
    picture from picsum.photos.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` may block indefinitely on a stalled server.

    Returns:
        The image decoded by ``PIL.Image.open`` from the response body.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within *timeout*.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on HTTP errors instead of handing an HTML error page to
    # PIL, which would only report an unidentifiable image.
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))
def inference(checkpoint_path="./model_final.pt"):
    """Sample one image from the trained class-conditional diffusion model.

    Builds a 3-channel ``Unet`` (dropout 0.1) wrapped in a ``ConditionalUnet``
    with 13 classes and loads its weights from *checkpoint_path*. It then wraps
    the model in a ``GaussianDiffusion`` sampler: 1000 noise steps, betas from
    1e-4 to 0.02, image size (192, 128). The sampler draws one sample.

    Args:
        checkpoint_path: Path of the saved ``state_dict`` to load. The default
            matches the previously hard-coded location.

    Returns:
        The result of ``DiffusionImageAPI.tensor_to_image`` for the last entry
        of the ``versions`` sequence returned by ``diffusion.sample``. This is
        presumably a PIL image, since the script calls ``.show()`` on it.
        Confirm against ``DiffusionImageAPI``.
    """
    backbone = Unet(
        image_channels=3,
        dropout=0.1,
    )
    model = ConditionalUnet(
        unet=backbone,
        num_classes=13,
    )
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    diffusion = GaussianDiffusion(
        model=model,
        noise_steps=1000,
        beta_0=1e-4,
        beta_T=0.02,
        image_size=(192, 128),
    )
    model.to(device)
    diffusion.to(device)
    # Disable dropout (and any other train-only behaviour) before sampling;
    # otherwise the dropout=0.1 layers stay active and perturb every step.
    model.eval()
    image_api = DiffusionImageAPI(diffusion)

    # Sampling needs no gradients. Disabling autograd avoids recording a graph
    # across all denoising steps, and is a no-op if sample() already does it.
    with torch.no_grad():
        _, versions = diffusion.sample(1)

    # Only the final entry is returned, so convert just that one. To build an
    # animated GIF of the denoising process, convert every entry of
    # ``versions`` and save them with PIL's ``save_all=True``.
    return image_api.tensor_to_image(versions[-1].squeeze(0))
if __name__ == "__main__":
    # Script entry point: draw one sample and display it via its .show() method.
    sampled_image = inference()
    sampled_image.show()