{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "# Uncomment if you don't have the following modules\n", "#pip install -qq gradio\n", "#pip install -qq torch\n", "#pip install -qq PIL\n", "#pip install -qq torchvision" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "from PIL import Image\n", "import torch\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "from utils import transformer, tensor_to_img\n", "from network import Style_Transfer_Network\n", "import gradio as gr" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "device = \"cpu\"\n", "if torch.cuda.is_available(): device = \"cuda\"" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", " warnings.warn(\n", "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.\n", " warnings.warn(msg)\n", "C:\\Users\\VICTUS\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\\LocalCache\\local-packages\\Python310\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=None`.\n", " warnings.warn(msg)\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#import gradio as gr\n", "check_point = torch.load('check_point1_0.pth', map_location = device)\n", "transfer_network = Style_Transfer_Network().to(device)\n", "transfer_network.load_state_dict(check_point['state_dict'])" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "Running on public URL: https://b4e9024bf7c14725c6.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def style_transfer(content_img, style_strength, style_img_1 = None, iw_1 = 0, style_img_2 = None, iw_2 = 0, style_img_3 = None, iw_3 = 0, preserve_color = None):\n", " transform = transformer(imsize = 512)\n", "\n", " content = transform(content_img).unsqueeze(0).to(device)\n", "\n", " iw = [iw_1, iw_2, iw_3]\n", " interpolation_weights = [i/ sum(iw) for i in iw]\n", "\n", " style_imgs = [style_img_1, style_img_2, style_img_3]\n", " styles = []\n", " for style_img in style_imgs:\n", " if style_img is not None:\n", " styles.append(transform(style_img).unsqueeze(0).to(device))\n", " if preserve_color == \"None\": preserve_color = None\n", " elif preserve_color == \"Whitening & Coloring\": preserve_color = \"whiten_and_color\"\n", " elif preserve_color == \"Histogram matching\": preserve_color = \"histogram_matching\"\n", " with torch.no_grad():\n", " stylized_img = transfer_network(content, styles, style_strength, interpolation_weights, preserve_color = preserve_color)\n", " return tensor_to_img(stylized_img)\n", "\n", "title = \"Artistic Style Transfer\"\n", "\n", "content_img = gr.components.Image(label=\"Content image\", type = \"pil\")\n", "\n", "style_img_1 = gr.components.Image(label=\"Style images\", type = \"pil\")\n", "iw_1 = gr.components.Slider(0., 1., label = \"Style 1 interpolation\")\n", "style_img_2 = gr.components.Image(label=\"Style images\", type = \"pil\")\n", "iw_2 = gr.components.Slider(0., 1., label = \"Style 2 interpolation\")\n", "style_img_3 = gr.components.Image(label=\"Style images\", type = \"pil\")\n", "iw_3 = gr.components.Slider(0., 1., label = \"Style 3 interpolation\")\n", "style_strength = gr.components.Slider(0., 1., label = \"Adjust style strength\")\n", "preserve_color = gr.components.Dropdown([\"None\", \"Whitening & Coloring\", \"Histogram matching\"], label = \"Choose color preserving mode\")\n", "\n", "interface = gr.Interface(fn = style_transfer,\n", " inputs = [content_img,\n", " style_strength,\n", " style_img_1,\n", " iw_1,\n", " style_img_2,\n", " iw_2,\n", " style_img_3,\n", " iw_3,\n", " preserve_color],\n", " outputs = gr.components.Image(),\n", " title = title,\n", " \n", " )\n", "interface.queue()\n", "interface.launch(share = True)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 2 }