Spaces:
Build error
Build error
File size: 1,652 Bytes
ef01fd7 000a894 ef01fd7 000a894 ef01fd7 c9fd122 ef01fd7 000a894 ef01fd7 c9fd122 1c5ae49 ef01fd7 2f675f6 ef01fd7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from __future__ import print_function
import torch
import process_stylization
from photo_wct import PhotoWCT
import gradio as gr
from datetime import datetime
# Load the pretrained PhotoWCT weights once at import time. The single shared
# model instance is moved between CPU and GPU per request inside `run`.
model_path = './models/photo_wct.pth'
p_wct = PhotoWCT()
# map_location='cpu' lets a checkpoint that was saved from CUDA tensors load on
# a CPU-only host; the freshly constructed model lives on the CPU anyway, and
# `run` moves it to the GPU only when CUDA is requested.
p_wct.load_state_dict(torch.load(model_path, map_location='cpu'))
def run(content_img, style_img, cuda, post_processing, fast):
    """Stylize ``content_img`` with the photographic style of ``style_img``.

    Gradio callback for the interface built under ``__main__``; the inputs
    arrive positionally in the order the interface declares them.

    Args:
        content_img: Content image from the first ``gr.Image`` input (gradio's
            default image type — presumably a numpy array; confirm against
            ``process_stylization.stylization_gradio``).
        style_img: Style reference image, same format as ``content_img``.
        cuda: Whether to run on GPU 0. If CUDA is not available on this host
            the request falls back to CPU instead of raising.
        post_processing: Forwarded unchanged to ``stylization_gradio``.
        fast: Index of the selected smoothing algorithm (the Radio input uses
            ``type="index"``): ``0`` selects guided image filtering
            (``GIFSmoothing``); any other value selects the slower
            photorealistic ``Propagator``.

    Returns:
        The result of ``process_stylization.stylization_gradio``; the output
        component is declared with ``type="pil"``.
    """
    dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    print("[TimeStamp] {}".format(dt_string))

    # The smoothing backends are imported inside the branches, so only the
    # selected one's module is loaded (presumably to avoid importing the
    # heavier backend when unused).
    if fast == 0:
        from photo_gif import GIFSmoothing
        p_pro = GIFSmoothing(r=35, eps=0.001)
    else:
        from photo_smooth import Propagator
        p_pro = Propagator()

    # Honour the checkbox only when a GPU is actually present; otherwise
    # p_wct.cuda(0) would raise on CPU-only hosts.
    use_cuda = bool(cuda) and torch.cuda.is_available()
    if cuda and not use_cuda:
        print("[Warning] CUDA requested but not available; running on CPU.")

    # p_wct is a module-level model shared across requests; move it to the
    # device this request will use.
    if use_cuda:
        p_wct.cuda(0)
    else:
        p_wct.to('cpu')

    return process_stylization.stylization_gradio(
        stylization_module=p_wct,
        smoothing_module=p_pro,
        content_image=content_img,
        style_image=style_img,
        cuda=use_cuda,
        post_processing=post_processing,
    )
if __name__ == '__main__':
    # Build the web UI. The input order must match run()'s positional
    # parameters: (content_img, style_img, cuda, post_processing, fast).
    style = gr.Interface(
        fn=run,
        inputs=[
            gr.Image(label='Content Image'),
            gr.Image(label='Style Image'),
            # Default to GPU only when one is present, so the out-of-the-box
            # configuration also works on CPU-only hosts.
            gr.Checkbox(value=torch.cuda.is_available(), label='Use CUDA'),
            gr.Checkbox(value=True, label='Post Processing'),
            # type="index" passes the selected position (0 or 1) to run()'s
            # `fast` argument; 0 selects the guided-filter smoother.
            # interactive=False locks the UI to the default choice.
            gr.Radio(
                choices=[
                    "Guided Image Filtering (Fast)",
                    "Photorealistic Smoothing (Slow)",
                ],
                value="Guided Image Filtering (Fast)",
                type="index",
                label="Algorithm",
                interactive=False,
            ),
        ],
        outputs=[
            gr.Image(type="pil", label="Result"),
        ],
    )
    # Enable request queuing before starting the server.
    style.queue()
    style.launch()