In [1]:
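# Uncomment to install any of the following packages you don't already have
# (%pip installs into the environment of the running kernel).
#%pip install -qq gradio
#%pip install -qq torch
#%pip install -qq pillow  # PIL is distributed on PyPI as "pillow"
#%pip install -qq torchvision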

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
from utils import transformer, tensor_to_img
from network import Style_Transfer_Network
import gradio as gr

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# Build the style-transfer network and load the trained weights into it.
CHECKPOINT_PATH = 'check_point1_0.pth'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
check_point = torch.load(CHECKPOINT_PATH, map_location = device)
# eval() puts the module in inference mode (matters for dropout / batch-norm layers, if the
# network has any). It returns the module itself, so it can be chained here.
transfer_network = Style_Transfer_Network().to(device).eval()
# Strict loading: raises if checkpoint keys don't match the model; the result is the cell output.
transfer_network.load_state_dict(check_point['state_dict'])



<All keys matched successfully>

In [6]:
def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0, style_img_2 = None, iw_2 = 0, style_img_3 = None, iw_3 = 0, preserve_color = None):
    """Stylize ``content_img`` with up to three style images using ``transfer_network``.

    Args:
        content_img: the content image (a PIL image when called from the Gradio
            ``type="pil"`` input), or ``None`` if nothing was uploaded.
        style_strength: forwarded unchanged to ``transfer_network``.
        style_img_1, style_img_2, style_img_3: optional style images; ``None`` slots are skipped.
        iw_1, iw_2, iw_3: raw interpolation weights, each paired with the style image in
            the same slot. Weights of the supplied styles are normalized to sum to 1; if they
            sum to zero (or less), the supplied styles are blended equally.
        preserve_color: a dropdown label ("None", "Whitening & Coloring",
            "Histogram matching"), translated to the mode string passed to the network.
            Any other value (e.g. ``None``) is passed through unchanged.

    Returns:
        ``tensor_to_img`` applied to the stylized tensor.

    Raises:
        gr.Error: if no content image or no style image is supplied.
    """
    if content_img is None:
        raise gr.Error("Please provide a content image.")

    # Keep every weight attached to its own style image so that an empty slot
    # cannot shift weights onto the wrong style (or leave lengths mismatched).
    slots = [(style_img_1, iw_1), (style_img_2, iw_2), (style_img_3, iw_3)]
    active = [(img, weight) for img, weight in slots if img is not None]
    if not active:
        raise gr.Error("Please provide at least one style image.")

    total = sum(weight for _, weight in active)
    if total > 0:
        interpolation_weights = [weight / total for _, weight in active]
    else:
        # All sliders at zero (their default): blend equally instead of dividing by zero.
        interpolation_weights = [1.0 / len(active)] * len(active)

    transform = transformer(imsize = 512)
    content = transform(content_img).unsqueeze(0).to(device)
    styles = [transform(img).unsqueeze(0).to(device) for img, _ in active]

    # Map UI labels to the mode names the network expects; unknown values pass through.
    color_modes = {
        "None": None,
        "Whitening & Coloring": "whiten_and_color",
        "Histogram matching": "histogram_matching",
    }
    preserve_color = color_modes.get(preserve_color, preserve_color)

    with torch.no_grad():
        stylized_img = transfer_network(content, styles, style_strength, interpolation_weights, preserve_color = preserve_color)
    return tensor_to_img(stylized_img)

title = "Artistic Style Transfer"

content_img = gr.components.Image(label="Content image", type = "pil")

# Each style upload gets its own label so the three boxes can be told apart;
# the slider below each one sets that style's raw interpolation weight.
style_img_1 = gr.components.Image(label="Style image 1", type = "pil")
iw_1 = gr.components.Slider(0., 1., label = "Style 1 interpolation")
style_img_2 = gr.components.Image(label="Style image 2", type = "pil")
iw_2 = gr.components.Slider(0., 1., label = "Style 2 interpolation")
style_img_3 = gr.components.Image(label="Style image 3", type = "pil")
iw_3 = gr.components.Slider(0., 1., label = "Style 3 interpolation")
style_strength = gr.components.Slider(0., 1., label = "Adjust style strength")
# Pre-select "None" so the dropdown displays the mode actually in effect.
preserve_color = gr.components.Dropdown(["None", "Whitening & Coloring", "Histogram matching"],
                                        value = "None",
                                        label = "Choose color preserving mode")

# The order of `inputs` must match the positional parameters of `style_transfer`.
interface = gr.Interface(fn = style_transfer,
                         inputs = [content_img,
                                   style_strength,
                                   style_img_1,
                                   iw_1,
                                   style_img_2,
                                   iw_2,
                                   style_img_3,
                                   iw_3,
                                   preserve_color],
                         outputs = gr.components.Image(label = "Stylized image"),
                         title = title)
interface.queue()
# share=True also publishes a temporary public *.gradio.live URL.
interface.launch(share = True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://b4e9024bf7c14725c6.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


