import gradio as gr
import json
import torch
from torch import nn
from diffusers import UNet2DModel, DDPMScheduler
import safetensors.torch  # needed for safetensors.torch.load_file below
from huggingface_hub import hf_hub_download

### GPU SETUP
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## DEFINE THE CLASS-CONDITIONED UNET, THEN LOAD WEIGHTS AND DDPM SCHEDULER FROM THE HUGGING FACE HUB
class ClassConditionedUnet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=10):
    super().__init__()

    # The embedding layer will map the class label to a vector of size class_emb_size
    self.class_emb = nn.Embedding(num_classes, class_emb_size)
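    # For illustration (not executed): self.class_emb(torch.tensor([3])) returns
    # a (1, class_emb_size) float tensor, i.e. one learned vector per class label.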

    # self.model is an unconditional UNet with extra input channels
    # to accept the conditioning information (the class embedding)
    self.model = UNet2DModel(
        sample_size=28,           # target image resolution (matches the input)
        in_channels=1 + class_emb_size, # 1 image channel + class_emb_size conditioning channels (11 total)
        out_channels=1,           # number of output channels, equal to the input
        layers_per_block=3,       # three ResNet layers per UNet block
        block_out_channels=(128, 256, 512), # output channels per down block; reversed for upsampling
        down_block_types=(
            "DownBlock2D",  # a regular ResNet downsampling block
            "AttnDownBlock2D",
            "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        ),
        up_block_types=(
            "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
            "AttnUpBlock2D",
            "UpBlock2D",  # a regular ResNet upsampling block
        ),
        dropout=0.1,  # dropout prob between Conv1 and Conv2 in a block, from the Improved DDPM paper
    )

  # Forward method takes the class labels as an additional argument
  def forward(self, x, t, class_labels):
    bs, ch, w, h = x.shape # x is shape (bs, 1, 28, 28)

    # Class conditioning embedding, broadcast to one value per spatial location
    class_cond = self.class_emb(class_labels) # map labels to the embedding dimension
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # class_cond final shape: (bs, class_emb_size, 28, 28), i.e. (bs, 10, 28, 28)

    # Model input is x and class_cond concatenated along the channel dimension.
    # We need to provide the conditioning information (the class label) to every
    # spatial location (pixel) in the image; the original pixels are untouched,
    # we only add new channels.
    net_input = torch.cat((x, class_cond), 1) # (bs, 1 + class_emb_size, 28, 28) = (bs, 11, 28, 28)

    # Feed this to the UNet alongside the timestep and return the prediction,
    # which has the same shape as the input image
    return self.model(net_input, t).sample # (bs, 1, 28, 28)
  
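# A minimal shape sanity check for the conditioning logic above, left commented
# out so the app does not build a throwaway model at import time (dummy inputs):
# _m = ClassConditionedUnet()
# _x = torch.randn(2, 1, 28, 28)                 # batch of 2 noisy images
# _t = torch.tensor([999, 500])                  # arbitrary timesteps
# _y = torch.tensor([0, 7])                      # class labels
# assert _m(_x, _t, _y).shape == (2, 1, 28, 28)  # output matches input resolution
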
# Define paths to download the model and scheduler
repo_name = "Huertas97/conditioned-unet-fashion-mnist-non-ema"

### UNET MODEL
# Download the safetensors model file
model_file_path = hf_hub_download(repo_id=repo_name, filename="fashion_class_cond_unet_model_best.safetensors")

# Load the Class Conditioned UNet model state dictionary
state_dict = safetensors.torch.load_file(model_file_path)
model_classcond_native = ClassConditionedUnet()
model_classcond_native.load_state_dict(state_dict)
model_classcond_native.to(device)
model_classcond_native.eval()  # disable dropout for sampling
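# Optional quick check (illustrative): report the model size after loading.
# print(f"{sum(p.numel() for p in model_classcond_native.parameters()) / 1e6:.1f}M parameters on {device}")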

### DDPM SCHEDULER
# Download and load the scheduler configuration file
scheduler_file_path = hf_hub_download(repo_id=repo_name, filename="scheduler_config.json")

with open(scheduler_file_path, 'r') as f:
    scheduler_config = json.load(f)

noise_scheduler = DDPMScheduler.from_config(scheduler_config)
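
# For reference (illustrative values, not the actual contents of the config file),
# a DDPM scheduler config is a plain dict with fields such as:
# {"num_train_timesteps": 1000, "beta_start": 0.0001, "beta_end": 0.02, ...}
# After from_config, noise_scheduler.timesteps holds the reversed timestep range
# (e.g. [999, ..., 0] for 1000 training steps) used by the sampling loop below.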

# Define the classes
class_labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def generate_images(selected_class, num_images, progress=gr.Progress()):
    """
    Generate images using the trained model.
    
    Parameters:
    - selected_class: The class label as a string.
    - num_images: Number of images to generate.
    - progress: Gradio progress tracker for the denoising loop.

    Returns:
    - A list of generated images as (28, 28) uint8 numpy arrays.
    """
    # Convert class label to corresponding index
    class_idx = class_labels.index(selected_class)
    
    # Start from pure Gaussian noise, with the class label repeated per image
    x = torch.randn(num_images, 1, 28, 28).to(device)
    y = torch.tensor([class_idx] * num_images).to(device)
    
    # Reverse diffusion: iteratively denoise x from t = T-1 down to t = 0
    for t in progress.tqdm(noise_scheduler.timesteps, desc="Generating image"):
        with torch.no_grad():
            residual = model_classcond_native(x, t, y)  # predicted noise
        x = noise_scheduler.step(residual, t, x).prev_sample  # compute x_{t-1}

    # Post-process the generated images:
    # map from [-1, 1] to [0, 1], scale to [0, 255] uint8,
    # and move to CPU for conversion to numpy
    x = (x.clamp(-1, 1) + 1) / 2
    x = (x * 255).type(torch.uint8).cpu()
    
    # Convert to a list of (28, 28) arrays for the gallery
    images = [img.squeeze(0).numpy() for img in x]
    return images
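
# Example of calling the generator directly, outside the UI (commented out; the
# full denoising loop over all scheduler timesteps is slow on CPU, and the
# default gr.Progress() may need an active Gradio event to display progress):
# imgs = generate_images("Sneaker", num_images=2)
# print(len(imgs), imgs[0].shape, imgs[0].dtype)  # 2 (28, 28) uint8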

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_images,
    inputs=[
        gr.Dropdown(class_labels, label="Select Class", value="T-shirt/top"),
        gr.Slider(minimum=1, maximum=8, step=1, value=1, label="Number of Images")
    ],
    outputs=gr.Gallery(type="numpy", label="Generated Images"),
    live=False,
    description="Generate images using a class-conditioned UNet model."
)

demo.launch(share=True)