{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "_qsogBHiKtzF",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "datasets 2.4.0 requires dill<0.3.6, but you have dill 0.3.7 which is incompatible.\n",
      "awscli 1.25.91 requires botocore==1.27.90, but you have botocore 1.31.17 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "# This notebook imports `deeplake` (not the legacy `hub` package) and relies on its 3.x API\n",
    "# (`deeplake.load`, `Dataset.pytorch(..., decode_method=...)`), so the install is pinned to 3.x.\n",
    "%pip install -qq \"deeplake<4\"\n",
    "%pip install -qq flask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "E8nHybN3KDIq",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import deeplake\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import transforms\n",
    "\n",
    "from network import Style_Transfer_Network, Encoder\n",
    "from utils import save_img  # redefined further down in this notebook; the local version wins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rnAFLCiIKqkM",
    "outputId": "81b8f1c3-3974-4ee3-a284-99186c1502c7",
    "tags": []
   },
   "outputs": [
    { "name": "stderr", "output_type": "stream", "text": [ "|" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "Opening dataset in read-only mode as you don't have write permissions.\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ "-" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/wiki-art\n", "\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ "-" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "hub://activeloop/wiki-art loaded successfully.\n", "\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ " " ] },
    { "name": "stdout", "output_type": "stream", "text": [ "Opening dataset in read-only mode as you don't have write permissions.\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ "\\" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/coco-test\n", "\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ "\\" ] },
    { "name": "stdout", "output_type": "stream", "text": [ "hub://activeloop/coco-test loaded successfully.\n", "\n" ] },
    { "name": "stderr", "output_type": "stream", "text": [ " " ] }
   ],
   "source": [
    "reshape_size = 512\n",
    "crop_size = 256\n",
    "\n",
    "def any_to_rgb(img):\n",
    "    \"\"\"Convert a PIL image of any mode to 3-channel RGB.\"\"\"\n",
    "    return img.convert('RGB')\n",
    "\n",
    "# PIL image -> RGB -> float tensor -> shorter side resized to `reshape_size`\n",
    "# -> random `crop_size` crop -> ImageNet mean/std normalisation.\n",
    "preprocess = transforms.Compose([\n",
    "    transforms.Lambda(any_to_rgb),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Resize(reshape_size),\n",
    "    transforms.RandomCrop(crop_size),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "])\n",
    "wiki_art_dataset = deeplake.load('hub://activeloop/wiki-art')\n",
    "coco_dataset = deeplake.load('hub://activeloop/coco-test')\n",
    "\n",
    "# drop_last=True: training cycles through both loaders, and a short final batch\n",
    "# would otherwise pair content and style batches of different sizes.\n",
    "style_data_loader = wiki_art_dataset.pytorch(batch_size = 8, num_workers = 0, drop_last = True,\n",
    "    transform = {'images': preprocess, 'labels': None}, shuffle = True, decode_method = {'images':'pil'})\n",
    "\n",
    "cnt_data_loader = coco_dataset.pytorch(batch_size = 8, num_workers = 0, drop_last = True,\n",
    "    transform = {'images': preprocess}, shuffle = True, decode_method = {'images': 'pil'})\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "XKqi9mMyoNUy",
    "tags": []
   },
   "outputs": [],
   "source": [
    "mse_loss = nn.MSELoss(reduction = 'mean')\n",
    "\n",
    "def content_loss(source, target):\n",
    "    \"\"\"Mean-squared error between two feature tensors.\"\"\"\n",
    "    return mse_loss(source, target)\n",
    "\n",
    "def style_loss(features, targets, weights=None):\n",
    "    \"\"\"Style loss: match per-channel mean and std at every feature level.\n",
    "\n",
    "    Args:\n",
    "        features: feature maps of the generated images, each indexed (B, C, ...).\n",
    "        targets: feature maps of the style images, aligned level-by-level with `features`.\n",
    "        weights: optional per-level weights. When None, levels are averaged uniformly.\n",
    "    Returns:\n",
    "        Scalar loss tensor.\n",
    "    \"\"\"\n",
    "    loss = 0\n",
    "    for level, (feature, target) in enumerate(zip(features, targets)):\n",
    "        B, C = feature.shape[:2]\n",
    "        # statistics over the flattened spatial positions of each channel\n",
    "        feature_std, feature_mean = torch.std_mean(feature.view(B, C, -1), dim = 2)\n",
    "        target_std, target_mean = torch.std_mean(target.view(B, C, -1), dim = 2)\n",
    "        level_loss = mse_loss(feature_std, target_std) + mse_loss(feature_mean, target_mean)\n",
    "        loss += level_loss if weights is None else weights[level] * level_loss\n",
    "    if weights is None:\n",
    "        return loss * 1. / len(features)\n",
    "    return loss\n",
    "\n",
    "def total_variational_loss(images):\n",
    "    \"\"\"Squared total-variation penalty, summed over all pixels and divided by the batch size.\n",
    "\n",
    "    `images` is presumably (B, C, H, W); neighbour differences are taken along dims 2 and 3.\n",
    "    \"\"\"\n",
    "    B = images.shape[0]\n",
    "    vertical_diff = images[:, :, :-1] - images[:, :, 1:]\n",
    "    horizontal_diff = images[:, :, :, :-1] - images[:, :, :, 1:]\n",
    "    loss = (vertical_diff ** 2).sum() + (horizontal_diff ** 2).sum()\n",
    "    return loss * 1.0 / B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "JAeuZ2Sq6E-0",
    "tags": []
   },
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "CHECKPOINT_PATH = \"/notebooks/Style_transfer_with_ADAin/check_point.pth\"\n",
    "style_transfer_network = Style_Transfer_Network().to(device)\n",
    "# map_location=device (not a hard-coded 'cuda') so the checkpoint also loads on CPU-only machines.\n",
    "check_point = torch.load(CHECKPOINT_PATH, map_location = device)\n",
    "style_transfer_network.load_state_dict(check_point['state_dict'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def denormalize():\n",
    "    \"\"\"Return a Normalize transform that undoes the ImageNet normalisation in `preprocess`.\n",
    "\n",
    "    Normalize computes out = (x - mean) / std, so x = (out - (-mean / std)) / (1 / std).\n",
    "    \"\"\"\n",
    "    MEAN = [0.485, 0.456, 0.406]\n",
    "    STD = [0.229, 0.224, 0.225]\n",
    "    inv_mean = [-mean / std for mean, std in zip(MEAN, STD)]\n",
    "    inv_std = [1 / std for std in STD]\n",
    "    return transforms.Normalize(mean=inv_mean, std=inv_std)\n",
    "\n",
    "def save_img(tensor, path):\n",
    "    \"\"\"Denormalise `tensor`, tile it into a grid, clamp to [0, 1] and write it to `path`.\n",
    "\n",
    "    Redefines the `save_img` imported from utils. The tensor is detached and moved to\n",
    "    the CPU first, so network outputs that carry autograd history can be passed directly.\n",
    "    \"\"\"\n",
    "    denormalizer = denormalize()\n",
    "    tensor = tensor.detach().cpu()\n",
    "    grid = torchvision.utils.make_grid(tensor)\n",
    "    torchvision.utils.save_image(denormalizer(grid).clamp_(0.0, 1.0), path)\n",
    "    return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1Y-JrlNquBwn",
    "outputId": "31d5fe14-5315-40cd-8946-99c34ff41726",
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _cycle(loader):\n",
    "    \"\"\"Yield batches from `loader` forever, starting a new pass whenever one ends.\n",
    "\n",
    "    Unlike itertools.cycle, batches are not cached, so memory stays bounded.\n",
    "    Raises ValueError if a full pass produces no batches (instead of spinning forever).\n",
    "    \"\"\"\n",
    "    while True:\n",
    "        produced = False\n",
    "        for batch in loader:\n",
    "            produced = True\n",
    "            yield batch\n",
    "        if not produced:\n",
    "            raise ValueError('data loader produced no batches')\n",
    "\n",
    "def train_network(iteration, loss_weight = (1.0, 100.0, 0.001), check_iter = 1, test_iter = 10,\n",
    "                  lr = 1e-6, checkpoint_path = 'check_point1.pth', preview_path = None):\n",
    "    \"\"\"Train the decoder of `style_transfer_network` with the content/style/TV objective.\n",
    "\n",
    "    Args:\n",
    "        iteration: number of optimisation steps.\n",
    "        loss_weight: (content, style, total-variation) loss weights.\n",
    "        check_iter: print the total loss every `check_iter` steps.\n",
    "        test_iter: write a checkpoint (and the optional preview) every `test_iter` steps.\n",
    "        lr: Adam learning rate for the decoder parameters.\n",
    "        checkpoint_path: file the {'iteration', 'state_dict'} checkpoint is written to.\n",
    "        preview_path: if not None, a (content, style, output) grid of the first sample\n",
    "            is saved there together with each checkpoint.\n",
    "    \"\"\"\n",
    "    # Only the decoder is optimised; the network's own encoder is frozen.\n",
    "    for param in style_transfer_network.encoder.parameters():\n",
    "        param.requires_grad = False\n",
    "    optimizer = torch.optim.Adam(style_transfer_network.decoder.parameters(), lr = lr)\n",
    "\n",
    "    # Separate frozen encoder used only to extract features for the losses.\n",
    "    encoder_net = Encoder().to(device)\n",
    "    for param in encoder_net.parameters():\n",
    "        param.requires_grad = False\n",
    "\n",
    "    # One long-lived iterator per loader. Calling iter() every step would start a new\n",
    "    # pass each time and only ever consume that pass's first batch.\n",
    "    content_batches = _cycle(cnt_data_loader)\n",
    "    style_batches = _cycle(style_data_loader)\n",
    "    cnt_w, style_w, tv_w = loss_weight\n",
    "\n",
    "    for i in range(iteration):\n",
    "        content_imgs = next(content_batches)['images'].to(device)\n",
    "        style_imgs = next(style_batches)['images'].to(device)\n",
    "\n",
    "        output_imgs, transformed_features = style_transfer_network(content_imgs, style_imgs, train = True)\n",
    "\n",
    "        output_features = encoder_net(output_imgs)\n",
    "        style_features = encoder_net(style_imgs)\n",
    "\n",
    "        cnt_loss = content_loss(transformed_features, output_features[-1])\n",
    "        st_loss = style_loss(output_features, style_features)\n",
    "        tv_loss = total_variational_loss(output_imgs)\n",
    "        # weighted sum of the content, style and total-variation terms\n",
    "        total_loss = cnt_w * cnt_loss + style_w * st_loss + tv_w * tv_loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if i % check_iter == 0:\n",
    "            print('-' * 80)\n",
    "            print(\"Iteration {} loss: {}\".format(i, total_loss.item()))\n",
    "\n",
    "        if i % test_iter == 0:\n",
    "            if preview_path is not None:\n",
    "                # stack (not cat) so make_grid receives a batch of three 3-channel images;\n",
    "                # assumes output_imgs has the same spatial size as the inputs.\n",
    "                save_img(torch.stack([content_imgs[0], style_imgs[0], output_imgs[0]]), preview_path)\n",
    "            # 'iteration' records the number of completed optimisation steps.\n",
    "            torch.save({'iteration': i + 1,\n",
    "                        'state_dict': style_transfer_network.state_dict()},\n",
    "                       checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------------------------------------------------------\n",
      "Iteration 0 loss: 0.8845198750495911\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 1 loss: 1.8098524808883667\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 2 loss: 1.868203043937683\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 3 loss: 1.1070071458816528\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 4 loss: 2.0751609802246094\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 5 loss: 2.7107627391815186\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 6 loss: 1.4618340730667114\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 7 loss: 1.2351319789886475\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 8 loss: 1.3090686798095703\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 9 loss: 1.7165802717208862\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 10 loss: 1.9655226469039917\n",
      "--------------------------------------------------------------------------------\n",
      "Iteration 11 loss: 1.8032971620559692\n",
"--------------------------------------------------------------------------------\n", "Iteration 12 loss: 1.757157802581787\n", "--------------------------------------------------------------------------------\n", "Iteration 13 loss: 1.2641586065292358\n", "--------------------------------------------------------------------------------\n", "Iteration 14 loss: 1.230526328086853\n", "--------------------------------------------------------------------------------\n", "Iteration 15 loss: 1.8332327604293823\n", "--------------------------------------------------------------------------------\n", "Iteration 16 loss: 2.347355365753174\n", "--------------------------------------------------------------------------------\n", "Iteration 17 loss: 0.8620480298995972\n", "--------------------------------------------------------------------------------\n", "Iteration 18 loss: 1.572771668434143\n", "--------------------------------------------------------------------------------\n", "Iteration 19 loss: 2.281660795211792\n", "--------------------------------------------------------------------------------\n", "Iteration 20 loss: 1.417534589767456\n", "--------------------------------------------------------------------------------\n", "Iteration 21 loss: 1.848774790763855\n", "--------------------------------------------------------------------------------\n", "Iteration 22 loss: 1.1456807851791382\n", "--------------------------------------------------------------------------------\n", "Iteration 23 loss: 1.2357560396194458\n", "--------------------------------------------------------------------------------\n", "Iteration 24 loss: 0.6565238833427429\n", "--------------------------------------------------------------------------------\n", "Iteration 25 loss: 1.2375402450561523\n", "--------------------------------------------------------------------------------\n", "Iteration 26 loss: 2.1140313148498535\n", 
"--------------------------------------------------------------------------------\n", "Iteration 27 loss: 1.0238616466522217\n", "--------------------------------------------------------------------------------\n", "Iteration 28 loss: 2.618056058883667\n", "--------------------------------------------------------------------------------\n", "Iteration 29 loss: 1.1616159677505493\n", "--------------------------------------------------------------------------------\n", "Iteration 30 loss: 1.919601559638977\n", "--------------------------------------------------------------------------------\n", "Iteration 31 loss: 1.0250651836395264\n", "--------------------------------------------------------------------------------\n", "Iteration 32 loss: 1.1823596954345703\n", "--------------------------------------------------------------------------------\n", "Iteration 33 loss: 0.8185012936592102\n", "--------------------------------------------------------------------------------\n", "Iteration 34 loss: 1.1374247074127197\n", "--------------------------------------------------------------------------------\n", "Iteration 35 loss: 1.9250235557556152\n", "--------------------------------------------------------------------------------\n", "Iteration 36 loss: 1.466286540031433\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.9/dist-packages/PIL/Image.py:3035: DecompressionBombWarning: Image size (99962094 pixels) exceeds limit of 89478485 pixels, could be decompression bomb DOS attack.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", "Iteration 37 loss: 0.7055997848510742\n", "--------------------------------------------------------------------------------\n", "Iteration 38 loss: 1.3557121753692627\n", "--------------------------------------------------------------------------------\n", "Iteration 39 loss: 
1.0668007135391235\n", "--------------------------------------------------------------------------------\n", "Iteration 40 loss: 1.1934823989868164\n", "--------------------------------------------------------------------------------\n", "Iteration 41 loss: 0.7692145109176636\n", "--------------------------------------------------------------------------------\n", "Iteration 42 loss: 1.141457438468933\n", "--------------------------------------------------------------------------------\n", "Iteration 43 loss: 1.5705242156982422\n", "--------------------------------------------------------------------------------\n", "Iteration 44 loss: 1.7851486206054688\n", "--------------------------------------------------------------------------------\n", "Iteration 45 loss: 0.7252503633499146\n", "--------------------------------------------------------------------------------\n", "Iteration 46 loss: 1.1291860342025757\n", "--------------------------------------------------------------------------------\n", "Iteration 47 loss: 1.3588659763336182\n", "--------------------------------------------------------------------------------\n", "Iteration 48 loss: 0.9960977435112\n", "--------------------------------------------------------------------------------\n", "Iteration 49 loss: 0.9272828102111816\n", "--------------------------------------------------------------------------------\n", "Iteration 50 loss: 2.4692296981811523\n" ] } ], "source": [ "train_network(iteration = 300)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": 
"3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }