{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdf36725-ec00-4027-95d6-374340c2264e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 4.72G/4.72G [02:04<00:00, 40.7MiB/s]\n",
      "extracting: ./1.3B/tokenizer/bpe-16k-vocab.json (size:0MB): 100%|██████████| 7/7 [00:59<00:00,  8.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/tokenizer successfully restored..\n",
      "/root/.cache/minDALL-E/1.3B/stage1_last.ckpt successfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|                                               | 0.00/338M [00:00<?, ?iB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/stage2_last.ckpt succesfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 38.5MiB/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import argparse\n",
    "import clip\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "\n",
    "from dalle.models import Dalle\n",
    "from dalle.utils.utils import set_seed, clip_score\n",
    "\n",
    "device = 'cuda:0'\n",
    "model = Dalle.from_pretrained(\"minDALL-E/1.3B\")\n",
    "model_clip, preprocess_clip = clip.load(\"ViT-B/32\", device=device)\n",
    "\n",
    "model_clip.to(device=device)\n",
    "model.to(device=device)\n",
    "\n",
    "def sampling(prompt, top_k, softmax_temperature, seed, num_candidates=96, num_samples_for_display=36):\n",
    "    # Setup\n",
    "    n_row = int(math.sqrt(num_samples_for_display))\n",
    "    n_col = int(math.sqrt(num_samples_for_display))\n",
    "    set_seed(seed)\n",
    "    \n",
    "    # Sampling\n",
    "    images = model.sampling(prompt=prompt,\n",
    "                            top_k=top_k,\n",
    "                            top_p=None,\n",
    "                            softmax_temperature=softmax_temperature,\n",
    "                            num_candidates=num_candidates,\n",
    "                            device=device).cpu().numpy()\n",
    "    images = np.transpose(images, (0, 2, 3, 1))\n",
    "\n",
    "    # CLIP Re-ranking\n",
    "    rank = clip_score(prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device)\n",
    "    images = images[rank]\n",
    "    \n",
    "    images = images[:num_samples_for_display]\n",
    "    fig = plt.figure(figsize=(8*n_row, 8*n_col))\n",
    "\n",
    "    for i in range(num_samples_for_display):\n",
    "        ax = fig.add_subplot(n_row, n_col, i+1)\n",
    "        ax.imshow(images[i])\n",
    "        ax.set_axis_off()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "619add15-073e-40f4-9a97-06b89d647c81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee477531ea0e4b86b20d997f8cb83767",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=0, description='RND SEED: ', max=1024)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d63edc4725ef4f4e8a6f03f7693a481d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatSlider(value=1.0, description='SOFTMAX TEMPERATURE:', max=5.0, step=0.2)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bb9170e9e8b4686a661799d8aff3901",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=256, description='TOP-K:', max=512, step=16)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b97b49debfc4f7ab002748e9fd89864",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Text(value='A painting of a monkey with sunglasses in the frame', description='String:', placeholder='Text pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a520b10d8c0b4dd0bb6db56dc37b4422",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Button(description='Generate!', style=ButtonStyle())"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a98437abf964636a467677dc4f816bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90d05006d50e4d88b8fb7c36095b12e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "from IPython.display import clear_output\n",
    "\n",
    "output = widgets.Output()\n",
    "plot_output = widgets.Output()\n",
    "\n",
    "def btn_eventhandler(obj):\n",
    "    output.clear_output()\n",
    "    plot_output.clear_output()\n",
    "    \n",
    "    with output:\n",
    "        print(f'SEED: {slider_seed.value}')\n",
    "        print(f'Softmax Temperature: {slider_temp.value}')\n",
    "        print(f'Top-K: {slider_topk.value}')\n",
    "        print(f'Text prompt: {wd_text.value}')\n",
    "        \n",
    "    with plot_output:\n",
    "        sampling(prompt=wd_text.value, top_k=slider_topk.value, softmax_temperature=slider_temp.value, seed=slider_seed.value)\n",
    "    \n",
    "slider_seed = widgets.IntSlider(\n",
    "    min=0,\n",
    "    max=1024,\n",
    "    step=1,\n",
    "    description='RND SEED: ',\n",
    "    value=0\n",
    ")\n",
    "slider_topk = widgets.IntSlider(\n",
    "    min=0,\n",
    "    max=512,\n",
    "    step=16,\n",
    "    description='TOP-K:',\n",
    "    value=256\n",
    ")\n",
    "slider_temp = widgets.FloatSlider(\n",
    "    min=0.0,\n",
    "    max=5.0,\n",
    "    step=0.2,\n",
    "    description='SOFTMAX TEMPERATURE:',\n",
    "    value=1.0\n",
    ")\n",
    "wd_text = widgets.Text(\n",
    "    value='A painting of a monkey with sunglasses in the frame',\n",
    "    placeholder='Text prompt',\n",
    "    description='String:',\n",
    "    disabled=False\n",
    ")\n",
    "\n",
    "display(slider_seed)\n",
    "display(slider_temp)\n",
    "display(slider_topk)\n",
    "display(wd_text)\n",
    "\n",
    "btn = widgets.Button(description='Generate!')\n",
    "display(btn)\n",
    "btn.on_click(btn_eventhandler)\n",
    "\n",
    "display(output)\n",
    "display(plot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20571236-3b9a-426e-ab29-96b643c8cbe1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}