In [1]:
!pip install -qq hub
!pip install -qq flask

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 2.4.0 requires dill<0.3.6, but you have dill 0.3.7 which is incompatible.
awscli 1.25.91 requires botocore==1.27.90, but you have botocore 1.31.17 which is incompatible.[0m[31m
[0m

In [4]:
import torch
import deeplake
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
from network import Style_Transfer_Network, Encoder
from utils import save_img
import torchvision

In [5]:
reshape_size = 512
crop_size = 256
def any_to_rgb(img):
 return img.convert('RGB')
preprocess = transforms.Compose([
 transforms.Lambda(any_to_rgb),
 transforms.ToTensor(),
 transforms.Resize(reshape_size),
 transforms.RandomCrop(crop_size),
 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 ])
wiki_art_dataset = deeplake.load('hub://activeloop/wiki-art')
coco_dataset = deeplake.load('hub://activeloop/coco-test')

style_data_loader = wiki_art_dataset.pytorch(batch_size = 8, num_workers = 0,
 transform = {'images': preprocess, 'labels': None}, shuffle = True, decode_method = {'images':'pil'})

cnt_data_loader = coco_dataset.pytorch(batch_size = 8, num_workers = 0,
 transform = {'images': preprocess}, shuffle = True, decode_method = {'images': 'pil'})


|

Opening dataset in read-only mode as you don't have write permissions.


-

This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/wiki-art



-

hub://activeloop/wiki-art loaded successfully.



 

Opening dataset in read-only mode as you don't have write permissions.


\

This dataset can be visualized in Jupyter Notebook by ds.visualize() or at https://app.activeloop.ai/activeloop/coco-test



\

hub://activeloop/coco-test loaded successfully.



 

In [7]:
mse_loss = nn.MSELoss(reduction = 'mean')
def content_loss(source, target):
 cnt_loss = mse_loss(source, target)
 return cnt_loss

def style_loss(features, targets):
 loss = 0
 for feature, target in zip(features, targets):
 B, C, H, W = feature.shape
 feature_std, feature_mean = torch.std_mean(feature.view(B, C, -1), dim = 2)
 target_std, target_mean = torch.std_mean(target.view(B, C, -1), dim = 2)
 loss += mse_loss(feature_std, target_std) + mse_loss(feature_mean, target_mean)
 return loss * 1. / len(features)
"""
def style_loss(features, targets, weights=None):
 if weights is None:
 weights = [1/len(features)] * len(features)
 
 loss = 0
 for feature, target, weight in zip(features, targets, weights):
 b, c, h, w = feature.size()
 feature_std, feature_mean = torch.std_mean(feature.view(b, c, -1), dim=2)
 target_std, target_mean = torch.std_mean(target.view(b, c, -1), dim=2)
 loss += (mse_loss(feature_std, target_std) + mse_loss(feature_mean, target_mean))*weight
 return loss
"""
def total_variational_loss(images):
 loss = 0.0
 B = images.shape[0]
 vertical_up = images[:,:,:-1]
 vertical_down = images[:,:,1:]

 horizontal_up = images[:,:,:,:-1]
 horizontal_down = images[:,:,:,1:]

 loss = ((vertical_up - vertical_down) ** 2).sum() + \
 ((horizontal_up - horizontal_down) ** 2).sum()

 return loss * 1.0 / B

In [8]:
if torch.cuda.is_available():
 device = "cuda"
else: device = "cpu"

In [14]:
style_transfer_network = Style_Transfer_Network().to(device)
check_point = torch.load("/notebooks/Style_transfer_with_ADAin/check_point.pth", map_location = 'cuda')
style_transfer_network.load_state_dict(check_point['state_dict'])



In [15]:
def denormalize():
 # out = (x - mean) / std
 MEAN = [0.485, 0.456, 0.406]
 STD = [0.229, 0.224, 0.225]
 MEAN = [-mean/std for mean, std in zip(MEAN, STD)]
 STD = [1/std for std in STD]
 return transforms.Normalize(mean=MEAN, std=STD)

def save_img(tensor, path):
 denormalizer = denormalize() 
 if tensor.is_cuda:
 tensor = tensor.cpu()
 tensor = torchvision.utils.make_grid(tensor)
 torchvision.utils.save_image(denormalizer(tensor).clamp_(0.0, 1.0), path) 
 return None

In [16]:
def train_network(iteration, loss_weight = [1.0, 100.0, 0.001], check_iter = 1, test_iter = 10):
 for param in style_transfer_network.encoder.parameters():
 # freeze parameter in the encoder network
 param.requires_grad = False
 optimizer = torch.optim.Adam(style_transfer_network.decoder.parameters(), lr = 1e-6)

 encoder_net = Encoder().to(device)
 for param in encoder_net.parameters():
 param.requires_grad = False
 for i in range(iteration):
 content_imgs = next(iter(cnt_data_loader))['images'].to(device)
 style_imgs = next(iter(style_data_loader))['images'].to(device)

 output_imgs, transformed_features = style_transfer_network(content_imgs, style_imgs, train = True)

 output_features = encoder_net(output_imgs)
 style_features = encoder_net(style_imgs)

 cnt_loss = content_loss(transformed_features, output_features[-1])
 st_loss = style_loss(output_features, style_features)
 tv_loss = total_variational_loss(output_imgs)
 cnt_w, style_w, tv_w = loss_weight
 total_loss = cnt_w * tv_loss + style_w * st_loss + tv_w * tv_loss

 optimizer.zero_grad()
 total_loss.backward()
 optimizer.step()

 if i % check_iter == 0:
 print('-' * 80)
 print("Iteration {} loss: {}".format(i, total_loss))

 if i % test_iter == 0:
 #save_img(torch.cat([content_imgs[0], style_imgs[0], output_imgs[0]], dim = 0), "training_image.png")
 torch.save({'iteration':iteration+1,
 'state_dict':style_transfer_network.state_dict()},
 'check_point1.pth')

In [17]:
train_network(iteration = 300)

--------------------------------------------------------------------------------
Iteration 0 loss: 0.8845198750495911
--------------------------------------------------------------------------------
Iteration 1 loss: 1.8098524808883667
--------------------------------------------------------------------------------
Iteration 2 loss: 1.868203043937683
--------------------------------------------------------------------------------
Iteration 3 loss: 1.1070071458816528
--------------------------------------------------------------------------------
Iteration 4 loss: 2.0751609802246094
--------------------------------------------------------------------------------
Iteration 5 loss: 2.7107627391815186
--------------------------------------------------------------------------------
Iteration 6 loss: 1.4618340730667114
--------------------------------------------------------------------------------
Iteration 7 loss: 1.2351319789886475
---------------------------------------------------------



--------------------------------------------------------------------------------
Iteration 37 loss: 0.7055997848510742
--------------------------------------------------------------------------------
Iteration 38 loss: 1.3557121753692627
--------------------------------------------------------------------------------
Iteration 39 loss: 1.0668007135391235
--------------------------------------------------------------------------------
Iteration 40 loss: 1.1934823989868164
--------------------------------------------------------------------------------
Iteration 41 loss: 0.7692145109176636
--------------------------------------------------------------------------------
Iteration 42 loss: 1.141457438468933
--------------------------------------------------------------------------------
Iteration 43 loss: 1.5705242156982422
--------------------------------------------------------------------------------
Iteration 44 loss: 1.7851486206054688
-------------------------------------------------