{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdf36725-ec00-4027-95d6-374340c2264e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 4.72G/4.72G [02:04<00:00, 40.7MiB/s]\n",
      "extracting: ./1.3B/tokenizer/bpe-16k-vocab.json (size:0MB): 100%|██████████| 7/7 [00:59<00:00, 8.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/tokenizer successfully restored..\n",
      "/root/.cache/minDALL-E/1.3B/stage1_last.ckpt successfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 0%| | 0.00/338M [00:00<?, ?iB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/stage2_last.ckpt succesfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 38.5MiB/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import argparse\n",
    "import clip\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "# Put the parent of the working directory on sys.path so that the `dalle`\n",
    "# package imported below can be resolved when running from this folder.\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "\n",
    "from dalle.models import Dalle\n",
    "from dalle.utils.utils import set_seed, clip_score\n",
    "\n",
    "# Both models are placed on the same device; sampling() reads these globals.\n",
    "device = 'cuda:0'\n",
    "model = Dalle.from_pretrained(\"minDALL-E/1.3B\")\n",
    "model_clip, preprocess_clip = clip.load(\"ViT-B/32\", device=device)\n",
    "\n",
    "model_clip.to(device=device)\n",
    "model.to(device=device)\n",
    "\n",
    "def sampling(prompt, top_k, softmax_temperature, seed, num_candidates=96, num_samples_for_display=36):\n",
    "    \"\"\"Generate images for a text prompt, re-rank them with CLIP and plot the best ones.\n",
    "\n",
    "    Uses the module-level `model`, `model_clip`, `preprocess_clip` and `device`.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text prompt passed to the generator and to the CLIP scorer.\n",
    "        top_k: Top-k cutoff forwarded to `model.sampling`.\n",
    "        softmax_temperature: Softmax temperature forwarded to `model.sampling`.\n",
    "        seed: Value handed to `set_seed` before sampling.\n",
    "        num_candidates: Number of candidate images requested from the model.\n",
    "        num_samples_for_display: Maximum number of re-ranked images to plot.\n",
    "            It does not need to be a perfect square, and it is capped at the\n",
    "            number of images actually generated.\n",
    "\n",
    "    Returns:\n",
    "        None. The image grid is rendered with matplotlib.\n",
    "    \"\"\"\n",
    "    set_seed(seed)\n",
    "\n",
    "    # Sampling\n",
    "    images = model.sampling(prompt=prompt,\n",
    "                            top_k=top_k,\n",
    "                            top_p=None,\n",
    "                            softmax_temperature=softmax_temperature,\n",
    "                            num_candidates=num_candidates,\n",
    "                            device=device).cpu().numpy()\n",
    "    # Move axis 1 to the end for imshow; presumably (N, C, H, W) -> (N, H, W, C) -- confirm.\n",
    "    images = np.transpose(images, (0, 2, 3, 1))\n",
    "\n",
    "    # CLIP Re-ranking\n",
    "    # NOTE(review): `rank` is used as an index array, so clip_score is expected\n",
    "    # to return an ordering of the images (best first) -- confirm.\n",
    "    rank = clip_score(prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device)\n",
    "    images = images[rank]\n",
    "\n",
    "    # Never try to show more images than exist (num_candidates may be smaller\n",
    "    # than num_samples_for_display), and bail out when there is nothing to draw.\n",
    "    num_shown = max(0, min(num_samples_for_display, len(images)))\n",
    "    if num_shown == 0:\n",
    "        return\n",
    "    images = images[:num_shown]\n",
    "\n",
    "    # Smallest near-square grid that holds every image. Truncating sqrt() on\n",
    "    # both axes produced too few cells for non-square counts, which made\n",
    "    # add_subplot raise. A perfect square such as 36 still yields 6 x 6.\n",
    "    n_col = math.ceil(math.sqrt(num_shown))\n",
    "    n_row = math.ceil(num_shown / n_col)\n",
    "    # figsize is (width, height): width follows columns, height follows rows.\n",
    "    fig = plt.figure(figsize=(8*n_col, 8*n_row))\n",
    "\n",
    "    for i in range(num_shown):\n",
    "        ax = fig.add_subplot(n_row, n_col, i+1)\n",
    "        ax.imshow(images[i])\n",
    "        ax.set_axis_off()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "619add15-073e-40f4-9a97-06b89d647c81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee477531ea0e4b86b20d997f8cb83767",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=0, description='RND SEED: ', max=1024)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d63edc4725ef4f4e8a6f03f7693a481d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatSlider(value=1.0, description='SOFTMAX TEMPERATURE:', max=5.0, step=0.2)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bb9170e9e8b4686a661799d8aff3901",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=256, description='TOP-K:', max=512, step=16)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b97b49debfc4f7ab002748e9fd89864",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Text(value='A painting of a monkey with sunglasses in the frame', description='String:', placeholder='Text pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": 
