Spaces: Runtime error
aliabd committed • Commit 05fb2e7 • 1 Parent(s): 101eb22
full demo working
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2021 Xiaoyu Xiang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,43 @@
import gradio as gr
import torch
import torchtext
from PIL import Image
from model import create_model
from data import read_img_path, tensor_to_img

# download sample images for the demo gallery
torch.hub.download_url_to_file('https://upload.wikimedia.org/wikipedia/commons/thumb/a/a5/Tsunami_by_hokusai_19th_century.jpg/1920px-Tsunami_by_hokusai_19th_century.jpg', 'wave.jpg')
torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2020/10/02/13/49/bridge-5621201_1280.jpg', 'building.jpg')

# download the pretrained generator weights to ./weights/
torchtext.utils.download_from_url("https://drive.google.com/uc?id=1RILKwUdjjBBngB17JHwhZNBEaW4Mr-Ml", root="./weights/")

gpu_ids = []
model = create_model(gpu_ids)
# model.eval()

def sketch2anime(img, load_size=512):
    """Extract a line sketch from an uploaded illustration."""
    img, aus_resize = read_img_path(img.name, load_size)
    aus_tensor = model(img)
    aus_img = tensor_to_img(aus_tensor)
    image_pil = Image.fromarray(aus_img)
    image_pil = image_pil.resize(aus_resize, Image.BICUBIC)
    return image_pil


title = "Anime2Sketch"
description = "A sketch extractor for illustration, anime art and manga. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.05703'>Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis</a> | <a href='https://github.com/Mukosame/Anime2Sketch'>Github Repo</a></p>"

gr.Interface(
    sketch2anime,
    [gr.inputs.Image(type="file", label="Input")],
    gr.outputs.Image(type="pil", label="Output"),
    title=title,
    description=description,
    article=article,
    examples=[
        ["test_samples/madoka.jpg"],
        ["building.jpg"],
        ["wave.jpg"]
    ]).launch(debug=True)
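For reference, the same inference path can be exercised without launching the Gradio UI. A minimal sketch, assuming the checkpoint has already been downloaded to weights/netG.pth as above; 'sample.jpg' is a hypothetical input path:

# smoke test for the inference pipeline, bypassing the Gradio interface
from PIL import Image
from model import create_model
from data import read_img_path, tensor_to_img

model = create_model([])                           # CPU-only
img, orig_size = read_img_path('sample.jpg', 512)  # hypothetical input file
sketch = Image.fromarray(tensor_to_img(model(img)))
sketch = sketch.resize(orig_size, Image.BICUBIC)   # restore original dimensions
sketch.save('sample_sketch.jpg')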
data.py
ADDED
@@ -0,0 +1,91 @@
import os
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import torch

IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']

def is_image_file(filename):
    """Return True if a given filename is a valid image.
    Parameters:
        filename (str) -- image filename
    """
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def get_image_list(path):
    """Read the paths of valid images from the given directory path.
    Parameters:
        path (str) -- input directory path
    """
    assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
    images = []
    for dirpath, _, fnames in sorted(os.walk(path)):
        for fname in sorted(fnames):
            if is_image_file(fname):
                img_path = os.path.join(dirpath, fname)
                images.append(img_path)
    assert images, '{:s} has no valid image file'.format(path)
    return images

def get_transform(load_size=0, grayscale=False, method=Image.BICUBIC, convert=True):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    if load_size > 0:
        osize = [load_size, load_size]
        transform_list.append(transforms.Resize(osize, method))
    if convert:
        transform_list += [transforms.ToTensor()]
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def read_img_path(path, load_size):
    """Read a tensor from a given image path.
    Parameters:
        path (str)      -- input image path
        load_size (int) -- the input size. If <= 0, don't resize
    """
    img = Image.open(path).convert('RGB')
    aus_resize = None
    if load_size > 0:
        aus_resize = img.size
    transform = get_transform(load_size=load_size)
    image = transform(img)
    return image.unsqueeze(0), aus_resize

def tensor_to_img(input_image, imtype=np.uint8):
    """Convert a Tensor array into a numpy image array.
    Parameters:
        input_image (tensor) -- the input image tensor array
        imtype (type)        -- the desired type of the converted numpy array
    """
    if not isinstance(input_image, np.ndarray):
        if isinstance(input_image, torch.Tensor):  # get the data from a variable
            image_tensor = input_image.data
        else:
            return input_image
        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array
        if image_numpy.shape[0] == 1:  # grayscale to RGB
            image_numpy = np.tile(image_numpy, (3, 1, 1))
        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: transpose and scaling
    else:  # if it is a numpy array, do nothing
        image_numpy = input_image
    return image_numpy.astype(imtype)

def save_image(image_numpy, image_path, output_resize=None):
    """Save a numpy image to the disk.
    Parameters:
        image_numpy (numpy array)     -- input numpy array
        image_path (str)              -- the path of the image
        output_resize (None or tuple) -- the output size. If None, don't resize
    """
    image_pil = Image.fromarray(image_numpy)
    if output_resize:
        image_pil = image_pil.resize(output_resize, Image.BICUBIC)
    image_pil.save(image_path)
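Note the convention here: get_transform maps pixel values into [-1, 1], and tensor_to_img maps them back to [0, 255]. A minimal round-trip check of that convention (illustrative only; uses a synthetic array instead of a file):

# verify the normalize/de-normalize round trip on a synthetic 8x8 image
import numpy as np
from PIL import Image
from data import get_transform, tensor_to_img

arr = np.random.randint(0, 256, (8, 8, 3), dtype=np.uint8)
tensor = get_transform(load_size=0)(Image.fromarray(arr)).unsqueeze(0)
print(float(tensor.min()), float(tensor.max()))              # within [-1.0, 1.0]
restored = tensor_to_img(tensor)
print(np.abs(restored.astype(int) - arr.astype(int)).max())  # 0 or 1 (float rounding)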
model.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import functools


class UnetGenerator(nn.Module):
    """Create a Unet-based generator"""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
                               an image of size 128x128 will become of size 1x1 at the bottleneck
            ngf (int)       -- the number of filters in the last conv layer
            norm_layer      -- normalization layer
        We construct the U-Net from the innermost layer to the outermost layer.
        It is a recursive process.
        """
        super(UnetGenerator, self).__init__()
        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
        for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
        # gradually reduce the number of filters from ngf * 8 to ngf
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class UnetSkipConnectionBlock(nn.Module):
    """Defines the Unet submodule with skip connection.
        X -------------------identity----------------------
        |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.
        Parameters:
            outer_nc (int) -- the number of filters in the outer conv layer
            inner_nc (int) -- the number of filters in the inner conv layer
            input_nc (int) -- the number of channels in input images/features
            submodule (UnetSkipConnectionBlock) -- previously defined submodules
            outermost (bool)   -- if this module is the outermost module
            innermost (bool)   -- if this module is the innermost module
            norm_layer         -- normalization layer
            use_dropout (bool) -- if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:  # add skip connections
            return torch.cat([x, self.model(x)], 1)


def create_model(gpu_ids=[]):
    """Create a model for anime2sketch
    hardcoding the options for simplicity
    """
    norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
    ckpt = torch.load('weights/netG.pth')
    for key in list(ckpt.keys()):
        if 'module.' in key:
            ckpt[key.replace('module.', '')] = ckpt[key]
            del ckpt[key]
    net.load_state_dict(ckpt)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    return net
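Since create_model hardcodes num_downs=8, input spatial dimensions must be divisible by 2^8 = 256 for the transposed convolutions to exactly invert the downsampling; the default load_size of 512 satisfies this. A quick shape check with randomly initialized weights, so no checkpoint is required (illustrative only):

# shape check: 3-channel input in, 1-channel sketch out, same spatial size
import functools
import torch
import torch.nn as nn
from model import UnetGenerator

norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
with torch.no_grad():
    out = net(torch.randn(1, 3, 512, 512))
print(out.shape)  # torch.Size([1, 1, 512, 512])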
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
Pillow
gradio
torchtext
test.py
ADDED
@@ -0,0 +1,42 @@
"""Test script for anime-to-sketch translation
Example:
    python3 test.py --dataroot /your_path/dir --load_size 512
    python3 test.py --dataroot /your_path/img.jpg --load_size 512
"""

import os
import argparse
from data import get_image_list, read_img_path, tensor_to_img, save_image
from model import create_model


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Anime-to-sketch test options.')
    parser.add_argument('--dataroot', '-i', default='test_samples/', type=str)
    parser.add_argument('--load_size', '-s', default=512, type=int)
    parser.add_argument('--output_dir', '-o', default='results/', type=str)
    parser.add_argument('--gpu_ids', '-g', default=[], help="gpu ids: e.g. 0  0,1,2  0,2.")
    opt = parser.parse_args()

    # create model
    model = create_model(opt.gpu_ids)  # create a model given opt.model and other options
    model.eval()
    # get input data
    if os.path.isdir(opt.dataroot):
        test_list = get_image_list(opt.dataroot)
    elif os.path.isfile(opt.dataroot):
        test_list = [opt.dataroot]
    else:
        raise Exception("{} is not a valid directory or image file.".format(opt.dataroot))
    # save outputs
    save_dir = opt.output_dir
    os.makedirs(save_dir, exist_ok=True)

    for test_path in test_list:
        basename = os.path.basename(test_path)
        aus_path = os.path.join(save_dir, basename)
        img, aus_resize = read_img_path(test_path, opt.load_size)
        aus_tensor = model(img)
        aus_img = tensor_to_img(aus_tensor)
        save_image(aus_img, aus_path, aus_resize)
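One caveat: when --gpu_ids is passed on the command line, argparse hands create_model the raw string (e.g. '0,1'), since no type is given; the [] default only applies when the flag is omitted. A hypothetical parsing helper, in case multi-GPU runs are wanted:

# hypothetical helper: turn a '0,1,2'-style --gpu_ids string into a list of ints
def parse_gpu_ids(value):
    if isinstance(value, list):  # flag omitted: argparse returned the [] default
        return value
    return [int(i) for i in value.split(',') if i.strip()]

# usage: model = create_model(parse_gpu_ids(opt.gpu_ids))
# parse_gpu_ids('0,1') -> [0, 1]; parse_gpu_ids([]) -> []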