# File size: 1,275 Bytes
# d661b19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import numpy as np
import torch


def torch_samps_to_imgs(imgs, uncenter=True):
    """Convert a batch of image tensors into a uint8 NumPy array.

    Args:
        imgs: 4-D torch tensor. It is permuted with ``(0, 2, 3, 1)``, so it
            is presumably laid out as (N, C, H, W) -- confirm against callers.
        uncenter: when True, values are first mapped from [-1, 1] to [0, 1];
            when False, they are taken to already be in [0, 1].

    Returns:
        A ``numpy.ndarray`` of dtype uint8 with the channel axis moved last,
        values scaled by 255, rounded to nearest, and clamped to [0, 255].
    """
    if uncenter:
        imgs = (imgs + 1) / 2  # [-1, 1] -> [0, 1]
    imgs = imgs * 255
    # Round to nearest before casting: a bare uint8 cast truncates toward
    # zero, so e.g. 254.9 would come out as 254 and the whole image would be
    # biased dark. Integer inputs are already exact, so leave them untouched.
    if imgs.is_floating_point():
        imgs = imgs.round()
    imgs = imgs.clamp(0, 255)
    imgs = imgs.to(torch.uint8)
    imgs = imgs.permute(0, 2, 3, 1)  # channel axis -> last
    imgs = imgs.cpu().numpy()
    return imgs


def imgs_to_torch(imgs):
    """Turn a batch of uint8 RGB images into a centered float32 torch tensor.

    Args:
        imgs: NumPy array of dtype uint8 and shape (N, H, W, 3).

    Returns:
        A float32 tensor of shape (N, 3, H', W') with values scaled from
        [0, 255] to [-1, 1], where H' and W' are H and W rounded down to the
        nearest multiple of 32 (reached by bilinear resizing).

    Raises:
        AssertionError: if the dtype is not uint8 or the shape is not
            4-D with a trailing axis of size 3.
    """
    assert imgs.dtype == np.uint8
    assert imgs.ndim == 4 and imgs.shape[-1] == 3, "expect (N, H, W, C)"

    # Snap each spatial side down to a multiple of 32.
    # NOTE(review): a side shorter than 32 yields a target of 0 -- confirm
    # callers never pass such images.
    height, width = imgs.shape[1:3]
    target_size = (height - height % 32, width - width % 32)

    channel_first = imgs.transpose(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    unit_range = (channel_first / 255).astype(np.float32)  # [0, 255] -> [0, 1]
    centered = torch.as_tensor(unit_range * 2 - 1)  # [0, 1] -> [-1, 1]
    return torch.nn.functional.interpolate(centered, target_size, mode="bilinear")


def test_encode_decode():
    """Round-trip one image through the model's encode/decode and draw both.

    Loads "~/clean.png", converts it with ``imgs_to_torch``, passes it through
    ``model.encode`` followed by ``model.decode``, converts the result back
    with ``torch_samps_to_imgs``, and hands the original image together with
    the reconstruction to ``vis._draw``.
    """
    # Project-local / optional dependencies are imported lazily so that the
    # conversion helpers in this module stay importable without them.
    import imageio
    from run_img_sampling import ScoreAdapter, SD
    from vis import _draw

    fname = "~/clean.png"
    # Add a leading batch axis: (H, W, C) -> (1, H, W, C).
    batch = imgs_to_torch(imageio.imread(fname)[np.newaxis, ...])

    model: ScoreAdapter = SD().run()
    latents = model.encode(batch.to(model.device))
    recon = torch_samps_to_imgs(model.decode(latents))

    # The file is read again so the untouched original sits next to the
    # reconstruction (with its batch axis dropped).
    _draw([imageio.imread(fname), recon.squeeze(0)])


def test():
    """Run the module's manual checks (currently only the encode/decode round trip)."""
    test_encode_decode()


# Script entry point: run the manual checks only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test()