"a520b10d8c0b4dd0bb6db56dc37b4422",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Button(description='Generate!', style=ButtonStyle())"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a98437abf964636a467677dc4f816bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90d05006d50e4d88b8fb7c36095b12e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "from IPython.display import clear_output\n",
    "\n",
    "# Two separate output areas: one for the echoed settings, one for the image\n",
    "# grid, so each can be cleared and redrawn on every click.\n",
    "output = widgets.Output()\n",
    "plot_output = widgets.Output()\n",
    "\n",
    "def btn_eventhandler(obj):\n",
    "    \"\"\"Click callback: echo the current widget settings, then generate and plot images.\n",
    "\n",
    "    Reads the module-level sliders and text box and calls `sampling` with\n",
    "    their current values.\n",
    "\n",
    "    Args:\n",
    "        obj: Argument supplied by the button's click event (unused).\n",
    "    \"\"\"\n",
    "    # Drop whatever the previous click rendered.\n",
    "    output.clear_output()\n",
    "    plot_output.clear_output()\n",
    "\n",
    "    with output:\n",
    "        print(f'SEED: {slider_seed.value}')\n",
    "        print(f'Softmax Temperature: {slider_temp.value}')\n",
    "        print(f'Top-K: {slider_topk.value}')\n",
    "        print(f'Text prompt: {wd_text.value}')\n",
    "\n",
    "    with plot_output:\n",
    "        sampling(prompt=wd_text.value, top_k=slider_topk.value, softmax_temperature=slider_temp.value, seed=slider_seed.value)\n",
    "\n",
    "slider_seed = widgets.IntSlider(\n",
    "    min=0,\n",
    "    max=1024,\n",
    "    step=1,\n",
    "    description='RND SEED: ',\n",
    "    value=0\n",
    ")\n",
    "slider_topk = widgets.IntSlider(\n",
    "    # Top-k sampling must keep at least one token, so the range starts at the\n",
    "    # first step (16) instead of 0.\n",
    "    min=16,\n",
    "    max=512,\n",
    "    step=16,\n",
    "    description='TOP-K:',\n",
    "    value=256\n",
    ")\n",
    "slider_temp = widgets.FloatSlider(\n",
    "    # A softmax temperature must be strictly positive, so the range starts at\n",
    "    # the first step (0.2) instead of 0.0.\n",
    "    min=0.2,\n",
    "    max=5.0,\n",
    "    step=0.2,\n",
    "    description='SOFTMAX TEMPERATURE:',\n",
    "    value=1.0\n",
    ")\n",
    "wd_text = widgets.Text(\n",
    "    value='A painting of a monkey with sunglasses in the frame',\n",
    "    placeholder='Text prompt',\n",
    "    description='String:',\n",
    "    disabled=False\n",
    ")\n",
    "\n",
    "display(slider_seed)\n",
    "display(slider_temp)\n",
    "display(slider_topk)\n",
    "display(wd_text)\n",
    "\n",
    "btn = widgets.Button(description='Generate!')\n",
    "display(btn)\n",
    "btn.on_click(btn_eventhandler)\n",
    "\n",
    "display(output)\n",
    "display(plot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20571236-3b9a-426e-ab29-96b643c8cbe1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}