|
import numpy as np |
|
import torch |
|
|
|
|
|
def torch_samps_to_imgs(imgs, uncenter=True): |
|
if uncenter: |
|
imgs = (imgs + 1) / 2 |
|
imgs = (imgs * 255).clamp(0, 255) |
|
imgs = imgs.to(torch.uint8) |
|
imgs = imgs.permute(0, 2, 3, 1) |
|
imgs = imgs.cpu().numpy() |
|
return imgs |
|
|
|
|
|
def imgs_to_torch(imgs): |
|
assert imgs.dtype == np.uint8 |
|
assert len(imgs.shape) == 4 and imgs.shape[-1] == 3, "expect (N, H, W, C)" |
|
_, H, W, _ = imgs.shape |
|
|
|
imgs = imgs.transpose(0, 3, 1, 2) |
|
imgs = (imgs / 255).astype(np.float32) |
|
imgs = (imgs * 2) - 1 |
|
imgs = torch.as_tensor(imgs) |
|
H, W = [_l - (_l % 32) for _l in (H, W)] |
|
imgs = torch.nn.functional.interpolate(imgs, (H, W), mode="bilinear") |
|
return imgs |
|
|
|
|
|
def test_encode_decode(): |
|
import imageio |
|
from run_img_sampling import ScoreAdapter, SD |
|
from vis import _draw |
|
|
|
fname = "~/clean.png" |
|
raw = imageio.imread(fname) |
|
raw = imgs_to_torch(raw[np.newaxis, ...]) |
|
|
|
model: ScoreAdapter = SD().run() |
|
raw = raw.to(model.device) |
|
zs = model.encode(raw) |
|
img = model.decode(zs) |
|
img = torch_samps_to_imgs(img) |
|
_draw( |
|
[imageio.imread(fname), img.squeeze(0)], |
|
) |
|
|
|
|
|
def test(): |
|
test_encode_decode() |
|
|
|
|
|
if __name__ == "__main__": |
|
test() |
|
|