import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

class ColorNet(nn.Module):
    DEFAULT_CHECKPOINT_PATH = "checkpoint/colornet.pt"

    def __init__(self, checkpoint_path: str = DEFAULT_CHECKPOINT_PATH):
        super().__init__()

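        # Encoder: two stride-2 convolutions shrink (H, W) to (H/4, W/4) while
        # widening channels 3 -> 256; the decoder mirrors this with transposed
        # convolutions back to the original resolution and 3 RGB channels.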
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()  # to scale the output to [0, 1]
        )
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        self.checkpoint_path = checkpoint_path

        # Load existing weights if a checkpoint is already on disk.
        if os.path.exists(checkpoint_path):
            self._load_model(checkpoint_path)

    def _load_model(self, path):
        print("Loading ColorNet model...", end="")
        self.load_state_dict(torch.load(path, map_location=self.device))
        print("done.")

    def forward(self, x):
        x = x.to(self.device)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def train_model(self, train_loader, criterion, optimizer, num_epochs=10):
        for epoch in range(num_epochs):
            self.train()
            running_loss = 0.0
            for inputs, _ in train_loader:
                # Inputs are color images; build 3-channel grayscale versions to
                # feed the network and use the originals as the targets.
                gray_images = transforms.Grayscale(num_output_channels=1)(inputs).to(self.device)
                gray_images = gray_images.repeat(1, 3, 1, 1)
                color_images = inputs.to(self.device)

                optimizer.zero_grad()

                outputs = self(gray_images)
                loss = criterion(outputs, color_images)
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * gray_images.size(0)

            epoch_loss = running_loss / len(train_loader.dataset)
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

        # Persist trained weights, creating the checkpoint directory if needed.
        os.makedirs(os.path.dirname(self.checkpoint_path) or ".", exist_ok=True)
        torch.save(self.state_dict(), self.checkpoint_path)


    def colorize(self, input_path: str, output_path: str):
        # Force a 3-channel grayscale input so it matches the repeated-channel
        # inputs seen during training, even if the source file is in color.
        input_image = Image.open(input_path).convert("L").convert("RGB")
        input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(self.device)

        with torch.inference_mode():
            output_tensor = self(input_tensor).squeeze(0).cpu()
            output_image = transforms.ToPILImage()(output_tensor)
            output_image.save(output_path)

    def visualize_results(self, test_loader, num_images=5):
        self.eval()
        with torch.no_grad():
            images, _ = next(iter(test_loader))
            images = images[:num_images]

            # Build grayscale inputs exactly as in training and run the model.
            gray_images = transforms.Grayscale(num_output_channels=1)(images)
            colorized_images = self(gray_images.repeat(1, 3, 1, 1)).cpu()

            # Rows: grayscale input, model output, original color image.
            plt.figure(figsize=(2 * num_images, 6))
            for i in range(num_images):
                plt.subplot(3, num_images, i + 1)
                plt.imshow(gray_images[i].permute(1, 2, 0).squeeze(), cmap="gray")
                plt.axis('off')

                plt.subplot(3, num_images, num_images + i + 1)
                plt.imshow(colorized_images[i].permute(1, 2, 0))
                plt.axis('off')

                plt.subplot(3, num_images, 2 * num_images + i + 1)
                plt.imshow(images[i].permute(1, 2, 0))
                plt.axis('off')

            plt.show()
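

if __name__ == "__main__":
    # Example usage sketch, assuming CIFAR-10 downloaded via torchvision; any
    # dataset of color images that yields (image, label) batches would work.
    # "input.jpg" and "output.jpg" are hypothetical paths for illustration.
    transform = transforms.ToTensor()
    train_set = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
    test_set = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

    net = ColorNet()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=1e-3)

    net.train_model(train_loader, criterion, optimizer, num_epochs=10)
    net.colorize("input.jpg", "output.jpg")
    net.visualize_results(test_loader, num_images=5)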