File size: 4,694 Bytes
6817d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "OaBIb0WNma-U",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 383
        },
        "outputId": "33f26e87-f2b5-4de8-dfd8-dca7e8973a01"
      },
      "outputs": [
        {
          "output_type": "error",
          "ename": "ModuleNotFoundError",
          "evalue": "No module named 'decoder'",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-1-140de5a90328>\u001b[0m in \u001b[0;36m<cell line: 4>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m  \u001b[0;32mimport\u001b[0m \u001b[0mfunctional\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdecoder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mVAE_AttentionBlock\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVAE_ResidualBlock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mVAE_Encoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSequential\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'decoder'",
            "",
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
          ],
          "errorDetails": {
            "actions": [
              {
                "action": "open_url",
                "actionText": "Open Examples",
                "url": "/notebooks/snippets/importing_libraries.ipynb"
              }
            ]
          }
        }
      ],
      "source": [
        "import torch\n",
        "from torch import nn\n",
        "from torch.nn  import functional as F\n",
        "from decoder import VAE_AttentionBlock, VAE_ResidualBlock\n",
        "\n",
        "class VAE_Encoder(nn.Sequential):\n",
        "  \"\"\"Convolutional VAE encoder that turns an image batch into a scaled latent sample.\n",
        "\n",
        "  The layers are held by ``nn.Sequential``; ``forward`` is overridden so it can\n",
        "  pad before stride-2 convolutions and then sample from the predicted mean and\n",
        "  log-variance using caller-supplied noise.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    \"\"\"Build the layer stack (3 input channels in, 8 output channels out).\"\"\"\n",
        "    super().__init__(\n",
        "        nn.Conv2d(3, 128, kernel_size=3, padding=1),\n",
        "        VAE_ResidualBlock(128, 128),\n",
        "        VAE_ResidualBlock(128, 128),\n",
        "        # NOTE(review): forward() already pads right/bottom by 1 before every\n",
        "        # stride-2 conv; padding=1 on top of that gives H//2 + 1 instead of\n",
        "        # H//2 for even H. Kept as in the original -- confirm padding=0 was meant.\n",
        "        nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),\n",
        "        VAE_ResidualBlock(128, 256),\n",
        "        VAE_ResidualBlock(256, 256),\n",
        "        # Fixed: this was Conv2d(3, 128, ...), which cannot accept the feature map\n",
        "        # produced by the block above (assumes VAE_ResidualBlock(in, out) emits\n",
        "        # `out` channels -- 256 here, and 256 is what the next block takes).\n",
        "        nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),\n",
        "        VAE_ResidualBlock(256, 512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        VAE_AttentionBlock(512),\n",
        "        VAE_ResidualBlock(512, 512),\n",
        "        nn.GroupNorm(32, 512),\n",
        "        nn.SiLU(),\n",
        "        # 8 output channels: forward() splits them into two halves of 4\n",
        "        # (mean and log-variance) along dim 1.\n",
        "        nn.Conv2d(512, 8, kernel_size=3, padding=1),\n",
        "        nn.Conv2d(8, 8, kernel_size=3, padding=0),\n",
        "    )\n",
        "\n",
        "  def forward(self, x, noise):\n",
        "    \"\"\"Run the full layer stack on ``x`` and sample a latent using ``noise``.\n",
        "\n",
        "    Args:\n",
        "      x: Input tensor for the first ``nn.Conv2d(3, 128, ...)`` layer, so it\n",
        "        must have 3 channels on dim 1.\n",
        "      noise: Tensor multiplied element-wise with the predicted standard\n",
        "        deviation; must be broadcastable to the shape of the latent mean.\n",
        "\n",
        "    Returns:\n",
        "      ``(stdev * noise + mean) * 0.18215``, where ``mean`` and the\n",
        "      log-variance are the two halves of the last layer's output along dim 1.\n",
        "    \"\"\"\n",
        "    for module in self:\n",
        "      # Only modules exposing stride == (2, 2) get the extra padding. F.pad takes\n",
        "      # (left, right, top, bottom) for the last two dims, so this adds one\n",
        "      # column on the right and one row at the bottom.\n",
        "      if getattr(module, \"stride\", None) == (2, 2):\n",
        "        x = F.pad(x, (0, 1, 0, 1))\n",
        "      x = module(x)\n",
        "\n",
        "    # Fixed: the sampling below used to sit inside the loop, so forward()\n",
        "    # returned after the very first layer. It now runs once, on the output of\n",
        "    # the complete stack.\n",
        "    mean, log_variance = torch.chunk(x, 2, dim=1)\n",
        "    # NOTE(review): bounds kept as written; confirm (-20, 30) is intended and\n",
        "    # not a transposition of (-30, 20).\n",
        "    log_variance = torch.clamp(log_variance, -20, 30)\n",
        "    variance = log_variance.exp()\n",
        "    stdev = variance.sqrt()\n",
        "    # Reparameterised sample: shift and scale the supplied noise.\n",
        "    x = stdev * noise + mean\n",
        "    # Constant latent scale; presumably the Stable Diffusion latent scaling\n",
        "    # factor -- confirm against the decoder, which would need to divide by it.\n",
        "    x *= 0.18215\n",
        "    return x\n"
      ]
    }
  ]
}