1-13-am's picture
Update app.py
7c189e4
raw
history blame
2.62 kB
import gradio as gr
import torch
from utils import transformer, tensor_to_img
from network import Style_Transfer_Network
# Load the pretrained style-transfer weights once at import time, on CPU.
# NOTE: torch.load unpickles the file; this is acceptable only because the
# checkpoint ships with the app (it is not user-supplied).
check_point = torch.load("check_point1_0.pth", map_location=torch.device("cpu"))
model = Style_Transfer_Network()
model.load_state_dict(check_point["state_dict"])
# The model is only used for inference below (under torch.no_grad()), so put
# it in eval mode; otherwise any dropout / batch-norm layers it may contain
# would run with training-time behaviour.
model.eval()
# Maps the dropdown labels shown in the UI to the mode names passed to the
# model. Values not listed here are passed through unchanged.
_PRESERVE_COLOR_MODES = {
    "None": None,
    "Whitening & Coloring": "whitening_and_coloring",
    "Histogram matching": "histogram_matching",
}


def style_transfer(content_img, style_strength, style_img_1=None, iw_1=0,
                   style_img_2=None, iw_2=0, style_img_3=None, iw_3=0,
                   preserve_color=None):
    """Stylize ``content_img`` with up to three style images.

    Args:
        content_img: Content image; the UI supplies it as a PIL image
            (``type="pil"``).
        style_strength: Overall stylization strength, forwarded to the model.
        style_img_1, style_img_2, style_img_3: Optional style images; ``None``
            entries are ignored.
        iw_1, iw_2, iw_3: Interpolation weight for the style image with the
            same index. Weights are normalised over the style images that
            were actually supplied; if they sum to 0, the supplied styles are
            weighted equally.
        preserve_color: UI label of the colour-preserving mode ("None",
            "Whitening & Coloring", "Histogram matching") or ``None``.

    Returns:
        Whatever ``tensor_to_img`` produces for the stylized tensor.

    Raises:
        ValueError: If no content image or no style image is supplied.
    """
    if content_img is None:
        raise ValueError("A content image is required.")

    # Keep each weight paired with its own image, so that dropping a missing
    # image also drops its weight and the two lists stay aligned.
    pairs = [
        (img, weight)
        for img, weight in zip((style_img_1, style_img_2, style_img_3),
                               (iw_1, iw_2, iw_3))
        if img is not None
    ]
    if not pairs:
        raise ValueError("At least one style image is required.")

    total = sum(weight for _, weight in pairs)
    if total > 0:
        interpolation_weights = [weight / total for _, weight in pairs]
    else:
        # All sliders left at 0 (their default): avoid dividing by zero and
        # blend the supplied styles equally.
        interpolation_weights = [1 / len(pairs)] * len(pairs)

    transform = transformer(imsize=512)
    content = transform(content_img).unsqueeze(0)
    styles = [transform(img).unsqueeze(0) for img, _ in pairs]

    mode = _PRESERVE_COLOR_MODES.get(preserve_color, preserve_color)

    with torch.no_grad():
        stylized_img = model(content, styles, style_strength,
                             interpolation_weights, preserve_color=mode)
    return tensor_to_img(stylized_img)
# ---- Gradio user interface ----------------------------------------------
title = "Artistic Style Transfer"

# Input components, listed in the same order as style_transfer's parameters.
content_img = gr.components.Image(label="Content image", type="pil")
# Each style image is labelled with its index so it can be matched to the
# "Style N strength" slider that controls its interpolation weight.
style_img_1 = gr.components.Image(label="Style image 1", type="pil")
iw_1 = gr.components.Slider(0., 1., label="Style 1 strength")
style_img_2 = gr.components.Image(label="Style image 2", type="pil")
iw_2 = gr.components.Slider(0., 1., label="Style 2 strength")
style_img_3 = gr.components.Image(label="Style image 3", type="pil")
iw_3 = gr.components.Slider(0., 1., label="Style 3 strength")
style_strength = gr.components.Slider(0., 1., label="Adjust style strength")
# Show "None" as the initial choice; style_transfer maps it to the same
# `None` mode it received when the dropdown was left unset.
preserve_color = gr.components.Dropdown(
    ["None", "Whitening & Coloring", "Histogram matching"],
    value="None",
    label="Choose color preserving mode",
)

interface = gr.Interface(
    fn=style_transfer,
    inputs=[
        content_img,
        style_strength,
        style_img_1,
        iw_1,
        style_img_2,
        iw_2,
        style_img_3,
        iw_3,
        preserve_color,
    ],
    outputs=gr.components.Image(),
    title=title,
)
interface.queue()
interface.launch(share=True, debug=True)