rhfeiyang committed on
Commit
262b155
1 Parent(s): d02b40b

Upload folder using huggingface_hub

Files changed (50)
  1. .ipynb_checkpoints/hf_demo_test-checkpoint.ipynb +336 -0
  2. README.md +3 -10
  3. __pycache__/inference.cpython-39.pyc +0 -0
  4. custom_datasets/__init__.py +141 -0
  5. custom_datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  6. custom_datasets/__pycache__/coco.cpython-39.pyc +0 -0
  7. custom_datasets/__pycache__/imagepair.cpython-39.pyc +0 -0
  8. custom_datasets/__pycache__/mypath.cpython-39.pyc +0 -0
  9. custom_datasets/coco.py +307 -0
  10. custom_datasets/custom_caption.py +113 -0
  11. custom_datasets/filt/coco/filt.py +186 -0
  12. custom_datasets/filt/sam_filt.py +299 -0
  13. custom_datasets/imagepair.py +240 -0
  14. custom_datasets/lhq.py +127 -0
  15. custom_datasets/mypath.py +29 -0
  16. custom_datasets/sam.py +160 -0
  17. data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  18. data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  19. data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  20. data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  21. data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  22. data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  23. data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  24. data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  25. data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  26. data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  27. data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  28. data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  29. data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  30. data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  31. data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  32. data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  33. data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
  34. data/unsafe.png +0 -0
  35. hf_demo.py +147 -0
  36. hf_demo_test.ipynb +336 -0
  37. inference.py +657 -0
  38. utils/__init__.py +1 -0
  39. utils/__pycache__/__init__.cpython-39.pyc +0 -0
  40. utils/__pycache__/lora.cpython-39.pyc +0 -0
  41. utils/__pycache__/metrics.cpython-39.pyc +0 -0
  42. utils/__pycache__/train_util.cpython-39.pyc +0 -0
  43. utils/art_filter.py +210 -0
  44. utils/config_util.py +105 -0
  45. utils/debug_util.py +16 -0
  46. utils/lora.py +282 -0
  47. utils/metrics.py +577 -0
  48. utils/model_util.py +291 -0
  49. utils/prompt_util.py +174 -0
  50. utils/train_util.py +526 -0
.ipynb_checkpoints/hf_demo_test-checkpoint.ipynb ADDED
@@ -0,0 +1,336 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "initial_id",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2024-12-09T09:44:30.641366Z",
10
+ "start_time": "2024-12-09T09:44:11.789050Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "import os\n",
16
+ "\n",
17
+ "import gradio as gr\n",
18
+ "from diffusers import DiffusionPipeline\n",
19
+ "import matplotlib.pyplot as plt\n",
20
+ "import torch\n",
21
+ "from PIL import Image\n"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "id": "ddf33e0d3abacc2c",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import sys\n",
32
+ "#append current path\n",
33
+ "sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "id": "643e49fd601daf8f",
40
+ "metadata": {
41
+ "ExecuteTime": {
42
+ "end_time": "2024-12-09T09:44:35.790962Z",
43
+ "start_time": "2024-12-09T09:44:35.779496Z"
44
+ }
45
+ },
46
+ "outputs": [],
47
+ "source": [
48
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 4,
54
+ "id": "e03aae2a4e5676dd",
55
+ "metadata": {
56
+ "ExecuteTime": {
57
+ "end_time": "2024-12-09T09:44:44.157412Z",
58
+ "start_time": "2024-12-09T09:44:37.138452Z"
59
+ }
60
+ },
61
+ "outputs": [
62
+ {
63
+ "name": "stderr",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
67
+ " warnings.warn(\n"
68
+ ]
69
+ },
70
+ {
71
+ "data": {
72
+ "application/vnd.jupyter.widget-view+json": {
73
+ "model_id": "9df8347307674ba8afb0250e23109aa1",
74
+ "version_major": 2,
75
+ "version_minor": 0
76
+ },
77
+ "text/plain": [
78
+ "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
79
+ ]
80
+ },
81
+ "metadata": {},
82
+ "output_type": "display_data"
83
+ }
84
+ ],
85
+ "source": [
86
+ "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
87
+ "device = \"cuda\""
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 5,
93
+ "id": "83916bc68ff5d914",
94
+ "metadata": {
95
+ "ExecuteTime": {
96
+ "end_time": "2024-12-09T09:44:52.694399Z",
97
+ "start_time": "2024-12-09T09:44:44.210695Z"
98
+ }
99
+ },
100
+ "outputs": [],
101
+ "source": [
102
+ "from inference import get_lora_network, inference, get_validation_dataloader\n",
103
+ "lora_map = {\n",
104
+ " \"None\": \"None\",\n",
105
+ " \"Andre Derain\": \"andre-derain_subset1\",\n",
106
+ " \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
107
+ " \"Andy Warhol\": \"andy_subset1\",\n",
108
+ " \"Walter Battiss\": \"walter-battiss_subset2\",\n",
109
+ " \"Camille Corot\": \"camille-corot_subset1\",\n",
110
+ " \"Claude Monet\": \"monet_subset2\",\n",
111
+ " \"Pablo Picasso\": \"picasso_subset1\",\n",
112
+ " \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
113
+ " \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
114
+ " \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
115
+ " \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
116
+ " \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
117
+ " \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
118
+ " \"Gustav Klimt\": \"klimt_subset3\",\n",
119
+ " \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
120
+ " \"Henri Matisse\": \"henri-matisse_subset1\",\n",
121
+ " \"Joan Miro\": \"joan-miro_subset2\",\n",
122
+ "}\n",
123
+ "\n",
124
+ "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
125
+ " adapter_path = lora_map[adapter_choice]\n",
126
+ " if adapter_path not in [None, \"None\"]:\n",
127
+ " adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
128
+ "\n",
129
+ " prompts = [prompt]*samples\n",
130
+ " infer_loader = get_validation_dataloader(prompts)\n",
131
+ " network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
132
+ " pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
133
+ " height=512, width=512, scales=[1.0],\n",
134
+ " save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
135
+ " start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
136
+ " from_scratch=True)[0][1.0]\n",
137
+ " return pred_images\n",
138
+ "\n",
139
+ "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
140
+ " infer_loader = get_validation_dataloader(prompts, image)\n",
141
+ " network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
142
+ " pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
143
+ " height=512, width=512, scales=[0.,1.],\n",
144
+ " save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
145
+ " start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
146
+ " from_scratch=False)\n",
147
+ " return pred_images\n",
148
+ "\n",
149
+ "# def infer(prompt, samples, steps, scale, seed):\n",
150
+ "# generator = torch.Generator(device=device).manual_seed(seed)\n",
151
+ "# images_list = pipe( # type: ignore\n",
152
+ "# [prompt] * samples,\n",
153
+ "# num_inference_steps=steps,\n",
154
+ "# guidance_scale=scale,\n",
155
+ "# generator=generator,\n",
156
+ "# )\n",
157
+ "# images = []\n",
158
+ "# safe_image = Image.open(r\"data/unsafe.png\")\n",
159
+ "# print(images_list)\n",
160
+ "# for i, image in enumerate(images_list[\"images\"]): # type: ignore\n",
161
+ "# if images_list[\"nsfw_content_detected\"][i]: # type: ignore\n",
162
+ "# images.append(safe_image)\n",
163
+ "# else:\n",
164
+ "# images.append(image)\n",
165
+ "# return images\n"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 6,
171
+ "id": "aa33e9d104023847",
172
+ "metadata": {
173
+ "ExecuteTime": {
174
+ "end_time": "2024-12-09T12:09:39.339583Z",
175
+ "start_time": "2024-12-09T12:09:38.953936Z"
176
+ }
177
+ },
178
+ "outputs": [
179
+ {
180
+ "name": "stdout",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
184
+ "Running on local URL: http://127.0.0.1:7876\n",
185
+ "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
186
+ "\n",
187
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
188
+ ]
189
+ },
190
+ {
191
+ "data": {
192
+ "text/html": [
193
+ "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
194
+ ],
195
+ "text/plain": [
196
+ "<IPython.core.display.HTML object>"
197
+ ]
198
+ },
199
+ "metadata": {},
200
+ "output_type": "display_data"
201
+ },
202
+ {
203
+ "data": {
204
+ "text/plain": []
205
+ },
206
+ "execution_count": 6,
207
+ "metadata": {},
208
+ "output_type": "execute_result"
209
+ },
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Train method: None\n",
215
+ "Rank: 1, Alpha: 1\n",
216
+ "create LoRA for U-Net: 0 modules.\n",
217
+ "save dir: None\n",
218
+ "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
219
+ ]
220
+ },
221
+ {
222
+ "name": "stderr",
223
+ "output_type": "stream",
224
+ "text": [
225
+ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
226
+ " return F.conv2d(input, weight, bias, self.stride,\n",
227
+ "\n",
228
+ "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
229
+ ]
230
+ },
231
+ {
232
+ "name": "stdout",
233
+ "output_type": "stream",
234
+ "text": [
235
+ "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
236
+ ]
237
+ }
238
+ ],
239
+ "source": [
240
+ "block = gr.Blocks()\n",
241
+ "# Direct infer\n",
242
+ "with block:\n",
243
+ " with gr.Group():\n",
244
+ " with gr.Row():\n",
245
+ " text = gr.Textbox(\n",
246
+ " label=\"Enter your prompt\",\n",
247
+ " max_lines=2,\n",
248
+ " placeholder=\"Enter your prompt\",\n",
249
+ " container=False,\n",
250
+ " value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
251
+ " )\n",
252
+ " \n",
253
+ "\n",
254
+ " \n",
255
+ " btn = gr.Button(\"Run\", scale=0)\n",
256
+ " gallery = gr.Gallery(\n",
257
+ " label=\"Generated images\",\n",
258
+ " show_label=False,\n",
259
+ " elem_id=\"gallery\",\n",
260
+ " columns=[2],\n",
261
+ " )\n",
262
+ "\n",
263
+ " advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
264
+ "\n",
265
+ " with gr.Row(elem_id=\"advanced-options\"):\n",
266
+ " adapter_choice = gr.Dropdown(\n",
267
+ " label=\"Choose adapter\",\n",
268
+ " choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
269
+ " \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
270
+ " \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
271
+ " \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
272
+ " \"Henri Matisse\", \"Joan Miro\"\n",
273
+ " ],\n",
274
+ " value=\"None\"\n",
275
+ " )\n",
276
+ " # print(adapter_choice[0])\n",
277
+ " # lora_path = lora_map[adapter_choice.value]\n",
278
+ " # if lora_path is not None:\n",
279
+ " # lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
280
+ "\n",
281
+ " samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
282
+ " steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
283
+ " scale = gr.Slider(\n",
284
+ " label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
285
+ " )\n",
286
+ " print(scale)\n",
287
+ " seed = gr.Slider(\n",
288
+ " label=\"Seed\",\n",
289
+ " minimum=0,\n",
290
+ " maximum=2147483647,\n",
291
+ " step=1,\n",
292
+ " randomize=True,\n",
293
+ " )\n",
294
+ "\n",
295
+ " gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
296
+ " advanced_button.click(\n",
297
+ " None,\n",
298
+ " [],\n",
299
+ " text,\n",
300
+ " )\n",
301
+ "\n",
302
+ "\n",
303
+ "block.launch(share=True)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": null,
309
+ "id": "3239c12167a5f2cd",
310
+ "metadata": {},
311
+ "outputs": [],
312
+ "source": []
313
+ }
314
+ ],
315
+ "metadata": {
316
+ "kernelspec": {
317
+ "display_name": "Python 3 (ipykernel)",
318
+ "language": "python",
319
+ "name": "python3"
320
+ },
321
+ "language_info": {
322
+ "codemirror_mode": {
323
+ "name": "ipython",
324
+ "version": 3
325
+ },
326
+ "file_extension": ".py",
327
+ "mimetype": "text/x-python",
328
+ "name": "python",
329
+ "nbconvert_exporter": "python",
330
+ "pygments_lexer": "ipython3",
331
+ "version": "3.9.18"
332
+ }
333
+ },
334
+ "nbformat": 4,
335
+ "nbformat_minor": 5
336
+ }
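For quick reference, without the Gradio UI the notebook above reduces to the following generation flow. This is a minimal sketch distilled from demo_inference_gen; it assumes this repo's inference.py helpers and the data/Art_adapters checkpoints added in this commit are available locally.

from diffusers import DiffusionPipeline
from inference import get_lora_network, inference, get_validation_dataloader

pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1").to("cuda")

prompt = "Park with cherry blossom trees and a clear blue pond."
adapter_path = "data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt"

infer_loader = get_validation_dataloader([prompt])
network = get_lora_network(pipe.unet, adapter_path)["network"]
# Indexed as in demo_inference_gen: first result, adapter scale 1.0
images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet,
                   pipe.scheduler, infer_loader, height=512, width=512, scales=[1.0],
                   save_dir=None, seed=0, steps=20, guidance_scale=7.5,
                   start_noise=-1, show=False, style_prompt="sks art", no_load=True,
                   from_scratch=True)[0][1.0]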
README.md CHANGED
@@ -1,13 +1,6 @@
1
  ---
2
- title: Art Free Diffusion
3
- emoji: 🖼
4
- colorFrom: purple
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 5.0.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: Demo for Art Free Diffusion
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Art-Free-Diffusion
3
+ app_file: hf_demo.py
4
  sdk: gradio
5
+ sdk_version: 4.44.1
6
  ---
 
 
__pycache__/inference.cpython-39.pyc ADDED
Binary file (19.8 kB).
 
custom_datasets/__init__.py ADDED
@@ -0,0 +1,141 @@
1
+ from .mypath import MyPath
2
+ from copy import deepcopy
3
+ from datasets import load_dataset
4
+ from torch.utils.data import Dataset
5
+ import numpy as np
6
+
7
+ def get_dataset(dataset_name, transformation=None , train_subsample:int =None, val_subsample:int = 10000, get_val=True):
8
+ if train_subsample is not None and train_subsample<val_subsample and train_subsample!=-1:
9
+ print(f"Warning: train_subsample is smaller than val_subsample. val_subsample will be set to train_subsample: {train_subsample}")
10
+ val_subsample = train_subsample
11
+
12
+ if dataset_name == "imagenet":
13
+ from .imagenet import Imagenet1k
14
+ train_set = Imagenet1k(data_dir = MyPath.db_root_dir(dataset_name), transform = transformation, split="train", prompt_transform=Label_prompt_transform(real=True))
15
+ elif dataset_name == "coco_train":
16
+ # raise NotImplementedError("Use coco_filtered instead")
17
+ from .coco import CocoCaptions
18
+ train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
19
+ elif dataset_name == "coco_val":
20
+ from .coco import CocoCaptions
21
+ train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
22
+ return {"val": train_set}
23
+
24
+ elif dataset_name == "coco_clip_filtered":
25
+ from .coco import CocoCaptions_clip_filtered
26
+ train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
27
+ elif dataset_name == "coco_filtered_sub100":
28
+ from .coco import CocoCaptions_clip_filtered
29
+ train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"),)
30
+ elif dataset_name == "cifar10":
31
+ from .cifar import CIFAR10
32
+ train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
33
+ elif dataset_name == "cifar100":
34
+ from .cifar import CIFAR100
35
+ train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
36
+ elif "wikiart" in dataset_name and "/" not in dataset_name:
37
+ from .wikiart.wikiart import Wikiart_caption
38
+ dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
39
+ return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
40
+ elif "imagepair" in dataset_name:
41
+ from .imagepair import ImagePair
42
+ train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
43
+ # elif dataset_name == "sam_clip_filtered":
44
+ # from .sam import SamDataset
45
+ # train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_ids"), transforms=transformation).subsample(train_subsample)
46
+ elif dataset_name == "sam_whole_filtered":
47
+ from .sam import SamDataset
48
+ train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
49
+ elif dataset_name == "sam_whole_filtered_val":
50
+ from .sam import SamDataset
51
+ train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
52
+ return {"val": train_set}
53
+ elif dataset_name == "lhq_sub100":
54
+ from .lhq import LhqDataset
55
+ train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
56
+ elif dataset_name == "lhq_sub500":
57
+ from .lhq import LhqDataset
58
+ train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
59
+ elif dataset_name == "lhq_sub9":
60
+ from .lhq import LhqDataset
61
+ train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)
62
+
63
+ elif dataset_name == "custom_coco100":
64
+ from .coco import CustomCocoCaptions
65
+ train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
66
+ custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
67
+ elif dataset_name == "custom_coco500":
68
+ from .coco import CustomCocoCaptions
69
+ train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
70
+ custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
71
+ elif dataset_name == "laion_pop500":
72
+ from .custom_caption import Laion_pop
73
+ train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
74
+
75
+ elif dataset_name == "laion_pop500_first_sentence":
76
+ from .custom_caption import Laion_pop
77
+ train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
78
+
79
+
80
+ else:
81
+ try:
82
+ train_set = load_dataset('imagefolder', data_dir = dataset_name, split="train")
83
+ val_set = deepcopy(train_set)
84
+ if val_subsample is not None and val_subsample != -1:
85
+ val_set = val_set.shuffle(seed=0).select(range(val_subsample))
86
+ return {"train": train_set, "val": val_set if get_val else None}
87
+ except:
88
+ raise ValueError(f"dataset_name {dataset_name} not found.")
89
+ return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}
90
+
91
+
92
+ class MergeDataset(Dataset):
93
+ @staticmethod
94
+ def get_merged_dataset(dataset_names:list, transformation=None, train_subsample:int =None, val_subsample:int = 10000):
95
+ train_datasets = []
96
+ val_datasets = []
97
+ for dataset_name in dataset_names:
98
+ datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
99
+ train_datasets.append(datasets["train"])
100
+ val_datasets.append(datasets["val"])
101
+ train_datasets = MergeDataset(train_datasets).subsample(train_subsample)
102
+ val_datasets = MergeDataset(val_datasets).subsample(val_subsample)
103
+ return {"train": train_datasets, "val": val_datasets}
104
+
105
+ def __init__(self, datasets:list):
106
+ self.datasets = datasets
107
+ self.column_names = self.datasets[0].column_names
108
+ # self.ids = []
109
+ # start = 0
110
+ # for dataset in self.datasets:
111
+ # self.ids += [i+start for i in dataset.ids]
112
+ def define_resolution(self, resolution: int):
113
+ for dataset in self.datasets:
114
+ dataset.define_resolution(resolution)
115
+
116
+ def __len__(self):
117
+ return sum([len(dataset) for dataset in self.datasets])
118
+ def __getitem__(self, index):
119
+ for i,dataset in enumerate(self.datasets):
120
+ if index < len(dataset):
121
+ ret = dataset[index]
122
+ ret["id"] = index
123
+ ret["dataset"] = i
124
+ return ret
125
+ index -= len(dataset)
126
+ raise IndexError
127
+
128
+ def subsample(self, num:int):
129
+ if num is None:
130
+ return self
131
+ dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
132
+ new_datasets = []
133
+ for i, dataset in enumerate(self.datasets):
134
+ new_datasets.append(dataset.subsample(int(num*dataset_ratio[i])))
135
+ return MergeDataset(new_datasets)
136
+
137
+ def with_transform(self, transform):
138
+ for dataset in self.datasets:
139
+ dataset.with_transform(transform)
140
+ return self
141
+
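A usage sketch for the factory above; dataset roots are resolved through MyPath, so this assumes the hard-coded paths in custom_datasets/mypath.py exist on the target machine.

from custom_datasets import get_dataset, MergeDataset

splits = get_dataset("coco_val")      # this name returns only {"val": CocoCaptions}
sample = splits["val"][0]             # dict with "id", "image", "caption"

# Several named datasets can be merged and subsampled proportionally:
merged = MergeDataset.get_merged_dataset(["coco_clip_filtered", "sam_whole_filtered"],
                                         train_subsample=10000, val_subsample=1000)
train_set, val_set = merged["train"], merged["val"]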
custom_datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (5.8 kB).
 
custom_datasets/__pycache__/coco.cpython-39.pyc ADDED
Binary file (10.4 kB).
 
custom_datasets/__pycache__/imagepair.cpython-39.pyc ADDED
Binary file (8.93 kB).
 
custom_datasets/__pycache__/mypath.cpython-39.pyc ADDED
Binary file (1.49 kB).
 
custom_datasets/coco.py ADDED
@@ -0,0 +1,307 @@
1
+ import os.path
2
+ from typing import Any, Callable, List, Optional, Tuple
3
+
4
+ from PIL import Image
5
+
6
+ from torchvision.datasets.vision import VisionDataset
7
+ import pickle
8
+ import csv
9
+ import pandas as pd
10
+ import torch
11
+ import torchvision
12
+ import re
13
+ # from torchvision.datasets import CocoDetection
14
+ # from utils.clip_filter import Clip_filter
15
+ from tqdm import tqdm
16
+ from .mypath import MyPath
17
+
18
+ class CocoDetection(VisionDataset):
19
+ """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
20
+
21
+ It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
22
+
23
+ Args:
24
+ root (string): Root directory where images are downloaded to.
25
+ annFile (string): Path to json annotation file.
26
+ transform (callable, optional): A function/transform that takes in an PIL image
27
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
28
+ target_transform (callable, optional): A function/transform that takes in the
29
+ target and transforms it.
30
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
31
+ and returns a transformed version.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ root: str ,
37
+ annFile: str,
38
+ transform: Optional[Callable] = None,
39
+ target_transform: Optional[Callable] = None,
40
+ transforms: Optional[Callable] = None,
41
+ get_img=True,
42
+ get_cap=True
43
+ ) -> None:
44
+ super().__init__(root, transforms, transform, target_transform)
45
+ from pycocotools.coco import COCO
46
+
47
+ self.coco = COCO(annFile)
48
+ self.ids = list(sorted(self.coco.imgs.keys()))
49
+ self.column_names = ["image", "text"]
50
+ self.get_img = get_img
51
+ self.get_cap = get_cap
52
+
53
+ def _load_image(self, id: int) -> Image.Image:
54
+ path = self.coco.loadImgs(id)[0]["file_name"]
55
+ with open(os.path.join(self.root, path), 'rb') as f:
56
+ img = Image.open(f).convert("RGB")
57
+
58
+ return img
59
+
60
+ def _load_target(self, id: int) -> List[Any]:
61
+ return self.coco.loadAnns(self.coco.getAnnIds(id))
62
+
63
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
64
+ id = self.ids[index]
65
+ ret={"id":id}
66
+ if self.get_img:
67
+ image = self._load_image(id)
68
+ ret["image"] = image
69
+ if self.get_cap:
70
+ target = self._load_target(id)
71
+ ret["caption"] = [target]
72
+
73
+ if self.transforms is not None:
74
+ ret = self.transforms(ret)
75
+
76
+ return ret
77
+
78
+ def subsample(self, n: int = 10000):
79
+ if n is None or n == -1:
80
+ return self
81
+ ori_len = len(self)
82
+ assert n <= ori_len
83
+ # equal interval subsample
84
+ ids = self.ids[::ori_len // n][:n]
85
+ self.ids = ids
86
+ print(f"COCO dataset subsampled from {ori_len} to {len(self)}")
87
+ return self
88
+
89
+
90
+ def with_transform(self, transform):
91
+ self.transforms = transform
92
+ return self
93
+
94
+ def __len__(self) -> int:
95
+ # return 100
96
+ return len(self.ids)
97
+
98
+
99
+ class CocoCaptions(CocoDetection):
100
+ """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
101
+
102
+ It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
103
+
104
+ Args:
105
+ root (string): Root directory where images are downloaded to.
106
+ annFile (string): Path to json annotation file.
107
+ transform (callable, optional): A function/transform that takes in an PIL image
108
+ and returns a transformed version. E.g, ``transforms.PILToTensor``
109
+ target_transform (callable, optional): A function/transform that takes in the
110
+ target and transforms it.
111
+ transforms (callable, optional): A function/transform that takes input sample and its target as entry
112
+ and returns a transformed version.
113
+
114
+ Example:
115
+
116
+ .. code:: python
117
+
118
+ import torchvision.datasets as dset
119
+ import torchvision.transforms as transforms
120
+ cap = dset.CocoCaptions(root = 'dir where images are',
121
+ annFile = 'json annotation file',
122
+ transform=transforms.PILToTensor())
123
+
124
+ print('Number of samples: ', len(cap))
125
+ img, target = cap[3] # load 4th sample
126
+
127
+ print("Image Size: ", img.size())
128
+ print(target)
129
+
130
+ Output: ::
131
+
132
+ Number of samples: 82783
133
+ Image Size: (3L, 427L, 640L)
134
+ [u'A plane emitting smoke stream flying over a mountain.',
135
+ u'A plane darts across a bright blue sky behind a mountain covered in snow',
136
+ u'A plane leaves a contrail above the snowy mountain top.',
137
+ u'A mountain that has a plane flying overheard in the distance.',
138
+ u'A mountain view with a plume of smoke in the background']
139
+
140
+ """
141
+
142
+ def _load_target(self, id: int) -> List[str]:
143
+ return [ann["caption"] for ann in super()._load_target(id)]
144
+
145
+
146
+ class CocoCaptions_clip_filtered(CocoCaptions):
147
+ positive_prompt=["painting", "drawing", "graffiti",]
148
+ def __init__(
149
+ self,
150
+ root: str ,
151
+ annFile: str,
152
+ transform: Optional[Callable] = None,
153
+ target_transform: Optional[Callable] = None,
154
+ transforms: Optional[Callable] = None,
155
+ regenerate: bool = False,
156
+ id_file: Optional[str] = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_ids.pickle"
157
+ ) -> None:
158
+ super().__init__(root, annFile, transform, target_transform, transforms)
159
+ os.makedirs(os.path.dirname(id_file), exist_ok=True)
160
+ if os.path.exists(id_file) and not regenerate:
161
+ with open(id_file, "rb") as f:
162
+ self.ids = pickle.load(f)
163
+ else:
164
+ self.ids, naive_filtered_num = self.naive_filter()
165
+ self.ids, clip_filtered_num = self.clip_filter(0.7)
166
+
167
+ print(f"naive Filtered {naive_filtered_num} images")
168
+ print(f"Clip Filtered {clip_filtered_num} images")
169
+
170
+ with open(id_file, "wb") as f:
171
+ pickle.dump(self.ids, f)
172
+ print(f"Filtered ids saved to {id_file}")
173
+ print(f"COCO filtered dataset size: {len(self)}")
174
+
175
+ def naive_filter(self, filter_prompt="painting"):
176
+ new_ids = []
177
+ naive_filtered_num = 0
178
+ for id in self.ids:
179
+ target = self._load_target(id)
180
+ filtered = False
181
+ for prompt in target:
182
+ if filter_prompt in prompt.lower():
183
+ filtered = True
184
+ naive_filtered_num += 1
185
+ break
186
+ # if "artwork" in prompt.lower():
187
+ # pass
188
+ if not filtered:
189
+ new_ids.append(id)
190
+ return new_ids, naive_filtered_num
191
+
192
+ # def clip_filter(self, threshold=0.7):
193
+ #
194
+ # def collate_fn(examples):
195
+ # # {"image": image, "text": [target], "id":id}
196
+ # pixel_values = [example["image"] for example in examples]
197
+ # prompts = [example["text"] for example in examples]
198
+ # id = [example["id"] for example in examples]
199
+ # return {"images": pixel_values, "prompts": prompts, "ids": id}
200
+ #
201
+ #
202
+ # clip_filtered_num = 0
203
+ # clip_filter = Clip_filter(positive_prompt=self.positive_prompt)
204
+ # clip_logs={"positive_prompt":clip_filter.positive_prompt, "negative_prompt":clip_filter.negative_prompt,
205
+ # "ids":torch.Tensor([]),"logits":torch.Tensor([])}
206
+ # clip_log_file = "data/coco/clip_logs.pth"
207
+ # new_ids = []
208
+ # batch_size = 128
209
+ # dataloader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=10, shuffle=False,
210
+ # collate_fn=collate_fn)
211
+ # for i, batch in enumerate(tqdm(dataloader)):
212
+ # images = batch["images"]
213
+ # filter_result, logits = clip_filter.filter(images, threshold=threshold)
214
+ # ids = torch.IntTensor(batch["ids"])
215
+ # clip_logs["ids"] = torch.cat([clip_logs["ids"], ids])
216
+ # clip_logs["logits"] = torch.cat([clip_logs["logits"], logits])
217
+ #
218
+ # new_ids.extend(ids[~filter_result].tolist())
219
+ # clip_filtered_num += filter_result.sum().item()
220
+ # if i % 50 == 0:
221
+ # torch.save(clip_logs, clip_log_file)
222
+ # torch.save(clip_logs, clip_log_file)
223
+ #
224
+ # return new_ids, clip_filtered_num
225
+
226
+
227
+ class CustomCocoCaptions(CocoCaptions):
228
+ def __init__(self, root: str=MyPath.db_root_dir("coco_val"), annFile: str=MyPath.db_root_dir("coco_caption_val"), custom_file:str="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt",transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None) -> None:
229
+
230
+ super().__init__(root, annFile, transform, target_transform, transforms)
231
+ self.column_names = ["image", "text"]
232
+ self.custom_file = custom_file
233
+ self.load_custom_data(custom_file)
234
+ self.transforms = transforms
235
+
236
+ def load_custom_data(self, custom_file):
237
+ self.custom_data = []
238
+ with open(custom_file, "r") as f:
239
+ data = f.readlines()
240
+ head = data[0].strip().split(",")
241
+ self.head = head
242
+ for line in data[1:]:
243
+ sub_data = line.strip().split(",")
244
+ if len(sub_data) > len(head):
245
+ sub_data_new = [sub_data[0]]
246
+ sub_data_new+=[",".join(sub_data[1:-1])]
247
+ sub_data_new.append(sub_data[-1])
248
+ sub_data = sub_data_new
249
+ assert len(sub_data) == len(head)
250
+ self.custom_data.append(sub_data)
251
+ # to pd
252
+ self.custom_data = pd.DataFrame(self.custom_data, columns=head)
253
+
254
+ def __len__(self) -> int:
255
+ return len(self.custom_data)
256
+
257
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
258
+ data = self.custom_data.iloc[index]
259
+ id = int(data["image_id"])
260
+ ret={"id":id}
261
+ if self.get_img:
262
+ image = self._load_image(id)
263
+ ret["image"] = image
264
+ if self.get_cap:
265
+ caption = data["caption"]
266
+ ret["caption"] = [caption]
267
+ ret["seed"] = int(data["random_seed"])
268
+
269
+ if self.transforms is not None:
270
+ ret = self.transforms(ret)
271
+
272
+ return ret
273
+
274
+
275
+
276
+ def get_validation_set():
277
+ coco_instance = CocoDetection(root="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/train2017/", annFile="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/annotations/instances_train2017.json")
278
+ discard_cat_id = coco_instance.coco.getCatIds(supNms=["person", "animal"])
279
+ discard_img_id = []
280
+ for cat_id in discard_cat_id:
281
+ discard_img_id += coco_instance.coco.catToImgs[cat_id]
282
+
283
+ coco_clip_filtered = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
284
+ regenerate=False)
285
+ coco_clip_filtered_ids = coco_clip_filtered.ids
286
+ new_ids = set(coco_clip_filtered_ids) - set(discard_img_id)
287
+ new_ids = list(new_ids)
288
+ new_ids = random.sample(new_ids, 100)
289
+ with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_subset100.pickle", "wb") as f:
290
+ pickle.dump(new_ids, f)
291
+
292
+ if __name__ == "__main__":
293
+ from mypath import MyPath
294
+ import random
295
+ # get_validation_set()
296
+ # coco_filtered_remian_id = pickle.load(open("data/coco/coco_clip_filtered_ids.pickle", "rb"))
297
+ #
298
+ # coco_filtered_subset100 = random.sample(coco_filtered_remian_id, 100)
299
+ # save_path = "data/coco/coco_clip_filtered_subset100.pickle"
300
+ # with open(save_path, "wb") as f:
301
+ # pickle.dump(coco_filtered_subset100, f)
302
+
303
+ # dataset = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
304
+ # regenerate=False)
305
+ dataset = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
306
+ custom_file="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt")
307
+ dataset[0]
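A usage sketch for the classes above. Items come back as dicts rather than (img, target) tuples, and the CLIP-filtered variant relies on a precomputed id pickle (its clip_filter method is commented out in this commit).

from custom_datasets.coco import CocoCaptions, CocoCaptions_clip_filtered
from custom_datasets.mypath import MyPath

caps = CocoCaptions(root=MyPath.db_root_dir("coco_val"),
                    annFile=MyPath.db_root_dir("coco_caption_val")).subsample(1000)
item = caps[0]   # {"id": ..., "image": <PIL.Image>, "caption": [[...annotated captions...]]}

# Loads cached filtered ids from id_file; regenerate=True would rerun naive_filter
# (and would need the commented-out clip_filter restored).
filtered = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"),
                                      annFile=MyPath.db_root_dir("coco_caption_train"))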
custom_datasets/custom_caption.py ADDED
@@ -0,0 +1,113 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import torch
3
+ import pandas as pd
4
+ import numpy as np
5
+ import os
6
+ from PIL import Image
7
+
8
+ class Caption_set(torch.utils.data.Dataset):
9
+
10
+ style_set_names=[
11
+ "andre-derain_subset1",
12
+ "andy_subset1",
13
+ "camille-corot_subset1",
14
+ "gerhard-richter_subset1",
15
+ "henri-matisse_subset1",
16
+ "katsushika-hokusai_subset1",
17
+ "klimt_subset3",
18
+ "monet_subset2",
19
+ "picasso_subset1",
20
+ "van_gogh_subset1",
21
+ ]
22
+ style_set_map={f"{name}":f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv" for name in style_set_names}
23
+
24
+ def __init__(self, prompts_path=None, set_name=None, transform=None):
25
+ assert prompts_path is not None or set_name is not None, "Either prompts_path or set_name should be provided"
26
+ if prompts_path is None:
27
+ prompts_path = self.style_set_map[set_name]
28
+
29
+ self.prompts = pd.read_csv(prompts_path, delimiter=';')
30
+ self.transform = transform
31
+ def __len__(self):
32
+ return len(self.prompts)
33
+ def __getitem__(self, idx):
34
+ ret={}
35
+ ret["id"] = idx
36
+ info = self.prompts.iloc[idx]
37
+ ret.update(info)
38
+ for k,v in ret.items():
39
+ if isinstance(v,np.int64):
40
+ ret[k] = int(v)
41
+ ret["caption"] = [ret["caption"]]
42
+ if self.transform:
43
+ ret = self.transform(ret)
44
+ return ret
45
+
46
+ def with_transform(self, transform):
47
+ self.transform = transform
48
+ return self
49
+
50
+
51
+ class HRS_caption(Caption_set):
52
+ def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
53
+ self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
54
+ self.transform = transform
55
+ self.caption_key = "original_prompts"
56
+
57
+ def __getitem__(self, idx):
58
+ ret={}
59
+ ret["id"] = idx
60
+ info = self.prompts.iloc[idx]
61
+ ret["caption"] = [info[self.caption_key]]
62
+ ret["seed"] = idx
63
+ if self.transform:
64
+ ret = self.transform(ret)
65
+ return ret
66
+
67
+ class Laion_pop(torch.utils.data.Dataset):
68
+ def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
69
+ self.transform = transform
70
+ self.info = pd.read_csv(anno_file, delimiter=";")
71
+ self.caption_key = "caption"
72
+ self.image_root = image_root
73
+ self.get_img=True
74
+ self.get_caption=True
75
+ def __len__(self):
76
+ return len(self.info)
77
+
78
+ # def subsample(self, num:int):
79
+ # self.data = self.data.select(range(num))
80
+ # return self
81
+
82
+ def load_image(self, key):
83
+ image_path = os.path.join(self.image_root, f"{key:09}.jpg")
84
+ with open(image_path, "rb") as f:
85
+ image = Image.open(f).convert("RGB")
86
+ return image
87
+
88
+ def __getitem__(self, idx):
89
+ info = self.info.iloc[idx]
90
+ ret = {}
91
+ key = info["key"]
92
+ ret["id"] = key
93
+ if self.get_caption:
94
+ ret["caption"] = [info[self.caption_key]]
95
+ ret["seed"] = int(key)
96
+ if self.get_img:
97
+ ret["image"] = self.load_image(key)
98
+
99
+ if self.transform:
100
+ ret = self.transform(ret)
101
+ return ret
102
+
103
+ def with_transform(self, transform):
104
+ self.transform = transform
105
+ return self
106
+
107
+ def subset(self, ids:list):
108
+ self.info = self.info[self.info["key"].isin(ids)]
109
+ return self
110
+
111
+ if __name__ == "__main__":
112
+ dataset = Caption_set("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv")
113
+ dataset[0]
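A usage sketch for the caption sets above; the csv and image locations below are placeholders to adjust, since the module's defaults are hard-coded cluster paths.

from custom_datasets.custom_caption import Caption_set, Laion_pop

style_caps = Caption_set(set_name="monet_subset2")   # resolves the csv via style_set_map
first = style_caps[0]                                # has "id" and a one-element "caption" list

laion = Laion_pop(anno_file="custom_datasets/laion_pop500.csv",  # placeholder path
                  image_root="/path/to/laion_pop")               # placeholder path
item = laion[0]                                      # {"id", "caption", "seed", "image"}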
custom_datasets/filt/coco/filt.py ADDED
@@ -0,0 +1,186 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ from PIL import Image
6
+ import pickle
7
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
8
+ from custom_datasets import get_dataset
9
+ from utils.art_filter import Art_filter
10
+ import torch
11
+ from matplotlib import pyplot as plt
12
+ import math
13
+ import argparse
14
+ import socket
15
+ import time
16
+ from tqdm import tqdm
17
+ import torch
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser(description="Filter the coco dataset")
20
+ parser.add_argument("--check", action="store_true", help="Check the complete")
21
+ parser.add_argument("--mode", default="clip_logit", help="Filter mode: clip_logit, clip_filt, caption_filt")
22
+ parser.add_argument("--split" , default="val", help="Dataset split, val/train")
23
+ # parser.add_argument("--start_idx", default=0, type=int, help="Start index")
24
+ args = parser.parse_args()
25
+ return args
26
+
27
+ def get_feat(save_path, dataloader, filter):
28
+ clip_feat_file = save_path
29
+ # compute_new = False
30
+ clip_feat={}
31
+ if os.path.exists(clip_feat_file):
32
+ with open(clip_feat_file, 'rb') as f:
33
+ clip_feat = pickle.load(f)
34
+ else:
35
+ print(f"computing clip feat",flush=True)
36
+ clip_feature_ret = filter.clip_feature(dataloader)
37
+ clip_feat["image_features"] = clip_feature_ret["clip_features"]
38
+ clip_feat["ids"] = clip_feature_ret["ids"]
39
+
40
+ with open(clip_feat_file, 'wb') as f:
41
+ pickle.dump(clip_feat, f)
42
+ print(f"clip_feat_result saved to {clip_feat_file}",flush=True)
43
+ return clip_feat
44
+
45
+ def get_clip_logit(save_root, dataloader, filter):
46
+ feat_path = os.path.join(save_root, "clip_feat.pickle")
47
+ clip_feat = get_feat(feat_path, dataloader, filter)
48
+ clip_logits_file = os.path.join(save_root, "clip_logits.pickle")
49
+ # if clip_logit:
50
+ if os.path.exists(clip_logits_file):
51
+ with open(clip_logits_file, 'rb') as f:
52
+ clip_logits = pickle.load(f)
53
+ else:
54
+ clip_logits = filter.clip_logit_by_feat(clip_feat["image_features"])
55
+ clip_logits["ids"] = clip_feat["ids"]
56
+ with open(clip_logits_file, 'wb') as f:
57
+ pickle.dump(clip_logits, f)
58
+ print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
59
+ return clip_logits
60
+
61
+ def clip_filt(save_root, dataloader, filter):
62
+ clip_filt_file = os.path.join(save_root, "clip_filt_result.pickle")
63
+ if os.path.exists(clip_filt_file):
64
+ with open(clip_filt_file, 'rb') as f:
65
+ clip_filt_result = pickle.load(f)
66
+ else:
67
+ clip_logits = get_clip_logit(save_root, dataloader, filter)
68
+ clip_filt_result = filter.clip_filt(clip_logits)
69
+ with open(clip_filt_file, 'wb') as f:
70
+ pickle.dump(clip_filt_result, f)
71
+ print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
72
+ return clip_filt_result
73
+
74
+ def caption_filt(save_root, dataloader, filter):
75
+ caption_filt_file = os.path.join(save_root, "caption_filt_result.pickle")
76
+ if os.path.exists(caption_filt_file):
77
+ with open(caption_filt_file, 'rb') as f:
78
+ caption_filt_result = pickle.load(f)
79
+ else:
80
+ caption_filt_result = filter.caption_filt(dataloader)
81
+ with open(caption_filt_file, 'wb') as f:
82
+ pickle.dump(caption_filt_result, f)
83
+ print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
84
+ return caption_filt_result
85
+
86
+ def gather_result(save_dir, dataloader, filter):
87
+ all_remain_ids=[]
88
+ all_remain_ids_train=[]
89
+ all_remain_ids_val=[]
90
+ all_filtered_id_num = 0
91
+
92
+ clip_filt_result = clip_filt(save_dir, dataloader, filter)
93
+ caption_filt_result = caption_filt(save_dir, dataloader, filter)
94
+
95
+ caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
96
+ all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
97
+ remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
98
+ remain_ids = list(remain_ids)
99
+ remain_ids.sort()
100
+ with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
101
+ pickle.dump(remain_ids, f)
102
+ print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
103
+ return remain_ids
104
+
105
+ @torch.no_grad()
106
+ def main(args):
107
+ filter = Art_filter()
108
+ if args.mode == "caption_filt" or args.mode == "gather_result":
109
+ filter.clip_filter = None
110
+ torch.cuda.empty_cache()
111
+
112
+ # caption_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/PixArt-alpha/captions"
113
+ # image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images"
114
+ # id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict"
115
+ # filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
116
+
117
+ def collate_fn(examples):
118
+ # {"image": image, "id":id}
119
+ ret = {}
120
+ if "image" in examples[0]:
121
+ pixel_values = [example["image"] for example in examples]
122
+ ret["images"] = pixel_values
123
+ if "caption" in examples[0]:
124
+ # prompts = [example["caption"] for example in examples]
125
+ prompts = []
126
+ for example in examples:
127
+ if isinstance(example["caption"][0], list):
128
+ prompts.append([" ".join(example["caption"][0])])
129
+ else:
130
+ prompts.append(example["caption"])
131
+ ret["text"] = prompts
132
+ id = [example["id"] for example in examples]
133
+ ret["ids"] = id
134
+ return ret
135
+ if args.split == "val":
136
+ dataset = get_dataset("coco_val")["val"]
137
+ elif args.split == "train":
138
+ dataset = get_dataset("coco_train", get_val=False)["train"]
139
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
140
+
141
+ error_files=[]
142
+
143
+
144
+
145
+ save_root = f"/vision-nfs/torralba/scratch/jomat/sam_dataset/coco/filt/{args.split}"
146
+ os.makedirs(save_root, exist_ok=True)
147
+
148
+ if args.mode == "clip_feat":
149
+ feat_path = os.path.join(save_root, "clip_feat.pickle")
150
+ clip_feat = get_feat(feat_path, dataloader, filter)
151
+
152
+ if args.mode == "clip_logit":
153
+ clip_logit = get_clip_logit(save_root, dataloader, filter)
154
+
155
+ if args.mode == "clip_filt":
156
+ # if os.path.exists(clip_filt_file):
157
+ # with open(clip_filt_file, 'rb') as f:
158
+ # ret = pickle.load(f)
159
+ # else:
160
+ clip_filt_result = clip_filt(save_root, dataloader, filter)
161
+
162
+ if args.mode == "caption_filt":
163
+ caption_filt_result = caption_filt(save_root, dataloader, filter)
164
+
165
+ if args.mode == "gather_result":
166
+ filtered_result = gather_result(save_root, dataloader, filter)
167
+
168
+ print("finished",flush=True)
169
+ for file in error_files:
170
+ # os.remove(file)
171
+ print(file,flush=True)
172
+
173
+ if __name__ == "__main__":
174
+ args = parse_args()
175
+
176
+ log_file = "sam_filt"
177
+ idx=0
178
+ hostname = socket.gethostname()
179
+ now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
180
+ while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
181
+ idx+=1
182
+
183
+ main(args)
184
+ # clip_logits_analysis()
185
+
186
+
custom_datasets/filt/sam_filt.py ADDED
@@ -0,0 +1,299 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import os
3
+ import sys
4
+ import numpy as np
5
+ from PIL import Image
6
+ import pickle
7
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))
8
+ from custom_datasets.sam import SamDataset
9
+ from utils.art_filter import Art_filter
10
+ import torch
11
+ from matplotlib import pyplot as plt
12
+ import math
13
+ import argparse
14
+ import socket
15
+ import time
16
+ from tqdm import tqdm
17
+
18
+ def parse_args():
19
+ parser = argparse.ArgumentParser(description="Filter the sam dataset")
20
+ parser.add_argument("--check", action="store_true", help="Check the complete")
21
+ parser.add_argument("--mode", default="clip_logit", choices=["clip_logit_update","clip_logit", "clip_filt", "caption_filt", "gather_result","caption_flit_append"])
22
+ parser.add_argument("--start_idx", default=0, type=int, help="Start index")
23
+ parser.add_argument("--end_idx", default=9e10, type=int, help="Start index")
24
+ args = parser.parse_args()
25
+ return args
26
+ @torch.no_grad()
27
+ def main(args):
28
+ filter = Art_filter()
29
+ if args.mode == "caption_filt" or args.mode == "gather_result":
30
+ filter.clip_filter = None
31
+ torch.cuda.empty_cache()
32
+
33
+ caption_folder_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/SAM/subset/captions"
34
+ image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
35
+ id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
36
+ filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
37
+ def collate_fn(examples):
38
+ # {"image": image, "id":id}
39
+ ret = {}
40
+ if "image" in examples[0]:
41
+ pixel_values = [example["image"] for example in examples]
42
+ ret["images"] = pixel_values
43
+ if "text" in examples[0]:
44
+ prompts = [example["text"] for example in examples]
45
+ ret["text"] = prompts
46
+ id = [example["id"] for example in examples]
47
+ ret["ids"] = id
48
+ return ret
49
+ error_files=[]
50
+ val_set = ["sa_000000"]
51
+ result_check_set = ["sa_000020"]
52
+ all_remain_ids=[]
53
+ all_remain_ids_train=[]
54
+ all_remain_ids_val=[]
55
+ all_filtered_id_num = 0
56
+ remain_feat_num = 0
57
+ remain_caption_num = 0
58
+ filter_feat_num = 0
59
+ filter_caption_num = 0
60
+ for idx,file in tqdm(enumerate(sorted(os.listdir(id_dict_dir)))):
61
+ if idx < args.start_idx or idx >= args.end_idx:
62
+ continue
63
+ if file.endswith(".pickle") and not file.startswith("all"):
64
+ print("=====================================")
65
+ print(file,flush=True)
66
+ save_dir = os.path.join(filt_dir, file.replace("_id_dict.pickle", ""))
67
+ if not os.path.exists(save_dir):
68
+ os.makedirs(save_dir, exist_ok=True)
69
+ id_dict_file = os.path.join(id_dict_dir, file)
70
+ with open(id_dict_file, 'rb') as f:
71
+ id_dict = pickle.load(f)
72
+ ids = list(id_dict.keys())
73
+ dataset = SamDataset(image_folder_path, caption_folder_path, id_file=ids, id_dict_file=id_dict_file)
74
+ # dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
75
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
76
+ clip_logits = None
77
+ clip_logits_file = os.path.join(save_dir, "clip_logits_result.pickle")
78
+ clip_filt_file = os.path.join(save_dir, "clip_filt_result.pickle")
79
+ caption_filt_file = os.path.join(save_dir, "caption_filt_result.pickle")
80
+
81
+ if args.mode == "clip_feat":
82
+ compute_new = False
83
+ clip_logits = {}
84
+ if os.path.exists(clip_logits_file):
85
+ with open(clip_logits_file, 'rb') as f:
86
+ clip_logits = pickle.load(f)
87
+ if "image_features" not in clip_logits:
88
+ compute_new = True
89
+ else:
90
+ compute_new=True
91
+ if compute_new:
92
+ if clip_logits == '':
93
+ clip_logits = {}
94
+ print(f"compute clip_feat {file}",flush=True)
95
+ clip_feature_ret = filter.clip_feature(dataloader)
96
+ clip_logits["image_features"] = clip_feature_ret["clip_features"]
97
+ if "ids" in clip_logits:
98
+ assert clip_feature_ret["ids"] == clip_logits["ids"]
99
+ else:
100
+ clip_logits["ids"] = clip_feature_ret["ids"]
101
+
102
+ with open(clip_logits_file, 'wb') as f:
103
+ pickle.dump(clip_logits, f)
104
+ print(f"clip_feat_result saved to {clip_logits_file}",flush=True)
105
+ else:
106
+ print(f"skip {clip_logits_file}",flush=True)
107
+
108
+ if args.mode == "clip_logit":
109
+ # if clip_logit:
110
+ if os.path.exists(clip_logits_file):
111
+ try:
112
+ with open(clip_logits_file, 'rb') as f:
113
+ clip_logits = pickle.load(f)
114
+ except:
115
+ continue
116
+ skip = True
117
+ if args.check and clip_logits=="":
118
+ skip = False
119
+
120
+ else:
121
+ skip = False
122
+ # skip = False
123
+ if not skip:
124
+ # os.makedirs(os.path.join(save_dir, "tmp"), exist_ok=True)
125
+ with open(clip_logits_file, 'wb') as f:
126
+ pickle.dump("", f)
127
+ try:
128
+ clip_logits = filter.clip_logit(dataloader)
129
+ except:
130
+ print(f"Error in clip_logit {file}",flush=True)
131
+ continue
132
+ with open(clip_logits_file, 'wb') as f:
133
+ pickle.dump(clip_logits, f)
134
+ print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
135
+ else:
136
+ print(f"skip {clip_logits_file}",flush=True)
137
+
138
+ if args.mode == "clip_logit_update":
139
+ if os.path.exists(clip_logits_file):
140
+ with open(clip_logits_file, 'rb') as f:
141
+ clip_logits = pickle.load(f)
142
+ else:
143
+ print(f"{clip_logits_file} not exist",flush=True)
144
+ continue
145
+ if clip_logits == "":
146
+ print(f"skip {clip_logits_file}",flush=True)
147
+ continue
148
+ ret = filter.clip_logit_by_feat(clip_logits["clip_features"])
149
+ # assert (clip_logits["clip_logits"] - ret["clip_logits"]).abs().max() < 0.01
150
+ clip_logits["clip_logits"] = ret["clip_logits"]
151
+ clip_logits["text"] = ret["text"]
152
+ with open(clip_logits_file, 'wb') as f:
153
+ pickle.dump(clip_logits, f)
154
+
155
+
156
+ if args.mode == "clip_filt":
157
+ # if os.path.exists(clip_filt_file):
158
+ # with open(clip_filt_file, 'rb') as f:
159
+ # ret = pickle.load(f)
160
+ # else:
161
+
162
+ if clip_logits is None:
163
+ try:
164
+ with open(clip_logits_file, 'rb') as f:
165
+ clip_logits = pickle.load(f)
166
+ except:
167
+ print(f"Error in loading {clip_logits_file}",flush=True)
168
+ error_files.append(clip_logits_file)
169
+ continue
170
+ if clip_logits == "":
171
+ print(f"skip {clip_logits_file}",flush=True)
172
+ error_files.append(clip_logits_file)
173
+ continue
174
+ clip_filt_result = filter.clip_filt(clip_logits)
175
+ with open(clip_filt_file, 'wb') as f:
176
+ pickle.dump(clip_filt_result, f)
177
+ print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
178
+
179
+ if args.mode == "caption_filt":
180
+ if os.path.exists(caption_filt_file):
181
+ try:
182
+ with open(caption_filt_file, 'rb') as f:
183
+ ret = pickle.load(f)
184
+ except:
185
+ continue
186
+ skip = True
187
+ if args.check and ret=="":
188
+ skip = False
189
+ # os.remove(caption_filt_file)
190
+ print(f"empty {caption_filt_file}",flush=True)
191
+ # skip = True
192
+ else:
193
+ skip = False
194
+ if not skip:
195
+ with open(caption_filt_file, 'wb') as f:
196
+ pickle.dump("", f)
197
+ # try:
198
+ ret = filter.caption_filt(dataloader)
199
+ # except:
200
+ # print(f"Error in filtering {file}",flush=True)
201
+ # continue
202
+ with open(caption_filt_file, 'wb') as f:
203
+ pickle.dump(ret, f)
204
+ print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
205
+ else:
206
+ print(f"skip {caption_filt_file}",flush=True)
207
+
208
+ if args.mode == "caption_flit_append":
209
+ if not os.path.exists(caption_filt_file):
210
+ print(f"{caption_filt_file} not exist",flush=True)
211
+ continue
212
+ with open(caption_filt_file, 'rb') as f:
213
+ old_caption_filt_result = pickle.load(f)
214
+ skip = True
215
+ for i in filter.caption_filter.filter_prompts:
216
+ if i not in old_caption_filt_result["filter_prompts"]:
217
+ skip = False
218
+ break
219
+ if skip:
220
+ print(f"skip {caption_filt_file}",flush=True)
221
+ continue
222
+ old_remain_ids = old_caption_filt_result["remain_ids"]
223
+ new_dataset = SamDataset(image_folder_path, caption_folder_path, id_file=old_remain_ids, id_dict_file=id_dict_file)
224
+ new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
225
+ ret = filter.caption_filt(new_dataloader)
226
+ old_caption_filt_result["remain_ids"] = ret["remain_ids"]
227
+ old_caption_filt_result["filtered_ids"].extend(ret["filtered_ids"])
228
+ new_filter_count = ret["filter_count"].copy()
229
+ for i in range(len(old_caption_filt_result["filter_count"])):
230
+ new_filter_count[i] += old_caption_filt_result["filter_count"][i]
231
+
232
+ old_caption_filt_result["filter_count"] = new_filter_count
233
+ old_caption_filt_result["filter_prompts"] = ret["filter_prompts"]
234
+ with open(caption_filt_file, 'wb') as f:
235
+ pickle.dump(old_caption_filt_result, f)
236
+
237
+
238
+
239
+ if args.mode == "gather_result":
240
+ with open(clip_filt_file, 'rb') as f:
241
+ clip_filt_result = pickle.load(f)
242
+ with open(caption_filt_file, 'rb') as f:
243
+ caption_filt_result = pickle.load(f)
244
+ caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
245
+ all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
246
+
247
+ remain_feat_num += len(clip_filt_result["remain_ids"])
248
+ remain_caption_num += len(caption_filt_result["remain_ids"])
249
+ filter_feat_num += len(clip_filt_result["filtered_ids"])
250
+ filter_caption_num += len(caption_filtered_ids)
251
+
252
+ remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
253
+ remain_ids = list(remain_ids)
254
+ remain_ids.sort()
255
+ # with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
256
+ # pickle.dump(remain_ids, f)
257
+ # print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
258
+ all_remain_ids.extend(remain_ids)
259
+ if file.replace("_id_dict.pickle","") in val_set:
260
+ all_remain_ids_val.extend(remain_ids)
261
+ else:
262
+ all_remain_ids_train.extend(remain_ids)
263
+ if args.mode == "gather_result":
264
+ print(f"filtered ids: {all_filtered_id_num}",flush=True)
265
+ print(f"remain feat num: {remain_feat_num}",flush=True)
266
+ print(f"remain caption num: {remain_caption_num}",flush=True)
267
+ print(f"filter feat num: {filter_feat_num}",flush=True)
268
+ print(f"filter caption num: {filter_caption_num}",flush=True)
269
+ all_remain_ids.sort()
270
+ with open(os.path.join(filt_dir, "all_remain_ids.pickle"), 'wb') as f:
271
+ pickle.dump(all_remain_ids, f)
272
+ with open(os.path.join(filt_dir, "all_remain_ids_train.pickle"), 'wb') as f:
273
+ pickle.dump(all_remain_ids_train, f)
274
+ with open(os.path.join(filt_dir, "all_remain_ids_val.pickle"), 'wb') as f:
275
+ pickle.dump(all_remain_ids_val, f)
276
+
277
+ print(f"all_remain_ids saved to {filt_dir}/all_remain_ids.pickle",flush=True)
278
+ print(f"all_remain_ids_train saved to {filt_dir}/all_remain_ids_train.pickle",flush=True)
279
+ print(f"all_remain_ids_val saved to {filt_dir}/all_remain_ids_val.pickle",flush=True)
280
+
281
+ print("finished",flush=True)
282
+ for file in error_files:
283
+ # os.remove(file)
284
+ print(file,flush=True)
285
+
286
+ if __name__ == "__main__":
287
+ args = parse_args()
288
+
289
+ log_file = "sam_filt"
290
+ idx=0
291
+ hostname = socket.gethostname()
292
+ now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
293
+ while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
294
+ idx+=1
295
+
296
+ main(args)
297
+ # clip_logits_analysis()
298
+
299
+
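For clarity, a toy sketch (not part of the commit) of how the per-shard pickles combine in the "gather_result" mode above: an image id survives only if both the CLIP-logit filter and the caption filter kept it, and the surviving ids are then accumulated into the train/val splits. The dictionaries below are made-up stand-ins for the real pickled outputs.

    # Hypothetical stand-ins for one shard's saved filter results
    clip_filt_result = {"remain_ids": [1, 2, 3, 5], "filtered_ids": [4]}
    caption_filt_result = {"remain_ids": [2, 3, 5], "filtered_ids": [(1, "matched filter prompt")]}

    # the caption filter stores (id, reason) tuples, so keep only the id
    caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]

    # an image is kept only if both filters kept it
    remain_ids = sorted(set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"]))
    print(remain_ids)  # [2, 3, 5]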
custom_datasets/imagepair.py ADDED
@@ -0,0 +1,240 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import random
3
+
4
+ import torch.utils.data as data
5
+ from PIL import Image
6
+ import os
7
+ import torch
8
+ # from tqdm import tqdm
9
+ class ImageSet(data.Dataset):
10
+ def __init__(self, folder , transform=None, keep_in_mem=True, caption=None):
11
+ self.path = folder
12
+ self.transform = transform
13
+ self.caption_path = None
14
+ self.images = []
15
+ self.captions = []
16
+ self.keep_in_mem = keep_in_mem
17
+
18
+ if not isinstance(folder, list):
19
+ self.image_files = [file for file in os.listdir(folder) if file.endswith((".png",".jpg"))]
20
+ self.image_files.sort()
21
+ else:
22
+ self.images = folder
23
+
24
+ if not isinstance(caption, list):
25
+ if caption not in [None, "", "None"]:
26
+ self.caption_path = caption
27
+ self.caption_files = [os.path.join(caption, file.replace(".png", ".txt").replace(".jpg", ".txt")) for file in self.image_files]
28
+ self.caption_files.sort()
29
+ else:
30
+ self.caption_path = True
31
+ self.captions = caption
32
+ # get all the image files png/jpg
33
+
34
+
35
+ if keep_in_mem:
36
+ if len(self.images) == 0:
37
+ for file in self.image_files:
38
+ img = self.load_image(os.path.join(self.path, file))
39
+ self.images.append(img)
40
+ if len(self.captions) == 0:
41
+ if self.caption_path is not None:
42
+ self.captions = []
43
+ for file in self.caption_files:
44
+ caption = self.load_caption(file)
45
+ self.captions.append(caption)
46
+ else:
47
+ self.images = []  # keep an empty list (not None) so __len__/__getitem__ fall back to lazy loading from image_files
48
+
49
+ def limit_num(self, n):
50
+ raise NotImplementedError  # limit_num is disabled; the code below is unreachable and kept only for reference
51
+ assert n <= len(self), f"n should be less than the length of the dataset {len(self)}"
52
+ self.image_files = self.image_files[:n]
53
+ self.caption_files = self.caption_files[:n]
54
+ if self.keep_in_mem:
55
+ self.images = self.images[:n]
56
+ self.captions = self.captions[:n]
57
+ print(f"Dataset limited to {n}")
58
+
59
+ def __len__(self):
60
+ if len(self.images) != 0:
61
+ return len(self.images)
62
+ else:
63
+ return len(self.image_files)
64
+
65
+ def load_image(self, path):
66
+ with open(path, 'rb') as f:
67
+ img = Image.open(f).convert('RGB')
68
+ return img
69
+
70
+ def load_caption(self, path):
71
+ with open(path, 'r') as f:
72
+ caption = f.readlines()
73
+ caption = [line.strip() for line in caption if len(line.strip()) > 0]
74
+ return caption
75
+
76
+ def __getitem__(self, index):
77
+ if len(self.images) != 0:
78
+ img = self.images[index]
79
+ else:
80
+ img = self.load_image(os.path.join(self.path, self.image_files[index]))
81
+
82
+ # if self.transform is not None:
83
+ # img = self.transform(img)
84
+
85
+ if self.caption_path is not None or len(self.captions) != 0:
86
+ if len(self.captions) != 0:
87
+ caption = self.captions[index]
88
+ else:
89
+ caption = self.load_caption(self.caption_files[index])
90
+ ret= {"image": img, "caption": caption, "id": index}
91
+ else:
92
+ ret= {"image": img, "id": index}
93
+ if self.transform is not None:
94
+ ret = self.transform(ret)
95
+ return ret
96
+
97
+ def subsample(self, n: int = 10):
98
+ if n is None or n == -1:
99
+ return self
100
+ ori_len = len(self)
101
+ assert n <= ori_len
102
+ # equal interval subsample
103
+ ids = self.image_files[::ori_len // n][:n]
104
+ self.image_files = ids
105
+ if self.keep_in_mem:
106
+ self.images = self.images[::ori_len // n][:n]
107
+ print(f"Dataset subsampled from {ori_len} to {len(self)}")
108
+ return self
109
+
110
+ def with_transform(self, transform):
111
+ self.transform = transform
112
+ return self
113
+ @staticmethod
114
+ def collate_fn(examples):
115
+ images = [example["image"] for example in examples]
116
+ ids = [example["id"] for example in examples]
117
+ if "caption" in examples[0]:
118
+ captions = [random.choice(example["caption"]) for example in examples]
119
+ return {"images": images, "captions": captions, "id": ids}
120
+ else:
121
+ return {"images": images, "id": ids}
122
+
123
+
124
+ class ImagePair(ImageSet):
125
+ def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
126
+ self.path1 = folder1
127
+ self.path2 = folder2
128
+ self.transform = transform
129
+ # get all the image files png/jpg
130
+ self.image_files = [file for file in os.listdir(folder1) if file.endswith(".png") or file.endswith(".jpg")]
131
+ self.image_files.sort()
132
+ self.keep_in_mem = keep_in_mem
133
+ if keep_in_mem:
134
+ self.images = []
135
+ for file in self.image_files:
136
+ img1 = self.load_image(os.path.join(self.path1, file))
137
+ img2 = self.load_image(os.path.join(self.path2, file))
138
+ self.images.append((img1, img2))
139
+ else:
140
+ self.images = []  # empty list (not None) so the inherited __len__ falls back to image_files
141
+
142
+ def __getitem__(self, index):
143
+ if self.keep_in_mem:
144
+ img1, img2 = self.images[index]
145
+ else:
146
+ img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
147
+ img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))
148
+
149
+ if self.transform is not None:
150
+ img1 = self.transform(img1)
151
+ img2 = self.transform(img2)
152
+ return {"image1": img1, "image2": img2, "id": index}
153
+
154
+
155
+
156
+ @staticmethod
157
+ def collate_fn(examples):
158
+ images1 = [example["image1"] for example in examples]
159
+ images2 = [example["image2"] for example in examples]
160
+ # images1 = torch.stack(images1)
161
+ # images2 = torch.stack(images2)
162
+ ids = [example["id"] for example in examples]
163
+ return {"image1": images1, "image2": images2, "id": ids}
164
+
165
+ def push_to_huggingface(self, hug_folder):
166
+ from datasets import Dataset
167
+ from datasets import Image as HugImage
168
+ photo_path = [os.path.join(self.path1, file) for file in self.image_files]
169
+ sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
170
+ dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
171
+ dataset = dataset.cast_column("photo", HugImage())
172
+ dataset = dataset.cast_column("sketch", HugImage())
173
+ dataset.push_to_hub(hug_folder, private=True)
174
+
175
+ class ImageClass(ImageSet):
176
+ def __init__(self, folders: list, transform=None, keep_in_mem=True):
177
+ self.paths = folders
178
+ self.transform = transform
179
+ # get all the image files png/jpg
180
+ self.image_files = []
181
+ self.keep_in_mem = keep_in_mem
182
+ for i, folder in enumerate(folders):
183
+ self.image_files+=[(os.path.join(folder, file), i) for file in os.listdir(folder) if file.endswith(".png") or file.endswith(".jpg")]
184
+ if keep_in_mem:
185
+ self.images = []
186
+ print("Loading images to memory")
187
+ for file in self.image_files:
188
+ img = self.load_image(file[0])
189
+ self.images.append((img, file[1]))
190
+ print("Loading images to memory done")
191
+ else:
192
+ self.images = []  # empty list (not None) so the inherited __len__ falls back to image_files
193
+
194
+ def __getitem__(self, index):
195
+ if self.keep_in_mem:
196
+ img, label = self.images[index]
197
+ else:
198
+ img_path, label = self.image_files[index]
199
+ img = self.load_image(img_path)
200
+
201
+ if self.transform is not None:
202
+ img = self.transform(img)
203
+ return {"image": img, "label": label, "id": index}
204
+
205
+ @staticmethod
206
+ def collate_fn(examples):
207
+ images = [example["image"] for example in examples]
208
+ labels = [example["label"] for example in examples]
209
+ ids = [example["id"] for example in examples]
210
+ return {"images": images, "labels":labels, "id": ids}
211
+
212
+
213
+ if __name__ == "__main__":
214
+ # dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_50",
215
+ # "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_50",keep_in_mem=False)
216
+ # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-50")
217
+
218
+
219
+
220
+ dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500",
221
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500",
222
+ keep_in_mem=True)
223
+ # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-500")
224
+ # ret = dataset[0]
225
+ # print(len(dataset))
226
+ import torch
227
+ from torchvision import transforms
228
+ train_transforms = transforms.Compose(
229
+ [
230
+ transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
231
+ transforms.CenterCrop(256),
232
+ transforms.RandomHorizontalFlip(),
233
+ transforms.ToTensor(),
234
+ transforms.Normalize([0.5], [0.5]),
235
+ ]
236
+ )
237
+ dataset = dataset.with_transform(train_transforms)
238
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=ImagePair.collate_fn)
239
+ ret = dataloader.__iter__().__next__()
240
+ pass
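A minimal usage sketch for ImageSet (the __main__ block above only exercises ImagePair); the folder paths are placeholders:

    import torch
    from custom_datasets.imagepair import ImageSet

    # a folder of .png/.jpg files plus a folder of matching .txt caption files
    dataset = ImageSet("path/to/images", caption="path/to/captions", keep_in_mem=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False,
                                         collate_fn=ImageSet.collate_fn)
    batch = next(iter(loader))  # {"images": [...], "captions": [...], "id": [...]}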
custom_datasets/lhq.py ADDED
@@ -0,0 +1,127 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import os
3
+ import pickle
4
+ import random
5
+ import shutil
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+ from PIL import Image
9
+
10
+ class LhqDataset(Dataset):
11
+ def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "clip_dissection/lhq/idx/subsample_100.pickle", transforms: transforms = None,
12
+ get_img=True,
13
+ get_cap=True,):
14
+
15
+ if isinstance(id_file, list):
16
+ self.ids = id_file
17
+ elif isinstance(id_file, str):
18
+ with open(id_file, 'rb') as f:
19
+ print(f"Loading ids from {id_file}", flush=True)
20
+ self.ids = pickle.load(f)
21
+ print(f"Loaded ids from {id_file}", flush=True)
22
+ self.image_folder_path = image_folder_path
23
+ self.caption_folder_path = caption_folder_path
24
+ self.transforms = transforms
25
+ self.column_names = ["image", "text"]
26
+ self.get_img = get_img
27
+ self.get_cap = get_cap
28
+
29
+ def __len__(self):
30
+ return len(self.ids)
31
+
32
+ def __getitem__(self, index: int):
33
+ id = self.ids[index]
34
+ ret={"id":id}
35
+ if self.get_img:
36
+ image = self._load_image(id)
37
+ ret["image"]=image
38
+ if self.get_cap:
39
+ target = self._load_caption(id)
40
+ ret["caption"]=[target]
41
+ if self.transforms is not None:
42
+ ret = self.transforms(ret)
43
+ return ret
44
+
45
+ def _load_image(self, id: int):
46
+ image_path = f"{self.image_folder_path}/{id}.jpg"
47
+ with open(image_path, 'rb') as f:
48
+ img = Image.open(f).convert("RGB")
49
+ return img
50
+
51
+ def _load_caption(self, id: int):
52
+ caption_path = f"{self.caption_folder_path}/{id}.txt"
53
+ with open(caption_path, 'r') as f:
54
+ caption_file = f.read()
55
+ caption = []
56
+ for line in caption_file.split("\n"):
57
+ line = line.strip()
58
+ if len(line) > 0:
59
+ caption.append(line)
60
+ return caption
61
+
62
+ def subsample(self, n: int = 10000):
63
+ if n is None or n == -1:
64
+ return self
65
+ ori_len = len(self)
66
+ assert n <= ori_len
67
+ # equal interval subsample
68
+ ids = self.ids[::ori_len // n][:n]
69
+ self.ids = ids
70
+ print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
71
+ return self
72
+
73
+ def with_transform(self, transform):
74
+ self.transforms = transform
75
+ return self
76
+
77
+
78
+ def generate_idx(data_folder = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"):
79
+ all_ids = os.listdir(data_folder)
80
+ all_ids = [i.split(".")[0] for i in all_ids if i.endswith(".jpg") or i.endswith(".png")]
81
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
82
+ pickle.dump(all_ids, open(f"{save_path}", "wb"))
83
+ print("all_ids generated")
84
+ return all_ids
85
+
86
+ def random_sample(all_ids, sample_num = 110, save_root = "/data/vision/torralba/clip_dissection/huiren/lhq/subsample"):
87
+ chosen_id = random.sample(all_ids, sample_num)
88
+ save_dir = f"{save_root}/{sample_num}"
89
+ os.makedirs(save_dir, exist_ok=True)
90
+ for id in chosen_id:
91
+ img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
92
+ shutil.copy(img_path, save_dir)
93
+
94
+ return chosen_id
95
+
96
+ if __name__ == "__main__":
97
+ # all_ids = generate_idx()
98
+ # with open("/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle", "rb") as f:
99
+ # all_ids = pickle.load(f)
100
+ # # random_sample(all_ids, 1)
101
+ #
102
+ # # generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/100",
103
+ # # save_path="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle")
104
+ #
105
+ # # lhq 500
106
+ # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle", "rb") as f:
107
+ # lhq_100_idx = pickle.load(f)
108
+ #
109
+ # extra_idx = set(all_ids) - set(lhq_100_idx)
110
+ # add_idx = random.sample(extra_idx, 400)
111
+ # lhq_500_idx = lhq_100_idx + add_idx
112
+ # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_500.pickle", "wb") as f:
113
+ # pickle.dump(lhq_500_idx, f)
114
+ # save_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/500"
115
+ # os.makedirs(save_dir, exist_ok=True)
116
+ # for id in lhq_500_idx:
117
+ # img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
118
+ # # softlink
119
+ # os.symlink(img_path, os.path.join(save_dir, f"{id}.jpg"))
120
+
121
+ # lhq9
122
+ all_ids = generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9",
123
+ save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle")
124
+ print(all_ids)
125
+
126
+
127
+
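A short usage sketch for LhqDataset, reusing the LHQ paths registered in custom_datasets/mypath.py (added below); it assumes the packaged LHQ subset is present:

    from custom_datasets.lhq import LhqDataset

    dataset = LhqDataset(
        image_folder_path="data/LHQ500_caption/subsample_500",
        caption_folder_path="data/LHQ500_caption/captions",
        id_file="data/LHQ500_caption/idx/subsample_500.pickle",
    )
    sample = dataset[0]  # {"id": ..., "image": <PIL.Image>, "caption": [<list of caption lines>]}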
custom_datasets/mypath.py ADDED
@@ -0,0 +1,29 @@
1
+ import os
2
+
3
+
4
+ class MyPath(object):
5
+ @staticmethod
6
+ def db_root_dir(database=''):
7
+ coco_root = "/data/vision/torralba/datasets/coco_2017"
8
+ sam_caption_root = "/vision-nfs/torralba/datasets/vision/sam/captions"
9
+
10
+ root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11
+ map={
12
+ "coco_train": f"{coco_root}/train2017/",
13
+ "coco_caption_train": f"{coco_root}/annotations/captions_train2017.json",
14
+ "coco_val": f"{coco_root}/val2017/",
15
+ "coco_caption_val": f"{coco_root}/annotations/captions_val2017.json",
16
+ "sam_images": "/vision-nfs/torralba/datasets/vision/sam/images",
17
+ "sam_captions": sam_caption_root,
18
+ "sam_whole_filtered_ids_train": "data/filtered_sam/all_remain_ids_train.pickle",
19
+ "sam_whole_filtered_ids_val": "data/filtered_sam/all_remain_ids_val.pickle",
20
+ "sam_id_dict": "data/filtered_sam/all_id_dict.pickle",
21
+
22
+ "lhq_ids_sub500": "data/LHQ500_caption/idx/subsample_500.pickle",
23
+ "lhq_images": "data/LHQ500_caption/subsample_500",
24
+ "lhq_captions": "data/LHQ500_caption/captions",
25
+ }
26
+ ret = map.get(database, None)
27
+ if ret is None:
28
+ raise NotImplementedError
29
+ return ret
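MyPath is a simple path registry: callers resolve dataset locations by key, and unknown keys raise NotImplementedError. For example:

    from custom_datasets.mypath import MyPath

    MyPath.db_root_dir("sam_images")      # "/vision-nfs/torralba/datasets/vision/sam/images"
    MyPath.db_root_dir("lhq_ids_sub500")  # "data/LHQ500_caption/idx/subsample_500.pickle"
    MyPath.db_root_dir("unknown_key")     # raises NotImplementedError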
custom_datasets/sam.py ADDED
@@ -0,0 +1,160 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import os.path
3
+ import sys
4
+ from typing import Any, Callable, List, Optional, Tuple
5
+
6
+ import tqdm
7
+ from PIL import Image
8
+
9
+ from torch.utils.data import Dataset
10
+ import pickle
11
+ from torchvision import transforms
12
+ # import torch
13
+ # import torchvision
14
+ # import re
15
+
16
+
17
+ class SamDataset(Dataset):
18
+ def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "data/sam/clip_filtered_ids.pickle",id_dict_file:str =None , transforms: Optional[Callable] = None,
19
+ resolution=None,
20
+ get_img=True,
21
+ get_cap=True,):
22
+ if id_dict_file is not None:
23
+ with open(id_dict_file, 'rb') as f:
24
+ print(f"Loading id_dict from {id_dict_file}", flush=True)
25
+ self.id_dict = pickle.load(f)
26
+ print(f"Loaded id_dict from {id_dict_file}", flush=True)
27
+ else:
28
+ self.id_dict = None
29
+ if isinstance(id_file, list):
30
+ self.ids = id_file
31
+ elif isinstance(id_file, str):
32
+ with open(id_file, 'rb') as f:
33
+ print(f"Loading ids from {id_file}", flush=True)
34
+ self.ids = pickle.load(f)
35
+ print(f"Loaded ids from {id_file}", flush=True)
36
+ self.resolution = resolution
37
+ self.ori_image_folder_path = image_folder_path
38
+ if self.resolution is not None:
39
+ if os.path.exists("/var/jomat/datasets/"):
40
+ # self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
41
+ self.image_folder_path = f"{image_folder_path}_{resolution}"
42
+ else:
43
+ self.image_folder_path = f"{image_folder_path}_{resolution}"
44
+ os.makedirs(self.image_folder_path, exist_ok=True)
45
+ else:
46
+ self.image_folder_path = image_folder_path
47
+ self.caption_folder_path = caption_folder_path
48
+ self.transforms = transforms
49
+ self.column_names = ["image", "text"]
50
+ self.get_img = get_img
51
+ self.get_cap = get_cap
52
+
53
+ def __len__(self):
54
+ # return 100
55
+ return len(self.ids)
56
+
57
+ def __getitem__(self, index: int):
58
+ id = self.ids[index]
59
+ ret={"id":id}
60
+ try:
61
+ # if index == 1:
62
+ # raise Exception("test")
63
+ if self.get_img:
64
+ image = self._load_image(id)
65
+ ret["image"]=image
66
+ if self.get_cap:
67
+ target = self._load_caption(id)
68
+ ret["text"] = [target]
69
+ if self.transforms is not None:
70
+ ret = self.transforms(ret)
71
+ return ret
72
+ except Exception as e:
73
+ raise e  # re-raise immediately; the index-0 fallback below is currently disabled (unreachable)
74
+ print(f"Error loading image and caption for id {id}, error: {e}, redirecting to index 0", flush=True)
75
+ ret = self[0]
76
+ return ret
77
+
78
+ def define_resolution(self, resolution: int):
79
+ self.resolution = resolution
80
+ if os.path.exists("/var/jomat/datasets/"):
81
+ self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
82
+ # self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
83
+ else:
84
+ self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
85
+ print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")
86
+ def _load_image(self, id: int) -> Image.Image:
87
+ if self.id_dict is not None:
88
+ subfolder = self.id_dict[id]
89
+ image_path = f"{self.image_folder_path}/{subfolder}/sa_{id}.jpg"
90
+ else:
91
+ image_path = f"{self.image_folder_path}/sa_{id}.jpg"
92
+
93
+ try:
94
+ with open(image_path, 'rb') as f:
95
+ img = Image.open(f).convert("RGB")
96
+ # return img
97
+ except:
98
+ # load original image
99
+ if self.id_dict is not None:
100
+ subfolder = self.id_dict[id]
101
+ ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{id}.jpg"
102
+ else:
103
+ ori_image_path = f"{self.ori_image_folder_path}/sa_{id}.jpg"
104
+ assert os.path.exists(ori_image_path)
105
+ with open(ori_image_path, 'rb') as f:
106
+ img = Image.open(f).convert("RGB")
107
+ # resize the image, keeping the aspect ratio
108
+ if self.resolution is not None:
109
+ img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
110
+ # write image
111
+ os.makedirs(os.path.dirname(image_path), exist_ok=True)
112
+ img.save(image_path)
113
+
114
+ return img
115
+
116
+
117
+ def _load_caption(self, id: int):
118
+ caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
119
+ if not os.path.exists(caption_path):
120
+ return None
121
+ try:
122
+ with open(caption_path, 'r', encoding="utf-8") as f:
123
+ content = f.read()
124
+ except Exception as e:
125
+ raise e  # re-raise immediately; the fallback below (log and return None) is currently disabled
126
+ print(f"Error reading caption file {caption_path}, error: {e}")
127
+ return None
128
+ sentences = content.split('.')
129
+ # remove empty sentences and sentences containing "black and white" (too many false predictions)
130
+ sentences = [sentence.strip() for sentence in sentences if sentence.strip() and "black and white" not in sentence]
131
+ # join sentence
132
+ sentences = ". ".join(sentences)
133
+ if len(sentences) > 0 and sentences[-1] != '.':
134
+ sentences += '.'
135
+
136
+ return sentences
137
+
138
+ def with_transform(self, transform):
139
+ self.transforms = transform
140
+ return self
141
+
142
+ def subsample(self, n: int = 10000):
143
+ if n is None or n == -1:
144
+ return self
145
+ ori_len = len(self)
146
+ assert n <= ori_len
147
+ # equal interval subsample
148
+ ids = self.ids[::ori_len // n][:n]
149
+ self.ids = ids
150
+ print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
151
+ return self
152
+
153
+
154
+ if __name__ == "__main__":
155
+ # sam_filt(caption_filt=False, clip_filt=False, clip_logit=True)
156
+ from custom_datasets.mypath import MyPath
157
+ dataset = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"))
158
+ dataset.get_img = False
159
+ for i in tqdm.tqdm(dataset):
160
+ a=i['text']
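A sketch of SamDataset's on-disk resize cache: once a resolution is set, _load_image first looks for sa_<id>.jpg under "<image_folder>_<resolution>"; on a miss it loads the original, resizes it (keeping the aspect ratio), writes the resized copy, and returns it. The MyPath keys resolve to cluster-specific storage, so this only runs where that data exists:

    from custom_datasets.mypath import MyPath
    from custom_datasets.sam import SamDataset

    dataset = SamDataset(
        image_folder_path=MyPath.db_root_dir("sam_images"),
        caption_folder_path=MyPath.db_root_dir("sam_captions"),
        id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"),
        id_dict_file=MyPath.db_root_dir("sam_id_dict"),
    )
    dataset.define_resolution(512)  # later __getitem__ calls read/write the 512px cache
    sample = dataset[0]             # {"id": ..., "image": <PIL.Image>, "text": [caption string]}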
data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1802d12e4d9526eedb89d99f69051849f14774da3c73ebc9b1393c2b13f17022
3
+ size 2187129
data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c39b39f32ff88dfed978ccc651715ade9edfd901d529adbeb5eedb715b8e159
3
+ size 2187129
data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd7764b19a2b4513b3c22f1607d72daa63c4ace97ea803e29e2bcf3f13bab2e8
3
+ size 2187129
data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:426c2e4a3bfc26f7fdcc3e82989d717fa5fc6e732cd9df9f8bb293ab72cacfa5
3
+ size 2187129
data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8be8ef590baceb2bdfac8b25976df88fa7baa1a9c718ed16aa4fa8fa247bb421
3
+ size 2187129
data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:212f0f16ae84c0bae96e213a0b0d5f4309209b332d48cbaa1748b5cdcfb3238a
3
+ size 2187129
data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cff54e3e7c544577dbc39d7015a89c4786cd012cf944d0b9db334c1a1d7e30b
3
+ size 2187129
data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c26bdb5bfba85b4eb00631eda149912ba557935773842f95c0596999f799a2b4
3
+ size 2187129
data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24b33205841d9b09c0076b4ba295be29d94677e69b7269465897bbf059a40454
3
+ size 2187129
data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b34b75325c3fd0353b55f390027a32a98f771df7d2fb21dbd8bce81a12ba59e9
3
+ size 2187129
data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7457f14af7c77f98675063582b35317963d46e942459575d38b5996ed190c58f
3
+ size 2187129
data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6df86764f4d4ceec0bd6124a74a51c36665c8491511a5488737b9a64300b97b
3
+ size 2187129
data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6a9ba0305edca3286258a06023b97914b850fbc8b4f5a14769537f9a01ef33f1
3
+ size 2187129
data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ce7899c19b32dacd2dc46090fd3429495a2230c173bcd96149236d27b5151fd
3
+ size 2187129
data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ac428a5d0fb136b79eec2349fbcbd99dfac2315c0a7f54d7985299b60b6f66f
3
+ size 2187129
data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ca866dd868fb89a1180bb140dfaf1e48701993c8fa173d70c56c60c9af8d8fb
3
+ size 2187129
data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41cad39d7b6e1873cfef85be478851820f5dc80cd7ce11afe2bfa3584662e3ac
3
+ size 2187129
data/unsafe.png ADDED
hf_demo.py ADDED
@@ -0,0 +1,147 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import os
3
+
4
+ import gradio as gr
5
+ from diffusers import DiffusionPipeline
6
+ import matplotlib.pyplot as plt
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",).to(device)
14
+
15
+ from inference import get_lora_network, inference, get_validation_dataloader
16
+ lora_map = {
17
+ "None": "None",
18
+ "Andre Derain": "andre-derain_subset1",
19
+ "Vincent van Gogh": "van_gogh_subset1",
20
+ "Andy Warhol": "andy_subset1",
21
+ "Walter Battiss": "walter-battiss_subset2",
22
+ "Camille Corot": "camille-corot_subset1",
23
+ "Claude Monet": "monet_subset2",
24
+ "Pablo Picasso": "picasso_subset1",
25
+ "Jackson Pollock": "jackson-pollock_subset1",
26
+ "Gerhard Richter": "gerhard-richter_subset1",
27
+ "M.C. Escher": "m.c.-escher_subset1",
28
+ "Albert Gleizes": "albert-gleizes_subset1",
29
+ "Hokusai": "katsushika-hokusai_subset1",
30
+ "Wassily Kandinsky": "kandinsky_subset1",
31
+ "Gustav Klimt": "klimt_subset3",
32
+ "Roy Lichtenstein": "roy-lichtenstein_subset1",
33
+ "Henri Matisse": "henri-matisse_subset1",
34
+ "Joan Miro": "joan-miro_subset2",
35
+ }
36
+
37
+ def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
38
+ adapter_path = lora_map[adapter_choice]
39
+ if adapter_path not in [None, "None"]:
40
+ adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
41
+
42
+ prompts = [prompt]*samples
43
+ infer_loader = get_validation_dataloader(prompts)
44
+ network = get_lora_network(pipe.unet, adapter_path)["network"]
45
+ pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
46
+ height=512, width=512, scales=[1.0],
47
+ save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
48
+ start_noise=-1, show=False, style_prompt="sks art", no_load=True,
49
+ from_scratch=True)[0][1.0]
50
+ return pred_images
51
+
52
+ def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
53
+ infer_loader = get_validation_dataloader(prompts, image)
54
+ network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
55
+ pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
56
+ height=512, width=512, scales=[0.,1.],
57
+ save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
58
+ start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
59
+ from_scratch=False)
60
+ return pred_images
61
+
62
+ # def infer(prompt, samples, steps, scale, seed):
63
+ # generator = torch.Generator(device=device).manual_seed(seed)
64
+ # images_list = pipe( # type: ignore
65
+ # [prompt] * samples,
66
+ # num_inference_steps=steps,
67
+ # guidance_scale=scale,
68
+ # generator=generator,
69
+ # )
70
+ # images = []
71
+ # safe_image = Image.open(r"data/unsafe.png")
72
+ # print(images_list)
73
+ # for i, image in enumerate(images_list["images"]): # type: ignore
74
+ # if images_list["nsfw_content_detected"][i]: # type: ignore
75
+ # images.append(safe_image)
76
+ # else:
77
+ # images.append(image)
78
+ # return images
79
+
80
+
81
+
82
+
83
+ block = gr.Blocks()
84
+ # Direct infer
85
+ with block:
86
+ with gr.Group():
87
+ with gr.Row():
88
+ text = gr.Textbox(
89
+ label="Enter your prompt",
90
+ max_lines=2,
91
+ placeholder="Enter your prompt",
92
+ container=False,
93
+ value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
94
+ )
95
+
96
+
97
+
98
+ btn = gr.Button("Run", scale=0)
99
+ gallery = gr.Gallery(
100
+ label="Generated images",
101
+ show_label=False,
102
+ elem_id="gallery",
103
+ columns=[2],
104
+ )
105
+
106
+ advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
107
+
108
+ with gr.Row(elem_id="advanced-options"):
109
+ adapter_choice = gr.Dropdown(
110
+ label="Choose adapter",
111
+ choices=["None", "Andre Derain","Vincent van Gogh","Andy Warhol", "Walter Battiss",
112
+ "Camille Corot", "Claude Monet", "Pablo Picasso",
113
+ "Jackson Pollock", "Gerhard Richter", "M.C. Escher",
114
+ "Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
115
+ "Henri Matisse", "Joan Miro"
116
+ ],
117
+ value="None"
118
+ )
119
+ # print(adapter_choice[0])
120
+ # lora_path = lora_map[adapter_choice.value]
121
+ # if lora_path is not None:
122
+ # lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
123
+
124
+ samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
125
+ steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
126
+ scale = gr.Slider(
127
+ label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
128
+ )
129
+ print(scale)
130
+ seed = gr.Slider(
131
+ label="Seed",
132
+ minimum=0,
133
+ maximum=2147483647,
134
+ step=1,
135
+ randomize=True,
136
+ )
137
+
138
+ gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
139
+ advanced_button.click(
140
+ None,
141
+ [],
142
+ text,
143
+ )
144
+
145
+
146
+
147
+ block.launch()
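The Gradio callback can also be used without launching the UI. A sketch, assuming the adapter checkpoints under data/Art_adapters/ and the rhfeiyang/art-free-diffusion-v1 pipeline are available:

    images = demo_inference_gen(
        adapter_choice="Vincent van Gogh",
        prompt="Park with cherry blossom trees and a clear blue pond.",
        samples=2, seed=0, steps=20, guidance_scale=7.5,
    )
    # `images` is the batch shown in the gallery for Art Adapter scale 1.0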
hf_demo_test.ipynb ADDED
@@ -0,0 +1,336 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "initial_id",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2024-12-09T09:44:30.641366Z",
10
+ "start_time": "2024-12-09T09:44:11.789050Z"
11
+ }
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "import os\n",
16
+ "\n",
17
+ "import gradio as gr\n",
18
+ "from diffusers import DiffusionPipeline\n",
19
+ "import matplotlib.pyplot as plt\n",
20
+ "import torch\n",
21
+ "from PIL import Image\n"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 2,
27
+ "id": "ddf33e0d3abacc2c",
28
+ "metadata": {},
29
+ "outputs": [],
30
+ "source": [
31
+ "import sys\n",
32
+ "#append current path\n",
33
+ "sys.path.append(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 3,
39
+ "id": "643e49fd601daf8f",
40
+ "metadata": {
41
+ "ExecuteTime": {
42
+ "end_time": "2024-12-09T09:44:35.790962Z",
43
+ "start_time": "2024-12-09T09:44:35.779496Z"
44
+ }
45
+ },
46
+ "outputs": [],
47
+ "source": [
48
+ "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 4,
54
+ "id": "e03aae2a4e5676dd",
55
+ "metadata": {
56
+ "ExecuteTime": {
57
+ "end_time": "2024-12-09T09:44:44.157412Z",
58
+ "start_time": "2024-12-09T09:44:37.138452Z"
59
+ }
60
+ },
61
+ "outputs": [
62
+ {
63
+ "name": "stderr",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
67
+ " warnings.warn(\n"
68
+ ]
69
+ },
70
+ {
71
+ "data": {
72
+ "application/vnd.jupyter.widget-view+json": {
73
+ "model_id": "9df8347307674ba8afb0250e23109aa1",
74
+ "version_major": 2,
75
+ "version_minor": 0
76
+ },
77
+ "text/plain": [
78
+ "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
79
+ ]
80
+ },
81
+ "metadata": {},
82
+ "output_type": "display_data"
83
+ }
84
+ ],
85
+ "source": [
86
+ "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
87
+ "device = \"cuda\""
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 5,
93
+ "id": "83916bc68ff5d914",
94
+ "metadata": {
95
+ "ExecuteTime": {
96
+ "end_time": "2024-12-09T09:44:52.694399Z",
97
+ "start_time": "2024-12-09T09:44:44.210695Z"
98
+ }
99
+ },
100
+ "outputs": [],
101
+ "source": [
102
+ "from inference import get_lora_network, inference, get_validation_dataloader\n",
103
+ "lora_map = {\n",
104
+ " \"None\": \"None\",\n",
105
+ " \"Andre Derain\": \"andre-derain_subset1\",\n",
106
+ " \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
107
+ " \"Andy Warhol\": \"andy_subset1\",\n",
108
+ " \"Walter Battiss\": \"walter-battiss_subset2\",\n",
109
+ " \"Camille Corot\": \"camille-corot_subset1\",\n",
110
+ " \"Claude Monet\": \"monet_subset2\",\n",
111
+ " \"Pablo Picasso\": \"picasso_subset1\",\n",
112
+ " \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
113
+ " \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
114
+ " \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
115
+ " \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
116
+ " \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
117
+ " \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
118
+ " \"Gustav Klimt\": \"klimt_subset3\",\n",
119
+ " \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
120
+ " \"Henri Matisse\": \"henri-matisse_subset1\",\n",
121
+ " \"Joan Miro\": \"joan-miro_subset2\",\n",
122
+ "}\n",
123
+ "\n",
124
+ "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
125
+ " adapter_path = lora_map[adapter_choice]\n",
126
+ " if adapter_path not in [None, \"None\"]:\n",
127
+ " adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
128
+ "\n",
129
+ " prompts = [prompt]*samples\n",
130
+ " infer_loader = get_validation_dataloader(prompts)\n",
131
+ " network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
132
+ " pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
133
+ " height=512, width=512, scales=[1.0],\n",
134
+ " save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
135
+ " start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
136
+ " from_scratch=True)[0][1.0]\n",
137
+ " return pred_images\n",
138
+ "\n",
139
+ "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
140
+ " infer_loader = get_validation_dataloader(prompts, image)\n",
141
+ " network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
142
+ " pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
143
+ " height=512, width=512, scales=[0.,1.],\n",
144
+ " save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
145
+ " start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
146
+ " from_scratch=False)\n",
147
+ " return pred_images\n",
148
+ "\n",
149
+ "# def infer(prompt, samples, steps, scale, seed):\n",
150
+ "# generator = torch.Generator(device=device).manual_seed(seed)\n",
151
+ "# images_list = pipe( # type: ignore\n",
152
+ "# [prompt] * samples,\n",
153
+ "# num_inference_steps=steps,\n",
154
+ "# guidance_scale=scale,\n",
155
+ "# generator=generator,\n",
156
+ "# )\n",
157
+ "# images = []\n",
158
+ "# safe_image = Image.open(r\"data/unsafe.png\")\n",
159
+ "# print(images_list)\n",
160
+ "# for i, image in enumerate(images_list[\"images\"]): # type: ignore\n",
161
+ "# if images_list[\"nsfw_content_detected\"][i]: # type: ignore\n",
162
+ "# images.append(safe_image)\n",
163
+ "# else:\n",
164
+ "# images.append(image)\n",
165
+ "# return images\n"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": 6,
171
+ "id": "aa33e9d104023847",
172
+ "metadata": {
173
+ "ExecuteTime": {
174
+ "end_time": "2024-12-09T12:09:39.339583Z",
175
+ "start_time": "2024-12-09T12:09:38.953936Z"
176
+ }
177
+ },
178
+ "outputs": [
179
+ {
180
+ "name": "stdout",
181
+ "output_type": "stream",
182
+ "text": [
183
+ "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
184
+ "Running on local URL: http://127.0.0.1:7876\n",
185
+ "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
186
+ "\n",
187
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
188
+ ]
189
+ },
190
+ {
191
+ "data": {
192
+ "text/html": [
193
+ "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
194
+ ],
195
+ "text/plain": [
196
+ "<IPython.core.display.HTML object>"
197
+ ]
198
+ },
199
+ "metadata": {},
200
+ "output_type": "display_data"
201
+ },
202
+ {
203
+ "data": {
204
+ "text/plain": []
205
+ },
206
+ "execution_count": 6,
207
+ "metadata": {},
208
+ "output_type": "execute_result"
209
+ },
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Train method: None\n",
215
+ "Rank: 1, Alpha: 1\n",
216
+ "create LoRA for U-Net: 0 modules.\n",
217
+ "save dir: None\n",
218
+ "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
219
+ ]
220
+ },
221
+ {
222
+ "name": "stderr",
223
+ "output_type": "stream",
224
+ "text": [
225
+ "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
226
+ " return F.conv2d(input, weight, bias, self.stride,\n",
227
+ "\n",
228
+ "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
229
+ ]
230
+ },
231
+ {
232
+ "name": "stdout",
233
+ "output_type": "stream",
234
+ "text": [
235
+ "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
236
+ ]
237
+ }
238
+ ],
239
+ "source": [
240
+ "block = gr.Blocks()\n",
241
+ "# Direct infer\n",
242
+ "with block:\n",
243
+ " with gr.Group():\n",
244
+ " with gr.Row():\n",
245
+ " text = gr.Textbox(\n",
246
+ " label=\"Enter your prompt\",\n",
247
+ " max_lines=2,\n",
248
+ " placeholder=\"Enter your prompt\",\n",
249
+ " container=False,\n",
250
+ " value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
251
+ " )\n",
252
+ " \n",
253
+ "\n",
254
+ " \n",
255
+ " btn = gr.Button(\"Run\", scale=0)\n",
256
+ " gallery = gr.Gallery(\n",
257
+ " label=\"Generated images\",\n",
258
+ " show_label=False,\n",
259
+ " elem_id=\"gallery\",\n",
260
+ " columns=[2],\n",
261
+ " )\n",
262
+ "\n",
263
+ " advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
264
+ "\n",
265
+ " with gr.Row(elem_id=\"advanced-options\"):\n",
266
+ " adapter_choice = gr.Dropdown(\n",
267
+ " label=\"Choose adapter\",\n",
268
+ " choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
269
+ " \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
270
+ " \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
271
+ " \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
272
+ " \"Henri Matisse\", \"Joan Miro\"\n",
273
+ " ],\n",
274
+ " value=\"None\"\n",
275
+ " )\n",
276
+ " # print(adapter_choice[0])\n",
277
+ " # lora_path = lora_map[adapter_choice.value]\n",
278
+ " # if lora_path is not None:\n",
279
+ " # lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
280
+ "\n",
281
+ " samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
282
+ " steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
283
+ " scale = gr.Slider(\n",
284
+ " label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
285
+ " )\n",
286
+ " print(scale)\n",
287
+ " seed = gr.Slider(\n",
288
+ " label=\"Seed\",\n",
289
+ " minimum=0,\n",
290
+ " maximum=2147483647,\n",
291
+ " step=1,\n",
292
+ " randomize=True,\n",
293
+ " )\n",
294
+ "\n",
295
+ " gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
296
+ " advanced_button.click(\n",
297
+ " None,\n",
298
+ " [],\n",
299
+ " text,\n",
300
+ " )\n",
301
+ "\n",
302
+ "\n",
303
+ "block.launch(share=True)"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": null,
309
+ "id": "3239c12167a5f2cd",
310
+ "metadata": {},
311
+ "outputs": [],
312
+ "source": []
313
+ }
314
+ ],
315
+ "metadata": {
316
+ "kernelspec": {
317
+ "display_name": "Python 3 (ipykernel)",
318
+ "language": "python",
319
+ "name": "python3"
320
+ },
321
+ "language_info": {
322
+ "codemirror_mode": {
323
+ "name": "ipython",
324
+ "version": 3
325
+ },
326
+ "file_extension": ".py",
327
+ "mimetype": "text/x-python",
328
+ "name": "python",
329
+ "nbconvert_exporter": "python",
330
+ "pygments_lexer": "ipython3",
331
+ "version": "3.9.18"
332
+ }
333
+ },
334
+ "nbformat": 4,
335
+ "nbformat_minor": 5
336
+ }
inference.py ADDED
@@ -0,0 +1,657 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+ import torch
3
+ from PIL import Image
4
+ import argparse
5
+ import os, json, random
6
+
7
+ import matplotlib.pyplot as plt
8
+ import glob, re
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+ import sys
14
+ import gc
15
+ from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
16
+
17
+ # import train_util
18
+
19
+ from utils.train_util import get_noisy_image, encode_prompts
20
+
21
+ from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, DDIMScheduler, PNDMScheduler
22
+
23
+ from typing import Any, Dict, List, Optional, Tuple, Union
24
+ from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
25
+ import argparse
26
+ # from diffusers.training_utils import EMAModel
27
+ import shutil
28
+ import yaml
29
+ from easydict import EasyDict
30
+ from utils.metrics import StyleContentMetric
31
+ from torchvision import transforms
32
+
33
+ from custom_datasets.coco import CustomCocoCaptions
34
+ from custom_datasets.imagepair import ImageSet
35
+ from custom_datasets import get_dataset
36
+ # from stable_diffusion.utils.modules import get_diffusion_modules
37
+ # from diffusers import StableDiffusionImg2ImgPipeline
38
+ from diffusers.utils.torch_utils import randn_tensor
39
+ import pickle
40
+ import time
41
+ def flush():
42
+ torch.cuda.empty_cache()
43
+ gc.collect()
44
+
45
+ def get_train_method(lora_weight):
46
+ if lora_weight is None:
47
+ return 'None'
48
+ if 'full' in lora_weight:
49
+ train_method = 'full'
50
+ elif "down_1_up_2_attn" in lora_weight:
51
+ train_method = 'up_2_attn'
52
+ print(f"Using up_2_attn for {lora_weight}")
53
+ elif "down_2_up_1_up_2_attn" in lora_weight:
54
+ train_method = 'down_2_up_2_attn'
55
+ elif "down_2_up_2_attn" in lora_weight:
56
+ train_method = 'down_2_up_2_attn'
57
+ elif "down_2_attn" in lora_weight:
58
+ train_method = 'down_2_attn'
59
+ elif 'noxattn' in lora_weight:
60
+ train_method = 'noxattn'
61
+ elif "xattn" in lora_weight:
62
+ train_method = 'xattn'
63
+ elif "attn" in lora_weight:
64
+ train_method = 'attn'
65
+ elif "all_up" in lora_weight:
66
+ train_method = 'all_up'
67
+ else:
68
+ train_method = 'None'
69
+ return train_method
70
+
71
+ def get_validation_dataloader(infer_prompts:list[str]=None, infer_images :list[str]=None,resolution=512, batch_size=10, num_workers=4, val_set="laion_pop500"):
72
+ data_transforms = transforms.Compose(
73
+ [
74
+ transforms.Resize(resolution),
75
+ transforms.CenterCrop(resolution),
76
+ ]
77
+ )
78
+ def preprocess(example):
79
+ ret={}
80
+ ret["image"] = data_transforms(example["image"]) if "image" in example else None
81
+ if "caption" in example:
82
+ if isinstance(example["caption"][0], list):
83
+ ret["caption"] = example["caption"][0][0]
84
+ else:
85
+ ret["caption"] = example["caption"][0]
86
+ if "seed" in example:
87
+ ret["seed"] = example["seed"]
88
+ if "id" in example:
89
+ ret["id"] = example["id"]
90
+ if "path" in example:
91
+ ret["path"] = example["path"]
92
+ return ret
93
+
94
+ def collate_fn(examples):
95
+ out = {}
96
+ if "image" in examples[0]:
97
+ pixel_values = [example["image"] for example in examples]
98
+ out["pixel_values"] = pixel_values
99
+ # notice: only take the first prompt for each image
100
+ if "caption" in examples[0]:
101
+ prompts = [example["caption"] for example in examples]
102
+ out["prompts"] = prompts
103
+ if "seed" in examples[0]:
104
+ seeds = [example["seed"] for example in examples]
105
+ out["seed"] = seeds
106
+ if "path" in examples[0]:
107
+ paths = [example["path"] for example in examples]
108
+ out["path"] = paths
109
+ return out
110
+ if infer_prompts is None:
111
+ if val_set == "lhq500":
112
+ dataset = get_dataset("lhq_sub500", get_val=False)["train"]
113
+ elif val_set == "custom_coco100":
114
+ dataset = get_dataset("custom_coco100", get_val=False)["train"]
115
+ elif val_set == "custom_coco500":
116
+ dataset = get_dataset("custom_coco500", get_val=False)["train"]
117
+
118
+ elif os.path.isdir(val_set):
119
+ image_folder = os.path.join(val_set, "paintings")
120
+ caption_folder = os.path.join(val_set, "captions")
121
+ dataset = ImageSet(folder=image_folder, caption=caption_folder, keep_in_mem=True)
122
+ elif "custom_caption" in val_set:
123
+ from custom_datasets.custom_caption import Caption_set
124
+ name = val_set.replace("custom_caption_", "")
125
+ dataset = Caption_set(set_name = name)
126
+ elif val_set == "laion_pop500":
127
+ dataset = get_dataset("laion_pop500", get_val=False)["train"]
128
+ elif val_set == "laion_pop500_first_sentence":
129
+ dataset = get_dataset("laion_pop500_first_sentence", get_val=False)["train"]
130
+ else:
131
+ raise ValueError("Unknown dataset")
132
+ dataset.with_transform(preprocess)
133
+ elif isinstance(infer_prompts, torch.utils.data.Dataset):
134
+ dataset = infer_prompts
135
+ try:
136
+ dataset.with_transform(preprocess)
137
+ except:
138
+ pass
139
+
140
+ else:
141
+ class Dataset(torch.utils.data.Dataset):
142
+ def __init__(self, prompts, images=None):
143
+ self.prompts = prompts
144
+ self.images = images
145
+ self.get_img = False
146
+ if images is not None:
147
+ assert len(prompts) == len(images)
148
+ self.get_img = True
149
+ if isinstance(images[0], str):
150
+ self.images = [Image.open(image).convert("RGB") for image in images]
151
+ else:
152
+ self.images = [None] * len(prompts)
153
+ def __len__(self):
154
+ return len(self.prompts)
155
+ def __getitem__(self, idx):
156
+ img = self.images[idx]
157
+ if self.get_img and img is not None:
158
+ img = data_transforms(img)
159
+ return {"caption": self.prompts[idx], "image":img}
160
+ dataset = Dataset(infer_prompts, infer_images)
161
+
162
+ dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False,
163
+ num_workers=num_workers, pin_memory=True)
164
+ return dataloader
165
+
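A minimal sketch of the prompt-only path above (the two captions are made up; each example carries image=None, so batches expose only "prompts" and placeholder pixel values):
loader = get_validation_dataloader(
    infer_prompts=["a quiet harbor at dawn", "a snowy mountain lake"],
    batch_size=2, num_workers=0)
batch = next(iter(loader))
print(batch["prompts"])  # the two captions, in order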
166
+ def get_lora_network(unet , lora_path, train_method="None", rank=1, alpha=1.0, device="cuda", weight_dtype=torch.float32):
167
+ if train_method in [None, "None"]:
168
+ train_method = get_train_method(lora_path)
169
+ print(f"Train method: {train_method}")
170
+
171
+ network_type = "c3lier"
172
+ if train_method == 'xattn':
173
+ network_type = 'lierla'
174
+
175
+ modules = DEFAULT_TARGET_REPLACE
176
+ if network_type == "c3lier":
177
+ modules += UNET_TARGET_REPLACE_MODULE_CONV
178
+
179
+ alpha = 1
180
+ if "rank" in lora_path:
181
+ rank = int(re.search(r'rank(\d+)', lora_path).group(1))
182
+ if 'alpha1' in lora_path:
183
+ alpha = 1.0
184
+ print(f"Rank: {rank}, Alpha: {alpha}")
185
+
186
+ network = LoRANetwork(
187
+ unet,
188
+ rank=rank,
189
+ multiplier=1.0,
190
+ alpha=alpha,
191
+ train_method=train_method,
192
+ ).to(device, dtype=weight_dtype)
193
+ if lora_path not in [None, "None"]:
194
+ lora_state_dict = torch.load(lora_path)
195
+ miss = network.load_state_dict(lora_state_dict, strict=False)
196
+ print(f"Missing: {miss}")
197
+ ret = {"network": network, "train_method": train_method}
198
+ return ret
199
+
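A hedged usage sketch (the adapter path and unet are placeholders): the rank is parsed from the filename, the train method comes from get_train_method, and the returned LoRANetwork patches the UNet's forward passes in place.
ret = get_lora_network(unet, "adapters/my_style_adapter_rank1_all_up_1000steps.pt")
network = ret["network"]            # LoRANetwork hooked into the UNet
network.set_lora_slider(scale=1.0)
with network:                       # LoRA deltas are only active inside this context
    ...                             # run UNet forward passes here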
200
+ def get_model(pretrained_ckpt_path, unet_ckpt=None,revision=None, variant=None, lora_path=None, weight_dtype=torch.float32,
201
+ device="cuda"):
202
+ modules = {}
203
+ pipe = DiffusionPipeline.from_pretrained(pretrained_ckpt_path, revision=revision, variant=variant)
204
+ if unet_ckpt is not None:
205
+ pipe.unet = UNet2DConditionModel.from_pretrained(unet_ckpt, subfolder="unet_ema", revision=revision, variant=variant)  # from_pretrained is a classmethod: assign the returned model
206
+ unet = pipe.unet
207
+ vae = pipe.vae
208
+ text_encoder = pipe.text_encoder
209
+ tokenizer = pipe.tokenizer
210
+ modules["unet"] = unet
211
+ modules["vae"] = vae
212
+ modules["text_encoder"] = text_encoder
213
+ modules["tokenizer"] = tokenizer
214
+ # tokenizer = modules["tokenizer"]
215
+
216
+ unet.enable_xformers_memory_efficient_attention()
217
+ unet.to(device, dtype=weight_dtype)
218
+ if weight_dtype != torch.bfloat16:
219
+ vae.to(device, dtype=torch.float32)
220
+ else:
221
+ vae.to(device, dtype=weight_dtype)
222
+ text_encoder.to(device, dtype=weight_dtype)
223
+
224
+ if lora_path is not None:
225
+ network = get_lora_network(unet, lora_path, device=device, weight_dtype=weight_dtype)
226
+ modules["network"] = network
227
+ return modules
228
+
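A minimal loading sketch (the model id is the argparse default further below; fp16 is only an example dtype, and the xformers call above must be available):
mods = get_model("rhfeiyang/art-free-diffusion-v1", weight_dtype=torch.float16, device="cuda")
unet, vae = mods["unet"], mods["vae"]
tokenizer, text_encoder = mods["tokenizer"], mods["text_encoder"]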
229
+
230
+
231
+ @torch.no_grad()
232
+ def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, vae: AutoencoderKL, unet: UNet2DConditionModel, noise_scheduler: LMSDiscreteScheduler,
233
+ dataloader, height:int, width:int, scales:list = np.linspace(0,2,5),save_dir:str=None, seed:int = None,
234
+ weight_dtype: torch.dtype = torch.float32, device: torch.device="cuda", batch_size:int=1, steps:int=50, guidance_scale:float=7.5, start_noise:int=800,
235
+ uncond_prompt:str=None, uncond_embed=None, style_prompt = None, show:bool = False, no_load:bool=False, from_scratch=False):
236
+ print(f"save dir: {save_dir}")
237
+ if start_noise < 0:
238
+ assert from_scratch
239
+ network = network.eval()
240
+ unet = unet.eval()
241
+ vae = vae.eval()
242
+ do_convert = not from_scratch
243
+
244
+ if not do_convert:
245
+ try:
246
+ dataloader.dataset.get_img = False
247
+ except:
248
+ pass
249
+ scales = list(scales)
250
+ else:
251
+ scales = ["Real Image"] + list(scales)
252
+
253
+ if not no_load and save_dir is not None and os.path.exists(os.path.join(save_dir, "infer_imgs.pickle")):
254
+ with open(os.path.join(save_dir, "infer_imgs.pickle"), 'rb') as f:
255
+ pred_images = pickle.load(f)
256
+ take=True
257
+ for key in scales:
258
+ if key not in pred_images:
259
+ take=False
260
+ break
261
+ if take:
262
+ print(f"Found existing inference results in {save_dir}", flush=True)
263
+ return pred_images, None  # keep the return arity consistent with the final return
264
+
265
+ max_length = tokenizer.model_max_length
266
+
267
+ pred_images = {scale :[] for scale in scales}
268
+ all_seeds = {scale:[] for scale in scales}
269
+
270
+ prompts = []
271
+ ori_prompts = []
272
+ if save_dir is not None:
273
+ img_output_dir = os.path.join(save_dir, "outputs")
274
+ os.makedirs(img_output_dir, exist_ok=True)
275
+
276
+ if uncond_embed is None:
277
+ if uncond_prompt is None:
278
+ uncond_input_text = [""]
279
+ else:
280
+ uncond_input_text = [uncond_prompt]
281
+ uncond_embed = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = uncond_input_text)
282
+
283
+
284
+ for batch in dataloader:
285
+ ori_prompt = batch["prompts"]
286
+ image = batch["pixel_values"] if do_convert else None
287
+ if do_convert:
288
+ pred_images["Real Image"] += image
289
+ if isinstance(ori_prompt, list):
290
+ if isinstance(text_encoder, CLIPTextModel):
291
+ # trunc prompts for clip encoder
292
+ ori_prompt = [p.split(".")[0]+"." for p in ori_prompt]
293
+ prompt = [f"{p.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" for p in ori_prompt] if style_prompt is not None else ori_prompt
294
+ else:
295
+ if isinstance(text_encoder, CLIPTextModel):
296
+ ori_prompt = ori_prompt.split(".")[0]+"."
297
+ prompt = f"{prompt.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" if style_prompt is not None else ori_prompt
298
+
299
+ bcz = len(prompt)
300
+ single_seed = seed
301
+ if dataloader.batch_size == 1 and seed is None:
302
+ if "seed" in batch:
303
+ single_seed = batch["seed"][0]
304
+
305
+ print(f"{prompt}, seed={single_seed}")
306
+
307
+ # text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
308
+ # original_embeddings = text_encoder(**text_input)[0]
309
+
310
+ prompts += prompt
311
+ ori_prompts += ori_prompt
312
+ # style_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
313
+ # # style_embeddings = text_encoder(**style_input)[0]
314
+ # style_embeddings = text_encoder(style_input.input_ids, return_dict=False)[0]
315
+
316
+ style_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = prompt)
317
+ original_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = ori_prompt)
318
+ if uncond_embed.shape[0] == 1 and bcz > 1:
319
+ uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
320
+ else:
321
+ uncond_embeddings = uncond_embed
322
+ style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings])
323
+ original_embeddings = torch.cat([uncond_embeddings, original_embeddings])
324
+
325
+ generator = torch.manual_seed(single_seed) if single_seed is not None else None
326
+ noise_scheduler.set_timesteps(steps)
327
+ if do_convert:
328
+ noised_latent, _, _ = get_noisy_image(image, vae, generator, unet, noise_scheduler, total_timesteps=int((1000-start_noise)/1000 *steps))
329
+ else:
330
+ latent_shape = (bcz, 4, height//8, width//8)
331
+ noised_latent = randn_tensor(latent_shape, generator=generator, device=vae.device)
332
+ noised_latent = noised_latent.to(unet.dtype)
333
+ noised_latent = noised_latent * noise_scheduler.init_noise_sigma
334
+ for scale in scales:
335
+ start_time = time.time()
336
+ if not isinstance(scale, float) and not isinstance(scale, int):
337
+ continue
338
+
339
+ latents = noised_latent.clone().to(weight_dtype).to(device)
340
+ noise_scheduler.set_timesteps(steps)
341
+ for t in tqdm(noise_scheduler.timesteps):
342
+ if do_convert and t>start_noise:
343
+ continue
344
+ else:
345
+ if t > start_noise and start_noise >= 0:
346
+ current_scale = 0
347
+ else:
348
+ current_scale = scale
349
+ network.set_lora_slider(scale=current_scale)
350
+ text_embedding = style_text_embeddings
351
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
352
+ latent_model_input = torch.cat([latents] * 2)
353
+
354
+ latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
355
+ # predict the noise residual
356
+ with network:
357
+ noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embedding).sample
358
+
359
+ # perform guidance
360
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
361
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
362
+
363
+ # compute the previous noisy sample x_t -> x_t-1
364
+ if isinstance(noise_scheduler, DDPMScheduler):
365
+ latents = noise_scheduler.step(noise_pred, t, latents, generator=torch.manual_seed(single_seed+t) if single_seed is not None else None).prev_sample
366
+ else:
367
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
368
+
369
+ # scale and decode the image latents with vae
370
+ latents = 1 / 0.18215 * latents.to(vae.dtype)
371
+
372
+
373
+ with torch.no_grad():
374
+ image = vae.decode(latents).sample
375
+ image = (image / 2 + 0.5).clamp(0, 1)
376
+ image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
377
+ images = (image * 255).round().astype("uint8")
378
+
379
+
380
+ pil_images = [Image.fromarray(image) for image in images]
381
+ pred_images[scale]+=pil_images
382
+ all_seeds[scale] += [single_seed] * bcz
383
+
384
+ end_time = time.time()
385
+ print(f"Time taken for one batch, Art Adapter scale={scale}: {end_time-start_time}", flush=True)
386
+
387
+ if save_dir is not None or show:
388
+ end_idx = len(list(pred_images.values())[0])
389
+ for i in range(end_idx-bcz, end_idx):
390
+ keys = list(pred_images.keys())
391
+ images_list = [pred_images[key][i] for key in keys]
392
+ prompt = prompts[i]
393
+ if len(scales)==1:
394
+ plt.imshow(images_list[0])
395
+ plt.axis('off')
396
+ plt.title(f"{prompt}_{single_seed}_start{start_noise}", fontsize=20)
397
+ else:
398
+ fig, ax = plt.subplots(1, len(images_list), figsize=(len(scales)*5,6), layout="constrained")
399
+ for id, a in enumerate(ax):
400
+ a.imshow(images_list[id])
401
+ if isinstance(scales[id], float) or isinstance(scales[id], int):
402
+ a.set_title(f"Art Adapter scale={scales[id]}", fontsize=20)
403
+ else:
404
+ a.set_title(f"{keys[id]}", fontsize=20)
405
+ a.axis('off')
406
+
407
+ # plt.suptitle(f"{os.path.basename(lora_weight).replace('.pt','')}", fontsize=20)
408
+
409
+ # plt.tight_layout()
410
+ # if do_convert:
411
+ # plt.suptitle(f"{prompt}\nseed{single_seed}_start{start_noise}_guidance{guidance_scale}", fontsize=20)
412
+ # else:
413
+ # plt.suptitle(f"{prompt}\nseed{single_seed}_from_scratch_guidance{guidance_scale}", fontsize=20)
414
+
415
+ if save_dir is not None:
416
+ plt.savefig(f"{img_output_dir}/{prompt.replace(' ', '_')[:100]}_seed{single_seed}_start{start_noise}.png")
417
+ if show:
418
+ plt.show()
419
+ plt.close()
420
+
421
+ flush()
422
+
423
+ if save_dir is not None:
424
+ with open(os.path.join(save_dir, "infer_imgs.pickle" ), 'wb') as f:
425
+ pickle.dump(pred_images, f)
426
+ with open(os.path.join(save_dir, "all_seeds.pickle"), 'wb') as f:
427
+ to_save={"all_seeds":all_seeds, "batch_size":batch_size}
428
+ pickle.dump(to_save, f)
429
+ for scale, images in pred_images.items():
430
+ subfolder = os.path.join(save_dir,"images", f"{scale}")
431
+ os.makedirs(subfolder, exist_ok=True)
432
+
433
+ used_prompt = ori_prompts
434
+ if (isinstance(scale, float) or isinstance(scale, int)): #and scale != 0:
435
+ used_prompt = prompts
436
+ for i, image in enumerate(images):
437
+ if scale == "Real Image":
438
+ suffix = ""
439
+ else:
440
+ suffix = f"_seed{all_seeds[scale][i]}"
441
+ image.save(os.path.join(subfolder, f"{used_prompt[i].replace(' ', '_')[:100]}{suffix}.jpg"))
442
+ with open(os.path.join(save_dir, "infer_prompts.txt"), 'w') as f:
443
+ for prompt in prompts:
444
+ f.write(f"{prompt}\n")
445
+ with open(os.path.join(save_dir, "ori_prompts.txt"), 'w') as f:
446
+ for prompt in ori_prompts:
447
+ f.write(f"{prompt}\n")
448
+ print(f"Saved inference results to {save_dir}", flush=True)
449
+ return pred_images, prompts
450
+
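A hedged sketch of a from-scratch run over the dataloader built above (all module objects come from the helpers; save_dir, seed and the style prompt are illustrative):
preds, used_prompts = inference(network, tokenizer, text_encoder, vae, unet, noise_scheduler,
                                dataloader, height=512, width=512, scales=[0.0, 1.0],
                                save_dir="inference_output", seed=0, steps=50,
                                guidance_scale=5.0, start_noise=-1,
                                style_prompt="sks art", from_scratch=True)
# preds maps each adapter scale to a list of PIL images, in dataloader order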
451
+ @torch.no_grad()
452
+ def infer_metric(ref_image_folder,pred_images, prompts, save_dir, start_noise=""):
453
+ prompts = [prompt.split(" in the style of ")[0] for prompt in prompts]
454
+ scores = {}
455
+ original_images = pred_images["Real Image"] if "Real Image" in pred_images else None
456
+ metric = StyleContentMetric(ref_image_folder)
457
+ for scale, images in pred_images.items():
458
+ score = metric(images, original_images, prompts)
459
+
460
+ scores[scale] = score
461
+ print(f"Style transfer score at scale {scale}: {score}")
462
+ scores["ref_path"] = ref_image_folder
463
+ save_name = f"scores_start{start_noise}.json"
464
+ os.makedirs(save_dir, exist_ok=True)
465
+ with open(os.path.join(save_dir, save_name), 'w') as f:
466
+ json.dump(scores, f, indent=2)
467
+ return scores
468
+
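Hypothetical follow-up on the outputs of inference() (the reference folder is a placeholder and should contain the style's paintings):
scores = infer_metric("ref_style/paintings", preds, used_prompts,
                      save_dir="inference_output", start_noise=800)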
469
+ def parse_args():
470
+ parser = argparse.ArgumentParser(description='Inference with LoRA')
471
+ parser.add_argument('--lora_weights', type=str, default=["None"],
472
+ nargs='+', help='path to your model file')
473
+ parser.add_argument('--prompts', type=str, default=[],
474
+ nargs='+', help='prompts to try')
475
+ parser.add_argument("--prompt_file", type=str, default=None, help="path to the prompt file")
476
+ parser.add_argument("--prompt_file_key", type=str, default="prompts", help="key to the prompt file")
477
+ parser.add_argument('--resolution', type=int, default=512, help='resolution of the image')
478
+ parser.add_argument('--seed', type=int, default=None, help='seed for the random number generator')
479
+ parser.add_argument("--start_noise", type=int, default=800, help="start noise")
480
+ parser.add_argument("--from_scratch", default=False, action="store_true", help="from scratch")
481
+ parser.add_argument("--ref_image_folder", type=str, default=None, help="folder containing reference images")
482
+ parser.add_argument("--show", action="store_true", help="show the image")
483
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size")
484
+ parser.add_argument("--scales", type=float, default=[0.,1.], nargs='+', help="scales to test")
485
+ parser.add_argument("--train_method", type=str, default=None, help="train method")
486
+
487
+ # parser.add_argument("--vae_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the VAE model.")
488
+ # parser.add_argument("--text_encoder_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the text encoder model.")
489
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default="rhfeiyang/art-free-diffusion-v1", help="Path to the pretrained model.")
490
+ parser.add_argument("--unet_ckpt", default=None, type=str, help="Path to the unet checkpoint")
491
+ parser.add_argument("--guidance_scale", type=float, default=5.0, help="guidance scale")
492
+ parser.add_argument("--infer_mode", default="sks_art", help="inference mode") #, choices=["style", "ori", "artist", "sks_art","Peter"]
493
+ parser.add_argument("--save_dir", type=str, default="inference_output", help="save directory")
494
+ parser.add_argument("--num_workers", type=int, default=4, help="number of workers")
495
+ parser.add_argument("--no_load", action="store_true", help="no load the pre-inferred results")
496
+ parser.add_argument("--infer_prompts", type=str, default=None, nargs="+", help="prompts to infer")
497
+ parser.add_argument("--infer_images", type=str, default=None, nargs="+", help="images to infer")
498
+ parser.add_argument("--rank", type=int, default=1, help="rank of the lora")
499
+ parser.add_argument("--val_set", type=str, default="laion_pop500", help="validation set")
500
+ parser.add_argument("--folder_name", type=str, default=None, help="folder name")
501
+ parser.add_argument("--scheduler_type",type=str, choices=["ddpm", "ddim", "pndm","lms"], default="ddpm", help="scheduler type")
502
+ parser.add_argument("--infer_steps", type=int, default=50, help="inference steps")
503
+ parser.add_argument("--weight_dtype", type=str, default="fp32", help="weight dtype")
504
+ parser.add_argument("--custom_coco_cap", action="store_true", help="use custom coco caption")
505
+ args = parser.parse_args()
506
+ if args.infer_prompts is not None and len(args.infer_prompts) == 1 and os.path.isfile(args.infer_prompts[0]):
507
+ if args.infer_prompts[0].endswith(".txt") and args.custom_coco_cap:
508
+ args.infer_prompts = CustomCocoCaptions(custom_file=args.infer_prompts[0])
509
+ elif args.infer_prompts[0].endswith(".txt"):
510
+ with open(args.infer_prompts[0], 'r') as f:
511
+ args.infer_prompts = f.readlines()
512
+ args.infer_prompts = [prompt.strip() for prompt in args.infer_prompts]
513
+ elif args.infer_prompts[0].endswith(".csv"):
514
+ from custom_datasets.custom_caption import Caption_set
515
+ caption_set = Caption_set(args.infer_prompts[0])
516
+ args.infer_prompts = caption_set
517
+
518
+
519
+ if args.infer_mode == "style":
520
+ with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
521
+ args.style_label = f.readlines()[0].strip()
522
+ elif args.infer_mode == "artist":
523
+ with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
524
+ args.style_label = f.readlines()[0].strip()
525
+ args.style_label = args.style_label.split(",")[0].strip()
526
+ elif args.infer_mode == "ori":
527
+ args.style_label = None
528
+ else:
529
+ args.style_label = args.infer_mode.replace("_", " ")
530
+ if args.ref_image_folder is not None:
531
+ args.ref_image_folder = os.path.join(args.ref_image_folder, "paintings")
532
+
533
+ if args.start_noise < 0:
534
+ args.from_scratch = True
535
+
536
+
537
+ print(args.__dict__)
538
+ return args
539
+
540
+
541
+ def main(args):
542
+ lora_weights = args.lora_weights
543
+
544
+ if len(lora_weights) == 1 and isinstance(lora_weights[0], str) and os.path.isdir(lora_weights[0]):
545
+ lora_weights = glob.glob(os.path.join(lora_weights[0], "*.pt"))
546
+ lora_weights=sorted(lora_weights, reverse=True)
547
+
548
+ width = args.resolution
549
+ height = args.resolution
550
+ steps = args.infer_steps
551
+
552
+ revision = None
553
+ device = 'cuda'
554
+ rank = args.rank
555
+ if args.weight_dtype == "fp32":
556
+ weight_dtype = torch.float32
557
+ elif args.weight_dtype=="fp16":
558
+ weight_dtype = torch.float16
559
+ elif args.weight_dtype=="bf16":
560
+ weight_dtype = torch.bfloat16
561
+
562
+ modules = get_model(args.pretrained_model_name_or_path, unet_ckpt=args.unet_ckpt, revision=revision, variant=None, lora_path=None, weight_dtype=weight_dtype, device=device, )
563
+ if args.scheduler_type == "pndm":
564
+ noise_scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
565
+
566
+ elif args.scheduler_type == "ddpm":
567
+ noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
568
+ elif args.scheduler_type == "ddim":
569
+ noise_scheduler = DDIMScheduler(
570
+ beta_start=0.00085,
571
+ beta_end=0.012,
572
+ beta_schedule="scaled_linear",
573
+ num_train_timesteps=1000,
574
+ clip_sample=False,
575
+ prediction_type="epsilon",
576
+ )
577
+ elif args.scheduler_type == "lms":
578
+ noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
579
+ beta_end=0.012,
580
+ beta_schedule="scaled_linear",
581
+ num_train_timesteps=1000)
582
+ else:
583
+ raise ValueError("Unknown scheduler type")
584
+ cache=EasyDict()
585
+ cache.modules = modules
586
+
587
+ unet = modules["unet"]
588
+ vae = modules["vae"]
589
+ text_encoder = modules["text_encoder"]
590
+ tokenizer = modules["tokenizer"]
591
+
592
+ unet.requires_grad_(False)
593
+
594
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
595
+ vae.requires_grad_(False)
596
+ text_encoder.requires_grad_(False)
597
+
598
+ ## dataloader
599
+ dataloader = get_validation_dataloader(infer_prompts=args.infer_prompts, infer_images=args.infer_images,
600
+ resolution=args.resolution,
601
+ batch_size=args.batch_size, num_workers=args.num_workers,
602
+ val_set=args.val_set)
603
+
604
+
605
+ for lora_weight in lora_weights:
606
+ print(f"Testing {lora_weight}")
607
+ # for different seeds on same prompt
608
+ seed = args.seed
609
+
610
+ network_ret = get_lora_network(unet, lora_weight, train_method=args.train_method, rank=rank, alpha=1.0, device=device, weight_dtype=weight_dtype)
611
+ network = network_ret["network"]
612
+ train_method = network_ret["train_method"]
613
+ if args.save_dir is not None:
614
+ save_dir = args.save_dir
615
+ if args.style_label is not None:
616
+ save_dir = os.path.join(save_dir, f"{args.style_label.replace(' ', '_')}")
617
+ else:
618
+ save_dir = os.path.join(save_dir, f"ori/{args.start_noise}")
619
+ else:
620
+ if args.folder_name is not None:
621
+ folder_name = args.folder_name
622
+ else:
623
+ folder_name = "validation" if args.infer_prompts is None else "validation_prompts"
624
+ save_dir = os.path.join(os.path.dirname(lora_weight), f"{folder_name}/{train_method}", os.path.basename(lora_weight).replace('.pt','').split('_')[-1])
625
+ if args.infer_prompts is None:
626
+ save_dir = os.path.join(save_dir, f"{args.val_set}")
627
+
628
+ infer_config = f"{args.scheduler_type}{args.infer_steps}_{args.weight_dtype}_guidance{args.guidance_scale}"
629
+ save_dir = os.path.join(save_dir, infer_config)
630
+ os.makedirs(save_dir, exist_ok=True)
631
+ if args.from_scratch:
632
+ save_dir = os.path.join(save_dir, "from_scratch")
633
+ else:
634
+ save_dir = os.path.join(save_dir, "transfer")
635
+ save_dir = os.path.join(save_dir, f"start{args.start_noise}")
636
+ os.makedirs(save_dir, exist_ok=True)
637
+ with open(os.path.join(save_dir, "infer_args.yaml"), 'w') as f:
638
+ yaml.dump(vars(args), f)
639
+ # save code
640
+ code_dir = os.path.join(save_dir, "code")
641
+ os.makedirs(code_dir, exist_ok=True)
642
+ current_file = os.path.basename(__file__)
643
+ shutil.copy(__file__, os.path.join(code_dir, current_file))
644
+ with torch.no_grad():
645
+ pred_images, prompts = inference(network, tokenizer, text_encoder, vae, unet, noise_scheduler, dataloader, height, width,
646
+ args.scales, save_dir, seed, weight_dtype, device, args.batch_size, steps, guidance_scale=args.guidance_scale,
647
+ start_noise=args.start_noise, show=args.show, style_prompt=args.style_label, no_load=args.no_load,
648
+ from_scratch=args.from_scratch)
649
+
650
+ if args.ref_image_folder is not None:
651
+ flush()
652
+ print("Calculating metrics")
653
+ infer_metric(args.ref_image_folder, pred_images, prompts, save_dir, args.start_noise)
654
+
655
+ if __name__ == "__main__":
656
+ args = parse_args()
657
+ main(args)
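A typical invocation (hypothetical, using only the flags defined in parse_args above; a start_noise below 0 switches on from-scratch generation automatically):
python inference.py --lora_weights adapters/my_style_adapter_rank1_all_up_1000steps.pt --infer_prompts "a quiet harbor at dawn" --start_noise -1 --scales 0. 1. --seed 0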
utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (203 Bytes).
 
utils/__pycache__/lora.cpython-39.pyc ADDED
Binary file (6.29 kB).
 
utils/__pycache__/metrics.cpython-39.pyc ADDED
Binary file (19.3 kB).
 
utils/__pycache__/train_util.cpython-39.pyc ADDED
Binary file (10.9 kB).
 
utils/art_filter.py ADDED
@@ -0,0 +1,210 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
+ from tqdm import tqdm
9
+
10
+ class Caption_filter:
11
+ def __init__(self, filter_prompts=["painting", "paintings", "art", "artwork", "drawings", "sketch", "sketches", "illustration", "illustrations",
12
+ "sculpture","sculptures", "installation", "printmaking", "digital art", "conceptual art", "mosaic", "tapestry",
13
+ "abstract", "realism", "surrealism", "impressionism", "expressionism", "cubism", "minimalism", "baroque", "rococo",
14
+ "pop art", "art nouveau", "art deco", "futurism", "dadaism",
15
+ "stamp", "stamps", "advertisement", "advertisements","logo", "logos"
16
+ ],):
17
+ self.filter_prompts = filter_prompts
18
+ self.total_count=0
19
+ self.filter_count=[0]*len(filter_prompts)
20
+
21
+ def reset(self):
22
+ self.total_count=0
23
+ self.filter_count=[0]*len(self.filter_prompts)
24
+ def filter(self, captions):
25
+ filter_result = []
26
+ for caption in captions:
27
+ words = caption[0]
28
+ if words is None:
29
+ filter_result.append((True, "None"))
30
+ continue
31
+ words = words.lower()
32
+ words = words.split()
33
+ filt = False
34
+ reason=None
35
+ for i, filter_keyword in enumerate(self.filter_prompts):
36
+ key_len = len(filter_keyword.split())
37
+ for j in range(len(words)-key_len+1):
38
+ if " ".join(words[j:j+key_len]) == filter_keyword:
39
+ self.filter_count[i] += 1
40
+ filt = True
41
+ reason = filter_keyword
42
+ break
43
+ if filt:
44
+ break
45
+ filter_result.append((filt, reason))
46
+ self.total_count += 1
47
+ return filter_result
48
+
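A small sketch of the keyword matching above; each caption is wrapped in a list because filter() reads caption[0]:
f = Caption_filter()
print(f.filter([["An oil painting of a harbor"], ["A photo of a dog on a beach"]]))
# -> [(True, 'painting'), (False, None)]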
49
+ class Clip_filter:
50
+ prompt_threshold = {
51
+ "painting": 17,
52
+ "art": 17.5,
53
+ "artwork": 19,
54
+ "drawing": 15.8,
55
+ "sketch": 17,
56
+ "illustration": 15,
57
+ "sculpture": 19.2,
58
+ "installation art": 20,
59
+ "printmaking art": 16.3,
60
+ "digital art": 15,
61
+ "conceptual art": 18,
62
+ "mosaic art": 19,
63
+ "tapestry": 16,
64
+ "abstract art":16.5,
65
+ "realism art": 16,
66
+ "surrealism art": 15,
67
+ "impressionism art": 17,
68
+ "expressionism art": 17,
69
+ "cubism art": 15,
70
+ "minimalism art": 16,
71
+ "baroque art": 17.5,
72
+ "rococo art": 17,
73
+ "pop art": 16,
74
+ "art nouveau": 19,
75
+ "art deco": 19,
76
+ "futurism art": 16.5,
77
+ "dadaism art": 16.5,
78
+ "stamp": 18,
79
+ "advertisement": 16.5,
80
+ "logo": 15.5,
81
+ }
82
+ @torch.no_grad()
83
+ def __init__(self, positive_prompt=["painting", "art", "artwork", "drawing", "sketch", "illustration",
84
+ "sculpture", "installation art", "printmaking art", "digital art", "conceptual art", "mosaic art", "tapestry",
85
+ "abstract art", "realism art", "surrealism art", "impressionism art", "expressionism art", "cubism art",
86
+ "minimalism art", "baroque art", "rococo art",
87
+ "pop art", "art nouveau", "art deco", "futurism art", "dadaism art",
88
+ "stamp", "advertisement",
89
+ "logo"
90
+ ],
91
+ device="cuda"):
92
+ self.device = device
93
+ self.model = (CLIPModel.from_pretrained("openai/clip-vit-large-patch14")).to(device)
94
+ self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
95
+ self.positive_prompt = positive_prompt
96
+ self.text = self.positive_prompt
97
+ self.tokenizer = self.processor.tokenizer
98
+ self.image_processor = self.processor.image_processor
99
+ self.text_encoding = self.tokenizer(self.text, return_tensors="pt", padding=True).to(device)
100
+ self.text_features = self.model.get_text_features(**self.text_encoding)
101
+ self.text_features = self.text_features / self.text_features.norm(p=2, dim=-1, keepdim=True)
102
+ @torch.no_grad()
103
+ def similarity(self, image):
104
+ # inputs = self.processor(text=self.text, images=image, return_tensors="pt", padding=True)
105
+ image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
106
+ inputs = {**self.text_encoding, **image_processed}
107
+ outputs = self.model(**inputs)
108
+ logits_per_image = outputs.logits_per_image
109
+ return logits_per_image
110
+
111
+ def get_logits(self, image):
112
+ logits_per_image = self.similarity(image)
113
+ return logits_per_image.cpu()
114
+
115
+ def get_image_features(self, image):
116
+ image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
117
+ image_features = self.model.get_image_features(**image_processed)
118
+ return image_features
119
+
120
+
121
+ class Art_filter:
122
+ def __init__(self):
123
+ self.caption_filter = Caption_filter()
124
+ self.clip_filter = Clip_filter()
125
+ def caption_filt(self, dataloader):
126
+ self.caption_filter.reset()
127
+ dataloader.dataset.get_img = False
128
+ dataloader.dataset.get_cap = True
129
+ remain_ids = []
130
+ filtered_ids = []
131
+ for i, batch in tqdm(enumerate(dataloader)):
132
+ captions = batch["text"]
133
+ filter_result = self.caption_filter.filter(captions)
134
+ for j, (filt, reason) in enumerate(filter_result):
135
+ if filt:
136
+ filtered_ids.append((batch["ids"][j], reason))
137
+ if i%10==0:
138
+ print(f"Filtered caption: {captions[j]}, reason: {reason}")
139
+ else:
140
+ remain_ids.append(batch["ids"][j])
141
+ return {"remain_ids":remain_ids, "filtered_ids":filtered_ids, "total_count":self.caption_filter.total_count, "filter_count":self.caption_filter.filter_count, "filter_prompts":self.caption_filter.filter_prompts}
142
+
143
+ def clip_filt(self, clip_logits_ckpt:dict):
144
+ logits = clip_logits_ckpt["clip_logits"]
145
+ ids = clip_logits_ckpt["ids"]
146
+ text = clip_logits_ckpt["text"]
147
+ filt_mask = torch.zeros(logits.shape[0], dtype=torch.bool)
148
+ for i, prompt in enumerate(text):
149
+ threshold = Clip_filter.prompt_threshold[prompt]
150
+ filt_mask = filt_mask | (logits[:,i] >= threshold)
151
+ filt_ids = []
152
+ remain_ids = []
153
+ for i, id in enumerate(ids):
154
+ if filt_mask[i]:
155
+ filt_ids.append(id)
156
+ else:
157
+ remain_ids.append(id)
158
+ return {"remain_ids":remain_ids, "filtered_ids":filt_ids}
159
+
160
+ def clip_feature(self, dataloader):
161
+ dataloader.dataset.get_img = True
162
+ dataloader.dataset.get_cap = False
163
+ clip_features = []
164
+ ids = []
165
+ for i, batch in enumerate(dataloader):
166
+ images = batch["images"]
167
+ features = self.clip_filter.get_image_features(images).cpu()
168
+ clip_features.append(features)
169
+ ids.extend(batch["ids"])
170
+ clip_features = torch.cat(clip_features)
171
+ return {"clip_features":clip_features, "ids":ids}
172
+
173
+
174
+ def clip_logit(self, dataloader):
175
+ dataloader.dataset.get_img = True
176
+ dataloader.dataset.get_cap = False
177
+ clip_features = []
178
+ clip_logits = []
179
+ ids = []
180
+ for i, batch in enumerate(dataloader):
181
+ images = batch["images"]
182
+ # logits = self.clip_filter.get_logits(images)
183
+ feature = self.clip_filter.get_image_features(images)
184
+ logits = self.clip_logit_by_feat(feature)["clip_logits"]
185
+
186
+ clip_features.append(feature)
187
+ clip_logits.append(logits)
188
+ ids.extend(batch["ids"])
189
+
190
+ clip_features = torch.cat(clip_features)
191
+ clip_logits = torch.cat(clip_logits)
192
+ return {"clip_features":clip_features, "clip_logits":clip_logits, "ids":ids, "text": self.clip_filter.text}
193
+
194
+ def clip_logit_by_feat(self, feature):
195
+ feature = feature.clone().to(self.clip_filter.device)
196
+ feature = feature / feature.norm(p=2, dim=-1, keepdim=True)
197
+ logit_scale = self.clip_filter.model.logit_scale.exp()
198
+ logits = ((feature @ self.clip_filter.text_features.T) * logit_scale).cpu()
199
+ return {"clip_logits":logits, "text": self.clip_filter.text}
200
+
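A hedged sketch of the thresholding step with made-up logits for two image ids (note that instantiating Art_filter loads the CLIP model):
import torch
ckpt = {"clip_logits": torch.tensor([[20.0], [10.0]]),
        "ids": ["sa_1", "sa_2"],
        "text": ["painting"]}          # the threshold for "painting" is 17
split = Art_filter().clip_filt(ckpt)
# -> {"remain_ids": ["sa_2"], "filtered_ids": ["sa_1"]}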
201
+
202
+
203
+ if __name__ == "__main__":
204
+ import pickle
205
+ with open("/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle","rb") as f:
206
+ result=pickle.load(f)
207
+ feat = result['clip_features']
208
+ logits =Art_filter().clip_logit_by_feat(feat)
209
+ print(logits)
210
+
utils/config_util.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import Literal, Optional
2
+
3
+ import yaml
4
+
5
+ from pydantic import BaseModel
6
+ import torch
7
+
8
+ from lora import TRAINING_METHODS
9
+
10
+ PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
11
+ NETWORK_TYPES = Literal["lierla", "c3lier"]
12
+
13
+
14
+ class PretrainedModelConfig(BaseModel):
15
+ name_or_path: str
16
+ ckpt_path: Optional[str] = None
17
+ v2: bool = False
18
+ v_pred: bool = False
19
+
20
+ clip_skip: Optional[int] = None
21
+
22
+
23
+ class NetworkConfig(BaseModel):
24
+ type: NETWORK_TYPES = "lierla"
25
+ rank: int = 4
26
+ alpha: float = 1.0
27
+
28
+ training_method: TRAINING_METHODS = "full"
29
+
30
+
31
+ class TrainConfig(BaseModel):
32
+ precision: PRECISION_TYPES = "bfloat16"
33
+ noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
34
+
35
+ iterations: int = 500
36
+ lr: float = 1e-4
37
+ optimizer: str = "adamw"
38
+ optimizer_args: str = ""
39
+ lr_scheduler: str = "constant"
40
+
41
+ max_denoising_steps: int = 50
42
+
43
+
44
+ class SaveConfig(BaseModel):
45
+ name: str = "untitled"
46
+ path: str = "./output"
47
+ per_steps: int = 200
48
+ precision: PRECISION_TYPES = "float32"
49
+
50
+
51
+ class LoggingConfig(BaseModel):
52
+ use_wandb: bool = False
53
+
54
+ verbose: bool = False
55
+
56
+
57
+ class OtherConfig(BaseModel):
58
+ use_xformers: bool = False
59
+
60
+
61
+ class RootConfig(BaseModel):
62
+ # prompts_file: str
63
+ pretrained_model: PretrainedModelConfig
64
+
65
+ network: NetworkConfig
66
+
67
+ train: Optional[TrainConfig]
68
+
69
+ save: Optional[SaveConfig]
70
+
71
+ logging: Optional[LoggingConfig]
72
+
73
+ other: Optional[OtherConfig]
74
+
75
+
76
+ def parse_precision(precision: str) -> torch.dtype:
77
+ if precision == "fp32" or precision == "float32":
78
+ return torch.float32
79
+ elif precision == "fp16" or precision == "float16":
80
+ return torch.float16
81
+ elif precision == "bf16" or precision == "bfloat16":
82
+ return torch.bfloat16
83
+
84
+ raise ValueError(f"Invalid precision type: {precision}")
85
+
86
+
87
+ def load_config_from_yaml(config_path: str) -> RootConfig:
88
+ with open(config_path, "r") as f:
89
+ config = yaml.load(f, Loader=yaml.FullLoader)
90
+
91
+ root = RootConfig(**config)
92
+
93
+ if root.train is None:
94
+ root.train = TrainConfig()
95
+
96
+ if root.save is None:
97
+ root.save = SaveConfig()
98
+
99
+ if root.logging is None:
100
+ root.logging = LoggingConfig()
101
+
102
+ if root.other is None:
103
+ root.other = OtherConfig()
104
+
105
+ return root
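A minimal sketch of building the same config in code rather than from YAML (values are illustrative; the optional sections are left as None here, which load_config_from_yaml would replace with defaults):
cfg = RootConfig(
    pretrained_model=PretrainedModelConfig(name_or_path="rhfeiyang/art-free-diffusion-v1"),
    network=NetworkConfig(type="c3lier", rank=1, alpha=1.0, training_method="full"),
    train=None, save=None, logging=None, other=None,
)
print(parse_precision(TrainConfig().precision))  # "bfloat16" -> torch.bfloat16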
utils/debug_util.py ADDED
@@ -0,0 +1,16 @@
1
+ # Debug helpers...
2
+
3
+ import torch
4
+
5
+
6
+ def check_requires_grad(model: torch.nn.Module):
7
+ for name, module in list(model.named_modules())[:5]:
8
+ if len(list(module.parameters())) > 0:
9
+ print(f"Module: {name}")
10
+ for name, param in list(module.named_parameters())[:2]:
11
+ print(f" Parameter: {name}, Requires Grad: {param.requires_grad}")
12
+
13
+
14
+ def check_training_mode(model: torch.nn.Module):
15
+ for name, module in list(model.named_modules())[:5]:
16
+ print(f"Module: {name}, Training Mode: {module.training}")
utils/lora.py ADDED
@@ -0,0 +1,282 @@
1
+ # ref:
2
+ # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
3
+ # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
4
+
5
+ import os
6
+ import math
7
+ from typing import Optional, List, Type, Set, Literal
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from diffusers import UNet2DConditionModel
12
+ from safetensors.torch import save_file
13
+
14
+
15
+ UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
16
+ # "Transformer2DModel", # どうやらこっちの方らしい? # attn1, 2
17
+ "Attention"
18
+ ]
19
+ UNET_TARGET_REPLACE_MODULE_CONV = [
20
+ "ResnetBlock2D",
21
+ "Downsample2D",
22
+ "Upsample2D",
23
+ # "DownBlock2D",
24
+ # "UpBlock2D"
25
+ ] # locon, 3clier
26
+
27
+ LORA_PREFIX_UNET = "lora_unet"
28
+
29
+ DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
30
+
31
+ TRAINING_METHODS = Literal[
32
+ "noxattn", # train all layers except x-attns and time_embed layers
33
+ "innoxattn", # train all layers except self attention layers
34
+ "selfattn", # ESD-u, train only self attention layers
35
+ "xattn", # ESD-x, train only x attention layers
36
+ "full", # train all layers
37
+ "xattn-strict", # q and k values
38
+ "noxattn-hspace",
39
+ "noxattn-hspace-last",
40
+ # "xlayer",
41
+ # "outxattn",
42
+ # "outsattn",
43
+ # "inxattn",
44
+ # "inmidsattn",
45
+ # "selflayer",
46
+ ]
47
+
48
+
49
+ class LoRAModule(nn.Module):
50
+ """
51
+ replaces forward method of the original Linear, instead of replacing the original Linear module.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ lora_name,
57
+ org_module: nn.Module,
58
+ multiplier=1.0,
59
+ lora_dim=4,
60
+ alpha=1,
61
+ ):
62
+ """if alpha == 0 or None, alpha is rank (no scaling)."""
63
+ super().__init__()
64
+ self.lora_name = lora_name
65
+ self.lora_dim = lora_dim
66
+
67
+ if "Linear" in org_module.__class__.__name__:
68
+ in_dim = org_module.in_features
69
+ out_dim = org_module.out_features
70
+ self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
71
+ self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
72
+
73
+ elif "Conv" in org_module.__class__.__name__: # 一応
74
+ in_dim = org_module.in_channels
75
+ out_dim = org_module.out_channels
76
+
77
+ self.lora_dim = min(self.lora_dim, in_dim, out_dim)
78
+ if self.lora_dim != lora_dim:
79
+ print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
80
+
81
+ kernel_size = org_module.kernel_size
82
+ stride = org_module.stride
83
+ padding = org_module.padding
84
+ self.lora_down = nn.Conv2d(
85
+ in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
86
+ )
87
+ self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
88
+
89
+ if type(alpha) == torch.Tensor:
90
+ alpha = alpha.detach().numpy()
91
+ alpha = lora_dim if alpha is None or alpha == 0 else alpha
92
+ self.scale = alpha / self.lora_dim
93
+ self.register_buffer("alpha", torch.tensor(alpha)) # 定数として扱える
94
+
95
+ # same as microsoft's
96
+ nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
97
+ nn.init.zeros_(self.lora_up.weight)
98
+
99
+ self.multiplier = multiplier
100
+ self.org_module = org_module # remove in applying
101
+
102
+ def apply_to(self):
103
+ self.org_forward = self.org_module.forward
104
+ self.org_module.forward = self.forward
105
+ del self.org_module
106
+
107
+ def forward(self, x):
108
+ return (
109
+ self.org_forward(x)
110
+ + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
111
+ )
112
+
113
+
114
+ class LoRANetwork(nn.Module):
115
+ def __init__(
116
+ self,
117
+ unet: UNet2DConditionModel,
118
+ rank: int = 4,
119
+ multiplier: float = 1.0,
120
+ alpha: float = 1.0,
121
+ train_method: TRAINING_METHODS = "full",
122
+ ) -> None:
123
+ super().__init__()
124
+ self.lora_scale = 1
125
+ self.multiplier = multiplier
126
+ self.lora_dim = rank
127
+ self.alpha = alpha
128
+
129
+
130
+ self.module = LoRAModule
131
+
132
+
133
+ self.unet_loras = self.create_modules(
134
+ LORA_PREFIX_UNET,
135
+ unet,
136
+ DEFAULT_TARGET_REPLACE,
137
+ self.lora_dim,
138
+ self.multiplier,
139
+ train_method=train_method,
140
+ )
141
+ print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
142
+
143
+
144
+ lora_names = set()
145
+ for lora in self.unet_loras:
146
+ assert (
147
+ lora.lora_name not in lora_names
148
+ ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
149
+ lora_names.add(lora.lora_name)
150
+
151
+
152
+ for lora in self.unet_loras:
153
+ lora.apply_to()
154
+ self.add_module(
155
+ lora.lora_name,
156
+ lora,
157
+ )
158
+
159
+ del unet
160
+
161
+ torch.cuda.empty_cache()
162
+
163
+ def create_modules(
164
+ self,
165
+ prefix: str,
166
+ root_module: nn.Module,
167
+ target_replace_modules: List[str],
168
+ rank: int,
169
+ multiplier: float,
170
+ train_method: TRAINING_METHODS,
171
+ ) -> list:
172
+ loras = []
173
+ names = []
174
+ for name, module in root_module.named_modules():
175
+ if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last": # Cross Attention と Time Embed 以外学習
176
+ if "attn2" in name or "time_embed" in name:
177
+ continue
178
+ elif train_method == "innoxattn": # Cross Attention 以外学習
179
+ if "attn2" in name:
180
+ continue
181
+ elif train_method == "selfattn": # Self Attention のみ学習
182
+ if "attn1" not in name:
183
+ continue
184
+ elif train_method == "xattn" or train_method == "xattn-strict": # Cross Attention のみ学習
185
+ if "attn2" not in name:
186
+ continue
187
+ elif train_method == "attn":
188
+ if "attn1" not in name and "attn2" not in name:
189
+ continue
190
+ elif train_method == "full":
191
+ pass
192
+ # else:
193
+ # raise NotImplementedError(
194
+ # f"train_method: {train_method} is not implemented."
195
+ # )
196
+ ##
197
+ # union condition(b-lora)
198
+ else:
199
+ discard = True
200
+ if "all_up" in train_method:
201
+ if "up_blocks" in name:
202
+ discard = False
203
+ if "down_1" in train_method:
204
+ if not ("down_blocks.1" not in name or "attentions" not in name):
205
+ discard = False
206
+ if "down_2" in train_method:
207
+ if not ("down_blocks.2" not in name or "attentions" not in name):
208
+ discard = False
209
+ if "up_1" in train_method:
210
+ if not ("up_blocks.1" not in name or "attentions" not in name):
211
+ discard = False
212
+ if "up_2" in train_method:
213
+ if not ("up_blocks.2" not in name or "attentions" not in name):
214
+ discard = False
215
+ if discard:
216
+ continue
217
+
218
+ ##
219
+ if module.__class__.__name__ in target_replace_modules:
220
+ for child_name, child_module in module.named_modules():
221
+ if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
222
+ if train_method == 'xattn-strict':
223
+ if 'out' in child_name:
224
+ continue
225
+ if train_method == 'noxattn-hspace':
226
+ if 'mid_block' not in name:
227
+ continue
228
+ if train_method == 'noxattn-hspace-last':
229
+ if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
230
+ continue
231
+ lora_name = prefix + "." + name + "." + child_name
232
+ lora_name = lora_name.replace(".", "_")
233
+ # print(f"{lora_name}")
234
+ lora = self.module(
235
+ lora_name, child_module, multiplier, rank, self.alpha
236
+ )
237
+ # print(name, child_name)
238
+ # print(child_module.weight.shape)
239
+ loras.append(lora)
240
+ names.append(lora_name)
241
+ # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
242
+ return loras
243
+
244
+ def prepare_optimizer_params(self):
245
+ all_params = []
246
+
247
+ if self.unet_loras:  # in practice this is the only case
248
+ params = []
249
+ [params.extend(lora.parameters()) for lora in self.unet_loras]
250
+ param_data = {"params": params}
251
+ all_params.append(param_data)
252
+
253
+ return all_params
254
+
255
+ def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
256
+ state_dict = self.state_dict()
257
+
258
+ if dtype is not None:
259
+ for key in list(state_dict.keys()):
260
+ v = state_dict[key]
261
+ v = v.detach().clone().to("cpu").to(dtype)
262
+ state_dict[key] = v
263
+
264
+ # for key in list(state_dict.keys()):
265
+ # if not key.startswith("lora"):
266
+ # # drop non-LoRA keys
267
+ # del state_dict[key]
268
+
269
+ if os.path.splitext(file)[1] == ".safetensors":
270
+ save_file(state_dict, file, metadata)
271
+ else:
272
+ torch.save(state_dict, file)
273
+ def set_lora_slider(self, scale):
274
+ self.lora_scale = scale
275
+
276
+ def __enter__(self):
277
+ for lora in self.unet_loras:
278
+ lora.multiplier = 1.0 * self.lora_scale
279
+
280
+ def __exit__(self, exc_type, exc_value, tb):
281
+ for lora in self.unet_loras:
282
+ lora.multiplier = 0
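A hedged sketch of the slider semantics (unet, latents, t and emb are placeholders): the LoRA delta is only added while inside the context manager, scaled by set_lora_slider.
net = LoRANetwork(unet, rank=1, alpha=1.0, train_method="full").to("cuda")
net.set_lora_slider(scale=0.8)
with net:                                   # every LoRAModule multiplier becomes 0.8
    noise_pred = unet(latents, t, encoder_hidden_states=emb).sample
# on exit the multipliers are reset to 0, so the UNet behaves as the base model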
utils/metrics.py ADDED
@@ -0,0 +1,577 @@
1
+ # Authors: Hui Ren (rhfeiyang.github.io)
2
+
3
+ import os
4
+
5
+ import numpy as np
6
+ from torchvision import transforms
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ from torch.autograd import Function
11
+ from PIL import Image
12
+ from transformers import CLIPProcessor, CLIPModel
13
+ from collections import OrderedDict
14
+ from transformers import BatchFeature
15
+ import clip
16
+ import copy
17
+ import lpips
18
+ from transformers import ViTImageProcessor, ViTModel
19
+
20
+ ## CSD_CLIP
21
+ def convert_weights_float(model: nn.Module):
22
+ """Convert applicable model parameters to fp32"""
23
+
24
+ def _convert_weights_to_fp32(l):
25
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
26
+ l.weight.data = l.weight.data.float()
27
+ if l.bias is not None:
28
+ l.bias.data = l.bias.data.float()
29
+
30
+ if isinstance(l, nn.MultiheadAttention):
31
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
32
+ tensor = getattr(l, attr)
33
+ if tensor is not None:
34
+ tensor.data = tensor.data.float()
35
+
36
+ for name in ["text_projection", "proj"]:
37
+ if hasattr(l, name):
38
+ attr = getattr(l, name)
39
+ if attr is not None:
40
+ attr.data = attr.data.float()
41
+
42
+ model.apply(_convert_weights_to_fp32)
43
+
44
+ class ReverseLayerF(Function):
45
+
46
+ @staticmethod
47
+ def forward(ctx, x, alpha):
48
+ ctx.alpha = alpha
49
+
50
+ return x.view_as(x)
51
+
52
+ @staticmethod
53
+ def backward(ctx, grad_output):
54
+ output = grad_output.neg() * ctx.alpha
55
+
56
+ return output, None
57
+
58
+
59
+ ## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
60
+ class ProjectionHead(nn.Module):
61
+ def __init__(
62
+ self,
63
+ embedding_dim,
64
+ projection_dim,
65
+ dropout=0
66
+ ):
67
+ super().__init__()
68
+ self.projection = nn.Linear(embedding_dim, projection_dim)
69
+ self.gelu = nn.GELU()
70
+ self.fc = nn.Linear(projection_dim, projection_dim)
71
+ self.dropout = nn.Dropout(dropout)
72
+ self.layer_norm = nn.LayerNorm(projection_dim)
73
+
74
+ def forward(self, x):
75
+ projected = self.projection(x)
76
+ x = self.gelu(projected)
77
+ x = self.fc(x)
78
+ x = self.dropout(x)
79
+ x = x + projected
80
+ x = self.layer_norm(x)
81
+ return x
82
+
83
+ def convert_state_dict(state_dict):
84
+ new_state_dict = OrderedDict()
85
+ for k, v in state_dict.items():
86
+ if k.startswith("module."):
87
+ k = k.replace("module.", "")
88
+ new_state_dict[k] = v
89
+ return new_state_dict
90
+ def init_weights(m):
91
+ if isinstance(m, nn.Linear):
92
+ torch.nn.init.xavier_uniform_(m.weight)
93
+ if m.bias is not None:
94
+ nn.init.normal_(m.bias, std=1e-6)
95
+
96
+ class Metric(nn.Module):
97
+ def __init__(self):
98
+ super().__init__()
99
+ self.image_preprocess = None
100
+
101
+ def load_image(self, image_path):
102
+ with open(image_path, 'rb') as f:
103
+ image = Image.open(f).convert("RGB")
104
+ return image
105
+
106
+ def load_image_path(self, image_path):
107
+ if isinstance(image_path, str):
108
+ # should be a image folder path
109
+ images_file = os.listdir(image_path)
110
+ images = [os.path.join(image_path, image) for image in images_file if
111
+ image.endswith(".jpg") or image.endswith(".png")]
112
+ if isinstance(image_path[0], str):
113
+ images = [self.load_image(image) for image in image_path]
114
+ elif isinstance(image_path[0], np.ndarray):
115
+ images = [Image.fromarray(image) for image in image_path]
116
+ elif isinstance(image_path[0], Image.Image):
117
+ images = image_path
118
+ else:
119
+ raise Exception("Invalid input")
120
+ return images
121
+
122
+ def preprocess_image(self, image, **kwargs):
123
+ if (isinstance(image, str) and os.path.isdir(image)) or (isinstance(image, list) and (isinstance(image[0], Image.Image) or isinstance(image[0], np.ndarray) or os.path.isfile(image[0]))):
124
+ input_data = self.load_image_path(image)
125
+ input_data = [self.image_preprocess(image, **kwargs) for image in input_data]
126
+ input_data = torch.stack(input_data)
127
+ elif os.path.isfile(image):
128
+ input_data = self.load_image(image)
129
+ input_data = self.image_preprocess(input_data, **kwargs)
130
+ input_data = input_data.unsqueeze(0)
131
+ elif isinstance(image, torch.Tensor):
132
+ raise Exception("Unsupported input")
133
+ return input_data
134
+
135
+ class Clip_Basic_Metric(Metric):
136
+ def __init__(self):
137
+ super().__init__()
138
+ self.tensor_preprocess = transforms.Compose([
139
+ transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
140
+ # transforms.rescale
141
+ transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
142
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
143
+ ])
144
+ self.image_preprocess = transforms.Compose([
145
+ transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
146
+ transforms.CenterCrop(224),
147
+ transforms.ToTensor(),
148
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
149
+ ])
150
+
151
+ class Clip_metric(Clip_Basic_Metric):
152
+
153
+ @torch.no_grad()
154
+ def __init__(self, target_style_prompt: str=None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
155
+ bath_size=8, alpha=0.5):
156
+ super().__init__()
157
+ self.device = device
158
+ self.alpha = alpha
159
+ self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
160
+ self.processor = CLIPProcessor.from_pretrained(clip_model_name)
161
+ self.tokenizer = self.processor.tokenizer
162
+ self.image_processor = self.processor.image_processor
163
+ # self.style_class_features = self.get_text_features(self.styles).cpu()
164
+ self.style_class_features=[]
165
+ # self.noise_prompt_features = self.get_text_features("Noise")
166
+ self.model.eval()
167
+ self.batch_size = bath_size
168
+ if target_style_prompt is not None:
169
+ self.ref_style_features = self.get_text_features(target_style_prompt)
170
+ else:
171
+ self.ref_style_features = None
172
+
173
+ self.ref_image_style_prototype = None
174
+
175
+ def get_text_features(self, text):
176
+ prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
177
+ prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
178
+ prompt_features = F.normalize(prompt_features, p=2, dim=-1)
179
+ return prompt_features
180
+
181
+ def get_image_features(self, images):
182
+ # if isinstance(image, torch.Tensor):
183
+ # self.tensor_transform(image)
184
+ # else:
185
+ # image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
186
+ images = self.load_image_path(images)
187
+ if isinstance(images, torch.Tensor):
188
+ images = self.tensor_preprocess(images)
189
+ data = {"pixel_values": images}
190
+ image_features = BatchFeature(data=data, tensor_type="pt")
191
+ else:
192
+ image_features = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
193
+ non_blocking=True)
194
+
195
+ image_features = self.model.get_image_features(**image_features).to(self.device)
196
+ image_features = F.normalize(image_features, p=2, dim=-1)
197
+ return image_features
198
+
199
+ def img_text_similarity(self, image_features, text=None):
200
+ if text is not None:
201
+ prompt_feature = self.get_text_features(text)
202
+ if isinstance(text, str):
203
+ prompt_feature = prompt_feature.repeat(len(image_features), 1)
204
+ else:
205
+ prompt_feature = self.ref_style_features
206
+
207
+ similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
208
+ return similarity_each
209
+
210
+ def forward(self, output_imgs, prompt=None):
211
+ image_features = self.get_image_features(output_imgs)
212
+ # print(image_features)
213
+ style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
214
+ if prompt is not None:
215
+ content_score = self.img_text_similarity(image_features, prompt)
216
+
217
+ score = self.alpha * style_score + (1 - self.alpha) * content_score
218
+ return {"score": score, "style_score": style_score, "content_score": content_score}
219
+ else:
220
+ return {"style_score": style_score}
221
+
222
+ def content_score(self, output_imgs, prompt):
223
+ self.to(self.device)
224
+ image_features = self.get_image_features(output_imgs)
225
+ content_score_details = self.img_text_similarity(image_features, prompt)
226
+ self.to("cpu")
227
+ return {"CLIP_content_score": content_score_details.mean().cpu(), "CLIP_content_score_details": content_score_details.cpu()}
228
+
229
+
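A hedged sketch (the image list and captions are placeholders): style is scored against the target prompt on the mean image embedding, content against each image's own caption.
metric = Clip_metric(target_style_prompt="an impressionist oil painting")
out = metric(generated_pils, prompt=captions)   # lists of PIL images and matching captions
print(out["style_score"], out["content_score"].mean())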
230
+ class CSD_CLIP(Clip_Basic_Metric):
231
+ """backbone + projection head"""
232
+ def __init__(self, name='vit_large',content_proj_head='default', ckpt_path = "data/weights/CSD-checkpoint.pth", device="cuda",
233
+ alpha=0.5, **kwargs):
234
+ super(CSD_CLIP, self).__init__()
235
+ self.alpha = alpha
236
+ self.content_proj_head = content_proj_head
237
+ self.device = device
238
+ if name == 'vit_large':
239
+ clipmodel, _ = clip.load("ViT-L/14")
240
+ self.backbone = clipmodel.visual
241
+ self.embedding_dim = 1024
242
+ elif name == 'vit_base':
243
+ clipmodel, _ = clip.load("ViT-B/16")
244
+ self.backbone = clipmodel.visual
245
+ self.embedding_dim = 768
246
+ self.feat_dim = 512
247
+ else:
248
+ raise Exception('This model is not implemented')
249
+
250
+ convert_weights_float(self.backbone)
251
+ self.last_layer_style = copy.deepcopy(self.backbone.proj)
252
+ if content_proj_head == 'custom':
253
+ self.last_layer_content = ProjectionHead(self.embedding_dim,self.feat_dim)
254
+ self.last_layer_content.apply(init_weights)
255
+
256
+ else:
257
+ self.last_layer_content = copy.deepcopy(self.backbone.proj)
258
+
259
+ self.backbone.proj = None
260
+ self.backbone.requires_grad_(False)
261
+ self.last_layer_style.requires_grad_(False)
262
+ self.last_layer_content.requires_grad_(False)
263
+ self.backbone.eval()
264
+
265
+ if ckpt_path is not None:
266
+ self.load_ckpt(ckpt_path)
267
+ self.to("cpu")
268
+
269
+ def load_ckpt(self, ckpt_path):
270
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
271
+ state_dict = convert_state_dict(checkpoint['model_state_dict'])
272
+ msg = self.load_state_dict(state_dict, strict=False)
273
+ print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")
274
+
275
+ @property
276
+ def dtype(self):
277
+ return self.backbone.conv1.weight.dtype
278
+
279
+ def get_image_features(self, input_data, get_style=True,get_content=False,feature_alpha=None):
280
+ if isinstance(input_data, torch.Tensor):
281
+ input_data = self.tensor_preprocess(input_data)
282
+ elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
283
+ input_data = self.load_image_path(input_data)
284
+ input_data = [self.image_preprocess(image) for image in input_data]
285
+ input_data = torch.stack(input_data)
286
+ elif os.path.isfile(input_data):
287
+ input_data = self.load_image(input_data)
288
+ input_data = self.image_preprocess(input_data)
289
+ input_data = input_data.unsqueeze(0)
290
+ input_data = input_data.to(self.device)
291
+ style_output = None
292
+
293
+ feature = self.backbone(input_data)
294
+ if get_style:
295
+ style_output = feature @ self.last_layer_style
296
+ # style_output = style_output.mean(dim=0)
297
+ style_output = nn.functional.normalize(style_output, dim=-1, p=2)
298
+
299
+ content_output=None
300
+ if get_content:
301
+ if feature_alpha is not None:
302
+ reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
303
+ else:
304
+ reverse_feature = feature
305
+ # if alpha is not None:
306
+ if self.content_proj_head == 'custom':
307
+ content_output = self.last_layer_content(reverse_feature)
308
+ else:
309
+ content_output = reverse_feature @ self.last_layer_content
310
+ content_output = nn.functional.normalize(content_output, dim=-1, p=2)
311
+
312
+ return feature, content_output, style_output
313
+
314
+
315
+ @torch.no_grad()
316
+ def define_ref_image_style_prototype(self, ref_image_path: str):
317
+ self.to(self.device)
318
+ _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
319
+ self.to("cpu")
320
+ # self.ref_style_feature = self.ref_style_feature.mean(dim=0)
321
+ @torch.no_grad()
322
+ def forward(self, styled_data):
323
+ self.to(self.device)
324
+ # get_content_feature = original_data is not None
325
+ _, content_output, style_output = self.get_image_features(styled_data, get_content=False)
326
+ style_similarities = style_output @ self.ref_style_feature.T
327
+ mean_style_similarities = style_similarities.mean(dim=-1)
328
+ mean_style_similarity = mean_style_similarities.mean()
329
+
330
+ max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
331
+ max_style_similarity = max_style_similarities_v.mean()
332
+
333
+
334
+ self.to("cpu")
335
+ return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
336
+ "CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}
337
+
338
+ def get_style_loss(self, styled_data):
339
+ _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
340
+ style_similarity = (style_output @ self.ref_style_feature).mean()
341
+ loss = 1 - style_similarity
342
+ return loss.mean()
343
+
344
+ class LPIPS_metric(Metric):
345
+ def __init__(self, type="vgg", device="cuda"):
346
+ super(LPIPS_metric, self).__init__()
347
+ self.lpips = lpips.LPIPS(net=type)
348
+ self.device = device
349
+ self.image_preprocess = transforms.Compose([
350
+ transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
351
+ transforms.CenterCrop(256),
352
+ transforms.ToTensor(),
353
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
354
+ ])
355
+ self.to("cpu")
356
+
357
+ @torch.no_grad()
358
+ def forward(self, img1, img2):
359
+ self.to(self.device)
360
+ differences = []
361
+ for i in range(0, len(img1), 50):
362
+ img1_batch = img1[i:i+50]
363
+ img2_batch = img2[i:i+50]
364
+ img1_batch = self.preprocess_image(img1_batch).to(self.device)
365
+ img2_batch = self.preprocess_image(img2_batch).to(self.device)
366
+ differences.append(self.lpips(img1_batch, img2_batch).squeeze())
367
+ differences = torch.cat(differences)
368
+ difference = differences.mean()
369
+ # similarity = 1 - difference
370
+ self.to("cpu")
371
+ return {"LPIPS_content_difference": difference, "LPIPS_content_difference_details": differences}
372
+
373
+ class Vit_metric(Metric):
374
+ def __init__(self, device="cuda"):
375
+ super(Vit_metric, self).__init__()
376
+ self.device = device
377
+ self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
378
+ self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
379
+ self.to("cpu")
380
+ def get_image_features(self, images):
381
+ # if isinstance(image, torch.Tensor):
382
+ # self.tensor_transform(image)
383
+ # else:
384
+ # image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
385
+ images = self.load_image_path(images)
386
+ batch_size = 20
387
+ all_image_features = []
388
+ for i in range(0, len(images), batch_size):
389
+ image_batch = images[i:i+batch_size]
390
+ if isinstance(image_batch, torch.Tensor):
391
+ image_batch = self.tensor_preprocess(image_batch)
392
+ data = {"pixel_values": image_batch}
393
+ image_processed = BatchFeature(data=data, tensor_type="pt")
394
+ else:
395
+ image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
396
+ image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
397
+ image_features = F.normalize(image_features, p=2, dim=-1)
398
+ all_image_features.append(image_features)
399
+ all_image_features = torch.cat(all_image_features)
400
+ return all_image_features
401
+
402
+ @torch.no_grad()
403
+ def content_metric(self, img1, img2):
404
+ self.to(self.device)
405
+ if not(isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
406
+ img1 = self.get_image_features(img1)
407
+ if not(isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
408
+ img2 = self.get_image_features(img2)
409
+ similarities = torch.einsum("nc, nc -> n", img1, img2)
410
+ similarity = similarities.mean()
411
+ # self.to("cpu")
412
+ return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}
413
+
414
+ # style
415
+ @torch.no_grad()
416
+ def define_ref_image_style_prototype(self, ref_image_path: str):
417
+ self.to(self.device)
418
+ self.ref_style_feature = self.get_image_features(ref_image_path)
419
+ self.to("cpu")
420
+ @torch.no_grad()
421
+ def style_metric(self, styled_data):
422
+ self.to(self.device)
423
+ if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
424
+ style_output = styled_data
425
+ else:
426
+ style_output = self.get_image_features(styled_data)
427
+ style_similarities = style_output @ self.ref_style_feature.T
428
+ mean_style_similarities = style_similarities.mean(dim=-1)
429
+ mean_style_similarity = mean_style_similarities.mean()
430
+
431
+ max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
432
+ max_style_similarity = max_style_similarities_v.mean()
433
+
434
+ # self.to("cpu")
435
+ return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
436
+ "Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}
437
+ @torch.no_grad()
438
+ def forward(self, styled_data, original_data=None):
439
+ self.to(self.device)
440
+ styled_features = self.get_image_features(styled_data)
441
+ ret ={}
442
+ if original_data is not None:
443
+ content_metric = self.content_metric(styled_features, original_data)
444
+ ret["Vit_content"] = content_metric
445
+ style_metric = self.style_metric(styled_features)
446
+ ret["Vit_style"] = style_metric
447
+ self.to("cpu")
448
+ return ret
449
+
450
+
451
+
452
+ class StyleContentMetric(nn.Module):
453
+ def __init__(self, style_ref_image_folder, device="cuda"):
454
+ super(StyleContentMetric, self).__init__()
455
+ self.device = device
456
+ self.clip_style_metric = CSD_CLIP(device=device)
457
+ self.ref_image_file = os.listdir(style_ref_image_folder)
458
+ self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
459
+ self.ref_image_file.sort()
460
+ self.ref_image_file = np.array(self.ref_image_file)
461
+ ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]
462
+
463
+ self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)
464
+ self.vit_metric = Vit_metric(device=device)
465
+ self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
466
+ self.lpips_metric = LPIPS_metric(device=device)
467
+
468
+ self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None)
469
+
470
+ self.to("cpu")
471
+
472
+ def forward(self, styled_data, original_data=None, content_caption=None):
473
+ ret ={}
474
+ csd_score = self.clip_style_metric(styled_data)
475
+ csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
476
+ torch.cuda.empty_cache()
477
+ ret["Style_CSD"] = csd_score
478
+ vit_score = self.vit_metric(styled_data, original_data)
479
+ torch.cuda.empty_cache()
480
+ vit_style = vit_score["Vit_style"]
481
+ vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
482
+ ret["Style_VIT"] = vit_style
483
+
484
+ if original_data is not None:
485
+ vit_content = vit_score["Vit_content"]
486
+ ret["Content_VIT"] = vit_content
487
+ lpips_score = self.lpips_metric(styled_data, original_data)
488
+ torch.cuda.empty_cache()
489
+ ret["Content_LPIPS"] = lpips_score
490
+
491
+ if content_caption is not None:
492
+ clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
493
+ ret["Content_CLIP"] = clip_content
494
+ torch.cuda.empty_cache()
495
+
496
+ for type_key, type_value in ret.items():
497
+ for key, value in type_value.items():
498
+ if isinstance(value, torch.Tensor):
499
+ if value.numel() == 1:
500
+ ret[type_key][key] = round(value.item(), 4)
501
+ else:
502
+ ret[type_key][key] = value.tolist()
503
+ ret[type_key][key] = [round(v, 4) for v in ret[type_key][key]]
504
+
505
+ self.to("cpu")
506
+ ret["ref_image_file"] = self.ref_image_file.tolist()
507
+ return ret
508
+
509
+
510
+ if __name__ == "__main__":
511
+ with torch.no_grad():
512
+ metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
513
+ score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
514
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
515
+ print(score)
516
+
517
+
518
+
519
+ lpips = LPIPS_metric()
520
+ score = lpips("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
521
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
522
+
523
+ print("lpips", score)
524
+
525
+
526
+ clip_metric = CSD_CLIP()
527
+ clip_metric.define_ref_image_style_prototype(
528
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
529
+
530
+ score = clip_metric(
531
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
532
+ print("subset3-subset3_sd14_converted", score)
533
+
534
+ score = clip_metric(
535
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
536
+ print("subset3-photo", score)
537
+
538
+
539
+
540
+ score = clip_metric(
541
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
542
+ print("subset3-subset1", score)
543
+
544
+ score = clip_metric(
545
+ "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
546
+ print("subset3-andy", score)
547
+ # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")
548
+
549
+ # print("subset3",score)
550
+ # score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
551
+ # print("subset2",score_subset2)
552
+ # score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
553
+ # print("subset3",score_subset3)
554
+ #
555
+ # score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
556
+ # print("subset3-subset3_sd14_converted" , score_subset3_converted)
557
+ #
558
+ # score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
559
+ # print("subset3-subset3_coco_converted" , score_subset3_coco_converted)
560
+ #
561
+ # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
562
+ # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
563
+ # print("photo500_1-sketch" ,score)
564
+ #
565
+ # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
566
+ # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
567
+ # print("photo500_1-photo500_2" ,score)
568
+ # from custom_datasets.imagepair import ImageSet
569
+ # import matplotlib.pyplot as plt
570
+ # dataset = ImageSet(folder = "/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
571
+ # caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
572
+ # keep_in_mem=False)
573
+ # for sample in dataset:
574
+ # score = clip_metric.content_score(sample["image"], sample["caption"][0])
575
+ # plt.imshow(sample["image"])
576
+ # plt.title(f"score: {round(score.item(),2)}\n prompt: {sample['caption'][0]}")
577
+ # plt.show()
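For orientation, a minimal usage sketch of the StyleContentMetric defined above; the folder paths are placeholders rather than paths from this repository, and the call pattern simply mirrors the __main__ block:

    import json
    import torch

    # assumed layout: one folder of reference paintings, one folder of stylized outputs
    metric = StyleContentMetric("path/to/style_reference_paintings", device="cuda")
    with torch.no_grad():
        scores = metric(
            "path/to/stylized_outputs",          # images whose style/content we want to score
            "path/to/original_content_images",   # optional content references (enables LPIPS/ViT content scores)
        )
    print(json.dumps(scores, indent=2))  # tensors are already converted to rounded floats/lists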
utils/model_util.py ADDED
@@ -0,0 +1,291 @@
1
+ from typing import Literal, Union, Optional
2
+
3
+ import torch
4
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
5
+ from diffusers import (
6
+ UNet2DConditionModel,
7
+ SchedulerMixin,
8
+ StableDiffusionPipeline,
9
+ StableDiffusionXLPipeline,
10
+ AutoencoderKL,
11
+ )
12
+ from diffusers.schedulers import (
13
+ DDIMScheduler,
14
+ DDPMScheduler,
15
+ LMSDiscreteScheduler,
16
+ EulerAncestralDiscreteScheduler,
17
+ )
18
+
19
+
20
+ TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
21
+ TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
22
+
23
+ AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
24
+
25
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
26
+
27
+ DIFFUSERS_CACHE_DIR = None # if you want to change the cache dir, change this
28
+ from diffusers.training_utils import EMAModel
29
+ import os
30
+ import sys
31
+
32
+ # from utils.modules import get_diffusion_modules
33
+ def load_diffusers_model(
34
+ pretrained_model_name_or_path: str,
35
+ v2: bool = False,
36
+ clip_skip: Optional[int] = None,
37
+ weight_dtype: torch.dtype = torch.float32,
38
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
39
+ # the VAE is not strictly needed here, but it is loaded and returned for convenience
40
+
41
+ if v2:
42
+ tokenizer = CLIPTokenizer.from_pretrained(
43
+ TOKENIZER_V2_MODEL_NAME,
44
+ subfolder="tokenizer",
45
+ torch_dtype=weight_dtype,
46
+ cache_dir=DIFFUSERS_CACHE_DIR,
47
+ )
48
+ text_encoder = CLIPTextModel.from_pretrained(
49
+ pretrained_model_name_or_path,
50
+ subfolder="text_encoder",
51
+ # default is clip skip 2
52
+ num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
53
+ torch_dtype=weight_dtype,
54
+ cache_dir=DIFFUSERS_CACHE_DIR,
55
+ )
56
+ else:
57
+ tokenizer = CLIPTokenizer.from_pretrained(
58
+ TOKENIZER_V1_MODEL_NAME,
59
+ subfolder="tokenizer",
60
+ torch_dtype=weight_dtype,
61
+ cache_dir=DIFFUSERS_CACHE_DIR,
62
+ )
63
+ text_encoder = CLIPTextModel.from_pretrained(
64
+ pretrained_model_name_or_path,
65
+ subfolder="text_encoder",
66
+ num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
67
+ torch_dtype=weight_dtype,
68
+ cache_dir=DIFFUSERS_CACHE_DIR,
69
+ )
70
+
71
+ unet = UNet2DConditionModel.from_pretrained(
72
+ pretrained_model_name_or_path,
73
+ subfolder="unet",
74
+ torch_dtype=weight_dtype,
75
+ cache_dir=DIFFUSERS_CACHE_DIR,
76
+ )
77
+
78
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
79
+
80
+ return tokenizer, text_encoder, unet, vae
81
+
82
+
83
+ def load_checkpoint_model(
84
+ checkpoint_path: str,
85
+ v2: bool = False,
86
+ clip_skip: Optional[int] = None,
87
+ weight_dtype: torch.dtype = torch.float32,
88
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, AutoencoderKL]:
89
+ pipe = StableDiffusionPipeline.from_ckpt(
90
+ checkpoint_path,
91
+ upcast_attention=True if v2 else False,
92
+ torch_dtype=weight_dtype,
93
+ cache_dir=DIFFUSERS_CACHE_DIR,
94
+ )
95
+
96
+ unet = pipe.unet
97
+ tokenizer = pipe.tokenizer
98
+ text_encoder = pipe.text_encoder
99
+ vae = pipe.vae
100
+ if clip_skip is not None:
101
+ if v2:
102
+ text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
103
+ else:
104
+ text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
105
+
106
+ del pipe
107
+
108
+ return tokenizer, text_encoder, unet, vae
109
+
110
+
111
+ def load_models(
112
+ pretrained_model_name_or_path: str,
113
+ ckpt_path: str,
114
+ scheduler_name: AVAILABLE_SCHEDULERS,
115
+ v2: bool = False,
116
+ v_pred: bool = False,
117
+ weight_dtype: torch.dtype = torch.float32,
118
+ ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin, AutoencoderKL]:
119
+ if pretrained_model_name_or_path.endswith(
120
+ ".ckpt"
121
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
122
+ tokenizer, text_encoder, unet, vae = load_checkpoint_model(
123
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
124
+ )
125
+ else: # diffusers
126
+ tokenizer, text_encoder, unet, vae = load_diffusers_model(
127
+ pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
128
+ )
129
+
130
+ # the VAE is not strictly needed, but it is passed through and returned
131
+
132
+ scheduler = create_noise_scheduler(
133
+ scheduler_name,
134
+ prediction_type="v_prediction" if v_pred else "epsilon",
135
+ )
136
+ # trained unet_ema
137
+ if ckpt_path not in [None, "None"]:
138
+ ema_unet = EMAModel.from_pretrained(os.path.join(ckpt_path, "unet_ema"), UNet2DConditionModel)
139
+ ema_unet.copy_to(unet.parameters())
140
+ return tokenizer, text_encoder, unet, scheduler, vae
141
+
142
+
143
+ def load_diffusers_model_xl(
144
+ pretrained_model_name_or_path: str,
145
+ weight_dtype: torch.dtype = torch.float32,
146
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
147
+ # returns [tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet, vae
148
+
149
+ tokenizers = [
150
+ CLIPTokenizer.from_pretrained(
151
+ pretrained_model_name_or_path,
152
+ subfolder="tokenizer",
153
+ torch_dtype=weight_dtype,
154
+ cache_dir=DIFFUSERS_CACHE_DIR,
155
+ ),
156
+ CLIPTokenizer.from_pretrained(
157
+ pretrained_model_name_or_path,
158
+ subfolder="tokenizer_2",
159
+ torch_dtype=weight_dtype,
160
+ cache_dir=DIFFUSERS_CACHE_DIR,
161
+ pad_token_id=0, # same as open clip
162
+ ),
163
+ ]
164
+
165
+ text_encoders = [
166
+ CLIPTextModel.from_pretrained(
167
+ pretrained_model_name_or_path,
168
+ subfolder="text_encoder",
169
+ torch_dtype=weight_dtype,
170
+ cache_dir=DIFFUSERS_CACHE_DIR,
171
+ ),
172
+ CLIPTextModelWithProjection.from_pretrained(
173
+ pretrained_model_name_or_path,
174
+ subfolder="text_encoder_2",
175
+ torch_dtype=weight_dtype,
176
+ cache_dir=DIFFUSERS_CACHE_DIR,
177
+ ),
178
+ ]
179
+
180
+ unet = UNet2DConditionModel.from_pretrained(
181
+ pretrained_model_name_or_path,
182
+ subfolder="unet",
183
+ torch_dtype=weight_dtype,
184
+ cache_dir=DIFFUSERS_CACHE_DIR,
185
+ )
186
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
187
+ return tokenizers, text_encoders, unet, vae
188
+
189
+
190
+ def load_checkpoint_model_xl(
191
+ checkpoint_path: str,
192
+ weight_dtype: torch.dtype = torch.float32,
193
+ ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel, AutoencoderKL]:
194
+ pipe = StableDiffusionXLPipeline.from_single_file(
195
+ checkpoint_path,
196
+ torch_dtype=weight_dtype,
197
+ cache_dir=DIFFUSERS_CACHE_DIR,
198
+ )
199
+
200
+ unet = pipe.unet
201
+ tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
202
+ text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
203
+ if len(text_encoders) == 2:
204
+ text_encoders[1].pad_token_id = 0
205
+ vae = pipe.vae
206
+ del pipe
207
+
208
+ return tokenizers, text_encoders, unet, vae
209
+
210
+
211
+ def load_models_xl(
212
+ pretrained_model_name_or_path: str,
213
+ scheduler_name: AVAILABLE_SCHEDULERS,
214
+ weight_dtype: torch.dtype = torch.float32,
215
+ ) -> tuple[
216
+ list[CLIPTokenizer],
217
+ list[SDXL_TEXT_ENCODER_TYPE],
218
+ UNet2DConditionModel,
219
+ SchedulerMixin,
+ AutoencoderKL,
220
+ ]:
221
+ if pretrained_model_name_or_path.endswith(
222
+ ".ckpt"
223
+ ) or pretrained_model_name_or_path.endswith(".safetensors"):
224
+ (
225
+ tokenizers,
226
+ text_encoders,
227
+ unet,
228
+ vae
229
+ ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
230
+ else: # diffusers
231
+ (
232
+ tokenizers,
233
+ text_encoders,
234
+ unet,
235
+ vae
236
+ ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
237
+
238
+ scheduler = create_noise_scheduler(scheduler_name)
239
+
240
+ return tokenizers, text_encoders, unet, scheduler, vae
241
+
242
+
243
+ def create_noise_scheduler(
244
+ scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
245
+ prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
246
+ ) -> SchedulerMixin:
247
+
248
+
249
+ name = scheduler_name.lower().replace(" ", "_")
250
+ if name == "ddim":
251
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
252
+ scheduler = DDIMScheduler(
253
+ beta_start=0.00085,
254
+ beta_end=0.012,
255
+ beta_schedule="scaled_linear",
256
+ num_train_timesteps=1000,
257
+ clip_sample=False,
258
+ prediction_type=prediction_type, # is this the right choice here?
259
+ )
260
+ elif name == "ddpm":
261
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
262
+ scheduler = DDPMScheduler(
263
+ beta_start=0.00085,
264
+ beta_end=0.012,
265
+ beta_schedule="scaled_linear",
266
+ num_train_timesteps=1000,
267
+ clip_sample=False,
268
+ prediction_type=prediction_type,
269
+ )
270
+ elif name == "lms":
271
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
272
+ scheduler = LMSDiscreteScheduler(
273
+ beta_start=0.00085,
274
+ beta_end=0.012,
275
+ beta_schedule="scaled_linear",
276
+ num_train_timesteps=1000,
277
+ prediction_type=prediction_type,
278
+ )
279
+ elif name == "euler_a":
280
+ # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
281
+ scheduler = EulerAncestralDiscreteScheduler(
282
+ beta_start=0.00085,
283
+ beta_end=0.012,
284
+ beta_schedule="scaled_linear",
285
+ num_train_timesteps=1000,
286
+ prediction_type=prediction_type,
287
+ )
288
+ else:
289
+ raise ValueError(f"Unknown scheduler name: {name}")
290
+
291
+ return scheduler
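As a quick reference, a hedged sketch of how the loaders above fit together for an SD v1.x model; the repo id matches TOKENIZER_V1_MODEL_NAME used in this file, while ckpt_path, scheduler choice and dtype are illustrative values, not requirements:

    import torch

    tokenizer, text_encoder, unet, scheduler, vae = load_models(
        "CompVis/stable-diffusion-v1-4",  # diffusers repo id, or a .ckpt/.safetensors file
        ckpt_path=None,                   # optionally a directory containing a trained "unet_ema"
        scheduler_name="ddim",            # one of: ddim, ddpm, lms, euler_a
        v2=False,
        v_pred=False,
        weight_dtype=torch.float32,
    )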
utils/prompt_util.py ADDED
@@ -0,0 +1,174 @@
1
+ from typing import Literal, Optional, Union, List
2
+
3
+ import yaml
4
+ from pathlib import Path
5
+
6
+
7
+ from pydantic import BaseModel, root_validator
8
+ import torch
9
+ import copy
10
+
11
+ ACTION_TYPES = Literal[
12
+ "erase",
13
+ "enhance",
14
+ ]
15
+
16
+
17
+ # SDXL needs two kinds of embeddings, hence this container
18
+ class PromptEmbedsXL:
19
+ text_embeds: torch.FloatTensor
20
+ pooled_embeds: torch.FloatTensor
21
+
22
+ def __init__(self, *args) -> None:
23
+ self.text_embeds = args[0]
24
+ self.pooled_embeds = args[1]
25
+
26
+
27
+ # SDv1.x / SDv2.x use a FloatTensor; XL uses PromptEmbedsXL
28
+ PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
29
+
30
+
31
+ class PromptEmbedsCache: # cache so prompt embeddings can be reused
32
+ prompts: dict[str, PROMPT_EMBEDDING] = {}
33
+
34
+ def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
35
+ self.prompts[__name] = __value
36
+
37
+ def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
38
+ if __name in self.prompts:
39
+ return self.prompts[__name]
40
+ else:
41
+ return None
42
+
43
+
44
+ class PromptSettings(BaseModel): # the settings defined in the yaml file
45
+ target: str
46
+ positive: str = None # if None, target will be used
47
+ unconditional: str = "" # default is ""
48
+ neutral: str = None # if None, unconditional will be used
49
+ action: ACTION_TYPES = "erase" # default is "erase"
50
+ guidance_scale: float = 1.0 # default is 1.0
51
+ resolution: int = 512 # default is 512
52
+ dynamic_resolution: bool = False # default is False
53
+ batch_size: int = 1 # default is 1
54
+ dynamic_crops: bool = False # default is False. only used when model is XL
55
+
56
+ @root_validator(pre=True)
57
+ def fill_prompts(cls, values):
58
+ keys = values.keys()
59
+ if "target" not in keys:
60
+ raise ValueError("target must be specified")
61
+ if "positive" not in keys:
62
+ values["positive"] = values["target"]
63
+ if "unconditional" not in keys:
64
+ values["unconditional"] = ""
65
+ if "neutral" not in keys:
66
+ values["neutral"] = values["unconditional"]
67
+
68
+ return values
69
+
70
+
71
+ class PromptEmbedsPair:
72
+ target: PROMPT_EMBEDDING # the concept we do not want to generate
73
+ positive: PROMPT_EMBEDDING # the concept to generate
74
+ unconditional: PROMPT_EMBEDDING # unconditional prompt (default should be empty)
75
+ neutral: PROMPT_EMBEDDING # base condition (default should be empty)
76
+
77
+ guidance_scale: float
78
+ resolution: int
79
+ dynamic_resolution: bool
80
+ batch_size: int
81
+ dynamic_crops: bool
82
+
83
+ loss_fn: torch.nn.Module
84
+ action: ACTION_TYPES
85
+
86
+ def __init__(
87
+ self,
88
+ loss_fn: torch.nn.Module,
89
+ target: PROMPT_EMBEDDING,
90
+ positive: PROMPT_EMBEDDING,
91
+ unconditional: PROMPT_EMBEDDING,
92
+ neutral: PROMPT_EMBEDDING,
93
+ settings: PromptSettings,
94
+ ) -> None:
95
+ self.loss_fn = loss_fn
96
+ self.target = target
97
+ self.positive = positive
98
+ self.unconditional = unconditional
99
+ self.neutral = neutral
100
+
101
+ self.guidance_scale = settings.guidance_scale
102
+ self.resolution = settings.resolution
103
+ self.dynamic_resolution = settings.dynamic_resolution
104
+ self.batch_size = settings.batch_size
105
+ self.dynamic_crops = settings.dynamic_crops
106
+ self.action = settings.action
107
+
108
+ def _erase(
109
+ self,
110
+ target_latents: torch.FloatTensor, # "van gogh"
111
+ positive_latents: torch.FloatTensor, # "van gogh"
112
+ unconditional_latents: torch.FloatTensor, # ""
113
+ neutral_latents: torch.FloatTensor, # ""
114
+ ) -> torch.FloatTensor:
115
+ """Target latents are going not to have the positive concept."""
116
+ return self.loss_fn(
117
+ target_latents,
118
+ neutral_latents
119
+ - self.guidance_scale * (positive_latents - unconditional_latents)
120
+ )
121
+
122
+
123
+ def _enhance(
124
+ self,
125
+ target_latents: torch.FloatTensor, # "van gogh"
126
+ positive_latents: torch.FloatTensor, # "van gogh"
127
+ unconditional_latents: torch.FloatTensor, # ""
128
+ neutral_latents: torch.FloatTensor, # ""
129
+ ):
130
+ """Target latents are going to have the positive concept."""
131
+ return self.loss_fn(
132
+ target_latents,
133
+ neutral_latents
134
+ + self.guidance_scale * (positive_latents - unconditional_latents)
135
+ )
136
+
137
+ def loss(
138
+ self,
139
+ **kwargs,
140
+ ):
141
+ if self.action == "erase":
142
+ return self._erase(**kwargs)
143
+
144
+ elif self.action == "enhance":
145
+ return self._enhance(**kwargs)
146
+
147
+ else:
148
+ raise ValueError("action must be erase or enhance")
149
+
150
+
151
+ def load_prompts_from_yaml(path, attributes = []):
152
+ with open(path, "r") as f:
153
+ prompts = yaml.safe_load(f)
154
+ print(prompts)
155
+ if len(prompts) == 0:
156
+ raise ValueError("prompts file is empty")
157
+ if len(attributes)!=0:
158
+ newprompts = []
159
+ for i in range(len(prompts)):
160
+ for att in attributes:
161
+ copy_ = copy.deepcopy(prompts[i])
162
+ copy_['target'] = att + ' ' + copy_['target']
163
+ copy_['positive'] = att + ' ' + copy_['positive']
164
+ copy_['neutral'] = att + ' ' + copy_['neutral']
165
+ copy_['unconditional'] = att + ' ' + copy_['unconditional']
166
+ newprompts.append(copy_)
167
+ else:
168
+ newprompts = copy.deepcopy(prompts)
169
+
170
+ print(newprompts)
171
+ print(len(prompts), len(newprompts))
172
+ prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
173
+
174
+ return prompt_settings
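For clarity, a hedged sketch of what load_prompts_from_yaml expects and returns; the file name and attribute are hypothetical, and the fields follow the PromptSettings model above:

    # prompts.yaml (assumed name) would contain entries such as:
    # - target: "van gogh"
    #   positive: "van gogh"
    #   unconditional: ""
    #   neutral: ""
    #   action: "erase"
    #   guidance_scale: 1.0
    settings = load_prompts_from_yaml("prompts.yaml", attributes=["colorful"])
    print(settings[0].target)  # -> "colorful van gogh" (each attribute is prefixed to every prompt field)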
utils/train_util.py ADDED
@@ -0,0 +1,526 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+
5
+ from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
6
+ from diffusers import UNet2DConditionModel, SchedulerMixin
7
+ from diffusers.image_processor import VaeImageProcessor
8
+ import sys
9
+ import os
10
+ # sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
11
+ # from imagesliders.model_util import SDXL_TEXT_ENCODER_TYPE
12
+ from diffusers.utils.torch_utils import randn_tensor
13
+
14
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
15
+
16
+ SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
17
+
18
+ from tqdm import tqdm
19
+
20
+ UNET_IN_CHANNELS = 4 # Stable Diffusion in_channels
21
+ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
22
+
23
+ UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
24
+ TEXT_ENCODER_2_PROJECTION_DIM = 1280
25
+ UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
26
+
27
+
28
+ def get_random_noise(
29
+ batch_size: int, height: int, width: int, generator: torch.Generator = None
30
+ ) -> torch.Tensor:
31
+ return torch.randn(
32
+ (
33
+ batch_size,
34
+ UNET_IN_CHANNELS,
35
+ height // VAE_SCALE_FACTOR,
36
+ width // VAE_SCALE_FACTOR,
37
+ ),
38
+ generator=generator,
39
+ device="cpu",
40
+ )
41
+
42
+
43
+
44
+ def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
45
+ latents = latents + noise_offset * torch.randn(
46
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
47
+ )
48
+ return latents
49
+
50
+
51
+ def get_initial_latents(
52
+ scheduler: SchedulerMixin,
53
+ n_imgs: int,
54
+ height: int,
55
+ width: int,
56
+ n_prompts: int,
57
+ generator=None,
58
+ ) -> torch.Tensor:
59
+ noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
60
+ n_prompts, 1, 1, 1
61
+ )
62
+
63
+ latents = noise * scheduler.init_noise_sigma
64
+
65
+ return latents
66
+
67
+
68
+ def text_tokenize(
69
+ tokenizer, # a single tokenizer normally; two for XL
70
+ prompts,
71
+ ):
72
+ return tokenizer(
73
+ prompts,
74
+ padding="max_length",
75
+ max_length=tokenizer.model_max_length,
76
+ truncation=True,
77
+ return_tensors="pt",
78
+ )
79
+
80
+
81
+ def text_encode(text_encoder , tokens):
82
+ tokens = tokens.to(text_encoder.device)
83
+ if isinstance(text_encoder, BertModel):
84
+ embed = text_encoder(**tokens, return_dict=False)[0]
85
+ elif isinstance(text_encoder, CLIPTextModel):
86
+ # embed = text_encoder(**tokens, return_dict=False)[0]
87
+ embed = text_encoder(tokens.input_ids, return_dict=False)[0]
88
+ else:
89
+ raise ValueError("text_encoder must be BertModel or CLIPTextModel")
90
+ return embed
91
+
92
+ def encode_prompts(
93
+ tokenizer,
94
+ text_encoder,
95
+ prompts: list[str],
96
+ ):
97
+ # print(f"prompts: {prompts}")
98
+ text_tokens = text_tokenize(tokenizer, prompts)
99
+ # print(f"text_tokens: {text_tokens}")
100
+ text_embeddings = text_encode(text_encoder, text_tokens)
101
+ # print(f"text_embeddings: {text_embeddings}")
102
+
103
+
104
+ return text_embeddings
105
+
106
+ def prompt_replace(original, key="{prompt}", prompt=""):
107
+ if key not in original:
108
+ return original
109
+
110
+ if isinstance(prompt, list):
111
+ ret =[]
112
+ for p in prompt:
113
+ p = p.replace(".", "")
114
+ r = original.replace(key, p)
115
+ r = r.capitalize()
116
+ ret.append(r)
117
+ else:
118
+ prompt = prompt.replace(".", "")
119
+ ret = original.replace(key, prompt)
120
+ ret = ret.capitalize()
121
+ return ret
122
+
123
+
124
+
125
+ def text_encode_xl(
126
+ text_encoder: SDXL_TEXT_ENCODER_TYPE,
127
+ tokens: torch.FloatTensor,
128
+ num_images_per_prompt: int = 1,
129
+ ):
130
+ prompt_embeds = text_encoder(
131
+ tokens.to(text_encoder.device), output_hidden_states=True
132
+ )
133
+ pooled_prompt_embeds = prompt_embeds[0]
134
+ prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
135
+
136
+ bs_embed, seq_len, _ = prompt_embeds.shape
137
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
138
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
139
+
140
+ return prompt_embeds, pooled_prompt_embeds
141
+
142
+
143
+ def encode_prompts_xl(
144
+ tokenizers: list[CLIPTokenizer],
145
+ text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
146
+ prompts: list[str],
147
+ num_images_per_prompt: int = 1,
148
+ ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
149
+ # text_encoder and text_encoder_2's penultimate layer's output
150
+ text_embeds_list = []
151
+ pooled_text_embeds = None # always text_encoder_2's pool
152
+
153
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
154
+ text_tokens_input_ids = text_tokenize(tokenizer, prompts)
155
+ text_embeds, pooled_text_embeds = text_encode_xl(
156
+ text_encoder, text_tokens_input_ids, num_images_per_prompt
157
+ )
158
+
159
+ text_embeds_list.append(text_embeds)
160
+
161
+ bs_embed = pooled_text_embeds.shape[0]
162
+ pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
163
+ bs_embed * num_images_per_prompt, -1
164
+ )
165
+
166
+ return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
167
+
168
+
169
+ def concat_embeddings(
170
+ unconditional: torch.FloatTensor,
171
+ conditional: torch.FloatTensor,
172
+ n_imgs: int,
173
+ ):
174
+ if conditional.shape[0] == n_imgs and unconditional.shape[0] == 1:
175
+ return torch.cat([unconditional.repeat(n_imgs, 1, 1), conditional], dim=0)
176
+ return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
177
+
178
+
179
+ def predict_noise(
180
+ unet: UNet2DConditionModel,
181
+ scheduler: SchedulerMixin,
182
+ timestep: int,
183
+ latents: torch.FloatTensor,
184
+ text_embeddings: torch.FloatTensor, # concatenation of the unconditional and conditional text embeddings
185
+ guidance_scale=7.5,
186
+ ) -> torch.FloatTensor:
187
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
188
+ latent_model_input = torch.cat([latents] * 2)
189
+
190
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
191
+ # batch_size = latents.shape[0]
192
+ # text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)
193
+ # predict the noise residual
194
+ noise_pred = unet(
195
+ latent_model_input,
196
+ timestep,
197
+ encoder_hidden_states=text_embeddings,
198
+ ).sample
199
+
200
+ # perform guidance
201
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
202
+ guided_target = noise_pred_uncond + guidance_scale * (
203
+ noise_pred_text - noise_pred_uncond
204
+ )
205
+
206
+ return guided_target
207
+
208
+
209
+
210
+ @torch.no_grad()
211
+ def diffusion(
212
+ unet: UNet2DConditionModel,
213
+ scheduler: SchedulerMixin,
214
+ latents: torch.FloatTensor,
215
+ text_embeddings: torch.FloatTensor,
216
+ total_timesteps: int = 1000,
217
+ start_timesteps=0,
218
+ **kwargs,
219
+ ):
220
+ # latents_steps = []
221
+
222
+ for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
223
+ noise_pred = predict_noise(
224
+ unet, scheduler, timestep, latents, text_embeddings, **kwargs
225
+ )
226
+
227
+ # compute the previous noisy sample x_t -> x_t-1
228
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
229
+
230
+ # return latents_steps
231
+ return latents
232
+
233
+ @torch.no_grad()
234
+ def get_noisy_image(
235
+ img,
236
+ vae,
237
+ generator,
238
+ unet: UNet2DConditionModel,
239
+ scheduler: SchedulerMixin,
240
+ total_timesteps: int = 1000,
241
+ start_timesteps=0,
242
+
243
+ **kwargs,
244
+ ):
245
+ # latents_steps = []
246
+ vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
247
+ image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
248
+
249
+ image = img
250
+ # im_orig = image
251
+ device = vae.device
252
+ image = image_processor.preprocess(image).to(device)
253
+
254
+ init_latents = vae.encode(image).latent_dist.sample(None)
255
+ init_latents = vae.config.scaling_factor * init_latents
256
+
257
+ init_latents = torch.cat([init_latents], dim=0)
258
+
259
+ shape = init_latents.shape
260
+
261
+ noise = randn_tensor(shape, generator=generator, device=device)
262
+
263
+ time_ = total_timesteps
264
+ timestep = scheduler.timesteps[time_:time_+1]
265
+ # get latents
266
+ noised_latents = scheduler.add_noise(init_latents, noise, timestep)
267
+
268
+ return noised_latents, noise, init_latents
269
+
270
+ def subtract_noise(
271
+ latent: torch.FloatTensor,
272
+ noise: torch.FloatTensor,
273
+ timesteps: torch.IntTensor,
274
+ scheduler: SchedulerMixin,
275
+ ) -> torch.FloatTensor:
276
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
277
+ # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
278
+ # for the subsequent add_noise calls
279
+ scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device=latent.device)
280
+ alphas_cumprod = scheduler.alphas_cumprod.to(dtype=latent.dtype)
281
+ timesteps = timesteps.to(latent.device)
282
+
283
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
284
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
285
+ while len(sqrt_alpha_prod.shape) < len(latent.shape):
286
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
287
+
288
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
289
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
290
+ while len(sqrt_one_minus_alpha_prod.shape) < len(latent.shape):
291
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
292
+
293
+ denoised_latent = (latent - sqrt_one_minus_alpha_prod * noise) / sqrt_alpha_prod
294
+ return denoised_latent
295
+ def get_denoised_image(
296
+ latents: torch.FloatTensor,
297
+ noise_pred: torch.FloatTensor,
298
+ timestep: int,
299
+ # total_timesteps: int,
300
+ scheduler: SchedulerMixin,
301
+ vae: VaeImageProcessor,
302
+ ):
303
+ denoised_latents = subtract_noise(latents, noise_pred, timestep, scheduler)
304
+ denoised_latents = denoised_latents / vae.config.scaling_factor # 0.18215
305
+ denoised_img = vae.decode(denoised_latents).sample
306
+ # denoised_img = denoised_img.clamp(-1,1)
307
+ return denoised_img
308
+
309
+
310
+ def rescale_noise_cfg(
311
+ noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
312
+ ):
313
+
314
+ std_text = noise_pred_text.std(
315
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
316
+ )
317
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
318
+ # rescale the results from guidance (fixes overexposure)
319
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
320
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
321
+ noise_cfg = (
322
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
323
+ )
324
+
325
+ return noise_cfg
326
+
327
+
328
+ def predict_noise_xl(
329
+ unet: UNet2DConditionModel,
330
+ scheduler: SchedulerMixin,
331
+ timestep: int,
332
+ latents: torch.FloatTensor,
333
+ text_embeddings: torch.FloatTensor, # concatenation of the unconditional and conditional text embeddings
334
+ add_text_embeddings: torch.FloatTensor, # the pooled text embeddings
335
+ add_time_ids: torch.FloatTensor,
336
+ guidance_scale=7.5,
337
+ guidance_rescale=0.7,
338
+ ) -> torch.FloatTensor:
339
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
340
+ latent_model_input = torch.cat([latents] * 2)
341
+
342
+ latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
343
+
344
+ added_cond_kwargs = {
345
+ "text_embeds": add_text_embeddings,
346
+ "time_ids": add_time_ids,
347
+ }
348
+
349
+ # predict the noise residual
350
+ noise_pred = unet(
351
+ latent_model_input,
352
+ timestep,
353
+ encoder_hidden_states=text_embeddings,
354
+ added_cond_kwargs=added_cond_kwargs,
355
+ ).sample
356
+
357
+ # perform guidance
358
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
359
+ guided_target = noise_pred_uncond + guidance_scale * (
360
+ noise_pred_text - noise_pred_uncond
361
+ )
362
+
363
+ guided_target = rescale_noise_cfg(
364
+ guided_target, noise_pred_text, guidance_rescale=guidance_rescale
365
+ )
366
+
367
+ return guided_target
368
+
369
+
370
+ @torch.no_grad()
371
+ def diffusion_xl(
372
+ unet: UNet2DConditionModel,
373
+ scheduler: SchedulerMixin,
374
+ latents: torch.FloatTensor,
375
+ text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
376
+ add_text_embeddings: torch.FloatTensor,
377
+ add_time_ids: torch.FloatTensor,
378
+ guidance_scale: float = 1.0,
379
+ total_timesteps: int = 1000,
380
+ start_timesteps=0,
381
+ ):
382
+ # latents_steps = []
383
+
384
+ for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
385
+ noise_pred = predict_noise_xl(
386
+ unet,
387
+ scheduler,
388
+ timestep,
389
+ latents,
390
+ text_embeddings,
391
+ add_text_embeddings,
392
+ add_time_ids,
393
+ guidance_scale=guidance_scale,
394
+ guidance_rescale=0.7,
395
+ )
396
+
397
+ # compute the previous noisy sample x_t -> x_t-1
398
+ latents = scheduler.step(noise_pred, timestep, latents).prev_sample
399
+
400
+ # return latents_steps
401
+ return latents
402
+
403
+
404
+ # for XL
405
+ def get_add_time_ids(
406
+ height: int,
407
+ width: int,
408
+ dynamic_crops: bool = False,
409
+ dtype: torch.dtype = torch.float32,
410
+ ):
411
+ if dynamic_crops:
412
+ # random float scale between 1 and 3
413
+ random_scale = torch.rand(1).item() * 2 + 1
414
+ original_size = (int(height * random_scale), int(width * random_scale))
415
+ # random position
416
+ crops_coords_top_left = (
417
+ torch.randint(0, original_size[0] - height, (1,)).item(),
418
+ torch.randint(0, original_size[1] - width, (1,)).item(),
419
+ )
420
+ target_size = (height, width)
421
+ else:
422
+ original_size = (height, width)
423
+ crops_coords_top_left = (0, 0)
424
+ target_size = (height, width)
425
+
426
+ # this is expected as 6
427
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
428
+
429
+ # this is expected as 2816
430
+ passed_add_embed_dim = (
431
+ UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
432
+ + TEXT_ENCODER_2_PROJECTION_DIM # + 1280
433
+ )
434
+ if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
435
+ raise ValueError(
436
+ f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
437
+ )
438
+
439
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
440
+ return add_time_ids
441
+
442
+
443
+ def get_optimizer(name: str):
444
+ name = name.lower()
445
+
446
+ if name.startswith("dadapt"):
447
+ import dadaptation
448
+
449
+ if name == "dadaptadam":
450
+ return dadaptation.DAdaptAdam
451
+ elif name == "dadaptlion":
452
+ return dadaptation.DAdaptLion
453
+ else:
454
+ raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
455
+
456
+ elif name.endswith("8bit"):
457
+ import bitsandbytes as bnb
458
+
459
+ if name == "adam8bit":
460
+ return bnb.optim.Adam8bit
461
+ elif name == "lion8bit":
462
+ return bnb.optim.Lion8bit
463
+ else:
464
+ raise ValueError("8bit optimizer must be adam8bit or lion8bit")
465
+
466
+ else:
467
+ if name == "adam":
468
+ return torch.optim.Adam
469
+ elif name == "adamw":
470
+ return torch.optim.AdamW
471
+ elif name == "lion":
472
+ from lion_pytorch import Lion
473
+
474
+ return Lion
475
+ elif name == "prodigy":
476
+ import prodigyopt
477
+
478
+ return prodigyopt.Prodigy
479
+ else:
480
+ raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
481
+
482
+
483
+ def get_lr_scheduler(
484
+ name: Optional[str],
485
+ optimizer: torch.optim.Optimizer,
486
+ max_iterations: Optional[int],
487
+ lr_min: Optional[float],
488
+ **kwargs,
489
+ ):
490
+ if name == "cosine":
491
+ return torch.optim.lr_scheduler.CosineAnnealingLR(
492
+ optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
493
+ )
494
+ elif name == "cosine_with_restarts":
495
+ return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
496
+ optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
497
+ )
498
+ elif name == "step":
499
+ return torch.optim.lr_scheduler.StepLR(
500
+ optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
501
+ )
502
+ elif name == "constant":
503
+ return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
504
+ elif name == "linear":
505
+ return torch.optim.lr_scheduler.LinearLR(
506
+ optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
507
+ )
508
+ else:
509
+ raise ValueError(
510
+ "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
511
+ )
512
+
513
+
514
+ def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
515
+ max_resolution = bucket_resolution
516
+ min_resolution = bucket_resolution // 2
517
+
518
+ step = 64
519
+
520
+ min_step = min_resolution // step
521
+ max_step = max_resolution // step
522
+
523
+ height = torch.randint(min_step, max_step, (1,)).item() * step
524
+ width = torch.randint(min_step, max_step, (1,)).item() * step
525
+
526
+ return height, width
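To show how these helpers compose, a minimal sampling sketch under the assumption that tokenizer, text_encoder, unet, scheduler and vae come from utils.model_util.load_models and have already been moved to the target device; the prompt and step count are illustrative:

    import torch

    device = "cuda"
    scheduler.set_timesteps(50, device=device)

    cond = encode_prompts(tokenizer, text_encoder, ["a painting of a cat"])
    uncond = encode_prompts(tokenizer, text_encoder, [""])
    text_embeddings = concat_embeddings(uncond, cond, n_imgs=1).to(device)

    latents = get_initial_latents(scheduler, n_imgs=1, height=512, width=512, n_prompts=1).to(device)
    latents = diffusion(unet, scheduler, latents, text_embeddings,
                        total_timesteps=50, guidance_scale=7.5)

    with torch.no_grad():
        images = vae.decode(latents / vae.config.scaling_factor).sample  # (1, 3, 512, 512), roughly in [-1, 1]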