Spaces: Running on Zero
Upload folder using huggingface_hub
- .ipynb_checkpoints/hf_demo_test-checkpoint.ipynb +336 -0
- README.md +3 -10
- __pycache__/inference.cpython-39.pyc +0 -0
- custom_datasets/__init__.py +141 -0
- custom_datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/coco.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/imagepair.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/mypath.cpython-39.pyc +0 -0
- custom_datasets/coco.py +307 -0
- custom_datasets/custom_caption.py +113 -0
- custom_datasets/filt/coco/filt.py +186 -0
- custom_datasets/filt/sam_filt.py +299 -0
- custom_datasets/imagepair.py +240 -0
- custom_datasets/lhq.py +127 -0
- custom_datasets/mypath.py +29 -0
- custom_datasets/sam.py +160 -0
- data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/unsafe.png +0 -0
- hf_demo.py +147 -0
- hf_demo_test.ipynb +336 -0
- inference.py +657 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/lora.cpython-39.pyc +0 -0
- utils/__pycache__/metrics.cpython-39.pyc +0 -0
- utils/__pycache__/train_util.cpython-39.pyc +0 -0
- utils/art_filter.py +210 -0
- utils/config_util.py +105 -0
- utils/debug_util.py +16 -0
- utils/lora.py +282 -0
- utils/metrics.py +577 -0
- utils/model_util.py +291 -0
- utils/prompt_util.py +174 -0
- utils/train_util.py +526 -0
.ipynb_checkpoints/hf_demo_test-checkpoint.ipynb
ADDED
@@ -0,0 +1,336 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "initial_id",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:30.641366Z",
     "start_time": "2024-12-09T09:44:11.789050Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import gradio as gr\n",
    "from diffusers import DiffusionPipeline\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from PIL import Image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ddf33e0d3abacc2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "#append current path\n",
    "sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "643e49fd601daf8f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:35.790962Z",
     "start_time": "2024-12-09T09:44:35.779496Z"
    }
   },
   "outputs": [],
   "source": [
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e03aae2a4e5676dd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:44.157412Z",
     "start_time": "2024-12-09T09:44:37.138452Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9df8347307674ba8afb0250e23109aa1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
    "device = \"cuda\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "83916bc68ff5d914",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T09:44:52.694399Z",
     "start_time": "2024-12-09T09:44:44.210695Z"
    }
   },
   "outputs": [],
   "source": [
    "from inference import get_lora_network, inference, get_validation_dataloader\n",
    "lora_map = {\n",
    "    \"None\": \"None\",\n",
    "    \"Andre Derain\": \"andre-derain_subset1\",\n",
    "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
    "    \"Andy Warhol\": \"andy_subset1\",\n",
    "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
    "    \"Camille Corot\": \"camille-corot_subset1\",\n",
    "    \"Claude Monet\": \"monet_subset2\",\n",
    "    \"Pablo Picasso\": \"picasso_subset1\",\n",
    "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
    "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
    "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
    "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
    "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
    "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
    "    \"Gustav Klimt\": \"klimt_subset3\",\n",
    "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
    "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
    "    \"Joan Miro\": \"joan-miro_subset2\",\n",
    "}\n",
    "\n",
    "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
    "    adapter_path = lora_map[adapter_choice]\n",
    "    if adapter_path not in [None, \"None\"]:\n",
    "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
    "\n",
    "    prompts = [prompt]*samples\n",
    "    infer_loader = get_validation_dataloader(prompts)\n",
    "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                            height=512, width=512, scales=[1.0],\n",
    "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
    "                            start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
    "                            from_scratch=True)[0][1.0]\n",
    "    return pred_images\n",
    "\n",
    "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
    "    infer_loader = get_validation_dataloader(prompts, image)\n",
    "    network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
    "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
    "                            height=512, width=512, scales=[0.,1.],\n",
    "                            save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
    "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
    "                            from_scratch=False)\n",
    "    return pred_images\n",
    "\n",
    "# def infer(prompt, samples, steps, scale, seed):\n",
    "#     generator = torch.Generator(device=device).manual_seed(seed)\n",
    "#     images_list = pipe( # type: ignore\n",
    "#         [prompt] * samples,\n",
    "#         num_inference_steps=steps,\n",
    "#         guidance_scale=scale,\n",
    "#         generator=generator,\n",
    "#     )\n",
    "#     images = []\n",
    "#     safe_image = Image.open(r\"data/unsafe.png\")\n",
    "#     print(images_list)\n",
    "#     for i, image in enumerate(images_list[\"images\"]): # type: ignore\n",
    "#         if images_list[\"nsfw_content_detected\"][i]: # type: ignore\n",
    "#             images.append(safe_image)\n",
    "#         else:\n",
    "#             images.append(image)\n",
    "#     return images\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "aa33e9d104023847",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-12-09T12:09:39.339583Z",
     "start_time": "2024-12-09T12:09:38.953936Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
      "Running on local URL: http://127.0.0.1:7876\n",
      "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
      "\n",
      "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train method: None\n",
      "Rank: 1, Alpha: 1\n",
      "create LoRA for U-Net: 0 modules.\n",
      "save dir: None\n",
      "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
      "  return F.conv2d(input, weight, bias, self.stride,\n",
      "\n",
      "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
     ]
    }
   ],
   "source": [
    "block = gr.Blocks()\n",
    "# Direct infer\n",
    "with block:\n",
    "    with gr.Group():\n",
    "        with gr.Row():\n",
    "            text = gr.Textbox(\n",
    "                label=\"Enter your prompt\",\n",
    "                max_lines=2,\n",
    "                placeholder=\"Enter your prompt\",\n",
    "                container=False,\n",
    "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
    "            )\n",
    "\n",
    "\n",
    "\n",
    "            btn = gr.Button(\"Run\", scale=0)\n",
    "        gallery = gr.Gallery(\n",
    "            label=\"Generated images\",\n",
    "            show_label=False,\n",
    "            elem_id=\"gallery\",\n",
    "            columns=[2],\n",
    "        )\n",
    "\n",
    "    advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
    "\n",
    "    with gr.Row(elem_id=\"advanced-options\"):\n",
    "        adapter_choice = gr.Dropdown(\n",
    "            label=\"Choose adapter\",\n",
    "            choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
    "                     \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
    "                     \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
    "                     \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
    "                     \"Henri Matisse\", \"Joan Miro\"\n",
    "                     ],\n",
    "            value=\"None\"\n",
    "        )\n",
    "        # print(adapter_choice[0])\n",
    "        # lora_path = lora_map[adapter_choice.value]\n",
    "        # if lora_path is not None:\n",
    "        #     lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
    "\n",
    "        samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
    "        steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
    "        scale = gr.Slider(\n",
    "            label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
    "        )\n",
    "        print(scale)\n",
    "        seed = gr.Slider(\n",
    "            label=\"Seed\",\n",
    "            minimum=0,\n",
    "            maximum=2147483647,\n",
    "            step=1,\n",
    "            randomize=True,\n",
    "        )\n",
    "\n",
    "    gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
    "    advanced_button.click(\n",
    "        None,\n",
    "        [],\n",
    "        text,\n",
    "    )\n",
    "\n",
    "\n",
    "block.launch(share=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3239c12167a5f2cd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
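The notebook above wires `demo_inference_gen` into a Gradio UI. For completeness, here is a minimal sketch of calling the same helper directly, assuming the notebook's earlier cells have already run (so `pipe`, `lora_map`, and `demo_inference_gen` exist in the session) and the adapter checkpoints under `data/Art_adapters/` are present; the output filenames are illustrative, and the loop assumes the returned items are PIL images, as the Gallery usage suggests.

```python
# Minimal sketch, assuming the notebook cells above have already executed
# (pipeline loaded, inference helpers imported, demo_inference_gen defined).
prompt = "Park with cherry blossom trees, picnickers and a clear blue pond."

# Two samples with the Van Gogh adapter, 20 denoising steps.
images = demo_inference_gen(
    adapter_choice="Vincent van Gogh",
    prompt=prompt,
    samples=2,
    seed=0,
    steps=20,
    guidance_scale=7.5,
)

# Assumed to be PIL images (as the Gradio gallery implies); paths are illustrative.
for i, img in enumerate(images):
    img.save(f"van_gogh_sample_{i}.png")
```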
README.md
CHANGED
@@ -1,13 +1,6 @@
 ---
-title: Art
-
-colorFrom: purple
-colorTo: red
+title: Art-Free-Diffusion
+app_file: hf_demo.py
 sdk: gradio
-sdk_version:
-app_file: app.py
-pinned: false
-short_description: Demo for Art Free Diffusion
+sdk_version: 4.44.1
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/inference.cpython-39.pyc
ADDED
Binary file (19.8 kB).
custom_datasets/__init__.py
ADDED
@@ -0,0 +1,141 @@
from .mypath import MyPath
from copy import deepcopy
from datasets import load_dataset
from torch.utils.data import Dataset
import numpy as np

def get_dataset(dataset_name, transformation=None , train_subsample:int =None, val_subsample:int = 10000, get_val=True):
    if train_subsample is not None and train_subsample<val_subsample and train_subsample!=-1:
        print(f"Warning: train_subsample is smaller than val_subsample. val_subsample will be set to train_subsample: {train_subsample}")
        val_subsample = train_subsample

    if dataset_name == "imagenet":
        from .imagenet import Imagenet1k
        train_set = Imagenet1k(data_dir = MyPath.db_root_dir(dataset_name), transform = transformation, split="train", prompt_transform=Label_prompt_transform(real=True))
    elif dataset_name == "coco_train":
        # raise NotImplementedError("Use coco_filtered instead")
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_val":
        from .coco import CocoCaptions
        train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
        return {"val": train_set}

    elif dataset_name == "coco_clip_filtered":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
    elif dataset_name == "coco_filtered_sub100":
        from .coco import CocoCaptions_clip_filtered
        train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"),)
    elif dataset_name == "cifar10":
        from .cifar import CIFAR10
        train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif dataset_name == "cifar100":
        from .cifar import CIFAR100
        train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
    elif "wikiart" in dataset_name and "/" not in dataset_name:
        from .wikiart.wikiart import Wikiart_caption
        dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
        return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
    elif "imagepair" in dataset_name:
        from .imagepair import ImagePair
        train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
    # elif dataset_name == "sam_clip_filtered":
    #     from .sam import SamDataset
    #     train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_ids"), transforms=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
    elif dataset_name == "sam_whole_filtered_val":
        from .sam import SamDataset
        train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
        return {"val": train_set}
    elif dataset_name == "lhq_sub100":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
    elif dataset_name == "lhq_sub500":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
    elif dataset_name == "lhq_sub9":
        from .lhq import LhqDataset
        train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)

    elif dataset_name == "custom_coco100":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                                       custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
    elif dataset_name == "custom_coco500":
        from .coco import CustomCocoCaptions
        train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                                       custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
    elif dataset_name == "laion_pop500":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)

    elif dataset_name == "laion_pop500_first_sentence":
        from .custom_caption import Laion_pop
        train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)


    else:
        try:
            train_set = load_dataset('imagefolder', data_dir = dataset_name, split="train")
            val_set = deepcopy(train_set)
            if val_subsample is not None and val_subsample != -1:
                val_set = val_set.shuffle(seed=0).select(range(val_subsample))
            return {"train": train_set, "val": val_set if get_val else None}
        except:
            raise ValueError(f"dataset_name {dataset_name} not found.")
    return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}


class MergeDataset(Dataset):
    @staticmethod
    def get_merged_dataset(dataset_names:list, transformation=None, train_subsample:int =None, val_subsample:int = 10000):
        train_datasets = []
        val_datasets = []
        for dataset_name in dataset_names:
            datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
            train_datasets.append(datasets["train"])
            val_datasets.append(datasets["val"])
        train_datasets = MergeDataset(train_datasets).subsample(train_subsample)
        val_datasets = MergeDataset(val_datasets).subsample(val_subsample)
        return {"train": train_datasets, "val": val_datasets}

    def __init__(self, datasets:list):
        self.datasets = datasets
        self.column_names = self.datasets[0].column_names
        # self.ids = []
        # start = 0
        # for dataset in self.datasets:
        #     self.ids += [i+start for i in dataset.ids]
    def define_resolution(self, resolution: int):
        for dataset in self.datasets:
            dataset.define_resolution(resolution)

    def __len__(self):
        return sum([len(dataset) for dataset in self.datasets])
    def __getitem__(self, index):
        for i,dataset in enumerate(self.datasets):
            if index < len(dataset):
                ret = dataset[index]
                ret["id"] = index
                ret["dataset"] = i
                return ret
            index -= len(dataset)
        raise IndexError

    def subsample(self, num:int):
        if num is None:
            return self
        dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
        new_datasets = []
        for i, dataset in enumerate(self.datasets):
            new_datasets.append(dataset.subsample(int(num*dataset_ratio[i])))
        return MergeDataset(new_datasets)

    def with_transform(self, transform):
        for dataset in self.datasets:
            dataset.with_transform(transform)
        return self
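A hedged usage sketch for the loader above: it assumes the paths registered in `MyPath.db_root_dir` resolve to a local COCO 2017 validation copy, and it relies on the `subsample` and `__getitem__` behaviour defined in `custom_datasets/coco.py` further down.

```python
# Sketch only: requires MyPath.db_root_dir("coco_val") / ("coco_caption_val")
# to point at a local COCO 2017 validation set.
from custom_datasets import get_dataset

splits = get_dataset("coco_val")          # the "coco_val" branch returns {"val": CocoCaptions}
val_set = splits["val"].subsample(100)    # equal-interval subsample down to 100 images

sample = val_set[0]                       # dict with "id", "image", "caption"
print(sample["id"], sample["caption"][0][0])   # image id and its first caption string
```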
custom_datasets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (5.8 kB).
custom_datasets/__pycache__/coco.cpython-39.pyc
ADDED
Binary file (10.4 kB).
custom_datasets/__pycache__/imagepair.cpython-39.pyc
ADDED
Binary file (8.93 kB).
custom_datasets/__pycache__/mypath.cpython-39.pyc
ADDED
Binary file (1.49 kB).
custom_datasets/coco.py
ADDED
@@ -0,0 +1,307 @@
import os.path
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image

from torchvision.datasets.vision import VisionDataset
import pickle
import csv
import pandas as pd
import torch
import torchvision
import re
# from torchvision.datasets import CocoDetection
# from utils.clip_filter import Clip_filter
from tqdm import tqdm
from .mypath import MyPath

class CocoDetection(VisionDataset):
    """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root: str ,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        get_img=True,
        get_cap=True
    ) -> None:
        super().__init__(root, transforms, transform, target_transform)
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self.column_names = ["image", "text"]
        self.get_img = get_img
        self.get_cap = get_cap

    def _load_image(self, id: int) -> Image.Image:
        path = self.coco.loadImgs(id)[0]["file_name"]
        with open(os.path.join(self.root, path), 'rb') as f:
            img = Image.open(f).convert("RGB")

        return img

    def _load_target(self, id: int) -> List[Any]:
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        id = self.ids[index]
        ret={"id":id}
        if self.get_img:
            image = self._load_image(id)
            ret["image"] = image
        if self.get_cap:
            target = self._load_target(id)
            ret["caption"] = [target]

        if self.transforms is not None:
            ret = self.transforms(ret)

        return ret

    def subsample(self, n: int = 10000):
        if n is None or n == -1:
            return self
        ori_len = len(self)
        assert n <= ori_len
        # equal interval subsample
        ids = self.ids[::ori_len // n][:n]
        self.ids = ids
        print(f"COCO dataset subsampled from {ori_len} to {len(self)}")
        return self


    def with_transform(self, transform):
        self.transforms = transform
        return self

    def __len__(self) -> int:
        # return 100
        return len(self.ids)


class CocoCaptions(CocoDetection):
    """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.

    It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.

    Args:
        root (string): Root directory where images are downloaded to.
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.PILToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Example:

        .. code:: python

            import torchvision.datasets as dset
            import torchvision.transforms as transforms
            cap = dset.CocoCaptions(root = 'dir where images are',
                                    annFile = 'json annotation file',
                                    transform=transforms.PILToTensor())

            print('Number of samples: ', len(cap))
            img, target = cap[3]  # load 4th sample

            print("Image Size: ", img.size())
            print(target)

        Output: ::

            Number of samples: 82783
            Image Size: (3L, 427L, 640L)
            [u'A plane emitting smoke stream flying over a mountain.',
            u'A plane darts across a bright blue sky behind a mountain covered in snow',
            u'A plane leaves a contrail above the snowy mountain top.',
            u'A mountain that has a plane flying overheard in the distance.',
            u'A mountain view with a plume of smoke in the background']

    """

    def _load_target(self, id: int) -> List[str]:
        return [ann["caption"] for ann in super()._load_target(id)]


class CocoCaptions_clip_filtered(CocoCaptions):
    positive_prompt=["painting", "drawing", "graffiti",]
    def __init__(
        self,
        root: str ,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
        regenerate: bool = False,
        id_file: Optional[str] = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_ids.pickle"
    ) -> None:
        super().__init__(root, annFile, transform, target_transform, transforms)
        os.makedirs(os.path.dirname(id_file), exist_ok=True)
        if os.path.exists(id_file) and not regenerate:
            with open(id_file, "rb") as f:
                self.ids = pickle.load(f)
        else:
            self.ids, naive_filtered_num = self.naive_filter()
            self.ids, clip_filtered_num = self.clip_filter(0.7)

            print(f"naive Filtered {naive_filtered_num} images")
            print(f"Clip Filtered {clip_filtered_num} images")

            with open(id_file, "wb") as f:
                pickle.dump(self.ids, f)
            print(f"Filtered ids saved to {id_file}")
        print(f"COCO filtered dataset size: {len(self)}")

    def naive_filter(self, filter_prompt="painting"):
        new_ids = []
        naive_filtered_num = 0
        for id in self.ids:
            target = self._load_target(id)
            filtered = False
            for prompt in target:
                if filter_prompt in prompt.lower():
                    filtered = True
                    naive_filtered_num += 1
                    break
                # if "artwork" in prompt.lower():
                #     pass
            if not filtered:
                new_ids.append(id)
        return new_ids, naive_filtered_num

    # def clip_filter(self, threshold=0.7):
    #
    #     def collate_fn(examples):
    #         # {"image": image, "text": [target], "id":id}
    #         pixel_values = [example["image"] for example in examples]
    #         prompts = [example["text"] for example in examples]
    #         id = [example["id"] for example in examples]
    #         return {"images": pixel_values, "prompts": prompts, "ids": id}
    #
    #
    #     clip_filtered_num = 0
    #     clip_filter = Clip_filter(positive_prompt=self.positive_prompt)
    #     clip_logs={"positive_prompt":clip_filter.positive_prompt, "negative_prompt":clip_filter.negative_prompt,
    #                "ids":torch.Tensor([]),"logits":torch.Tensor([])}
    #     clip_log_file = "data/coco/clip_logs.pth"
    #     new_ids = []
    #     batch_size = 128
    #     dataloader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=10, shuffle=False,
    #                                              collate_fn=collate_fn)
    #     for i, batch in enumerate(tqdm(dataloader)):
    #         images = batch["images"]
    #         filter_result, logits = clip_filter.filter(images, threshold=threshold)
    #         ids = torch.IntTensor(batch["ids"])
    #         clip_logs["ids"] = torch.cat([clip_logs["ids"], ids])
    #         clip_logs["logits"] = torch.cat([clip_logs["logits"], logits])
    #
    #         new_ids.extend(ids[~filter_result].tolist())
    #         clip_filtered_num += filter_result.sum().item()
    #         if i % 50 == 0:
    #             torch.save(clip_logs, clip_log_file)
    #     torch.save(clip_logs, clip_log_file)
    #
    #     return new_ids, clip_filtered_num


class CustomCocoCaptions(CocoCaptions):
    def __init__(self, root: str=MyPath.db_root_dir("coco_val"), annFile: str=MyPath.db_root_dir("coco_caption_val"), custom_file:str="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt",transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None) -> None:

        super().__init__(root, annFile, transform, target_transform, transforms)
        self.column_names = ["image", "text"]
        self.custom_file = custom_file
        self.load_custom_data(custom_file)
        self.transforms = transforms

    def load_custom_data(self, custom_file):
        self.custom_data = []
        with open(custom_file, "r") as f:
            data = f.readlines()
            head = data[0].strip().split(",")
            self.head = head
            for line in data[1:]:
                sub_data = line.strip().split(",")
                if len(sub_data) > len(head):
                    sub_data_new = [sub_data[0]]
                    sub_data_new+=[",".join(sub_data[1:-1])]
                    sub_data_new.append(sub_data[-1])
                    sub_data = sub_data_new
                assert len(sub_data) == len(head)
                self.custom_data.append(sub_data)
        # to pd
        self.custom_data = pd.DataFrame(self.custom_data, columns=head)

    def __len__(self) -> int:
        return len(self.custom_data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        data = self.custom_data.iloc[index]
        id = int(data["image_id"])
        ret={"id":id}
        if self.get_img:
            image = self._load_image(id)
            ret["image"] = image
        if self.get_cap:
            caption = data["caption"]
            ret["caption"] = [caption]
        ret["seed"] = int(data["random_seed"])

        if self.transforms is not None:
            ret = self.transforms(ret)

        return ret



def get_validation_set():
    coco_instance = CocoDetection(root="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/train2017/", annFile="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/annotations/instances_train2017.json")
    discard_cat_id = coco_instance.coco.getCatIds(supNms=["person", "animal"])
    discard_img_id = []
    for cat_id in discard_cat_id:
        discard_img_id += coco_instance.coco.catToImgs[cat_id]

    coco_clip_filtered = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
                                                    regenerate=False)
    coco_clip_filtered_ids = coco_clip_filtered.ids
    new_ids = set(coco_clip_filtered_ids) - set(discard_img_id)
    new_ids = list(new_ids)
    new_ids = random.sample(new_ids, 100)
    with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_subset100.pickle", "wb") as f:
        pickle.dump(new_ids, f)

if __name__ == "__main__":
    from mypath import MyPath
    import random
    # get_validation_set()
    # coco_filtered_remian_id = pickle.load(open("data/coco/coco_clip_filtered_ids.pickle", "rb"))
    #
    # coco_filtered_subset100 = random.sample(coco_filtered_remian_id, 100)
    # save_path = "data/coco/coco_clip_filtered_subset100.pickle"
    # with open(save_path, "wb") as f:
    #     pickle.dump(coco_filtered_subset100, f)

    # dataset = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
    #                                      regenerate=False)
    dataset = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
                                 custom_file="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt")
    dataset[0]
custom_datasets/custom_caption.py
ADDED
@@ -0,0 +1,113 @@
# Authors: Hui Ren (rhfeiyang.github.io)
import torch
import pandas as pd
import numpy as np
import os
from PIL import Image

class Caption_set(torch.utils.data.Dataset):

    style_set_names=[
        "andre-derain_subset1",
        "andy_subset1",
        "camille-corot_subset1",
        "gerhard-richter_subset1",
        "henri-matisse_subset1",
        "katsushika-hokusai_subset1",
        "klimt_subset3",
        "monet_subset2",
        "picasso_subset1",
        "van_gogh_subset1",
    ]
    style_set_map={f"{name}":f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv" for name in style_set_names}

    def __init__(self, prompts_path=None, set_name=None, transform=None):
        assert prompts_path is not None or set_name is not None, "Either prompts_path or set_name should be provided"
        if prompts_path is None:
            prompts_path = self.style_set_map[set_name]

        self.prompts = pd.read_csv(prompts_path, delimiter=';')
        self.transform = transform
    def __len__(self):
        return len(self.prompts)
    def __getitem__(self, idx):
        ret={}
        ret["id"] = idx
        info = self.prompts.iloc[idx]
        ret.update(info)
        for k,v in ret.items():
            if isinstance(v,np.int64):
                ret[k] = int(v)
        ret["caption"] = [ret["caption"]]
        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        self.transform = transform
        return self


class HRS_caption(Caption_set):
    def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
        self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
        self.transform = transform
        self.caption_key = "original_prompts"

    def __getitem__(self, idx):
        ret={}
        ret["id"] = idx
        info = self.prompts.iloc[idx]
        ret["caption"] = [info[self.caption_key]]
        ret["seed"] = idx
        if self.transform:
            ret = self.transform(ret)
        return ret

class Laion_pop(torch.utils.data.Dataset):
    def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
        self.transform = transform
        self.info = pd.read_csv(anno_file, delimiter=";")
        self.caption_key = "caption"
        self.image_root = image_root
        self.get_img=True
        self.get_caption=True
    def __len__(self):
        return len(self.info)

    # def subsample(self, num:int):
    #     self.data = self.data.select(range(num))
    #     return self

    def load_image(self, key):
        image_path = os.path.join(self.image_root, f"{key:09}.jpg")
        with open(image_path, "rb") as f:
            image = Image.open(f).convert("RGB")
        return image

    def __getitem__(self, idx):
        info = self.info.iloc[idx]
        ret = {}
        key = info["key"]
        ret["id"] = key
        if self.get_caption:
            ret["caption"] = [info[self.caption_key]]
            ret["seed"] = int(key)
        if self.get_img:
            ret["image"] = self.load_image(key)

        if self.transform:
            ret = self.transform(ret)
        return ret

    def with_transform(self, transform):
        self.transform = transform
        return self

    def subset(self, ids:list):
        self.info = self.info[self.info["key"].isin(ids)]
        return self

if __name__ == "__main__":
    dataset = Caption_set("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv")
    dataset[0]
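A small sketch of iterating `Laion_pop` captions without decoding images; it assumes the default `anno_file`/`image_root` paths hard-coded above exist locally, and it uses the `get_img` flag the class defines to skip image loading.

```python
# Sketch only: the default anno_file / image_root in Laion_pop must exist locally.
import torch
from custom_datasets.custom_caption import Laion_pop

ds = Laion_pop()       # defaults from the class definition above
ds.get_img = False     # captions only; skip opening the JPEGs
loader = torch.utils.data.DataLoader(ds, batch_size=4, collate_fn=lambda batch: batch)

batch = next(iter(loader))
print([item["caption"][0] for item in batch])   # each item also carries "id" and "seed"
```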
custom_datasets/filt/coco/filt.py
ADDED
@@ -0,0 +1,186 @@
# Authors: Hui Ren (rhfeiyang.github.io)
import os
import sys
import numpy as np
from PIL import Image
import pickle
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
from custom_datasets import get_dataset
from utils.art_filter import Art_filter
import torch
from matplotlib import pyplot as plt
import math
import argparse
import socket
import time
from tqdm import tqdm
import torch
def parse_args():
    parser = argparse.ArgumentParser(description="Filter the coco dataset")
    parser.add_argument("--check", action="store_true", help="Check the complete")
    parser.add_argument("--mode", default="clip_logit", help="Filter mode: clip_logit, clip_filt, caption_filt")
    parser.add_argument("--split" , default="val", help="Dataset split, val/train")
    # parser.add_argument("--start_idx", default=0, type=int, help="Start index")
    args = parser.parse_args()
    return args

def get_feat(save_path, dataloader, filter):
    clip_feat_file = save_path
    # compute_new = False
    clip_feat={}
    if os.path.exists(clip_feat_file):
        with open(clip_feat_file, 'rb') as f:
            clip_feat = pickle.load(f)
    else:
        print(f"computing clip feat",flush=True)
        clip_feature_ret = filter.clip_feature(dataloader)
        clip_feat["image_features"] = clip_feature_ret["clip_features"]
        clip_feat["ids"] = clip_feature_ret["ids"]

        with open(clip_feat_file, 'wb') as f:
            pickle.dump(clip_feat, f)
        print(f"clip_feat_result saved to {clip_feat_file}",flush=True)
    return clip_feat

def get_clip_logit(save_root, dataloader, filter):
    feat_path = os.path.join(save_root, "clip_feat.pickle")
    clip_feat = get_feat(feat_path, dataloader, filter)
    clip_logits_file = os.path.join(save_root, "clip_logits.pickle")
    # if clip_logit:
    if os.path.exists(clip_logits_file):
        with open(clip_logits_file, 'rb') as f:
            clip_logits = pickle.load(f)
    else:
        clip_logits = filter.clip_logit_by_feat(clip_feat["image_features"])
        clip_logits["ids"] = clip_feat["ids"]
        with open(clip_logits_file, 'wb') as f:
            pickle.dump(clip_logits, f)
        print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
    return clip_logits

def clip_filt(save_root, dataloader, filter):
    clip_filt_file = os.path.join(save_root, "clip_filt_result.pickle")
    if os.path.exists(clip_filt_file):
        with open(clip_filt_file, 'rb') as f:
            clip_filt_result = pickle.load(f)
    else:
        clip_logits = get_clip_logit(save_root, dataloader, filter)
        clip_filt_result = filter.clip_filt(clip_logits)
        with open(clip_filt_file, 'wb') as f:
            pickle.dump(clip_filt_result, f)
        print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
    return clip_filt_result

def caption_filt(save_root, dataloader, filter):
    caption_filt_file = os.path.join(save_root, "caption_filt_result.pickle")
    if os.path.exists(caption_filt_file):
        with open(caption_filt_file, 'rb') as f:
            caption_filt_result = pickle.load(f)
    else:
        caption_filt_result = filter.caption_filt(dataloader)
        with open(caption_filt_file, 'wb') as f:
            pickle.dump(caption_filt_result, f)
        print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
    return caption_filt_result

def gather_result(save_dir, dataloader, filter):
    all_remain_ids=[]
    all_remain_ids_train=[]
    all_remain_ids_val=[]
    all_filtered_id_num = 0

    clip_filt_result = clip_filt(save_dir, dataloader, filter)
    caption_filt_result = caption_filt(save_dir, dataloader, filter)

    caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
    all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
    remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
    remain_ids = list(remain_ids)
    remain_ids.sort()
    with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
        pickle.dump(remain_ids, f)
    print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
    return remain_ids

@torch.no_grad()
def main(args):
    filter = Art_filter()
    if args.mode == "caption_filt" or args.mode == "gather_result":
        filter.clip_filter = None
        torch.cuda.empty_cache()

    # caption_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/PixArt-alpha/captions"
    # image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images"
    # id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict"
    # filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"

    def collate_fn(examples):
        # {"image": image, "id":id}
        ret = {}
        if "image" in examples[0]:
            pixel_values = [example["image"] for example in examples]
            ret["images"] = pixel_values
        if "caption" in examples[0]:
            # prompts = [example["caption"] for example in examples]
            prompts = []
            for example in examples:
                if isinstance(example["caption"][0], list):
                    prompts.append([" ".join(example["caption"][0])])
                else:
                    prompts.append(example["caption"])
            ret["text"] = prompts
        id = [example["id"] for example in examples]
        ret["ids"] = id
        return ret
    if args.split == "val":
        dataset = get_dataset("coco_val")["val"]
    elif args.split == "train":
        dataset = get_dataset("coco_train", get_val=False)["train"]
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)

    error_files=[]



    save_root = f"/vision-nfs/torralba/scratch/jomat/sam_dataset/coco/filt/{args.split}"
    os.makedirs(save_root, exist_ok=True)

    if args.mode == "clip_feat":
        feat_path = os.path.join(save_root, "clip_feat.pickle")
        clip_feat = get_feat(feat_path, dataloader, filter)

    if args.mode == "clip_logit":
        clip_logit = get_clip_logit(save_root, dataloader, filter)

    if args.mode == "clip_filt":
        # if os.path.exists(clip_filt_file):
        #     with open(clip_filt_file, 'rb') as f:
        #         ret = pickle.load(f)
        # else:
        clip_filt_result = clip_filt(save_root, dataloader, filter)

    if args.mode == "caption_filt":
        caption_filt_result = caption_filt(save_root, dataloader, filter)

    if args.mode == "gather_result":
        filtered_result = gather_result(save_root, dataloader, filter)

    print("finished",flush=True)
    for file in error_files:
        # os.remove(file)
        print(file,flush=True)

if __name__ == "__main__":
    args = parse_args()

    log_file = "sam_filt"
    idx=0
    hostname = socket.gethostname()
    now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
        idx+=1

    main(args)
    # clip_logits_analysis()

custom_datasets/filt/sam_filt.py
ADDED
@@ -0,0 +1,299 @@
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import pickle
|
7 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))
|
8 |
+
from custom_datasets.sam import SamDataset
|
9 |
+
from utils.art_filter import Art_filter
|
10 |
+
import torch
|
11 |
+
from matplotlib import pyplot as plt
|
12 |
+
import math
|
13 |
+
import argparse
|
14 |
+
import socket
|
15 |
+
import time
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
def parse_args():
|
19 |
+
parser = argparse.ArgumentParser(description="Filter the sam dataset")
|
20 |
+
parser.add_argument("--check", action="store_true", help="Check the complete")
|
21 |
+
parser.add_argument("--mode", default="clip_logit", choices=["clip_logit_update","clip_logit", "clip_filt", "caption_filt", "gather_result","caption_flit_append"])
|
22 |
+
parser.add_argument("--start_idx", default=0, type=int, help="Start index")
|
23 |
+
parser.add_argument("--end_idx", default=9e10, type=int, help="Start index")
|
24 |
+
args = parser.parse_args()
|
25 |
+
return args
|
26 |
+
@torch.no_grad()
|
27 |
+
def main(args):
|
28 |
+
filter = Art_filter()
|
29 |
+
if args.mode == "caption_filt" or args.mode == "gather_result":
|
30 |
+
filter.clip_filter = None
|
31 |
+
torch.cuda.empty_cache()
|
32 |
+
|
33 |
+
caption_folder_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/SAM/subset/captions"
|
34 |
+
image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
|
35 |
+
id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
|
36 |
+
filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
|
37 |
+
def collate_fn(examples):
|
38 |
+
# {"image": image, "id":id}
|
39 |
+
ret = {}
|
40 |
+
if "image" in examples[0]:
|
41 |
+
pixel_values = [example["image"] for example in examples]
|
42 |
+
ret["images"] = pixel_values
|
43 |
+
if "text" in examples[0]:
|
44 |
+
prompts = [example["text"] for example in examples]
|
45 |
+
ret["text"] = prompts
|
46 |
+
id = [example["id"] for example in examples]
|
47 |
+
ret["ids"] = id
|
48 |
+
return ret
|
49 |
+
error_files=[]
|
50 |
+
val_set = ["sa_000000"]
|
51 |
+
result_check_set = ["sa_000020"]
|
52 |
+
all_remain_ids=[]
|
53 |
+
all_remain_ids_train=[]
|
54 |
+
all_remain_ids_val=[]
|
55 |
+
all_filtered_id_num = 0
|
56 |
+
remain_feat_num = 0
|
57 |
+
remain_caption_num = 0
|
58 |
+
filter_feat_num = 0
|
59 |
+
filter_caption_num = 0
|
60 |
+
for idx,file in tqdm(enumerate(sorted(os.listdir(id_dict_dir)))):
|
61 |
+
if idx < args.start_idx or idx >= args.end_idx:
|
62 |
+
continue
|
63 |
+
if file.endswith(".pickle") and not file.startswith("all"):
|
64 |
+
print("=====================================")
|
65 |
+
print(file,flush=True)
|
66 |
+
save_dir = os.path.join(filt_dir, file.replace("_id_dict.pickle", ""))
|
67 |
+
if not os.path.exists(save_dir):
|
68 |
+
os.makedirs(save_dir, exist_ok=True)
|
69 |
+
id_dict_file = os.path.join(id_dict_dir, file)
|
70 |
+
with open(id_dict_file, 'rb') as f:
|
71 |
+
id_dict = pickle.load(f)
|
72 |
+
ids = list(id_dict.keys())
|
73 |
+
dataset = SamDataset(image_folder_path, caption_folder_path, id_file=ids, id_dict_file=id_dict_file)
|
74 |
+
# dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
|
75 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
|
76 |
+
clip_logits = None
|
77 |
+
clip_logits_file = os.path.join(save_dir, "clip_logits_result.pickle")
|
78 |
+
clip_filt_file = os.path.join(save_dir, "clip_filt_result.pickle")
|
79 |
+
caption_filt_file = os.path.join(save_dir, "caption_filt_result.pickle")
|
80 |
+
|
81 |
+
if args.mode == "clip_feat":
|
82 |
+
compute_new = False
|
83 |
+
clip_logits = {}
|
84 |
+
if os.path.exists(clip_logits_file):
|
85 |
+
with open(clip_logits_file, 'rb') as f:
|
86 |
+
clip_logits = pickle.load(f)
|
87 |
+
if "image_features" not in clip_logits:
|
88 |
+
compute_new = True
|
89 |
+
else:
|
90 |
+
compute_new=True
|
91 |
+
if compute_new:
|
92 |
+
if clip_logits == '':
|
93 |
+
clip_logits = {}
|
94 |
+
print(f"compute clip_feat {file}",flush=True)
|
95 |
+
clip_feature_ret = filter.clip_feature(dataloader)
|
96 |
+
clip_logits["image_features"] = clip_feature_ret["clip_features"]
|
97 |
+
if "ids" in clip_logits:
|
98 |
+
assert clip_feature_ret["ids"] == clip_logits["ids"]
|
99 |
+
else:
|
100 |
+
clip_logits["ids"] = clip_feature_ret["ids"]
|
101 |
+
|
102 |
+
with open(clip_logits_file, 'wb') as f:
|
103 |
+
pickle.dump(clip_logits, f)
|
104 |
+
print(f"clip_feat_result saved to {clip_logits_file}",flush=True)
|
105 |
+
else:
|
106 |
+
print(f"skip {clip_logits_file}",flush=True)
|
107 |
+
|
108 |
+
if args.mode == "clip_logit":
|
109 |
+
# if clip_logit:
|
110 |
+
if os.path.exists(clip_logits_file):
|
111 |
+
try:
|
112 |
+
with open(clip_logits_file, 'rb') as f:
|
113 |
+
clip_logits = pickle.load(f)
|
114 |
+
except:
|
115 |
+
continue
|
116 |
+
skip = True
|
117 |
+
if args.check and clip_logits=="":
|
118 |
+
skip = False
|
119 |
+
|
120 |
+
else:
|
121 |
+
skip = False
|
122 |
+
# skip = False
|
123 |
+
if not skip:
|
124 |
+
# os.makedirs(os.path.join(save_dir, "tmp"), exist_ok=True)
|
125 |
+
with open(clip_logits_file, 'wb') as f:
|
126 |
+
pickle.dump("", f)
|
127 |
+
try:
|
128 |
+
clip_logits = filter.clip_logit(dataloader)
|
129 |
+
except:
|
130 |
+
print(f"Error in clip_logit {file}",flush=True)
|
131 |
+
continue
|
132 |
+
with open(clip_logits_file, 'wb') as f:
|
133 |
+
pickle.dump(clip_logits, f)
|
134 |
+
print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
|
135 |
+
else:
|
136 |
+
print(f"skip {clip_logits_file}",flush=True)
|
137 |
+
|
138 |
+
if args.mode == "clip_logit_update":
|
139 |
+
if os.path.exists(clip_logits_file):
|
140 |
+
with open(clip_logits_file, 'rb') as f:
|
141 |
+
clip_logits = pickle.load(f)
|
142 |
+
else:
|
143 |
+
print(f"{clip_logits_file} not exist",flush=True)
|
144 |
+
continue
|
145 |
+
if clip_logits == "":
|
146 |
+
print(f"skip {clip_logits_file}",flush=True)
|
147 |
+
continue
|
148 |
+
ret = filter.clip_logit_by_feat(clip_logits["clip_features"])
|
149 |
+
# assert (clip_logits["clip_logits"] - ret["clip_logits"]).abs().max() < 0.01
|
150 |
+
clip_logits["clip_logits"] = ret["clip_logits"]
|
151 |
+
clip_logits["text"] = ret["text"]
|
152 |
+
with open(clip_logits_file, 'wb') as f:
|
153 |
+
pickle.dump(clip_logits, f)
|
154 |
+
|
155 |
+
|
156 |
+
if args.mode == "clip_filt":
|
157 |
+
# if os.path.exists(clip_filt_file):
|
158 |
+
# with open(clip_filt_file, 'rb') as f:
|
159 |
+
# ret = pickle.load(f)
|
160 |
+
# else:
|
161 |
+
|
162 |
+
if clip_logits is None:
|
163 |
+
try:
|
164 |
+
with open(clip_logits_file, 'rb') as f:
|
165 |
+
clip_logits = pickle.load(f)
|
166 |
+
except:
|
167 |
+
print(f"Error in loading {clip_logits_file}",flush=True)
|
168 |
+
error_files.append(clip_logits_file)
|
169 |
+
continue
|
170 |
+
if clip_logits == "":
|
171 |
+
print(f"skip {clip_logits_file}",flush=True)
|
172 |
+
error_files.append(clip_logits_file)
|
173 |
+
continue
|
174 |
+
clip_filt_result = filter.clip_filt(clip_logits)
|
175 |
+
with open(clip_filt_file, 'wb') as f:
|
176 |
+
pickle.dump(clip_filt_result, f)
|
177 |
+
print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
|
178 |
+
|
179 |
+
if args.mode == "caption_filt":
|
180 |
+
if os.path.exists(caption_filt_file):
|
181 |
+
try:
|
182 |
+
with open(caption_filt_file, 'rb') as f:
|
183 |
+
ret = pickle.load(f)
|
184 |
+
except:
|
185 |
+
continue
|
186 |
+
skip = True
|
187 |
+
if args.check and ret=="":
|
188 |
+
skip = False
|
189 |
+
# os.remove(caption_filt_file)
|
190 |
+
print(f"empty {caption_filt_file}",flush=True)
|
191 |
+
# skip = True
|
192 |
+
else:
|
193 |
+
skip = False
|
194 |
+
if not skip:
|
195 |
+
with open(caption_filt_file, 'wb') as f:
|
196 |
+
pickle.dump("", f)
|
197 |
+
# try:
|
198 |
+
ret = filter.caption_filt(dataloader)
|
199 |
+
# except:
|
200 |
+
# print(f"Error in filtering {file}",flush=True)
|
201 |
+
# continue
|
202 |
+
with open(caption_filt_file, 'wb') as f:
|
203 |
+
pickle.dump(ret, f)
|
204 |
+
print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
|
205 |
+
else:
|
206 |
+
print(f"skip {caption_filt_file}",flush=True)
|
207 |
+
|
208 |
+
if args.mode == "caption_flit_append":
|
209 |
+
if not os.path.exists(caption_filt_file):
|
210 |
+
print(f"{caption_filt_file} not exist",flush=True)
|
211 |
+
continue
|
212 |
+
with open(caption_filt_file, 'rb') as f:
|
213 |
+
old_caption_filt_result = pickle.load(f)
|
214 |
+
skip = True
|
215 |
+
for i in filter.caption_filter.filter_prompts:
|
216 |
+
if i not in old_caption_filt_result["filter_prompts"]:
|
217 |
+
skip = False
|
218 |
+
break
|
219 |
+
if skip:
|
220 |
+
print(f"skip {caption_filt_file}",flush=True)
|
221 |
+
continue
|
222 |
+
old_remain_ids = old_caption_filt_result["remain_ids"]
|
223 |
+
new_dataset = SamDataset(image_folder_path, caption_folder_path, id_file=old_remain_ids, id_dict_file=id_dict_file)
|
224 |
+
new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
|
225 |
+
ret = filter.caption_filt(new_dataloader)
|
226 |
+
old_caption_filt_result["remain_ids"] = ret["remain_ids"]
|
227 |
+
old_caption_filt_result["filtered_ids"].extend(ret["filtered_ids"])
|
228 |
+
new_filter_count = ret["filter_count"].copy()
|
229 |
+
for i in range(len(old_caption_filt_result["filter_count"])):
|
230 |
+
new_filter_count[i] += old_caption_filt_result["filter_count"][i]
|
231 |
+
|
232 |
+
old_caption_filt_result["filter_count"] = new_filter_count
|
233 |
+
old_caption_filt_result["filter_prompts"] = ret["filter_prompts"]
|
234 |
+
with open(caption_filt_file, 'wb') as f:
|
235 |
+
pickle.dump(old_caption_filt_result, f)
|
236 |
+
|
237 |
+
|
238 |
+
|
239 |
+
if args.mode == "gather_result":
|
240 |
+
with open(clip_filt_file, 'rb') as f:
|
241 |
+
clip_filt_result = pickle.load(f)
|
242 |
+
with open(caption_filt_file, 'rb') as f:
|
243 |
+
caption_filt_result = pickle.load(f)
|
244 |
+
caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
|
245 |
+
all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
|
246 |
+
|
247 |
+
remain_feat_num += len(clip_filt_result["remain_ids"])
|
248 |
+
remain_caption_num += len(caption_filt_result["remain_ids"])
|
249 |
+
filter_feat_num += len(clip_filt_result["filtered_ids"])
|
250 |
+
filter_caption_num += len(caption_filtered_ids)
|
251 |
+
|
252 |
+
remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
|
253 |
+
remain_ids = list(remain_ids)
|
254 |
+
remain_ids.sort()
|
255 |
+
# with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
|
256 |
+
# pickle.dump(remain_ids, f)
|
257 |
+
# print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
|
258 |
+
all_remain_ids.extend(remain_ids)
|
259 |
+
if file.replace("_id_dict.pickle","") in val_set:
|
260 |
+
all_remain_ids_val.extend(remain_ids)
|
261 |
+
else:
|
262 |
+
all_remain_ids_train.extend(remain_ids)
|
263 |
+
if args.mode == "gather_result":
|
264 |
+
print(f"filtered ids: {all_filtered_id_num}",flush=True)
|
265 |
+
print(f"remain feat num: {remain_feat_num}",flush=True)
|
266 |
+
print(f"remain caption num: {remain_caption_num}",flush=True)
|
267 |
+
print(f"filter feat num: {filter_feat_num}",flush=True)
|
268 |
+
print(f"filter caption num: {filter_caption_num}",flush=True)
|
269 |
+
all_remain_ids.sort()
|
270 |
+
with open(os.path.join(filt_dir, "all_remain_ids.pickle"), 'wb') as f:
|
271 |
+
pickle.dump(all_remain_ids, f)
|
272 |
+
with open(os.path.join(filt_dir, "all_remain_ids_train.pickle"), 'wb') as f:
|
273 |
+
pickle.dump(all_remain_ids_train, f)
|
274 |
+
with open(os.path.join(filt_dir, "all_remain_ids_val.pickle"), 'wb') as f:
|
275 |
+
pickle.dump(all_remain_ids_val, f)
|
276 |
+
|
277 |
+
print(f"all_remain_ids saved to {filt_dir}/all_remain_ids.pickle",flush=True)
|
278 |
+
print(f"all_remain_ids_train saved to {filt_dir}/all_remain_ids_train.pickle",flush=True)
|
279 |
+
print(f"all_remain_ids_val saved to {filt_dir}/all_remain_ids_val.pickle",flush=True)
|
280 |
+
|
281 |
+
print("finished",flush=True)
|
282 |
+
for file in error_files:
|
283 |
+
# os.remove(file)
|
284 |
+
print(file,flush=True)
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
args = parse_args()
|
288 |
+
|
289 |
+
log_file = "sam_filt"
|
290 |
+
idx=0
|
291 |
+
hostname = socket.gethostname()
|
292 |
+
now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
293 |
+
while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
|
294 |
+
idx+=1
|
295 |
+
|
296 |
+
main(args)
|
297 |
+
# clip_logits_analysis()
|
298 |
+
|
299 |
+
|
custom_datasets/imagepair.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import random
|
3 |
+
|
4 |
+
import torch.utils.data as data
|
5 |
+
from PIL import Image
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
# from tqdm import tqdm
|
9 |
+
class ImageSet(data.Dataset):
|
10 |
+
def __init__(self, folder , transform=None, keep_in_mem=True, caption=None):
|
11 |
+
self.path = folder
|
12 |
+
self.transform = transform
|
13 |
+
self.caption_path = None
|
14 |
+
self.images = []
|
15 |
+
self.captions = []
|
16 |
+
self.keep_in_mem = keep_in_mem
|
17 |
+
|
18 |
+
if not isinstance(folder, list):
|
19 |
+
self.image_files = [file for file in os.listdir(folder) if file.endswith((".png",".jpg"))]
|
20 |
+
self.image_files.sort()
|
21 |
+
else:
|
22 |
+
self.images = folder
|
23 |
+
|
24 |
+
if not isinstance(caption, list):
|
25 |
+
if caption not in [None, "", "None"]:
|
26 |
+
self.caption_path = caption
|
27 |
+
self.caption_files = [os.path.join(caption, file.replace(".png", ".txt").replace(".jpg", ".txt")) for file in self.image_files]
|
28 |
+
self.caption_files.sort()
|
29 |
+
else:
|
30 |
+
self.caption_path = True
|
31 |
+
self.captions = caption
|
32 |
+
# get all the image files png/jpg
|
33 |
+
|
34 |
+
|
35 |
+
if keep_in_mem:
|
36 |
+
if len(self.images) == 0:
|
37 |
+
for file in self.image_files:
|
38 |
+
img = self.load_image(os.path.join(self.path, file))
|
39 |
+
self.images.append(img)
|
40 |
+
if len(self.captions) == 0:
|
41 |
+
if self.caption_path is not None:
|
42 |
+
self.captions = []
|
43 |
+
for file in self.caption_files:
|
44 |
+
caption = self.load_caption(file)
|
45 |
+
self.captions.append(caption)
|
46 |
+
else:
|
47 |
+
self.images = None
|
48 |
+
|
49 |
+
def limit_num(self, n):
|
50 |
+
raise NotImplementedError
|
51 |
+
assert n <= len(self), f"n should be less than the length of the dataset {len(self)}"
|
52 |
+
self.image_files = self.image_files[:n]
|
53 |
+
self.caption_files = self.caption_files[:n]
|
54 |
+
if self.keep_in_mem:
|
55 |
+
self.images = self.images[:n]
|
56 |
+
self.captions = self.captions[:n]
|
57 |
+
print(f"Dataset limited to {n}")
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
if len(self.images) != 0:
|
61 |
+
return len(self.images)
|
62 |
+
else:
|
63 |
+
return len(self.image_files)
|
64 |
+
|
65 |
+
def load_image(self, path):
|
66 |
+
with open(path, 'rb') as f:
|
67 |
+
img = Image.open(f).convert('RGB')
|
68 |
+
return img
|
69 |
+
|
70 |
+
def load_caption(self, path):
|
71 |
+
with open(path, 'r') as f:
|
72 |
+
caption = f.readlines()
|
73 |
+
caption = [line.strip() for line in caption if len(line.strip()) > 0]
|
74 |
+
return caption
|
75 |
+
|
76 |
+
def __getitem__(self, index):
|
77 |
+
if len(self.images) != 0:
|
78 |
+
img = self.images[index]
|
79 |
+
else:
|
80 |
+
img = self.load_image(os.path.join(self.path, self.image_files[index]))
|
81 |
+
|
82 |
+
# if self.transform is not None:
|
83 |
+
# img = self.transform(img)
|
84 |
+
|
85 |
+
if self.caption_path is not None or len(self.captions) != 0:
|
86 |
+
if len(self.captions) != 0:
|
87 |
+
caption = self.captions[index]
|
88 |
+
else:
|
89 |
+
caption = self.load_caption(self.caption_files[index])
|
90 |
+
ret= {"image": img, "caption": caption, "id": index}
|
91 |
+
else:
|
92 |
+
ret= {"image": img, "id": index}
|
93 |
+
if self.transform is not None:
|
94 |
+
ret = self.transform(ret)
|
95 |
+
return ret
|
96 |
+
|
97 |
+
def subsample(self, n: int = 10):
|
98 |
+
if n is None or n == -1:
|
99 |
+
return self
|
100 |
+
ori_len = len(self)
|
101 |
+
assert n <= ori_len
|
102 |
+
# equal interval subsample
|
103 |
+
ids = self.image_files[::ori_len // n][:n]
|
104 |
+
self.image_files = ids
|
105 |
+
if self.keep_in_mem:
|
106 |
+
self.images = self.images[::ori_len // n][:n]
|
107 |
+
print(f"Dataset subsampled from {ori_len} to {len(self)}")
|
108 |
+
return self
|
109 |
+
|
110 |
+
def with_transform(self, transform):
|
111 |
+
self.transform = transform
|
112 |
+
return self
|
113 |
+
@staticmethod
|
114 |
+
def collate_fn(examples):
|
115 |
+
images = [example["image"] for example in examples]
|
116 |
+
ids = [example["id"] for example in examples]
|
117 |
+
if "caption" in examples[0]:
|
118 |
+
captions = [random.choice(example["caption"]) for example in examples]
|
119 |
+
return {"images": images, "captions": captions, "id": ids}
|
120 |
+
else:
|
121 |
+
return {"images": images, "id": ids}
|
122 |
+
|
123 |
+
|
124 |
+
class ImagePair(ImageSet):
|
125 |
+
def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
|
126 |
+
self.path1 = folder1
|
127 |
+
self.path2 = folder2
|
128 |
+
self.transform = transform
|
129 |
+
# get all the image files png/jpg
|
130 |
+
self.image_files = [file for file in os.listdir(folder1) if file.endswith(".png") or file.endswith(".jpg")]
|
131 |
+
self.image_files.sort()
|
132 |
+
self.keep_in_mem = keep_in_mem
|
133 |
+
if keep_in_mem:
|
134 |
+
self.images = []
|
135 |
+
for file in self.image_files:
|
136 |
+
img1 = self.load_image(os.path.join(self.path1, file))
|
137 |
+
img2 = self.load_image(os.path.join(self.path2, file))
|
138 |
+
self.images.append((img1, img2))
|
139 |
+
else:
|
140 |
+
self.images = None
|
141 |
+
|
142 |
+
def __getitem__(self, index):
|
143 |
+
if self.keep_in_mem:
|
144 |
+
img1, img2 = self.images[index]
|
145 |
+
else:
|
146 |
+
img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
|
147 |
+
img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))
|
148 |
+
|
149 |
+
if self.transform is not None:
|
150 |
+
img1 = self.transform(img1)
|
151 |
+
img2 = self.transform(img2)
|
152 |
+
return {"image1": img1, "image2": img2, "id": index}
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
@staticmethod
|
157 |
+
def collate_fn(examples):
|
158 |
+
images1 = [example["image1"] for example in examples]
|
159 |
+
images2 = [example["image2"] for example in examples]
|
160 |
+
# images1 = torch.stack(images1)
|
161 |
+
# images2 = torch.stack(images2)
|
162 |
+
ids = [example["id"] for example in examples]
|
163 |
+
return {"image1": images1, "image2": images2, "id": ids}
|
164 |
+
|
165 |
+
def push_to_huggingface(self, hug_folder):
|
166 |
+
from datasets import Dataset
|
167 |
+
from datasets import Image as HugImage
|
168 |
+
photo_path = [os.path.join(self.path1, file) for file in self.image_files]
|
169 |
+
sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
|
170 |
+
dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
|
171 |
+
dataset = dataset.cast_column("photo", HugImage())
|
172 |
+
dataset = dataset.cast_column("sketch", HugImage())
|
173 |
+
dataset.push_to_hub(hug_folder, private=True)
|
174 |
+
|
175 |
+
class ImageClass(ImageSet):
|
176 |
+
def __init__(self, folders: list, transform=None, keep_in_mem=True):
|
177 |
+
self.paths = folders
|
178 |
+
self.transform = transform
|
179 |
+
# get all the image files png/jpg
|
180 |
+
self.image_files = []
|
181 |
+
self.keep_in_mem = keep_in_mem
|
182 |
+
for i, folder in enumerate(folders):
|
183 |
+
self.image_files+=[(os.path.join(folder, file), i) for file in os.listdir(folder) if file.endswith(".png") or file.endswith(".jpg")]
|
184 |
+
if keep_in_mem:
|
185 |
+
self.images = []
|
186 |
+
print("Loading images to memory")
|
187 |
+
for file in self.image_files:
|
188 |
+
img = self.load_image(file[0])
|
189 |
+
self.images.append((img, file[1]))
|
190 |
+
print("Loading images to memory done")
|
191 |
+
else:
|
192 |
+
self.images = None
|
193 |
+
|
194 |
+
def __getitem__(self, index):
|
195 |
+
if self.keep_in_mem:
|
196 |
+
img, label = self.images[index]
|
197 |
+
else:
|
198 |
+
img_path, label = self.image_files[index]
|
199 |
+
img = self.load_image(img_path)
|
200 |
+
|
201 |
+
if self.transform is not None:
|
202 |
+
img = self.transform(img)
|
203 |
+
return {"image": img, "label": label, "id": index}
|
204 |
+
|
205 |
+
@staticmethod
|
206 |
+
def collate_fn(examples):
|
207 |
+
images = [example["image"] for example in examples]
|
208 |
+
labels = [example["label"] for example in examples]
|
209 |
+
ids = [example["id"] for example in examples]
|
210 |
+
return {"images": images, "labels":labels, "id": ids}
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
# dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_50",
|
215 |
+
# "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_50",keep_in_mem=False)
|
216 |
+
# dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-50")
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500",
|
221 |
+
"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500",
|
222 |
+
keep_in_mem=True)
|
223 |
+
# dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-500")
|
224 |
+
# ret = dataset[0]
|
225 |
+
# print(len(dataset))
|
226 |
+
import torch
|
227 |
+
from torchvision import transforms
|
228 |
+
train_transforms = transforms.Compose(
|
229 |
+
[
|
230 |
+
transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
|
231 |
+
transforms.CenterCrop(256),
|
232 |
+
transforms.RandomHorizontalFlip(),
|
233 |
+
transforms.ToTensor(),
|
234 |
+
transforms.Normalize([0.5], [0.5]),
|
235 |
+
]
|
236 |
+
)
|
237 |
+
dataset = dataset.with_transform(train_transforms)
|
238 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=ImagePair.collate_fn)
|
239 |
+
ret = dataloader.__iter__().__next__()
|
240 |
+
pass
|
custom_datasets/lhq.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import os
|
3 |
+
import pickle
|
4 |
+
import random
|
5 |
+
import shutil
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
from torchvision import transforms
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
class LhqDataset(Dataset):
|
11 |
+
def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "clip_dissection/lhq/idx/subsample_100.pickle", transforms: transforms = None,
|
12 |
+
get_img=True,
|
13 |
+
get_cap=True,):
|
14 |
+
|
15 |
+
if isinstance(id_file, list):
|
16 |
+
self.ids = id_file
|
17 |
+
elif isinstance(id_file, str):
|
18 |
+
with open(id_file, 'rb') as f:
|
19 |
+
print(f"Loading ids from {id_file}", flush=True)
|
20 |
+
self.ids = pickle.load(f)
|
21 |
+
print(f"Loaded ids from {id_file}", flush=True)
|
22 |
+
self.image_folder_path = image_folder_path
|
23 |
+
self.caption_folder_path = caption_folder_path
|
24 |
+
self.transforms = transforms
|
25 |
+
self.column_names = ["image", "text"]
|
26 |
+
self.get_img = get_img
|
27 |
+
self.get_cap = get_cap
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self.ids)
|
31 |
+
|
32 |
+
def __getitem__(self, index: int):
|
33 |
+
id = self.ids[index]
|
34 |
+
ret={"id":id}
|
35 |
+
if self.get_img:
|
36 |
+
image = self._load_image(id)
|
37 |
+
ret["image"]=image
|
38 |
+
if self.get_cap:
|
39 |
+
target = self._load_caption(id)
|
40 |
+
ret["caption"]=[target]
|
41 |
+
if self.transforms is not None:
|
42 |
+
ret = self.transforms(ret)
|
43 |
+
return ret
|
44 |
+
|
45 |
+
def _load_image(self, id: int):
|
46 |
+
image_path = f"{self.image_folder_path}/{id}.jpg"
|
47 |
+
with open(image_path, 'rb') as f:
|
48 |
+
img = Image.open(f).convert("RGB")
|
49 |
+
return img
|
50 |
+
|
51 |
+
def _load_caption(self, id: int):
|
52 |
+
caption_path = f"{self.caption_folder_path}/{id}.txt"
|
53 |
+
with open(caption_path, 'r') as f:
|
54 |
+
caption_file = f.read()
|
55 |
+
caption = []
|
56 |
+
for line in caption_file.split("\n"):
|
57 |
+
line = line.strip()
|
58 |
+
if len(line) > 0:
|
59 |
+
caption.append(line)
|
60 |
+
return caption
|
61 |
+
|
62 |
+
def subsample(self, n: int = 10000):
|
63 |
+
if n is None or n == -1:
|
64 |
+
return self
|
65 |
+
ori_len = len(self)
|
66 |
+
assert n <= ori_len
|
67 |
+
# equal interval subsample
|
68 |
+
ids = self.ids[::ori_len // n][:n]
|
69 |
+
self.ids = ids
|
70 |
+
print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
|
71 |
+
return self
|
72 |
+
|
73 |
+
def with_transform(self, transform):
|
74 |
+
self.transforms = transform
|
75 |
+
return self
|
76 |
+
|
77 |
+
|
78 |
+
def generate_idx(data_folder = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"):
|
79 |
+
all_ids = os.listdir(data_folder)
|
80 |
+
all_ids = [i.split(".")[0] for i in all_ids if i.endswith(".jpg") or i.endswith(".png")]
|
81 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
82 |
+
pickle.dump(all_ids, open(f"{save_path}", "wb"))
|
83 |
+
print("all_ids generated")
|
84 |
+
return all_ids
|
85 |
+
|
86 |
+
def random_sample(all_ids, sample_num = 110, save_root = "/data/vision/torralba/clip_dissection/huiren/lhq/subsample"):
|
87 |
+
chosen_id = random.sample(all_ids, sample_num)
|
88 |
+
save_dir = f"{save_root}/{sample_num}"
|
89 |
+
os.makedirs(save_dir, exist_ok=True)
|
90 |
+
for id in chosen_id:
|
91 |
+
img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
|
92 |
+
shutil.copy(img_path, save_dir)
|
93 |
+
|
94 |
+
return chosen_id
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
# all_ids = generate_idx()
|
98 |
+
# with open("/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle", "rb") as f:
|
99 |
+
# all_ids = pickle.load(f)
|
100 |
+
# # random_sample(all_ids, 1)
|
101 |
+
#
|
102 |
+
# # generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/100",
|
103 |
+
# # save_path="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle")
|
104 |
+
#
|
105 |
+
# # lhq 500
|
106 |
+
# with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle", "rb") as f:
|
107 |
+
# lhq_100_idx = pickle.load(f)
|
108 |
+
#
|
109 |
+
# extra_idx = set(all_ids) - set(lhq_100_idx)
|
110 |
+
# add_idx = random.sample(extra_idx, 400)
|
111 |
+
# lhq_500_idx = lhq_100_idx + add_idx
|
112 |
+
# with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_500.pickle", "wb") as f:
|
113 |
+
# pickle.dump(lhq_500_idx, f)
|
114 |
+
# save_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/500"
|
115 |
+
# os.makedirs(save_dir, exist_ok=True)
|
116 |
+
# for id in lhq_500_idx:
|
117 |
+
# img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
|
118 |
+
# # softlink
|
119 |
+
# os.symlink(img_path, os.path.join(save_dir, f"{id}.jpg"))
|
120 |
+
|
121 |
+
# lhq9
|
122 |
+
all_ids = generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9",
|
123 |
+
save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle")
|
124 |
+
print(all_ids)
|
125 |
+
|
126 |
+
|
127 |
+
|
custom_datasets/mypath.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
class MyPath(object):
|
5 |
+
@staticmethod
|
6 |
+
def db_root_dir(database=''):
|
7 |
+
coco_root = "/data/vision/torralba/datasets/coco_2017"
|
8 |
+
sam_caption_root = "/vision-nfs/torralba/datasets/vision/sam/captions"
|
9 |
+
|
10 |
+
root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
11 |
+
map={
|
12 |
+
"coco_train": f"{coco_root}/train2017/",
|
13 |
+
"coco_caption_train": f"{coco_root}/annotations/captions_train2017.json",
|
14 |
+
"coco_val": f"{coco_root}/val2017/",
|
15 |
+
"coco_caption_val": f"{coco_root}/annotations/captions_val2017.json",
|
16 |
+
"sam_images": "/vision-nfs/torralba/datasets/vision/sam/images",
|
17 |
+
"sam_captions": sam_caption_root,
|
18 |
+
"sam_whole_filtered_ids_train": "data/filtered_sam/all_remain_ids_train.pickle",
|
19 |
+
"sam_whole_filtered_ids_val": "data/filtered_sam/all_remain_ids_val.pickle",
|
20 |
+
"sam_id_dict": "data/filtered_sam/all_id_dict.pickle",
|
21 |
+
|
22 |
+
"lhq_ids_sub500": "data/LHQ500_caption/idx/subsample_500.pickle",
|
23 |
+
"lhq_images": "data/LHQ500_caption/subsample_500",
|
24 |
+
"lhq_captions": "data/LHQ500_caption/captions",
|
25 |
+
}
|
26 |
+
ret = map.get(database, None)
|
27 |
+
if ret is None:
|
28 |
+
raise NotImplementedError
|
29 |
+
return ret
|
custom_datasets/sam.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import os.path
|
3 |
+
import sys
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
import pickle
|
11 |
+
from torchvision import transforms
|
12 |
+
# import torch
|
13 |
+
# import torchvision
|
14 |
+
# import re
|
15 |
+
|
16 |
+
|
17 |
+
class SamDataset(Dataset):
|
18 |
+
def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "data/sam/clip_filtered_ids.pickle",id_dict_file:str =None , transforms: Optional[Callable] = None,
|
19 |
+
resolution=None,
|
20 |
+
get_img=True,
|
21 |
+
get_cap=True,):
|
22 |
+
if id_dict_file is not None:
|
23 |
+
with open(id_dict_file, 'rb') as f:
|
24 |
+
print(f"Loading id_dict from {id_dict_file}", flush=True)
|
25 |
+
self.id_dict = pickle.load(f)
|
26 |
+
print(f"Loaded id_dict from {id_dict_file}", flush=True)
|
27 |
+
else:
|
28 |
+
self.id_dict = None
|
29 |
+
if isinstance(id_file, list):
|
30 |
+
self.ids = id_file
|
31 |
+
elif isinstance(id_file, str):
|
32 |
+
with open(id_file, 'rb') as f:
|
33 |
+
print(f"Loading ids from {id_file}", flush=True)
|
34 |
+
self.ids = pickle.load(f)
|
35 |
+
print(f"Loaded ids from {id_file}", flush=True)
|
36 |
+
self.resolution = resolution
|
37 |
+
self.ori_image_folder_path = image_folder_path
|
38 |
+
if self.resolution is not None:
|
39 |
+
if os.path.exists("/var/jomat/datasets/"):
|
40 |
+
# self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
|
41 |
+
self.image_folder_path = f"{image_folder_path}_{resolution}"
|
42 |
+
else:
|
43 |
+
self.image_folder_path = f"{image_folder_path}_{resolution}"
|
44 |
+
os.makedirs(self.image_folder_path, exist_ok=True)
|
45 |
+
else:
|
46 |
+
self.image_folder_path = image_folder_path
|
47 |
+
self.caption_folder_path = caption_folder_path
|
48 |
+
self.transforms = transforms
|
49 |
+
self.column_names = ["image", "text"]
|
50 |
+
self.get_img = get_img
|
51 |
+
self.get_cap = get_cap
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
# return 100
|
55 |
+
return len(self.ids)
|
56 |
+
|
57 |
+
def __getitem__(self, index: int):
|
58 |
+
id = self.ids[index]
|
59 |
+
ret={"id":id}
|
60 |
+
try:
|
61 |
+
# if index == 1:
|
62 |
+
# raise Exception("test")
|
63 |
+
if self.get_img:
|
64 |
+
image = self._load_image(id)
|
65 |
+
ret["image"]=image
|
66 |
+
if self.get_cap:
|
67 |
+
target = self._load_caption(id)
|
68 |
+
ret["text"] = [target]
|
69 |
+
if self.transforms is not None:
|
70 |
+
ret = self.transforms(ret)
|
71 |
+
return ret
|
72 |
+
except Exception as e:
|
73 |
+
raise e
|
74 |
+
print(f"Error loading image and caption for id {id}, error: {e}, redirecting to index 0", flush=True)
|
75 |
+
ret = self[0]
|
76 |
+
return ret
|
77 |
+
|
78 |
+
def define_resolution(self, resolution: int):
|
79 |
+
self.resolution = resolution
|
80 |
+
if os.path.exists("/var/jomat/datasets/"):
|
81 |
+
self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
|
82 |
+
# self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
|
83 |
+
else:
|
84 |
+
self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
|
85 |
+
print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")
|
86 |
+
def _load_image(self, id: int) -> Image.Image:
|
87 |
+
if self.id_dict is not None:
|
88 |
+
subfolder = self.id_dict[id]
|
89 |
+
image_path = f"{self.image_folder_path}/{subfolder}/sa_{id}.jpg"
|
90 |
+
else:
|
91 |
+
image_path = f"{self.image_folder_path}/sa_{id}.jpg"
|
92 |
+
|
93 |
+
try:
|
94 |
+
with open(image_path, 'rb') as f:
|
95 |
+
img = Image.open(f).convert("RGB")
|
96 |
+
# return img
|
97 |
+
except:
|
98 |
+
# load original image
|
99 |
+
if self.id_dict is not None:
|
100 |
+
subfolder = self.id_dict[id]
|
101 |
+
ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{id}.jpg"
|
102 |
+
else:
|
103 |
+
ori_image_path = f"{self.ori_image_folder_path}/sa_{id}.jpg"
|
104 |
+
assert os.path.exists(ori_image_path)
|
105 |
+
with open(ori_image_path, 'rb') as f:
|
106 |
+
img = Image.open(f).convert("RGB")
|
107 |
+
# resize image keep aspect ratio
|
108 |
+
if self.resolution is not None:
|
109 |
+
img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
|
110 |
+
# write image
|
111 |
+
os.makedirs(os.path.dirname(image_path), exist_ok=True)
|
112 |
+
img.save(image_path)
|
113 |
+
|
114 |
+
return img
|
115 |
+
|
116 |
+
|
117 |
+
def _load_caption(self, id: int):
|
118 |
+
caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
|
119 |
+
if not os.path.exists(caption_path):
|
120 |
+
return None
|
121 |
+
try:
|
122 |
+
with open(caption_path, 'r', encoding="utf-8") as f:
|
123 |
+
content = f.read()
|
124 |
+
except Exception as e:
|
125 |
+
raise e
|
126 |
+
print(f"Error reading caption file {caption_path}, error: {e}")
|
127 |
+
return None
|
128 |
+
sentences = content.split('.')
|
129 |
+
# remove empty sentences and sentences with "black and white"(too many false prediction)
|
130 |
+
sentences = [sentence.strip() for sentence in sentences if sentence.strip() and "black and white" not in sentence]
|
131 |
+
# join sentence
|
132 |
+
sentences = ". ".join(sentences)
|
133 |
+
if len(sentences) > 0 and sentences[-1] != '.':
|
134 |
+
sentences += '.'
|
135 |
+
|
136 |
+
return sentences
|
137 |
+
|
138 |
+
def with_transform(self, transform):
|
139 |
+
self.transforms = transform
|
140 |
+
return self
|
141 |
+
|
142 |
+
def subsample(self, n: int = 10000):
|
143 |
+
if n is None or n == -1:
|
144 |
+
return self
|
145 |
+
ori_len = len(self)
|
146 |
+
assert n <= ori_len
|
147 |
+
# equal interval subsample
|
148 |
+
ids = self.ids[::ori_len // n][:n]
|
149 |
+
self.ids = ids
|
150 |
+
print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
|
151 |
+
return self
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
# sam_filt(caption_filt=False, clip_filt=False, clip_logit=True)
|
156 |
+
from custom_datasets.sam_caption.mypath import MyPath
|
157 |
+
dataset = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"))
|
158 |
+
dataset.get_img = False
|
159 |
+
for i in tqdm.tqdm(dataset):
|
160 |
+
a=i['text']
|
data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1802d12e4d9526eedb89d99f69051849f14774da3c73ebc9b1393c2b13f17022
|
3 |
+
size 2187129
|
data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c39b39f32ff88dfed978ccc651715ade9edfd901d529adbeb5eedb715b8e159
|
3 |
+
size 2187129
|
data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd7764b19a2b4513b3c22f1607d72daa63c4ace97ea803e29e2bcf3f13bab2e8
|
3 |
+
size 2187129
|
data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:426c2e4a3bfc26f7fdcc3e82989d717fa5fc6e732cd9df9f8bb293ab72cacfa5
|
3 |
+
size 2187129
|
data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8be8ef590baceb2bdfac8b25976df88fa7baa1a9c718ed16aa4fa8fa247bb421
|
3 |
+
size 2187129
|
data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:212f0f16ae84c0bae96e213a0b0d5f4309209b332d48cbaa1748b5cdcfb3238a
|
3 |
+
size 2187129
|
data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5cff54e3e7c544577dbc39d7015a89c4786cd012cf944d0b9db334c1a1d7e30b
|
3 |
+
size 2187129
|
data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c26bdb5bfba85b4eb00631eda149912ba557935773842f95c0596999f799a2b4
|
3 |
+
size 2187129
|
data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:24b33205841d9b09c0076b4ba295be29d94677e69b7269465897bbf059a40454
|
3 |
+
size 2187129
|
data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b34b75325c3fd0353b55f390027a32a98f771df7d2fb21dbd8bce81a12ba59e9
|
3 |
+
size 2187129
|
data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7457f14af7c77f98675063582b35317963d46e942459575d38b5996ed190c58f
|
3 |
+
size 2187129
|
data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6df86764f4d4ceec0bd6124a74a51c36665c8491511a5488737b9a64300b97b
|
3 |
+
size 2187129
|
data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6a9ba0305edca3286258a06023b97914b850fbc8b4f5a14769537f9a01ef33f1
|
3 |
+
size 2187129
|
data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8ce7899c19b32dacd2dc46090fd3429495a2230c173bcd96149236d27b5151fd
|
3 |
+
size 2187129
|
data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5ac428a5d0fb136b79eec2349fbcbd99dfac2315c0a7f54d7985299b60b6f66f
|
3 |
+
size 2187129
|
data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3ca866dd868fb89a1180bb140dfaf1e48701993c8fa173d70c56c60c9af8d8fb
|
3 |
+
size 2187129
|
data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:41cad39d7b6e1873cfef85be478851820f5dc80cd7ce11afe2bfa3584662e3ac
|
3 |
+
size 2187129
|
data/unsafe.png
ADDED
hf_demo.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import os
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",).to(device)
|
14 |
+
|
15 |
+
from inference import get_lora_network, inference, get_validation_dataloader
|
16 |
+
lora_map = {
|
17 |
+
"None": "None",
|
18 |
+
"Andre Derain": "andre-derain_subset1",
|
19 |
+
"Vincent van Gogh": "van_gogh_subset1",
|
20 |
+
"Andy Warhol": "andy_subset1",
|
21 |
+
"Walter Battiss": "walter-battiss_subset2",
|
22 |
+
"Camille Corot": "camille-corot_subset1",
|
23 |
+
"Claude Monet": "monet_subset2",
|
24 |
+
"Pablo Picasso": "picasso_subset1",
|
25 |
+
"Jackson Pollock": "jackson-pollock_subset1",
|
26 |
+
"Gerhard Richter": "gerhard-richter_subset1",
|
27 |
+
"M.C. Escher": "m.c.-escher_subset1",
|
28 |
+
"Albert Gleizes": "albert-gleizes_subset1",
|
29 |
+
"Hokusai": "katsushika-hokusai_subset1",
|
30 |
+
"Wassily Kandinsky": "kandinsky_subset1",
|
31 |
+
"Gustav Klimt": "klimt_subset3",
|
32 |
+
"Roy Lichtenstein": "roy-lichtenstein_subset1",
|
33 |
+
"Henri Matisse": "henri-matisse_subset1",
|
34 |
+
"Joan Miro": "joan-miro_subset2",
|
35 |
+
}
|
36 |
+
|
37 |
+
def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
|
38 |
+
adapter_path = lora_map[adapter_choice]
|
39 |
+
if adapter_path not in [None, "None"]:
|
40 |
+
adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
41 |
+
|
42 |
+
prompts = [prompt]*samples
|
43 |
+
infer_loader = get_validation_dataloader(prompts)
|
44 |
+
network = get_lora_network(pipe.unet, adapter_path)["network"]
|
45 |
+
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
46 |
+
height=512, width=512, scales=[1.0],
|
47 |
+
save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
|
48 |
+
start_noise=-1, show=False, style_prompt="sks art", no_load=True,
|
49 |
+
from_scratch=True)[0][1.0]
|
50 |
+
return pred_images
|
51 |
+
|
52 |
+
def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
|
53 |
+
infer_loader = get_validation_dataloader(prompts, image)
|
54 |
+
network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
|
55 |
+
pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
|
56 |
+
height=512, width=512, scales=[0.,1.],
|
57 |
+
save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
|
58 |
+
start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
|
59 |
+
from_scratch=False)
|
60 |
+
return pred_images
|
61 |
+
|
62 |
+
# def infer(prompt, samples, steps, scale, seed):
|
63 |
+
# generator = torch.Generator(device=device).manual_seed(seed)
|
64 |
+
# images_list = pipe( # type: ignore
|
65 |
+
# [prompt] * samples,
|
66 |
+
# num_inference_steps=steps,
|
67 |
+
# guidance_scale=scale,
|
68 |
+
# generator=generator,
|
69 |
+
# )
|
70 |
+
# images = []
|
71 |
+
# safe_image = Image.open(r"data/unsafe.png")
|
72 |
+
# print(images_list)
|
73 |
+
# for i, image in enumerate(images_list["images"]): # type: ignore
|
74 |
+
# if images_list["nsfw_content_detected"][i]: # type: ignore
|
75 |
+
# images.append(safe_image)
|
76 |
+
# else:
|
77 |
+
# images.append(image)
|
78 |
+
# return images
|
79 |
+
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
block = gr.Blocks()
|
84 |
+
# Direct infer
|
85 |
+
with block:
|
86 |
+
with gr.Group():
|
87 |
+
with gr.Row():
|
88 |
+
text = gr.Textbox(
|
89 |
+
label="Enter your prompt",
|
90 |
+
max_lines=2,
|
91 |
+
placeholder="Enter your prompt",
|
92 |
+
container=False,
|
93 |
+
value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
|
94 |
+
)
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
+
btn = gr.Button("Run", scale=0)
|
99 |
+
gallery = gr.Gallery(
|
100 |
+
label="Generated images",
|
101 |
+
show_label=False,
|
102 |
+
elem_id="gallery",
|
103 |
+
columns=[2],
|
104 |
+
)
|
105 |
+
|
106 |
+
advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
|
107 |
+
|
108 |
+
with gr.Row(elem_id="advanced-options"):
|
109 |
+
adapter_choice = gr.Dropdown(
|
110 |
+
label="Choose adapter",
|
111 |
+
choices=["None", "Andre Derain","Vincent van Gogh","Andy Warhol", "Walter Battiss",
|
112 |
+
"Camille Corot", "Claude Monet", "Pablo Picasso",
|
113 |
+
"Jackson Pollock", "Gerhard Richter", "M.C. Escher",
|
114 |
+
"Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
|
115 |
+
"Henri Matisse", "Joan Miro"
|
116 |
+
],
|
117 |
+
value="None"
|
118 |
+
)
|
119 |
+
# print(adapter_choice[0])
|
120 |
+
# lora_path = lora_map[adapter_choice.value]
|
121 |
+
# if lora_path is not None:
|
122 |
+
# lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
|
123 |
+
|
124 |
+
samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
|
125 |
+
steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
|
126 |
+
scale = gr.Slider(
|
127 |
+
label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
|
128 |
+
)
|
129 |
+
print(scale)
|
130 |
+
seed = gr.Slider(
|
131 |
+
label="Seed",
|
132 |
+
minimum=0,
|
133 |
+
maximum=2147483647,
|
134 |
+
step=1,
|
135 |
+
randomize=True,
|
136 |
+
)
|
137 |
+
|
138 |
+
gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
|
139 |
+
advanced_button.click(
|
140 |
+
None,
|
141 |
+
[],
|
142 |
+
text,
|
143 |
+
)
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
block.launch()
|
hf_demo_test.ipynb
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "initial_id",
|
7 |
+
"metadata": {
|
8 |
+
"ExecuteTime": {
|
9 |
+
"end_time": "2024-12-09T09:44:30.641366Z",
|
10 |
+
"start_time": "2024-12-09T09:44:11.789050Z"
|
11 |
+
}
|
12 |
+
},
|
13 |
+
"outputs": [],
|
14 |
+
"source": [
|
15 |
+
"import os\n",
|
16 |
+
"\n",
|
17 |
+
"import gradio as gr\n",
|
18 |
+
"from diffusers import DiffusionPipeline\n",
|
19 |
+
"import matplotlib.pyplot as plt\n",
|
20 |
+
"import torch\n",
|
21 |
+
"from PIL import Image\n"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 2,
|
27 |
+
"id": "ddf33e0d3abacc2c",
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [],
|
30 |
+
"source": [
|
31 |
+
"import sys\n",
|
32 |
+
"#append current path\n",
|
33 |
+
"sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
{
|
37 |
+
"cell_type": "code",
|
38 |
+
"execution_count": 3,
|
39 |
+
"id": "643e49fd601daf8f",
|
40 |
+
"metadata": {
|
41 |
+
"ExecuteTime": {
|
42 |
+
"end_time": "2024-12-09T09:44:35.790962Z",
|
43 |
+
"start_time": "2024-12-09T09:44:35.779496Z"
|
44 |
+
}
|
45 |
+
},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 4,
|
54 |
+
"id": "e03aae2a4e5676dd",
|
55 |
+
"metadata": {
|
56 |
+
"ExecuteTime": {
|
57 |
+
"end_time": "2024-12-09T09:44:44.157412Z",
|
58 |
+
"start_time": "2024-12-09T09:44:37.138452Z"
|
59 |
+
}
|
60 |
+
},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"name": "stderr",
|
64 |
+
"output_type": "stream",
|
65 |
+
"text": [
|
66 |
+
"/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
|
67 |
+
" warnings.warn(\n"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"data": {
|
72 |
+
"application/vnd.jupyter.widget-view+json": {
|
73 |
+
"model_id": "9df8347307674ba8afb0250e23109aa1",
|
74 |
+
"version_major": 2,
|
75 |
+
"version_minor": 0
|
76 |
+
},
|
77 |
+
"text/plain": [
|
78 |
+
"Loading pipeline components...: 0%| | 0/7 [00:00<?, ?it/s]"
|
79 |
+
]
|
80 |
+
},
|
81 |
+
"metadata": {},
|
82 |
+
"output_type": "display_data"
|
83 |
+
}
|
84 |
+
],
|
85 |
+
"source": [
|
86 |
+
"pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
|
87 |
+
"device = \"cuda\""
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 5,
|
93 |
+
"id": "83916bc68ff5d914",
|
94 |
+
"metadata": {
|
95 |
+
"ExecuteTime": {
|
96 |
+
"end_time": "2024-12-09T09:44:52.694399Z",
|
97 |
+
"start_time": "2024-12-09T09:44:44.210695Z"
|
98 |
+
}
|
99 |
+
},
|
100 |
+
"outputs": [],
|
101 |
+
"source": [
|
102 |
+
"from inference import get_lora_network, inference, get_validation_dataloader\n",
|
103 |
+
"lora_map = {\n",
|
104 |
+
" \"None\": \"None\",\n",
|
105 |
+
" \"Andre Derain\": \"andre-derain_subset1\",\n",
|
106 |
+
" \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
|
107 |
+
" \"Andy Warhol\": \"andy_subset1\",\n",
|
108 |
+
" \"Walter Battiss\": \"walter-battiss_subset2\",\n",
|
109 |
+
" \"Camille Corot\": \"camille-corot_subset1\",\n",
|
110 |
+
" \"Claude Monet\": \"monet_subset2\",\n",
|
111 |
+
" \"Pablo Picasso\": \"picasso_subset1\",\n",
|
112 |
+
" \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
|
113 |
+
" \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
|
114 |
+
" \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
|
115 |
+
" \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
|
116 |
+
" \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
|
117 |
+
" \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
|
118 |
+
" \"Gustav Klimt\": \"klimt_subset3\",\n",
|
119 |
+
" \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
|
120 |
+
" \"Henri Matisse\": \"henri-matisse_subset1\",\n",
|
121 |
+
" \"Joan Miro\": \"joan-miro_subset2\",\n",
|
122 |
+
"}\n",
|
123 |
+
"\n",
|
124 |
+
"def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
|
125 |
+
" adapter_path = lora_map[adapter_choice]\n",
|
126 |
+
" if adapter_path not in [None, \"None\"]:\n",
|
127 |
+
" adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
128 |
+
"\n",
|
129 |
+
" prompts = [prompt]*samples\n",
|
130 |
+
" infer_loader = get_validation_dataloader(prompts)\n",
|
131 |
+
" network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
|
132 |
+
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
133 |
+
" height=512, width=512, scales=[1.0],\n",
|
134 |
+
" save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
|
135 |
+
" start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
|
136 |
+
" from_scratch=True)[0][1.0]\n",
|
137 |
+
" return pred_images\n",
|
138 |
+
"\n",
|
139 |
+
"def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
|
140 |
+
" infer_loader = get_validation_dataloader(prompts, image)\n",
|
141 |
+
" network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
|
142 |
+
" pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
|
143 |
+
" height=512, width=512, scales=[0.,1.],\n",
|
144 |
+
" save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
|
145 |
+
" start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
|
146 |
+
" from_scratch=False)\n",
|
147 |
+
" return pred_images\n",
|
148 |
+
"\n",
|
149 |
+
"# def infer(prompt, samples, steps, scale, seed):\n",
|
150 |
+
"# generator = torch.Generator(device=device).manual_seed(seed)\n",
|
151 |
+
"# images_list = pipe( # type: ignore\n",
|
152 |
+
"# [prompt] * samples,\n",
|
153 |
+
"# num_inference_steps=steps,\n",
|
154 |
+
"# guidance_scale=scale,\n",
|
155 |
+
"# generator=generator,\n",
|
156 |
+
"# )\n",
|
157 |
+
"# images = []\n",
|
158 |
+
"# safe_image = Image.open(r\"data/unsafe.png\")\n",
|
159 |
+
"# print(images_list)\n",
|
160 |
+
"# for i, image in enumerate(images_list[\"images\"]): # type: ignore\n",
|
161 |
+
"# if images_list[\"nsfw_content_detected\"][i]: # type: ignore\n",
|
162 |
+
"# images.append(safe_image)\n",
|
163 |
+
"# else:\n",
|
164 |
+
"# images.append(image)\n",
|
165 |
+
"# return images\n"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"cell_type": "code",
|
170 |
+
"execution_count": 6,
|
171 |
+
"id": "aa33e9d104023847",
|
172 |
+
"metadata": {
|
173 |
+
"ExecuteTime": {
|
174 |
+
"end_time": "2024-12-09T12:09:39.339583Z",
|
175 |
+
"start_time": "2024-12-09T12:09:38.953936Z"
|
176 |
+
}
|
177 |
+
},
|
178 |
+
"outputs": [
|
179 |
+
{
|
180 |
+
"name": "stdout",
|
181 |
+
"output_type": "stream",
|
182 |
+
"text": [
|
183 |
+
"<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
|
184 |
+
"Running on local URL: http://127.0.0.1:7876\n",
|
185 |
+
"Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
|
186 |
+
"\n",
|
187 |
+
"This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
|
188 |
+
]
|
189 |
+
},
|
190 |
+
{
|
191 |
+
"data": {
|
192 |
+
"text/html": [
|
193 |
+
"<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
194 |
+
],
|
195 |
+
"text/plain": [
|
196 |
+
"<IPython.core.display.HTML object>"
|
197 |
+
]
|
198 |
+
},
|
199 |
+
"metadata": {},
|
200 |
+
"output_type": "display_data"
|
201 |
+
},
|
202 |
+
{
|
203 |
+
"data": {
|
204 |
+
"text/plain": []
|
205 |
+
},
|
206 |
+
"execution_count": 6,
|
207 |
+
"metadata": {},
|
208 |
+
"output_type": "execute_result"
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"name": "stdout",
|
212 |
+
"output_type": "stream",
|
213 |
+
"text": [
|
214 |
+
"Train method: None\n",
|
215 |
+
"Rank: 1, Alpha: 1\n",
|
216 |
+
"create LoRA for U-Net: 0 modules.\n",
|
217 |
+
"save dir: None\n",
|
218 |
+
"['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"name": "stderr",
|
223 |
+
"output_type": "stream",
|
224 |
+
"text": [
|
225 |
+
"/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
|
226 |
+
" return F.conv2d(input, weight, bias, self.stride,\n",
|
227 |
+
"\n",
|
228 |
+
"00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00, 6.90it/s]"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"name": "stdout",
|
233 |
+
"output_type": "stream",
|
234 |
+
"text": [
|
235 |
+
"Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
|
236 |
+
]
|
237 |
+
}
|
238 |
+
],
|
239 |
+
"source": [
|
240 |
+
"block = gr.Blocks()\n",
|
241 |
+
"# Direct infer\n",
|
242 |
+
"with block:\n",
|
243 |
+
" with gr.Group():\n",
|
244 |
+
" with gr.Row():\n",
|
245 |
+
" text = gr.Textbox(\n",
|
246 |
+
" label=\"Enter your prompt\",\n",
|
247 |
+
" max_lines=2,\n",
|
248 |
+
" placeholder=\"Enter your prompt\",\n",
|
249 |
+
" container=False,\n",
|
250 |
+
" value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
|
251 |
+
" )\n",
|
252 |
+
" \n",
|
253 |
+
"\n",
|
254 |
+
" \n",
|
255 |
+
" btn = gr.Button(\"Run\", scale=0)\n",
|
256 |
+
" gallery = gr.Gallery(\n",
|
257 |
+
" label=\"Generated images\",\n",
|
258 |
+
" show_label=False,\n",
|
259 |
+
" elem_id=\"gallery\",\n",
|
260 |
+
" columns=[2],\n",
|
261 |
+
" )\n",
|
262 |
+
"\n",
|
263 |
+
" advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
|
264 |
+
"\n",
|
265 |
+
" with gr.Row(elem_id=\"advanced-options\"):\n",
|
266 |
+
" adapter_choice = gr.Dropdown(\n",
|
267 |
+
" label=\"Choose adapter\",\n",
|
268 |
+
" choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
|
269 |
+
" \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
|
270 |
+
" \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
|
271 |
+
" \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
|
272 |
+
" \"Henri Matisse\", \"Joan Miro\"\n",
|
273 |
+
" ],\n",
|
274 |
+
" value=\"None\"\n",
|
275 |
+
" )\n",
|
276 |
+
" # print(adapter_choice[0])\n",
|
277 |
+
" # lora_path = lora_map[adapter_choice.value]\n",
|
278 |
+
" # if lora_path is not None:\n",
|
279 |
+
" # lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
|
280 |
+
"\n",
|
281 |
+
" samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
|
282 |
+
" steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
|
283 |
+
" scale = gr.Slider(\n",
|
284 |
+
" label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
|
285 |
+
" )\n",
|
286 |
+
" print(scale)\n",
|
287 |
+
" seed = gr.Slider(\n",
|
288 |
+
" label=\"Seed\",\n",
|
289 |
+
" minimum=0,\n",
|
290 |
+
" maximum=2147483647,\n",
|
291 |
+
" step=1,\n",
|
292 |
+
" randomize=True,\n",
|
293 |
+
" )\n",
|
294 |
+
"\n",
|
295 |
+
" gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
|
296 |
+
" advanced_button.click(\n",
|
297 |
+
" None,\n",
|
298 |
+
" [],\n",
|
299 |
+
" text,\n",
|
300 |
+
" )\n",
|
301 |
+
"\n",
|
302 |
+
"\n",
|
303 |
+
"block.launch(share=True)"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": null,
|
309 |
+
"id": "3239c12167a5f2cd",
|
310 |
+
"metadata": {},
|
311 |
+
"outputs": [],
|
312 |
+
"source": []
|
313 |
+
}
|
314 |
+
],
|
315 |
+
"metadata": {
|
316 |
+
"kernelspec": {
|
317 |
+
"display_name": "Python 3 (ipykernel)",
|
318 |
+
"language": "python",
|
319 |
+
"name": "python3"
|
320 |
+
},
|
321 |
+
"language_info": {
|
322 |
+
"codemirror_mode": {
|
323 |
+
"name": "ipython",
|
324 |
+
"version": 3
|
325 |
+
},
|
326 |
+
"file_extension": ".py",
|
327 |
+
"mimetype": "text/x-python",
|
328 |
+
"name": "python",
|
329 |
+
"nbconvert_exporter": "python",
|
330 |
+
"pygments_lexer": "ipython3",
|
331 |
+
"version": "3.9.18"
|
332 |
+
}
|
333 |
+
},
|
334 |
+
"nbformat": 4,
|
335 |
+
"nbformat_minor": 5
|
336 |
+
}
|
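For reference, the generation helper defined in the notebook above can also be called directly, without the Gradio UI. The sketch below is illustrative only: it assumes the notebook cells have already been executed (so that pipe, lora_map, and demo_inference_gen exist in the session); the adapter name, prompt, and output filename are example values.

# Illustrative sketch: call the notebook's helper directly.
# Assumes the cells above have been run, so `pipe`, `lora_map`,
# and `demo_inference_gen` are already defined in the session.
images = demo_inference_gen(
    adapter_choice="Claude Monet",   # any key of lora_map, or "None"
    prompt="Park with cherry blossom trees and a clear blue pond.",
    samples=1,
    seed=0,
    steps=20,
    guidance_scale=7.5,
)
images[0].save("monet_sample.png")   # the helper returns a list of PIL images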
inference.py
ADDED
@@ -0,0 +1,657 @@
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
import argparse
|
5 |
+
import os, json, random
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import glob, re
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
import sys
|
14 |
+
import gc
|
15 |
+
from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
|
16 |
+
|
17 |
+
# import train_util
|
18 |
+
|
19 |
+
from utils.train_util import get_noisy_image, encode_prompts
|
20 |
+
|
21 |
+
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, DDIMScheduler, PNDMScheduler
|
22 |
+
|
23 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
24 |
+
from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
|
25 |
+
import argparse
|
26 |
+
# from diffusers.training_utils import EMAModel
|
27 |
+
import shutil
|
28 |
+
import yaml
|
29 |
+
from easydict import EasyDict
|
30 |
+
from utils.metrics import StyleContentMetric
|
31 |
+
from torchvision import transforms
|
32 |
+
|
33 |
+
from custom_datasets.coco import CustomCocoCaptions
|
34 |
+
from custom_datasets.imagepair import ImageSet
|
35 |
+
from custom_datasets import get_dataset
|
36 |
+
# from stable_diffusion.utils.modules import get_diffusion_modules
|
37 |
+
# from diffusers import StableDiffusionImg2ImgPipeline
|
38 |
+
from diffusers.utils.torch_utils import randn_tensor
|
39 |
+
import pickle
|
40 |
+
import time
|
41 |
+
def flush():
|
42 |
+
torch.cuda.empty_cache()
|
43 |
+
gc.collect()
|
44 |
+
|
45 |
+
def get_train_method(lora_weight):
|
46 |
+
if lora_weight is None:
|
47 |
+
return 'None'
|
48 |
+
if 'full' in lora_weight:
|
49 |
+
train_method = 'full'
|
50 |
+
elif "down_1_up_2_attn" in lora_weight:
|
51 |
+
train_method = 'up_2_attn'
|
52 |
+
print(f"Using up_2_attn for {lora_weight}")
|
53 |
+
elif "down_2_up_1_up_2_attn" in lora_weight:
|
54 |
+
train_method = 'down_2_up_2_attn'
|
55 |
+
elif "down_2_up_2_attn" in lora_weight:
|
56 |
+
train_method = 'down_2_up_2_attn'
|
57 |
+
elif "down_2_attn" in lora_weight:
|
58 |
+
train_method = 'down_2_attn'
|
59 |
+
elif 'noxattn' in lora_weight:
|
60 |
+
train_method = 'noxattn'
|
61 |
+
elif "xattn" in lora_weight:
|
62 |
+
train_method = 'xattn'
|
63 |
+
elif "attn" in lora_weight:
|
64 |
+
train_method = 'attn'
|
65 |
+
elif "all_up" in lora_weight:
|
66 |
+
train_method = 'all_up'
|
67 |
+
else:
|
68 |
+
train_method = 'None'
|
69 |
+
return train_method
|
70 |
+
|
71 |
+
def get_validation_dataloader(infer_prompts:list[str]=None, infer_images :list[str]=None,resolution=512, batch_size=10, num_workers=4, val_set="laion_pop500"):
|
72 |
+
data_transforms = transforms.Compose(
|
73 |
+
[
|
74 |
+
transforms.Resize(resolution),
|
75 |
+
transforms.CenterCrop(resolution),
|
76 |
+
]
|
77 |
+
)
|
78 |
+
def preprocess(example):
|
79 |
+
ret={}
|
80 |
+
ret["image"] = data_transforms(example["image"]) if "image" in example else None
|
81 |
+
if "caption" in example:
|
82 |
+
if isinstance(example["caption"][0], list):
|
83 |
+
ret["caption"] = example["caption"][0][0]
|
84 |
+
else:
|
85 |
+
ret["caption"] = example["caption"][0]
|
86 |
+
if "seed" in example:
|
87 |
+
ret["seed"] = example["seed"]
|
88 |
+
if "id" in example:
|
89 |
+
ret["id"] = example["id"]
|
90 |
+
if "path" in example:
|
91 |
+
ret["path"] = example["path"]
|
92 |
+
return ret
|
93 |
+
|
94 |
+
def collate_fn(examples):
|
95 |
+
out = {}
|
96 |
+
if "image" in examples[0]:
|
97 |
+
pixel_values = [example["image"] for example in examples]
|
98 |
+
out["pixel_values"] = pixel_values
|
99 |
+
# notice: only take the first prompt for each image
|
100 |
+
if "caption" in examples[0]:
|
101 |
+
prompts = [example["caption"] for example in examples]
|
102 |
+
out["prompts"] = prompts
|
103 |
+
if "seed" in examples[0]:
|
104 |
+
seeds = [example["seed"] for example in examples]
|
105 |
+
out["seed"] = seeds
|
106 |
+
if "path" in examples[0]:
|
107 |
+
paths = [example["path"] for example in examples]
|
108 |
+
out["path"] = paths
|
109 |
+
return out
|
110 |
+
if infer_prompts is None:
|
111 |
+
if val_set == "lhq500":
|
112 |
+
dataset = get_dataset("lhq_sub500", get_val=False)["train"]
|
113 |
+
elif val_set == "custom_coco100":
|
114 |
+
dataset = get_dataset("custom_coco100", get_val=False)["train"]
|
115 |
+
elif val_set == "custom_coco500":
|
116 |
+
dataset = get_dataset("custom_coco500", get_val=False)["train"]
|
117 |
+
|
118 |
+
elif os.path.isdir(val_set):
|
119 |
+
image_folder = os.path.join(val_set, "paintings")
|
120 |
+
caption_folder = os.path.join(val_set, "captions")
|
121 |
+
dataset = ImageSet(folder=image_folder, caption=caption_folder, keep_in_mem=True)
|
122 |
+
elif "custom_caption" in val_set:
|
123 |
+
from custom_datasets.custom_caption import Caption_set
|
124 |
+
name = val_set.replace("custom_caption_", "")
|
125 |
+
dataset = Caption_set(set_name = name)
|
126 |
+
elif val_set == "laion_pop500":
|
127 |
+
dataset = get_dataset("laion_pop500", get_val=False)["train"]
|
128 |
+
elif val_set == "laion_pop500_first_sentence":
|
129 |
+
dataset = get_dataset("laion_pop500_first_sentence", get_val=False)["train"]
|
130 |
+
else:
|
131 |
+
raise ValueError("Unknown dataset")
|
132 |
+
dataset.with_transform(preprocess)
|
133 |
+
elif isinstance(infer_prompts, torch.utils.data.Dataset):
|
134 |
+
dataset = infer_prompts
|
135 |
+
try:
|
136 |
+
dataset.with_transform(preprocess)
|
137 |
+
except:
|
138 |
+
pass
|
139 |
+
|
140 |
+
else:
|
141 |
+
class Dataset(torch.utils.data.Dataset):
|
142 |
+
def __init__(self, prompts, images=None):
|
143 |
+
self.prompts = prompts
|
144 |
+
self.images = images
|
145 |
+
self.get_img = False
|
146 |
+
if images is not None:
|
147 |
+
assert len(prompts) == len(images)
|
148 |
+
self.get_img = True
|
149 |
+
if isinstance(images[0], str):
|
150 |
+
self.images = [Image.open(image).convert("RGB") for image in images]
|
151 |
+
else:
|
152 |
+
self.images = [None] * len(prompts)
|
153 |
+
def __len__(self):
|
154 |
+
return len(self.prompts)
|
155 |
+
def __getitem__(self, idx):
|
156 |
+
img = self.images[idx]
|
157 |
+
if self.get_img and img is not None:
|
158 |
+
img = data_transforms(img)
|
159 |
+
return {"caption": self.prompts[idx], "image":img}
|
160 |
+
dataset = Dataset(infer_prompts, infer_images)
|
161 |
+
|
162 |
+
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False,
|
163 |
+
num_workers=num_workers, pin_memory=True)
|
164 |
+
return dataloader
|
165 |
+
|
166 |
+
def get_lora_network(unet , lora_path, train_method="None", rank=1, alpha=1.0, device="cuda", weight_dtype=torch.float32):
|
167 |
+
if train_method in [None, "None"]:
|
168 |
+
train_method = get_train_method(lora_path)
|
169 |
+
print(f"Train method: {train_method}")
|
170 |
+
|
171 |
+
network_type = "c3lier"
|
172 |
+
if train_method == 'xattn':
|
173 |
+
network_type = 'lierla'
|
174 |
+
|
175 |
+
modules = DEFAULT_TARGET_REPLACE
|
176 |
+
if network_type == "c3lier":
|
177 |
+
modules += UNET_TARGET_REPLACE_MODULE_CONV
|
178 |
+
|
179 |
+
alpha = 1
|
180 |
+
if "rank" in lora_path:
|
181 |
+
rank = int(re.search(r'rank(\d+)', lora_path).group(1))
|
182 |
+
if 'alpha1' in lora_path:
|
183 |
+
alpha = 1.0
|
184 |
+
print(f"Rank: {rank}, Alpha: {alpha}")
|
185 |
+
|
186 |
+
network = LoRANetwork(
|
187 |
+
unet,
|
188 |
+
rank=rank,
|
189 |
+
multiplier=1.0,
|
190 |
+
alpha=alpha,
|
191 |
+
train_method=train_method,
|
192 |
+
).to(device, dtype=weight_dtype)
|
193 |
+
if lora_path not in [None, "None"]:
|
194 |
+
lora_state_dict = torch.load(lora_path)
|
195 |
+
miss = network.load_state_dict(lora_state_dict, strict=False)
|
196 |
+
print(f"Missing: {miss}")
|
197 |
+
ret = {"network": network, "train_method": train_method}
|
198 |
+
return ret
|
199 |
+
|
200 |
+
def get_model(pretrained_ckpt_path, unet_ckpt=None,revision=None, variant=None, lora_path=None, weight_dtype=torch.float32,
|
201 |
+
device="cuda"):
|
202 |
+
modules = {}
|
203 |
+
pipe = DiffusionPipeline.from_pretrained(pretrained_ckpt_path, revision=revision, variant=variant)
|
204 |
+
if unet_ckpt is not None:
|
205 |
+
pipe.unet.from_pretrained(unet_ckpt, subfolder="unet_ema", revision=revision, variant=variant)
|
206 |
+
unet = pipe.unet
|
207 |
+
vae = pipe.vae
|
208 |
+
text_encoder = pipe.text_encoder
|
209 |
+
tokenizer = pipe.tokenizer
|
210 |
+
modules["unet"] = unet
|
211 |
+
modules["vae"] = vae
|
212 |
+
modules["text_encoder"] = text_encoder
|
213 |
+
modules["tokenizer"] = tokenizer
|
214 |
+
# tokenizer = modules["tokenizer"]
|
215 |
+
|
216 |
+
unet.enable_xformers_memory_efficient_attention()
|
217 |
+
unet.to(device, dtype=weight_dtype)
|
218 |
+
if weight_dtype != torch.bfloat16:
|
219 |
+
vae.to(device, dtype=torch.float32)
|
220 |
+
else:
|
221 |
+
vae.to(device, dtype=weight_dtype)
|
222 |
+
text_encoder.to(device, dtype=weight_dtype)
|
223 |
+
|
224 |
+
if lora_path is not None:
|
225 |
+
network = get_lora_network(unet, lora_path, device=device, weight_dtype=weight_dtype)
|
226 |
+
modules["network"] = network
|
227 |
+
return modules
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
@torch.no_grad()
|
232 |
+
def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, vae: AutoencoderKL, unet: UNet2DConditionModel, noise_scheduler: LMSDiscreteScheduler,
|
233 |
+
dataloader, height:int, width:int, scales:list = np.linspace(0,2,5),save_dir:str=None, seed:int = None,
|
234 |
+
weight_dtype: torch.dtype = torch.float32, device: torch.device="cuda", batch_size:int=1, steps:int=50, guidance_scale:float=7.5, start_noise:int=800,
|
235 |
+
uncond_prompt:str=None, uncond_embed=None, style_prompt = None, show:bool = False, no_load:bool=False, from_scratch=False):
|
236 |
+
print(f"save dir: {save_dir}")
|
237 |
+
if start_noise < 0:
|
238 |
+
assert from_scratch
|
239 |
+
network = network.eval()
|
240 |
+
unet = unet.eval()
|
241 |
+
vae = vae.eval()
|
242 |
+
do_convert = not from_scratch
|
243 |
+
|
244 |
+
if not do_convert:
|
245 |
+
try:
|
246 |
+
dataloader.dataset.get_img = False
|
247 |
+
except:
|
248 |
+
pass
|
249 |
+
scales = list(scales)
|
250 |
+
else:
|
251 |
+
scales = ["Real Image"] + list(scales)
|
252 |
+
|
253 |
+
if not no_load and os.path.exists(os.path.join(save_dir, "infer_imgs.pickle")):
|
254 |
+
with open(os.path.join(save_dir, "infer_imgs.pickle"), 'rb') as f:
|
255 |
+
pred_images = pickle.load(f)
|
256 |
+
take=True
|
257 |
+
for key in scales:
|
258 |
+
if key not in pred_images:
|
259 |
+
take=False
|
260 |
+
break
|
261 |
+
if take:
|
262 |
+
print(f"Found existing inference results in {save_dir}", flush=True)
|
263 |
+
return pred_images
|
264 |
+
|
265 |
+
max_length = tokenizer.model_max_length
|
266 |
+
|
267 |
+
pred_images = {scale :[] for scale in scales}
|
268 |
+
all_seeds = {scale:[] for scale in scales}
|
269 |
+
|
270 |
+
prompts = []
|
271 |
+
ori_prompts = []
|
272 |
+
if save_dir is not None:
|
273 |
+
img_output_dir = os.path.join(save_dir, "outputs")
|
274 |
+
os.makedirs(img_output_dir, exist_ok=True)
|
275 |
+
|
276 |
+
if uncond_embed is None:
|
277 |
+
if uncond_prompt is None:
|
278 |
+
uncond_input_text = [""]
|
279 |
+
else:
|
280 |
+
uncond_input_text = [uncond_prompt]
|
281 |
+
uncond_embed = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = uncond_input_text)
|
282 |
+
|
283 |
+
|
284 |
+
for batch in dataloader:
|
285 |
+
ori_prompt = batch["prompts"]
|
286 |
+
image = batch["pixel_values"] if do_convert else None
|
287 |
+
if do_convert:
|
288 |
+
pred_images["Real Image"] += image
|
289 |
+
if isinstance(ori_prompt, list):
|
290 |
+
if isinstance(text_encoder, CLIPTextModel):
|
291 |
+
# trunc prompts for clip encoder
|
292 |
+
ori_prompt = [p.split(".")[0]+"." for p in ori_prompt]
|
293 |
+
prompt = [f"{p.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" for p in ori_prompt] if style_prompt is not None else ori_prompt
|
294 |
+
else:
|
295 |
+
if isinstance(text_encoder, CLIPTextModel):
|
296 |
+
ori_prompt = ori_prompt.split(".")[0]+"."
|
297 |
+
prompt = f"{prompt.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" if style_prompt is not None else ori_prompt
|
298 |
+
|
299 |
+
bcz = len(prompt)
|
300 |
+
single_seed = seed
|
301 |
+
if dataloader.batch_size == 1 and seed is None:
|
302 |
+
if "seed" in batch:
|
303 |
+
single_seed = batch["seed"][0]
|
304 |
+
|
305 |
+
print(f"{prompt}, seed={single_seed}")
|
306 |
+
|
307 |
+
# text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
|
308 |
+
# original_embeddings = text_encoder(**text_input)[0]
|
309 |
+
|
310 |
+
prompts += prompt
|
311 |
+
ori_prompts += ori_prompt
|
312 |
+
# style_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
|
313 |
+
# # style_embeddings = text_encoder(**style_input)[0]
|
314 |
+
# style_embeddings = text_encoder(style_input.input_ids, return_dict=False)[0]
|
315 |
+
|
316 |
+
style_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = prompt)
|
317 |
+
original_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = ori_prompt)
|
318 |
+
if uncond_embed.shape[0] == 1 and bcz > 1:
|
319 |
+
uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
|
320 |
+
else:
|
321 |
+
uncond_embeddings = uncond_embed
|
322 |
+
style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings])
|
323 |
+
original_embeddings = torch.cat([uncond_embeddings, original_embeddings])
|
324 |
+
|
325 |
+
generator = torch.manual_seed(single_seed) if single_seed is not None else None
|
326 |
+
noise_scheduler.set_timesteps(steps)
|
327 |
+
if do_convert:
|
328 |
+
noised_latent, _, _ = get_noisy_image(image, vae, generator, unet, noise_scheduler, total_timesteps=int((1000-start_noise)/1000 *steps))
|
329 |
+
else:
|
330 |
+
latent_shape = (bcz, 4, height//8, width//8)
|
331 |
+
noised_latent = randn_tensor(latent_shape, generator=generator, device=vae.device)
|
332 |
+
noised_latent = noised_latent.to(unet.dtype)
|
333 |
+
noised_latent = noised_latent * noise_scheduler.init_noise_sigma
|
334 |
+
for scale in scales:
|
335 |
+
start_time = time.time()
|
336 |
+
if not isinstance(scale, float) and not isinstance(scale, int):
|
337 |
+
continue
|
338 |
+
|
339 |
+
latents = noised_latent.clone().to(weight_dtype).to(device)
|
340 |
+
noise_scheduler.set_timesteps(steps)
|
341 |
+
for t in tqdm(noise_scheduler.timesteps):
|
342 |
+
if do_convert and t>start_noise:
|
343 |
+
continue
|
344 |
+
else:
|
345 |
+
if t > start_noise and start_noise >= 0:
|
346 |
+
current_scale = 0
|
347 |
+
else:
|
348 |
+
current_scale = scale
|
349 |
+
network.set_lora_slider(scale=current_scale)
|
350 |
+
text_embedding = style_text_embeddings
|
351 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
352 |
+
latent_model_input = torch.cat([latents] * 2)
|
353 |
+
|
354 |
+
latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
|
355 |
+
# predict the noise residual
|
356 |
+
with network:
|
357 |
+
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embedding).sample
|
358 |
+
|
359 |
+
# perform guidance
|
360 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
361 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
362 |
+
|
363 |
+
# compute the previous noisy sample x_t -> x_t-1
|
364 |
+
if isinstance(noise_scheduler, DDPMScheduler):
|
365 |
+
latents = noise_scheduler.step(noise_pred, t, latents, generator=torch.manual_seed(single_seed+t) if single_seed is not None else None).prev_sample
|
366 |
+
else:
|
367 |
+
latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
|
368 |
+
|
369 |
+
# scale and decode the image latents with vae
|
370 |
+
latents = 1 / 0.18215 * latents.to(vae.dtype)
|
371 |
+
|
372 |
+
|
373 |
+
with torch.no_grad():
|
374 |
+
image = vae.decode(latents).sample
|
375 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
376 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
377 |
+
images = (image * 255).round().astype("uint8")
|
378 |
+
|
379 |
+
|
380 |
+
pil_images = [Image.fromarray(image) for image in images]
|
381 |
+
pred_images[scale]+=pil_images
|
382 |
+
all_seeds[scale] += [single_seed] * bcz
|
383 |
+
|
384 |
+
end_time = time.time()
|
385 |
+
print(f"Time taken for one batch, Art Adapter scale={scale}: {end_time-start_time}", flush=True)
|
386 |
+
|
387 |
+
if save_dir is not None or show:
|
388 |
+
end_idx = len(list(pred_images.values())[0])
|
389 |
+
for i in range(end_idx-bcz, end_idx):
|
390 |
+
keys = list(pred_images.keys())
|
391 |
+
images_list = [pred_images[key][i] for key in keys]
|
392 |
+
prompt = prompts[i]
|
393 |
+
if len(scales)==1:
|
394 |
+
plt.imshow(images_list[0])
|
395 |
+
plt.axis('off')
|
396 |
+
plt.title(f"{prompt}_{single_seed}_start{start_noise}", fontsize=20)
|
397 |
+
else:
|
398 |
+
fig, ax = plt.subplots(1, len(images_list), figsize=(len(scales)*5,6), layout="constrained")
|
399 |
+
for id, a in enumerate(ax):
|
400 |
+
a.imshow(images_list[id])
|
401 |
+
if isinstance(scales[id], float) or isinstance(scales[id], int):
|
402 |
+
a.set_title(f"Art Adapter scale={scales[id]}", fontsize=20)
|
403 |
+
else:
|
404 |
+
a.set_title(f"{keys[id]}", fontsize=20)
|
405 |
+
a.axis('off')
|
406 |
+
|
407 |
+
# plt.suptitle(f"{os.path.basename(lora_weight).replace('.pt','')}", fontsize=20)
|
408 |
+
|
409 |
+
# plt.tight_layout()
|
410 |
+
# if do_convert:
|
411 |
+
# plt.suptitle(f"{prompt}\nseed{single_seed}_start{start_noise}_guidance{guidance_scale}", fontsize=20)
|
412 |
+
# else:
|
413 |
+
# plt.suptitle(f"{prompt}\nseed{single_seed}_from_scratch_guidance{guidance_scale}", fontsize=20)
|
414 |
+
|
415 |
+
if save_dir is not None:
|
416 |
+
plt.savefig(f"{img_output_dir}/{prompt.replace(' ', '_')[:100]}_seed{single_seed}_start{start_noise}.png")
|
417 |
+
if show:
|
418 |
+
plt.show()
|
419 |
+
plt.close()
|
420 |
+
|
421 |
+
flush()
|
422 |
+
|
423 |
+
if save_dir is not None:
|
424 |
+
with open(os.path.join(save_dir, "infer_imgs.pickle" ), 'wb') as f:
|
425 |
+
pickle.dump(pred_images, f)
|
426 |
+
with open(os.path.join(save_dir, "all_seeds.pickle"), 'wb') as f:
|
427 |
+
to_save={"all_seeds":all_seeds, "batch_size":batch_size}
|
428 |
+
pickle.dump(to_save, f)
|
429 |
+
for scale, images in pred_images.items():
|
430 |
+
subfolder = os.path.join(save_dir,"images", f"{scale}")
|
431 |
+
os.makedirs(subfolder, exist_ok=True)
|
432 |
+
|
433 |
+
used_prompt = ori_prompts
|
434 |
+
if (isinstance(scale, float) or isinstance(scale, int)): #and scale != 0:
|
435 |
+
used_prompt = prompts
|
436 |
+
for i, image in enumerate(images):
|
437 |
+
if scale == "Real Image":
|
438 |
+
suffix = ""
|
439 |
+
else:
|
440 |
+
suffix = f"_seed{all_seeds[scale][i]}"
|
441 |
+
image.save(os.path.join(subfolder, f"{used_prompt[i].replace(' ', '_')[:100]}{suffix}.jpg"))
|
442 |
+
with open(os.path.join(save_dir, "infer_prompts.txt"), 'w') as f:
|
443 |
+
for prompt in prompts:
|
444 |
+
f.write(f"{prompt}\n")
|
445 |
+
with open(os.path.join(save_dir, "ori_prompts.txt"), 'w') as f:
|
446 |
+
for prompt in ori_prompts:
|
447 |
+
f.write(f"{prompt}\n")
|
448 |
+
print(f"Saved inference results to {save_dir}", flush=True)
|
449 |
+
return pred_images, prompts
|
450 |
+
|
451 |
+
@torch.no_grad()
|
452 |
+
def infer_metric(ref_image_folder,pred_images, prompts, save_dir, start_noise=""):
|
453 |
+
prompts = [prompt.split(" in the style of ")[0] for prompt in prompts]
|
454 |
+
scores = {}
|
455 |
+
original_images = pred_images["Real Image"] if "Real Image" in pred_images else None
|
456 |
+
metric = StyleContentMetric(ref_image_folder)
|
457 |
+
for scale, images in pred_images.items():
|
458 |
+
score = metric(images, original_images, prompts)
|
459 |
+
|
460 |
+
scores[scale] = score
|
461 |
+
print(f"Style transfer score at scale {scale}: {score}")
|
462 |
+
scores["ref_path"] = ref_image_folder
|
463 |
+
save_name = f"scores_start{start_noise}.json"
|
464 |
+
os.makedirs(save_dir, exist_ok=True)
|
465 |
+
with open(os.path.join(save_dir, save_name), 'w') as f:
|
466 |
+
json.dump(scores, f, indent=2)
|
467 |
+
return scores
|
468 |
+
|
469 |
+
def parse_args():
|
470 |
+
parser = argparse.ArgumentParser(description='Inference with LoRA')
|
471 |
+
parser.add_argument('--lora_weights', type=str, default=["None"],
|
472 |
+
nargs='+', help='path to your model file')
|
473 |
+
parser.add_argument('--prompts', type=str, default=[],
|
474 |
+
nargs='+', help='prompts to try')
|
475 |
+
parser.add_argument("--prompt_file", type=str, default=None, help="path to the prompt file")
|
476 |
+
parser.add_argument("--prompt_file_key", type=str, default="prompts", help="key to the prompt file")
|
477 |
+
parser.add_argument('--resolution', type=int, default=512, help='resolution of the image')
|
478 |
+
parser.add_argument('--seed', type=int, default=None, help='seed for the random number generator')
|
479 |
+
parser.add_argument("--start_noise", type=int, default=800, help="start noise")
|
480 |
+
parser.add_argument("--from_scratch", default=False, action="store_true", help="from scratch")
|
481 |
+
parser.add_argument("--ref_image_folder", type=str, default=None, help="folder containing reference images")
|
482 |
+
parser.add_argument("--show", action="store_true", help="show the image")
|
483 |
+
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
|
484 |
+
parser.add_argument("--scales", type=float, default=[0.,1.], nargs='+', help="scales to test")
|
485 |
+
parser.add_argument("--train_method", type=str, default=None, help="train method")
|
486 |
+
|
487 |
+
# parser.add_argument("--vae_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the VAE model.")
|
488 |
+
# parser.add_argument("--text_encoder_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the text encoder model.")
|
489 |
+
parser.add_argument("--pretrained_model_name_or_path", type=str, default="rhfeiyang/art-free-diffusion-v1", help="Path to the pretrained model.")
|
490 |
+
parser.add_argument("--unet_ckpt", default=None, type=str, help="Path to the unet checkpoint")
|
491 |
+
parser.add_argument("--guidance_scale", type=float, default=5.0, help="guidance scale")
|
492 |
+
parser.add_argument("--infer_mode", default="sks_art", help="inference mode") #, choices=["style", "ori", "artist", "sks_art","Peter"]
|
493 |
+
parser.add_argument("--save_dir", type=str, default="inference_output", help="save directory")
|
494 |
+
parser.add_argument("--num_workers", type=int, default=4, help="number of workers")
|
495 |
+
parser.add_argument("--no_load", action="store_true", help="no load the pre-inferred results")
|
496 |
+
parser.add_argument("--infer_prompts", type=str, default=None, nargs="+", help="prompts to infer")
|
497 |
+
parser.add_argument("--infer_images", type=str, default=None, nargs="+", help="images to infer")
|
498 |
+
parser.add_argument("--rank", type=int, default=1, help="rank of the lora")
|
499 |
+
parser.add_argument("--val_set", type=str, default="laion_pop500", help="validation set")
|
500 |
+
parser.add_argument("--folder_name", type=str, default=None, help="folder name")
|
501 |
+
parser.add_argument("--scheduler_type",type=str, choices=["ddpm", "ddim", "pndm","lms"], default="ddpm", help="scheduler type")
|
502 |
+
parser.add_argument("--infer_steps", type=int, default=50, help="inference steps")
|
503 |
+
parser.add_argument("--weight_dtype", type=str, default="fp32", help="weight dtype")
|
504 |
+
parser.add_argument("--custom_coco_cap", action="store_true", help="use custom coco caption")
|
505 |
+
args = parser.parse_args()
|
506 |
+
if args.infer_prompts is not None and len(args.infer_prompts) == 1 and os.path.isfile(args.infer_prompts[0]):
|
507 |
+
if args.infer_prompts[0].endswith(".txt") and args.custom_coco_cap:
|
508 |
+
args.infer_prompts = CustomCocoCaptions(custom_file=args.infer_prompts[0])
|
509 |
+
elif args.infer_prompts[0].endswith(".txt"):
|
510 |
+
with open(args.infer_prompts[0], 'r') as f:
|
511 |
+
args.infer_prompts = f.readlines()
|
512 |
+
args.infer_prompts = [prompt.strip() for prompt in args.infer_prompts]
|
513 |
+
elif args.infer_prompts[0].endswith(".csv"):
|
514 |
+
from custom_datasets.custom_caption import Caption_set
|
515 |
+
caption_set = Caption_set(args.infer_prompts[0])
|
516 |
+
args.infer_prompts = caption_set
|
517 |
+
|
518 |
+
|
519 |
+
if args.infer_mode == "style":
|
520 |
+
with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
|
521 |
+
args.style_label = f.readlines()[0].strip()
|
522 |
+
elif args.infer_mode == "artist":
|
523 |
+
with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
|
524 |
+
args.style_label = f.readlines()[0].strip()
|
525 |
+
args.style_label = args.style_label.split(",")[0].strip()
|
526 |
+
elif args.infer_mode == "ori":
|
527 |
+
args.style_label = None
|
528 |
+
else:
|
529 |
+
args.style_label = args.infer_mode.replace("_", " ")
|
530 |
+
if args.ref_image_folder is not None:
|
531 |
+
args.ref_image_folder = os.path.join(args.ref_image_folder, "paintings")
|
532 |
+
|
533 |
+
if args.start_noise < 0:
|
534 |
+
args.from_scratch = True
|
535 |
+
|
536 |
+
|
537 |
+
print(args.__dict__)
|
538 |
+
return args
|
539 |
+
|
540 |
+
|
541 |
+
def main(args):
|
542 |
+
lora_weights = args.lora_weights
|
543 |
+
|
544 |
+
if len(lora_weights) == 1 and isinstance(lora_weights[0], str) and os.path.isdir(lora_weights[0]):
|
545 |
+
lora_weights = glob.glob(os.path.join(lora_weights[0], "*.pt"))
|
546 |
+
lora_weights=sorted(lora_weights, reverse=True)
|
547 |
+
|
548 |
+
width = args.resolution
|
549 |
+
height = args.resolution
|
550 |
+
steps = args.infer_steps
|
551 |
+
|
552 |
+
revision = None
|
553 |
+
device = 'cuda'
|
554 |
+
rank = args.rank
|
555 |
+
if args.weight_dtype == "fp32":
|
556 |
+
weight_dtype = torch.float32
|
557 |
+
elif args.weight_dtype=="fp16":
|
558 |
+
weight_dtype = torch.float16
|
559 |
+
elif args.weight_dtype=="bf16":
|
560 |
+
weight_dtype = torch.bfloat16
|
561 |
+
|
562 |
+
modules = get_model(args.pretrained_model_name_or_path, unet_ckpt=args.unet_ckpt, revision=revision, variant=None, lora_path=None, weight_dtype=weight_dtype, device=device, )
|
563 |
+
if args.scheduler_type == "pndm":
|
564 |
+
noise_scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
565 |
+
|
566 |
+
elif args.scheduler_type == "ddpm":
|
567 |
+
noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
|
568 |
+
elif args.scheduler_type == "ddim":
|
569 |
+
noise_scheduler = DDIMScheduler(
|
570 |
+
beta_start=0.00085,
|
571 |
+
beta_end=0.012,
|
572 |
+
beta_schedule="scaled_linear",
|
573 |
+
num_train_timesteps=1000,
|
574 |
+
clip_sample=False,
|
575 |
+
prediction_type="epsilon",
|
576 |
+
)
|
577 |
+
elif args.scheduler_type == "lms":
|
578 |
+
noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
|
579 |
+
beta_end=0.012,
|
580 |
+
beta_schedule="scaled_linear",
|
581 |
+
num_train_timesteps=1000)
|
582 |
+
else:
|
583 |
+
raise ValueError("Unknown scheduler type")
|
584 |
+
cache=EasyDict()
|
585 |
+
cache.modules = modules
|
586 |
+
|
587 |
+
unet = modules["unet"]
|
588 |
+
vae = modules["vae"]
|
589 |
+
text_encoder = modules["text_encoder"]
|
590 |
+
tokenizer = modules["tokenizer"]
|
591 |
+
|
592 |
+
unet.requires_grad_(False)
|
593 |
+
|
594 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
595 |
+
vae.requires_grad_(False)
|
596 |
+
text_encoder.requires_grad_(False)
|
597 |
+
|
598 |
+
## dataloader
|
599 |
+
dataloader = get_validation_dataloader(infer_prompts=args.infer_prompts, infer_images=args.infer_images,
|
600 |
+
resolution=args.resolution,
|
601 |
+
batch_size=args.batch_size, num_workers=args.num_workers,
|
602 |
+
val_set=args.val_set)
|
603 |
+
|
604 |
+
|
605 |
+
for lora_weight in lora_weights:
|
606 |
+
print(f"Testing {lora_weight}")
|
607 |
+
# for different seeds on same prompt
|
608 |
+
seed = args.seed
|
609 |
+
|
610 |
+
network_ret = get_lora_network(unet, lora_weight, train_method=args.train_method, rank=rank, alpha=1.0, device=device, weight_dtype=weight_dtype)
|
611 |
+
network = network_ret["network"]
|
612 |
+
train_method = network_ret["train_method"]
|
613 |
+
if args.save_dir is not None:
|
614 |
+
save_dir = args.save_dir
|
615 |
+
if args.style_label is not None:
|
616 |
+
save_dir = os.path.join(save_dir, f"{args.style_label.replace(' ', '_')}")
|
617 |
+
else:
|
618 |
+
save_dir = os.path.join(save_dir, f"ori/{args.start_noise}")
|
619 |
+
else:
|
620 |
+
if args.folder_name is not None:
|
621 |
+
folder_name = args.folder_name
|
622 |
+
else:
|
623 |
+
folder_name = "validation" if args.infer_prompts is None else "validation_prompts"
|
624 |
+
save_dir = os.path.join(os.path.dirname(lora_weight), f"{folder_name}/{train_method}", os.path.basename(lora_weight).replace('.pt','').split('_')[-1])
|
625 |
+
if args.infer_prompts is None:
|
626 |
+
save_dir = os.path.join(save_dir, f"{args.val_set}")
|
627 |
+
|
628 |
+
infer_config = f"{args.scheduler_type}{args.infer_steps}_{args.weight_dtype}_guidance{args.guidance_scale}"
|
629 |
+
save_dir = os.path.join(save_dir, infer_config)
|
630 |
+
os.makedirs(save_dir, exist_ok=True)
|
631 |
+
if args.from_scratch:
|
632 |
+
save_dir = os.path.join(save_dir, "from_scratch")
|
633 |
+
else:
|
634 |
+
save_dir = os.path.join(save_dir, "transfer")
|
635 |
+
save_dir = os.path.join(save_dir, f"start{args.start_noise}")
|
636 |
+
os.makedirs(save_dir, exist_ok=True)
|
637 |
+
with open(os.path.join(save_dir, "infer_args.yaml"), 'w') as f:
|
638 |
+
yaml.dump(vars(args), f)
|
639 |
+
# save code
|
640 |
+
code_dir = os.path.join(save_dir, "code")
|
641 |
+
os.makedirs(code_dir, exist_ok=True)
|
642 |
+
current_file = os.path.basename(__file__)
|
643 |
+
shutil.copy(__file__, os.path.join(code_dir, current_file))
|
644 |
+
with torch.no_grad():
|
645 |
+
pred_images, prompts = inference(network, tokenizer, text_encoder, vae, unet, noise_scheduler, dataloader, height, width,
|
646 |
+
args.scales, save_dir, seed, weight_dtype, device, args.batch_size, steps, guidance_scale=args.guidance_scale,
|
647 |
+
start_noise=args.start_noise, show=args.show, style_prompt=args.style_label, no_load=args.no_load,
|
648 |
+
from_scratch=args.from_scratch)
|
649 |
+
|
650 |
+
if args.ref_image_folder is not None:
|
651 |
+
flush()
|
652 |
+
print("Calculating metrics")
|
653 |
+
infer_metric(args.ref_image_folder, pred_images, prompts, save_dir, args.start_noise)
|
654 |
+
|
655 |
+
if __name__ == "__main__":
|
656 |
+
args = parse_args()
|
657 |
+
main(args)
|
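A minimal usage sketch for the functions in inference.py follows. It assumes an already-loaded DiffusionPipeline named pipe (as in the notebook above) and that the Monet adapter checkpoint exists at the path used throughout this repo; the prompt and output filename are example values.

# Minimal sketch: generate one "sks art" image with the Monet adapter.
# Assumes `pipe` is a loaded DiffusionPipeline (see the notebook above).
from inference import get_lora_network, get_validation_dataloader, inference

adapter = "data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt"
loader = get_validation_dataloader(["A quiet harbor at sunrise"], batch_size=1)
network = get_lora_network(pipe.unet, adapter, "all_up")["network"]
pred_images, prompts = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae,
                                 pipe.unet, pipe.scheduler, loader,
                                 height=512, width=512, scales=[1.0],
                                 save_dir=None, seed=0, steps=20, guidance_scale=7.5,
                                 start_noise=-1, style_prompt="sks art", no_load=True,
                                 from_scratch=True)
pred_images[1.0][0].save("harbor_sks_art.png")   # results are keyed by adapter scale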
utils/__init__.py
ADDED
@@ -0,0 +1 @@
# Authors: Hui Ren (rhfeiyang.github.io)
utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (203 Bytes)
utils/__pycache__/lora.cpython-39.pyc
ADDED
Binary file (6.29 kB)
utils/__pycache__/metrics.cpython-39.pyc
ADDED
Binary file (19.3 kB)
utils/__pycache__/train_util.cpython-39.pyc
ADDED
Binary file (10.9 kB)
utils/art_filter.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
# Authors: Hui Ren (rhfeiyang.github.io)
|
2 |
+
|
3 |
+
from transformers import CLIPProcessor, CLIPModel
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
class Caption_filter:
|
11 |
+
def __init__(self, filter_prompts=["painting", "paintings", "art", "artwork", "drawings", "sketch", "sketches", "illustration", "illustrations",
|
12 |
+
"sculpture","sculptures", "installation", "printmaking", "digital art", "conceptual art", "mosaic", "tapestry",
|
13 |
+
"abstract", "realism", "surrealism", "impressionism", "expressionism", "cubism", "minimalism", "baroque", "rococo",
|
14 |
+
"pop art", "art nouveau", "art deco", "futurism", "dadaism",
|
15 |
+
"stamp", "stamps", "advertisement", "advertisements","logo", "logos"
|
16 |
+
],):
|
17 |
+
self.filter_prompts = filter_prompts
|
18 |
+
self.total_count=0
|
19 |
+
self.filter_count=[0]*len(filter_prompts)
|
20 |
+
|
21 |
+
def reset(self):
|
22 |
+
self.total_count=0
|
23 |
+
self.filter_count=[0]*len(self.filter_prompts)
|
24 |
+
def filter(self, captions):
|
25 |
+
filter_result = []
|
26 |
+
for caption in captions:
|
27 |
+
words = caption[0]
|
28 |
+
if words == None:
|
29 |
+
filter_result.append((True, "None"))
|
30 |
+
continue
|
31 |
+
words = words.lower()
|
32 |
+
words = words.split()
|
33 |
+
filt = False
|
34 |
+
reason=None
|
35 |
+
for i, filter_keyword in enumerate(self.filter_prompts):
|
36 |
+
key_len = len(filter_keyword.split())
|
37 |
+
for j in range(len(words)-key_len+1):
|
38 |
+
if " ".join(words[j:j+key_len]) == filter_keyword:
|
39 |
+
self.filter_count[i] += 1
|
40 |
+
filt = True
|
41 |
+
reason = filter_keyword
|
42 |
+
break
|
43 |
+
if filt:
|
44 |
+
break
|
45 |
+
filter_result.append((filt, reason))
|
46 |
+
self.total_count += 1
|
47 |
+
return filter_result
|
48 |
+
|
49 |
+
class Clip_filter:
|
50 |
+
prompt_threshold = {
|
51 |
+
"painting": 17,
|
52 |
+
"art": 17.5,
|
53 |
+
"artwork": 19,
|
54 |
+
"drawing": 15.8,
|
55 |
+
"sketch": 17,
|
56 |
+
"illustration": 15,
|
57 |
+
"sculpture": 19.2,
|
58 |
+
"installation art": 20,
|
59 |
+
"printmaking art": 16.3,
|
60 |
+
"digital art": 15,
|
61 |
+
"conceptual art": 18,
|
62 |
+
"mosaic art": 19,
|
63 |
+
"tapestry": 16,
|
64 |
+
"abstract art":16.5,
|
65 |
+
"realism art": 16,
|
66 |
+
"surrealism art": 15,
|
67 |
+
"impressionism art": 17,
|
68 |
+
"expressionism art": 17,
|
69 |
+
"cubism art": 15,
|
70 |
+
"minimalism art": 16,
|
71 |
+
"baroque art": 17.5,
|
72 |
+
"rococo art": 17,
|
73 |
+
"pop art": 16,
|
74 |
+
"art nouveau": 19,
|
75 |
+
"art deco": 19,
|
76 |
+
"futurism art": 16.5,
|
77 |
+
"dadaism art": 16.5,
|
78 |
+
"stamp": 18,
|
79 |
+
"advertisement": 16.5,
|
80 |
+
"logo": 15.5,
|
81 |
+
}
|
82 |
+
@torch.no_grad()
|
83 |
+
def __init__(self, positive_prompt=["painting", "art", "artwork", "drawing", "sketch", "illustration",
|
84 |
+
"sculpture", "installation art", "printmaking art", "digital art", "conceptual art", "mosaic art", "tapestry",
|
85 |
+
"abstract art", "realism art", "surrealism art", "impressionism art", "expressionism art", "cubism art",
|
86 |
+
"minimalism art", "baroque art", "rococo art",
|
87 |
+
"pop art", "art nouveau", "art deco", "futurism art", "dadaism art",
|
88 |
+
"stamp", "advertisement",
|
89 |
+
"logo"
|
90 |
+
],
|
91 |
+
device="cuda"):
|
92 |
+
self.device = device
|
93 |
+
self.model = (CLIPModel.from_pretrained("openai/clip-vit-large-patch14")).to(device)
|
94 |
+
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
95 |
+
self.positive_prompt = positive_prompt
|
96 |
+
self.text = self.positive_prompt
|
97 |
+
self.tokenizer = self.processor.tokenizer
|
98 |
+
self.image_processor = self.processor.image_processor
|
99 |
+
self.text_encoding = self.tokenizer(self.text, return_tensors="pt", padding=True).to(device)
|
100 |
+
self.text_features = self.model.get_text_features(**self.text_encoding)
|
101 |
+
self.text_features = self.text_features / self.text_features.norm(p=2, dim=-1, keepdim=True)
|
102 |
+
@torch.no_grad()
|
103 |
+
def similarity(self, image):
|
104 |
+
# inputs = self.processor(text=self.text, images=image, return_tensors="pt", padding=True)
|
105 |
+
image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
|
106 |
+
inputs = {**self.text_encoding, **image_processed}
|
107 |
+
outputs = self.model(**inputs)
|
108 |
+
logits_per_image = outputs.logits_per_image
|
109 |
+
return logits_per_image
|
110 |
+
|
111 |
+
def get_logits(self, image):
|
112 |
+
logits_per_image = self.similarity(image)
|
113 |
+
return logits_per_image.cpu()
|
114 |
+
|
115 |
+
def get_image_features(self, image):
|
116 |
+
image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
|
117 |
+
image_features = self.model.get_image_features(**image_processed)
|
118 |
+
return image_features
|
119 |
+
|
120 |
+
|
121 |
+
class Art_filter:
|
122 |
+
def __init__(self):
|
123 |
+
self.caption_filter = Caption_filter()
|
124 |
+
self.clip_filter = Clip_filter()
|
125 |
+
def caption_filt(self, dataloader):
|
126 |
+
self.caption_filter.reset()
|
127 |
+
dataloader.dataset.get_img = False
|
128 |
+
dataloader.dataset.get_cap = True
|
129 |
+
remain_ids = []
|
130 |
+
filtered_ids = []
|
131 |
+
for i, batch in tqdm(enumerate(dataloader)):
|
132 |
+
captions = batch["text"]
|
133 |
+
filter_result = self.caption_filter.filter(captions)
|
134 |
+
for j, (filt, reason) in enumerate(filter_result):
|
135 |
+
if filt:
|
136 |
+
filtered_ids.append((batch["ids"][j], reason))
|
137 |
+
if i%10==0:
|
138 |
+
print(f"Filtered caption: {captions[j]}, reason: {reason}")
|
139 |
+
else:
|
140 |
+
remain_ids.append(batch["ids"][j])
|
141 |
+
return {"remain_ids":remain_ids, "filtered_ids":filtered_ids, "total_count":self.caption_filter.total_count, "filter_count":self.caption_filter.filter_count, "filter_prompts":self.caption_filter.filter_prompts}
|
142 |
+
|
143 |
+
def clip_filt(self, clip_logits_ckpt:dict):
|
144 |
+
logits = clip_logits_ckpt["clip_logits"]
|
145 |
+
ids = clip_logits_ckpt["ids"]
|
146 |
+
text = clip_logits_ckpt["text"]
|
147 |
+
filt_mask = torch.zeros(logits.shape[0], dtype=torch.bool)
|
148 |
+
for i, prompt in enumerate(text):
|
149 |
+
threshold = Clip_filter.prompt_threshold[prompt]
|
150 |
+
filt_mask = filt_mask | (logits[:,i] >= threshold)
|
151 |
+
filt_ids = []
|
152 |
+
remain_ids = []
|
153 |
+
for i, id in enumerate(ids):
|
154 |
+
if filt_mask[i]:
|
155 |
+
filt_ids.append(id)
|
156 |
+
else:
|
157 |
+
remain_ids.append(id)
|
158 |
+
return {"remain_ids":remain_ids, "filtered_ids":filt_ids}
|
159 |
+
|
160 |
+
def clip_feature(self, dataloader):
|
161 |
+
dataloader.dataset.get_img = True
|
162 |
+
dataloader.dataset.get_cap = False
|
163 |
+
clip_features = []
|
164 |
+
ids = []
|
165 |
+
for i, batch in enumerate(dataloader):
|
166 |
+
images = batch["images"]
|
167 |
+
features = self.clip_filter.get_image_features(images).cpu()
|
168 |
+
clip_features.append(features)
|
169 |
+
ids.extend(batch["ids"])
|
170 |
+
clip_features = torch.cat(clip_features)
|
171 |
+
return {"clip_features":clip_features, "ids":ids}
|
172 |
+
|
173 |
+
|
174 |
+
def clip_logit(self, dataloader):
|
175 |
+
dataloader.dataset.get_img = True
|
176 |
+
dataloader.dataset.get_cap = False
|
177 |
+
clip_features = []
|
178 |
+
clip_logits = []
|
179 |
+
ids = []
|
180 |
+
for i, batch in enumerate(dataloader):
|
181 |
+
images = batch["images"]
|
182 |
+
# logits = self.clip_filter.get_logits(images)
|
183 |
+
feature = self.clip_filter.get_image_features(images)
|
184 |
+
logits = self.clip_logit_by_feat(feature)["clip_logits"]
|
185 |
+
|
186 |
+
clip_features.append(feature)
|
187 |
+
clip_logits.append(logits)
|
188 |
+
ids.extend(batch["ids"])
|
189 |
+
|
190 |
+
clip_features = torch.cat(clip_features)
|
191 |
+
clip_logits = torch.cat(clip_logits)
|
192 |
+
return {"clip_features":clip_features, "clip_logits":clip_logits, "ids":ids, "text": self.clip_filter.text}
|
193 |
+
|
194 |
+
def clip_logit_by_feat(self, feature):
|
195 |
+
feature = feature.clone().to(self.clip_filter.device)
|
196 |
+
feature = feature / feature.norm(p=2, dim=-1, keepdim=True)
|
197 |
+
logit_scale = self.clip_filter.model.logit_scale.exp()
|
198 |
+
logits = ((feature @ self.clip_filter.text_features.T) * logit_scale).cpu()
|
199 |
+
return {"clip_logits":logits, "text": self.clip_filter.text}
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
if __name__ == "__main__":
|
204 |
+
import pickle
|
205 |
+
with open("/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle","rb") as f:
|
206 |
+
result=pickle.load(f)
|
207 |
+
feat = result['clip_features']
|
208 |
+
logits =Art_filter().clip_logit_by_feat(feat)
|
209 |
+
print(logits)
|
210 |
+
|
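The CLIP-based filter above scores an image against a list of art-related prompts and flags it when any score reaches the per-prompt threshold. A small sketch of scoring a single image follows; the import path and the image filename are assumptions, not part of the uploaded code.

# Sketch: score one image with Clip_filter and report which prompts would flag it.
# Assumes the repo root is on sys.path and an image file `photo.jpg` exists.
from PIL import Image
from utils.art_filter import Clip_filter

clip_filter = Clip_filter(device="cuda")
image = Image.open("photo.jpg").convert("RGB")
logits = clip_filter.get_logits(image)[0]        # one row of prompt similarities
for prompt, score in zip(clip_filter.text, logits.tolist()):
    threshold = Clip_filter.prompt_threshold.get(prompt, float("inf"))
    if score >= threshold:
        print(f"flagged by '{prompt}': {score:.2f} >= {threshold}")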
utils/config_util.py
ADDED
@@ -0,0 +1,105 @@
from typing import Literal, Optional

import yaml

from pydantic import BaseModel
import torch

from lora import TRAINING_METHODS

PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
NETWORK_TYPES = Literal["lierla", "c3lier"]


class PretrainedModelConfig(BaseModel):
    name_or_path: str
    ckpt_path: Optional[str] = None
    v2: bool = False
    v_pred: bool = False

    clip_skip: Optional[int] = None


class NetworkConfig(BaseModel):
    type: NETWORK_TYPES = "lierla"
    rank: int = 4
    alpha: float = 1.0

    training_method: TRAINING_METHODS = "full"


class TrainConfig(BaseModel):
    precision: PRECISION_TYPES = "bfloat16"
    noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"

    iterations: int = 500
    lr: float = 1e-4
    optimizer: str = "adamw"
    optimizer_args: str = ""
    lr_scheduler: str = "constant"

    max_denoising_steps: int = 50


class SaveConfig(BaseModel):
    name: str = "untitled"
    path: str = "./output"
    per_steps: int = 200
    precision: PRECISION_TYPES = "float32"


class LoggingConfig(BaseModel):
    use_wandb: bool = False

    verbose: bool = False


class OtherConfig(BaseModel):
    use_xformers: bool = False


class RootConfig(BaseModel):
    # prompts_file: str
    pretrained_model: PretrainedModelConfig

    network: NetworkConfig

    train: Optional[TrainConfig]

    save: Optional[SaveConfig]

    logging: Optional[LoggingConfig]

    other: Optional[OtherConfig]


def parse_precision(precision: str) -> torch.dtype:
    if precision == "fp32" or precision == "float32":
        return torch.float32
    elif precision == "fp16" or precision == "float16":
        return torch.float16
    elif precision == "bf16" or precision == "bfloat16":
        return torch.bfloat16

    raise ValueError(f"Invalid precision type: {precision}")


def load_config_from_yaml(config_path: str) -> RootConfig:
    with open(config_path, "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    root = RootConfig(**config)

    if root.train is None:
        root.train = TrainConfig()

    if root.save is None:
        root.save = SaveConfig()

    if root.logging is None:
        root.logging = LoggingConfig()

    if root.other is None:
        root.other = OtherConfig()

    return root
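The pydantic models above describe the expected training YAML. Below is a sketch of a matching config file and how load_config_from_yaml would consume it; all field values are examples, and the bare import assumes the utils directory is on the Python path (config_util itself does `from lora import TRAINING_METHODS`).

# Sketch: write an example config and load it with load_config_from_yaml.
# All values are illustrative; run with the utils/ directory on sys.path.
import yaml
from config_util import load_config_from_yaml, parse_precision

example = {
    "pretrained_model": {"name_or_path": "rhfeiyang/art-free-diffusion-v1"},
    "network": {"type": "c3lier", "rank": 1, "alpha": 1.0, "training_method": "full"},
    "train": {"precision": "bfloat16", "iterations": 1000, "lr": 1e-4},
    "save": {"name": "example_adapter", "path": "./output"},
    "logging": {"use_wandb": False},
    "other": {"use_xformers": True},
}
with open("example_config.yaml", "w") as f:
    yaml.safe_dump(example, f)

cfg = load_config_from_yaml("example_config.yaml")
print(cfg.save.name, parse_precision(cfg.train.precision))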
utils/debug_util.py
ADDED
@@ -0,0 +1,16 @@
# For debugging...

import torch


def check_requires_grad(model: torch.nn.Module):
    for name, module in list(model.named_modules())[:5]:
        if len(list(module.parameters())) > 0:
            print(f"Module: {name}")
            for name, param in list(module.named_parameters())[:2]:
                print(f"  Parameter: {name}, Requires Grad: {param.requires_grad}")


def check_training_mode(model: torch.nn.Module):
    for name, module in list(model.named_modules())[:5]:
        print(f"Module: {name}, Training Mode: {module.training}")
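A quick sketch of using these helpers on a U-Net; `pipe` is assumed to be a loaded DiffusionPipeline as in the notebook above.

# Sketch: print grad/training status for the first few U-Net modules.
from utils.debug_util import check_requires_grad, check_training_mode

check_requires_grad(pipe.unet)    # Requires Grad per parameter (first modules only)
check_training_mode(pipe.unet)    # train/eval mode per module (first modules only)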
utils/lora.py
ADDED
@@ -0,0 +1,282 @@
# ref:
# - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
# - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py

import os
import math
from typing import Optional, List, Type, Set, Literal

import torch
import torch.nn as nn
from diffusers import UNet2DConditionModel
from safetensors.torch import save_file


UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
    # "Transformer2DModel",  # apparently it's this one?  # attn1, 2
    "Attention"
]
UNET_TARGET_REPLACE_MODULE_CONV = [
    "ResnetBlock2D",
    "Downsample2D",
    "Upsample2D",
    # "DownBlock2D",
    # "UpBlock2D"
]  # locon, 3clier

LORA_PREFIX_UNET = "lora_unet"

DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER

TRAINING_METHODS = Literal[
    "noxattn",  # train all layers except x-attns and time_embed layers
    "innoxattn",  # train all layers except self attention layers
    "selfattn",  # ESD-u, train only self attention layers
    "xattn",  # ESD-x, train only x attention layers
    "full",  # train all layers
    "xattn-strict",  # q and k values
    "noxattn-hspace",
    "noxattn-hspace-last",
    # "xlayer",
    # "outxattn",
    # "outsattn",
    # "inxattn",
    # "inmidsattn",
    # "selflayer",
]


class LoRAModule(nn.Module):
    """
    replaces forward method of the original Linear, instead of replacing the original Linear module.
    """

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__()
        self.lora_name = lora_name
        self.lora_dim = lora_dim

        if "Linear" in org_module.__class__.__name__:
            in_dim = org_module.in_features
            out_dim = org_module.out_features
            self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
            self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)

        elif "Conv" in org_module.__class__.__name__:  # just in case
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels

            self.lora_dim = min(self.lora_dim, in_dim, out_dim)
            if self.lora_dim != lora_dim:
                print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")

            kernel_size = org_module.kernel_size
            stride = org_module.stride
            padding = org_module.padding
            self.lora_down = nn.Conv2d(
                in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
            )
            self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)

        if type(alpha) == torch.Tensor:
            alpha = alpha.detach().numpy()
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # can be treated as a constant

        # same as microsoft's
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        self.multiplier = multiplier
        self.org_module = org_module  # remove in applying

    def apply_to(self):
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        return (
            self.org_forward(x)
            + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
        )


class LoRANetwork(nn.Module):
    def __init__(
        self,
        unet: UNet2DConditionModel,
        rank: int = 4,
        multiplier: float = 1.0,
        alpha: float = 1.0,
        train_method: TRAINING_METHODS = "full",
    ) -> None:
        super().__init__()
        self.lora_scale = 1
        self.multiplier = multiplier
        self.lora_dim = rank
        self.alpha = alpha

        self.module = LoRAModule

        self.unet_loras = self.create_modules(
            LORA_PREFIX_UNET,
            unet,
            DEFAULT_TARGET_REPLACE,
            self.lora_dim,
            self.multiplier,
            train_method=train_method,
        )
        print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")

        lora_names = set()
        for lora in self.unet_loras:
            assert (
                lora.lora_name not in lora_names
            ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
            lora_names.add(lora.lora_name)

        for lora in self.unet_loras:
            lora.apply_to()
            self.add_module(
                lora.lora_name,
                lora,
            )

        del unet

        torch.cuda.empty_cache()

    def create_modules(
        self,
        prefix: str,
        root_module: nn.Module,
        target_replace_modules: List[str],
        rank: int,
        multiplier: float,
        train_method: TRAINING_METHODS,
    ) -> list:
        loras = []
        names = []
        for name, module in root_module.named_modules():
            if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last":  # train everything except cross attention and time embed
                if "attn2" in name or "time_embed" in name:
                    continue
            elif train_method == "innoxattn":  # train everything except cross attention
                if "attn2" in name:
                    continue
            elif train_method == "selfattn":  # train only self attention
                if "attn1" not in name:
                    continue
            elif train_method == "xattn" or train_method == "xattn-strict":  # train only cross attention
                if "attn2" not in name:
                    continue
            elif train_method == "attn":
                if "attn1" not in name and "attn2" not in name:
                    continue
            elif train_method == "full":
                pass
            # else:
            #     raise NotImplementedError(
            #         f"train_method: {train_method} is not implemented."
            #     )
            ##
            # union condition (b-lora)
            else:
                discard = True
                if "all_up" in train_method:
                    if "up_blocks" in name:
                        discard = False
                if "down_1" in train_method:
                    if not ("down_blocks.1" not in name or "attentions" not in name):
                        discard = False
                if "down_2" in train_method:
                    if not ("down_blocks.2" not in name or "attentions" not in name):
                        discard = False
                if "up_1" in train_method:
                    if not ("up_blocks.1" not in name or "attentions" not in name):
                        discard = False
                if "up_2" in train_method:
                    if not ("up_blocks.2" not in name or "attentions" not in name):
                        discard = False
                if discard:
                    continue

            ##
            if module.__class__.__name__ in target_replace_modules:
                for child_name, child_module in module.named_modules():
                    if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
                        if train_method == 'xattn-strict':
                            if 'out' in child_name:
                                continue
                        if train_method == 'noxattn-hspace':
                            if 'mid_block' not in name:
                                continue
                        if train_method == 'noxattn-hspace-last':
                            if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
                                continue
                        lora_name = prefix + "." + name + "." + child_name
                        lora_name = lora_name.replace(".", "_")
                        # print(f"{lora_name}")
                        lora = self.module(
                            lora_name, child_module, multiplier, rank, self.alpha
                        )
                        # print(name, child_name)
                        # print(child_module.weight.shape)
                        loras.append(lora)
                        names.append(lora_name)
        # print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
        return loras

    def prepare_optimizer_params(self):
        all_params = []

        if self.unet_loras:  # effectively the only case
            params = []
            [params.extend(lora.parameters()) for lora in self.unet_loras]
            param_data = {"params": params}
            all_params.append(param_data)

        return all_params

    def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        # for key in list(state_dict.keys()):
        #     if not key.startswith("lora"):
        #         # exclude everything except lora
        #         del state_dict[key]

        if os.path.splitext(file)[1] == ".safetensors":
            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def set_lora_slider(self, scale):
        self.lora_scale = scale

    def __enter__(self):
        for lora in self.unet_loras:
            lora.multiplier = 1.0 * self.lora_scale

    def __exit__(self, exc_type, exc_value, tb):
        for lora in self.unet_loras:
            lora.multiplier = 0
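A usage sketch for LoRANetwork (illustrative, not part of the upload). The adapter path comes from the data/Art_adapters folder listed in this commit; the choice of train_method="all_up" and the strict=False load are assumptions inferred from the adapter file name, not confirmed by the source.

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)
network = LoRANetwork(unet, rank=1, alpha=1.0, train_method="all_up")  # assumed train_method
state_dict = torch.load(
    "data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt",
    map_location="cpu",
)
network.load_state_dict(state_dict, strict=False)  # assumed checkpoint layout
network.set_lora_slider(scale=1.0)
with network:   # inside the block every LoRAModule multiplier equals lora_scale
    pass        # run the U-Net / sampling loop here
# on __exit__ the multipliers are reset to 0, i.e. the adapter is switched off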
utils/metrics.py
ADDED
@@ -0,0 +1,577 @@
# Authors: Hui Ren (rhfeiyang.github.io)

import os

import numpy as np
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.autograd import Function
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from collections import OrderedDict
from transformers import BatchFeature
import clip
import copy
import lpips
from transformers import ViTImageProcessor, ViTModel

## CSD_CLIP
def convert_weights_float(model: nn.Module):
    """Convert applicable model parameters to fp32"""

    def _convert_weights_to_fp32(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.float()
            if l.bias is not None:
                l.bias.data = l.bias.data.float()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.float()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.float()

    model.apply(_convert_weights_to_fp32)

class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None


## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
class ProjectionHead(nn.Module):
    def __init__(
        self,
        embedding_dim,
        projection_dim,
        dropout=0
    ):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

def convert_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith("module."):
            k = k.replace("module.", "")
        new_state_dict[k] = v
    return new_state_dict

def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.normal_(m.bias, std=1e-6)

class Metric(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_preprocess = None

    def load_image(self, image_path):
        with open(image_path, 'rb') as f:
            image = Image.open(f).convert("RGB")
        return image

    def load_image_path(self, image_path):
        if isinstance(image_path, str):
            # should be a image folder path
            images_file = os.listdir(image_path)
            image_path = [os.path.join(image_path, image) for image in images_file if
                          image.endswith(".jpg") or image.endswith(".png")]
        if isinstance(image_path[0], str):
            images = [self.load_image(image) for image in image_path]
        elif isinstance(image_path[0], np.ndarray):
            images = [Image.fromarray(image) for image in image_path]
        elif isinstance(image_path[0], Image.Image):
            images = image_path
        else:
            raise Exception("Invalid input")
        return images

    def preprocess_image(self, image, **kwargs):
        if (isinstance(image, str) and os.path.isdir(image)) or (isinstance(image, list) and (isinstance(image[0], Image.Image) or isinstance(image[0], np.ndarray) or os.path.isfile(image[0]))):
            input_data = self.load_image_path(image)
            input_data = [self.image_preprocess(image, **kwargs) for image in input_data]
            input_data = torch.stack(input_data)
        elif os.path.isfile(image):
            input_data = self.load_image(image)
            input_data = self.image_preprocess(input_data, **kwargs)
            input_data = input_data.unsqueeze(0)
        elif isinstance(image, torch.Tensor):
            raise Exception("Unsupported input")
        return input_data

class Clip_Basic_Metric(Metric):
    def __init__(self):
        super().__init__()
        self.tensor_preprocess = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
            # transforms.rescale
            transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
        self.image_preprocess = transforms.Compose([
            transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

class Clip_metric(Clip_Basic_Metric):

    @torch.no_grad()
    def __init__(self, target_style_prompt: str=None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
                 bath_size=8, alpha=0.5):
        super().__init__()
        self.device = device
        self.alpha = alpha
        self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
        self.processor = CLIPProcessor.from_pretrained(clip_model_name)
        self.tokenizer = self.processor.tokenizer
        self.image_processor = self.processor.image_processor
        # self.style_class_features = self.get_text_features(self.styles).cpu()
        self.style_class_features = []
        # self.noise_prompt_features = self.get_text_features("Noise")
        self.model.eval()
        self.batch_size = bath_size
        if target_style_prompt is not None:
            self.ref_style_features = self.get_text_features(target_style_prompt)
        else:
            self.ref_style_features = None

        self.ref_image_style_prototype = None

    def get_text_features(self, text):
        prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
        prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
        prompt_features = F.normalize(prompt_features, p=2, dim=-1)
        return prompt_features

    def get_image_features(self, images):
        # if isinstance(image, torch.Tensor):
        #     self.tensor_transform(image)
        # else:
        #     image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
        images = self.load_image_path(images)
        if isinstance(images, torch.Tensor):
            images = self.tensor_preprocess(images)
            data = {"pixel_values": images}
            image_features = BatchFeature(data=data, tensor_type="pt")
        else:
            image_features = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
                                                                                                 non_blocking=True)

        image_features = self.model.get_image_features(**image_features).to(self.device)
        image_features = F.normalize(image_features, p=2, dim=-1)
        return image_features

    def img_text_similarity(self, image_features, text=None):
        if text is not None:
            prompt_feature = self.get_text_features(text)
            if isinstance(text, str):
                prompt_feature = prompt_feature.repeat(len(image_features), 1)
        else:
            prompt_feature = self.ref_style_features

        similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
        return similarity_each

    def forward(self, output_imgs, prompt=None):
        image_features = self.get_image_features(output_imgs)
        # print(image_features)
        style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
        if prompt is not None:
            content_score = self.img_text_similarity(image_features, prompt)

            score = self.alpha * style_score + (1 - self.alpha) * content_score
            return {"score": score, "style_score": style_score, "content_score": content_score}
        else:
            return {"style_score": style_score}

    def content_score(self, output_imgs, prompt):
        self.to(self.device)
        image_features = self.get_image_features(output_imgs)
        content_score_details = self.img_text_similarity(image_features, prompt)
        self.to("cpu")
        return {"CLIP_content_score": content_score_details.mean().cpu(), "CLIP_content_score_details": content_score_details.cpu()}


class CSD_CLIP(Clip_Basic_Metric):
    """backbone + projection head"""
    def __init__(self, name='vit_large', content_proj_head='default', ckpt_path="data/weights/CSD-checkpoint.pth", device="cuda",
                 alpha=0.5, **kwargs):
        super(CSD_CLIP, self).__init__()
        self.alpha = alpha
        self.content_proj_head = content_proj_head
        self.device = device
        if name == 'vit_large':
            clipmodel, _ = clip.load("ViT-L/14")
            self.backbone = clipmodel.visual
            self.embedding_dim = 1024
        elif name == 'vit_base':
            clipmodel, _ = clip.load("ViT-B/16")
            self.backbone = clipmodel.visual
            self.embedding_dim = 768
            self.feat_dim = 512
        else:
            raise Exception('This model is not implemented')

        convert_weights_float(self.backbone)
        self.last_layer_style = copy.deepcopy(self.backbone.proj)
        if content_proj_head == 'custom':
            self.last_layer_content = ProjectionHead(self.embedding_dim, self.feat_dim)
            self.last_layer_content.apply(init_weights)

        else:
            self.last_layer_content = copy.deepcopy(self.backbone.proj)

        self.backbone.proj = None
        self.backbone.requires_grad_(False)
        self.last_layer_style.requires_grad_(False)
        self.last_layer_content.requires_grad_(False)
        self.backbone.eval()

        if ckpt_path is not None:
            self.load_ckpt(ckpt_path)
        self.to("cpu")

    def load_ckpt(self, ckpt_path):
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = convert_state_dict(checkpoint['model_state_dict'])
        msg = self.load_state_dict(state_dict, strict=False)
        print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")

    @property
    def dtype(self):
        return self.backbone.conv1.weight.dtype

    def get_image_features(self, input_data, get_style=True, get_content=False, feature_alpha=None):
        if isinstance(input_data, torch.Tensor):
            input_data = self.tensor_preprocess(input_data)
        elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
            input_data = self.load_image_path(input_data)
            input_data = [self.image_preprocess(image) for image in input_data]
            input_data = torch.stack(input_data)
        elif os.path.isfile(input_data):
            input_data = self.load_image(input_data)
            input_data = self.image_preprocess(input_data)
            input_data = input_data.unsqueeze(0)
        input_data = input_data.to(self.device)
        style_output = None

        feature = self.backbone(input_data)
        if get_style:
            style_output = feature @ self.last_layer_style
            # style_output = style_output.mean(dim=0)
            style_output = nn.functional.normalize(style_output, dim=-1, p=2)

        content_output = None
        if get_content:
            if feature_alpha is not None:
                reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
            else:
                reverse_feature = feature
            # if alpha is not None:
            if self.content_proj_head == 'custom':
                content_output = self.last_layer_content(reverse_feature)
            else:
                content_output = reverse_feature @ self.last_layer_content
            content_output = nn.functional.normalize(content_output, dim=-1, p=2)

        return feature, content_output, style_output


    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")
        # self.ref_style_feature = self.ref_style_feature.mean(dim=0)

    @torch.no_grad()
    def forward(self, styled_data):
        self.to(self.device)
        # get_content_feature = original_data is not None
        _, content_output, style_output = self.get_image_features(styled_data, get_content=False)
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()


        self.to("cpu")
        return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
                "CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}

    def get_style_loss(self, styled_data):
        _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
        style_similarity = (style_output @ self.ref_style_feature).mean()
        loss = 1 - style_similarity
        return loss.mean()

class LPIPS_metric(Metric):
    def __init__(self, type="vgg", device="cuda"):
        super(LPIPS_metric, self).__init__()
        self.lpips = lpips.LPIPS(net=type)
        self.device = device
        self.image_preprocess = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.to("cpu")

    @torch.no_grad()
    def forward(self, img1, img2):
        self.to(self.device)
        differences = []
        for i in range(0, len(img1), 50):
            img1_batch = img1[i:i+50]
            img2_batch = img2[i:i+50]
            img1_batch = self.preprocess_image(img1_batch).to(self.device)
            img2_batch = self.preprocess_image(img2_batch).to(self.device)
            differences.append(self.lpips(img1_batch, img2_batch).squeeze())
        differences = torch.cat(differences)
        difference = differences.mean()
        # similarity = 1 - difference
        self.to("cpu")
        return {"LPIPS_content_difference": difference, "LPIPS_content_difference_details": differences}

class Vit_metric(Metric):
    def __init__(self, device="cuda"):
        super(Vit_metric, self).__init__()
        self.device = device
        self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
        self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
        self.to("cpu")

    def get_image_features(self, images):
        # if isinstance(image, torch.Tensor):
        #     self.tensor_transform(image)
        # else:
        #     image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
        images = self.load_image_path(images)
        batch_size = 20
        all_image_features = []
        for i in range(0, len(images), batch_size):
            image_batch = images[i:i+batch_size]
            if isinstance(image_batch, torch.Tensor):
                image_batch = self.tensor_preprocess(image_batch)
                data = {"pixel_values": image_batch}
                image_processed = BatchFeature(data=data, tensor_type="pt")
            else:
                image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
            image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
            image_features = F.normalize(image_features, p=2, dim=-1)
            all_image_features.append(image_features)
        all_image_features = torch.cat(all_image_features)
        return all_image_features

    @torch.no_grad()
    def content_metric(self, img1, img2):
        self.to(self.device)
        if not (isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
            img1 = self.get_image_features(img1)
        if not (isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
            img2 = self.get_image_features(img2)
        similarities = torch.einsum("nc, nc -> n", img1, img2)
        similarity = similarities.mean()
        # self.to("cpu")
        return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}

    # style
    @torch.no_grad()
    def define_ref_image_style_prototype(self, ref_image_path: str):
        self.to(self.device)
        self.ref_style_feature = self.get_image_features(ref_image_path)
        self.to("cpu")

    @torch.no_grad()
    def style_metric(self, styled_data):
        self.to(self.device)
        if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
            style_output = styled_data
        else:
            style_output = self.get_image_features(styled_data)
        style_similarities = style_output @ self.ref_style_feature.T
        mean_style_similarities = style_similarities.mean(dim=-1)
        mean_style_similarity = mean_style_similarities.mean()

        max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
        max_style_similarity = max_style_similarities_v.mean()

        # self.to("cpu")
        return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
                "Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}

    @torch.no_grad()
    def forward(self, styled_data, original_data=None):
        self.to(self.device)
        styled_features = self.get_image_features(styled_data)
        ret = {}
        if original_data is not None:
            content_metric = self.content_metric(styled_features, original_data)
            ret["Vit_content"] = content_metric
        style_metric = self.style_metric(styled_features)
        ret["Vit_style"] = style_metric
        self.to("cpu")
        return ret



class StyleContentMetric(nn.Module):
    def __init__(self, style_ref_image_folder, device="cuda"):
        super(StyleContentMetric, self).__init__()
        self.device = device
        self.clip_style_metric = CSD_CLIP(device=device)
        self.ref_image_file = os.listdir(style_ref_image_folder)
        self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
        self.ref_image_file.sort()
        self.ref_image_file = np.array(self.ref_image_file)
        ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]

        self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)
        self.vit_metric = Vit_metric(device=device)
        self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
        self.lpips_metric = LPIPS_metric(device=device)

        self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None)

        self.to("cpu")

    def forward(self, styled_data, original_data=None, content_caption=None):
        ret = {}
        csd_score = self.clip_style_metric(styled_data)
        csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
        torch.cuda.empty_cache()
        ret["Style_CSD"] = csd_score
        vit_score = self.vit_metric(styled_data, original_data)
        torch.cuda.empty_cache()
        vit_style = vit_score["Vit_style"]
        vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
        ret["Style_VIT"] = vit_style

        if original_data is not None:
            vit_content = vit_score["Vit_content"]
            ret["Content_VIT"] = vit_content
            lpips_score = self.lpips_metric(styled_data, original_data)
            torch.cuda.empty_cache()
            ret["Content_LPIPS"] = lpips_score

        if content_caption is not None:
            clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
            ret["Content_CLIP"] = clip_content
            torch.cuda.empty_cache()

        for type_key, type_value in ret.items():
            for key, value in type_value.items():
                if isinstance(value, torch.Tensor):
                    if value.numel() == 1:
                        ret[type_key][key] = round(value.item(), 4)
                    else:
                        ret[type_key][key] = value.tolist()
                        ret[type_key][key] = [round(v, 4) for v in ret[type_key][key]]

        self.to("cpu")
        ret["ref_image_file"] = self.ref_image_file.tolist()
        return ret


if __name__ == "__main__":
    with torch.no_grad():
        metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
        score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
                       "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        print(score)


        lpips = LPIPS_metric()
        score = lpips("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
                      "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")

        print("lpips", score)


        clip_metric = CSD_CLIP()
        clip_metric.define_ref_image_style_prototype(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        print("subset3-subset3_sd14_converted", score)

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        print("subset3-photo", score)



        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
        print("subset3-subset1", score)

        score = clip_metric(
            "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
        print("subset3-andy", score)
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")

        # print("subset3",score)
        # score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
        # print("subset2",score_subset2)
        # score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
        # print("subset3",score_subset3)
        #
        # score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
        # print("subset3-subset3_sd14_converted" , score_subset3_converted)
        #
        # score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
        # print("subset3-subset3_coco_converted" , score_subset3_coco_converted)
        #
        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # print("photo500_1-sketch" ,score)
        #
        # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
        # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
        # print("photo500_1-photo500_2" ,score)
        # from custom_datasets.imagepair import ImageSet
        # import matplotlib.pyplot as plt
        # dataset = ImageSet(folder = "/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
        #                    caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
        #                    keep_in_mem=False)
        # for sample in dataset:
        #     score = clip_metric.content_score(sample["image"], sample["caption"][0])
        #     plt.imshow(sample["image"])
        #     plt.title(f"score: {round(score.item(),2)}\n prompt: {sample['caption'][0]}")
        #     plt.show()
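A usage sketch for StyleContentMetric (illustrative, not part of the upload; the folder paths and caption are placeholders). It compares a folder of stylized outputs against reference paintings and, optionally, the original photos.

metric = StyleContentMetric("path/to/reference_paintings", device="cuda")  # placeholder path
scores = metric(
    "path/to/stylized_outputs",               # styled_data: folder of generated images
    original_data="path/to/original_photos",  # optional: enables ViT / LPIPS content scores
    content_caption="a painting of a dog",    # optional: enables the CLIP content score
)
print(scores["Style_CSD"]["CSD_similarity_mean"], scores["Style_VIT"]["Vit_style_similarity_mean"])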
utils/model_util.py
ADDED
@@ -0,0 +1,291 @@
from typing import Literal, Union, Optional

import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    AutoencoderKL,
)
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
)


TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]

SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this
from diffusers.training_utils import EMAModel
import os
import sys

# from utils.modules import get_diffusion_modules
def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    # VAE is not needed

    if v2:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V2_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            # default is clip skip 2
            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
    else:
        tokenizer = CLIPTokenizer.from_pretrained(
            TOKENIZER_V1_MODEL_NAME,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )
        text_encoder = CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        )

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")

    return tokenizer, text_encoder, unet, vae


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    pipe = StableDiffusionPipeline.from_ckpt(
        checkpoint_path,
        upcast_attention=True if v2 else False,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    vae = pipe.vae
    if clip_skip is not None:
        if v2:
            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
        else:
            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)

    del pipe

    return tokenizer, text_encoder, unet, vae


def load_models(
    pretrained_model_name_or_path: str,
    ckpt_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )
    else:  # diffusers
        tokenizer, text_encoder, unet, vae = load_diffusers_model(
            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
        )

    # VAE is not needed

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )
    # trained unet_ema
    if ckpt_path not in [None, "None"]:
        ema_unet = EMAModel.from_pretrained(os.path.join(ckpt_path, "unet_ema"), UNet2DConditionModel)
        ema_unet.copy_to(unet.parameters())
    return tokenizer, text_encoder, unet, scheduler, vae


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet

    tokenizers = [
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="tokenizer_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
            pad_token_id=0,  # same as open clip
        ),
    ]

    text_encoders = [
        CLIPTextModel.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
        CLIPTextModelWithProjection.from_pretrained(
            pretrained_model_name_or_path,
            subfolder="text_encoder_2",
            torch_dtype=weight_dtype,
            cache_dir=DIFFUSERS_CACHE_DIR,
        ),
    ]

    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
    return tokenizers, text_encoders, unet, vae


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
    if len(text_encoders) == 2:
        text_encoders[1].pad_token_id = 0
    vae = pipe.vae
    del pipe

    return tokenizers, text_encoders, unet, vae


def load_models_xl(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    if pretrained_model_name_or_path.endswith(
        ".ckpt"
    ) or pretrained_model_name_or_path.endswith(".safetensors"):
        (
            tokenizers,
            text_encoders,
            unet,
            vae
        ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
    else:  # diffusers
        (
            tokenizers,
            text_encoders,
            unet,
            vae
        ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)

    scheduler = create_noise_scheduler(scheduler_name)

    return tokenizers, text_encoders, unet, scheduler, vae


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:

    name = scheduler_name.lower().replace(" ", "_")
    if name == "ddim":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,  # is this right?
        )
    elif name == "ddpm":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        scheduler = DDPMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            clip_sample=False,
            prediction_type=prediction_type,
        )
    elif name == "lms":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        scheduler = LMSDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type=prediction_type,
        )
    elif name == "euler_a":
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        scheduler = EulerAncestralDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type=prediction_type,
        )
    else:
        raise ValueError(f"Unknown scheduler name: {name}")

    return scheduler
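A minimal usage sketch for load_models (illustrative, not part of the upload): it loads the SD v1 components from the diffusers hub repo with a DDIM scheduler.

tokenizer, text_encoder, unet, scheduler, vae = load_models(
    "CompVis/stable-diffusion-v1-4",
    ckpt_path=None,          # or a folder containing "unet_ema" to restore EMA weights
    scheduler_name="ddim",
    v2=False,
    v_pred=False,
    weight_dtype=torch.float32,
)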
utils/prompt_util.py
ADDED
@@ -0,0 +1,174 @@
1 |
+
from typing import Literal, Optional, Union, List
|
2 |
+
|
3 |
+
import yaml
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
from pydantic import BaseModel, root_validator
|
8 |
+
import torch
|
9 |
+
import copy
|
10 |
+
|
11 |
+
ACTION_TYPES = Literal[
|
12 |
+
"erase",
|
13 |
+
"enhance",
|
14 |
+
]
|
15 |
+
|
16 |
+
|
17 |
+
# XL は二種類必要なので
|
18 |
+
class PromptEmbedsXL:
|
19 |
+
text_embeds: torch.FloatTensor
|
20 |
+
pooled_embeds: torch.FloatTensor
|
21 |
+
|
22 |
+
def __init__(self, *args) -> None:
|
23 |
+
self.text_embeds = args[0]
|
24 |
+
self.pooled_embeds = args[1]
|
25 |
+
|
26 |
+
|
27 |
+
# SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL
|
28 |
+
PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
|
29 |
+
|
30 |
+
|
31 |
+
class PromptEmbedsCache: # 使いまわしたいので
|
32 |
+
prompts: dict[str, PROMPT_EMBEDDING] = {}
|
33 |
+
|
34 |
+
def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
|
35 |
+
self.prompts[__name] = __value
|
36 |
+
|
37 |
+
def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
|
38 |
+
if __name in self.prompts:
|
39 |
+
return self.prompts[__name]
|
40 |
+
else:
|
41 |
+
return None
|
42 |
+
|
43 |
+
|
44 |
+
class PromptSettings(BaseModel): # yaml のやつ
|
45 |
+
target: str
|
46 |
+
positive: str = None # if None, target will be used
|
47 |
+
unconditional: str = "" # default is ""
|
48 |
+
neutral: str = None # if None, unconditional will be used
|
49 |
+
action: ACTION_TYPES = "erase" # default is "erase"
|
50 |
+
guidance_scale: float = 1.0 # default is 1.0
|
51 |
+
resolution: int = 512 # default is 512
|
52 |
+
dynamic_resolution: bool = False # default is False
|
53 |
+
batch_size: int = 1 # default is 1
|
54 |
+
dynamic_crops: bool = False # default is False. only used when model is XL
|
55 |
+
|
56 |
+
@root_validator(pre=True)
|
57 |
+
def fill_prompts(cls, values):
|
58 |
+
keys = values.keys()
|
59 |
+
if "target" not in keys:
|
60 |
+
raise ValueError("target must be specified")
|
61 |
+
if "positive" not in keys:
|
62 |
+
values["positive"] = values["target"]
|
63 |
+
if "unconditional" not in keys:
|
64 |
+
values["unconditional"] = ""
|
65 |
+
if "neutral" not in keys:
|
66 |
+
values["neutral"] = values["unconditional"]
|
67 |
+
|
68 |
+
return values
|
69 |
+
|
70 |
+
|
71 |
+
class PromptEmbedsPair:
|
72 |
+
target: PROMPT_EMBEDDING # not want to generate the concept
|
73 |
+
positive: PROMPT_EMBEDDING # generate the concept
|
74 |
+
unconditional: PROMPT_EMBEDDING # uncondition (default should be empty)
|
75 |
+
neutral: PROMPT_EMBEDDING # base condition (default should be empty)
|
76 |
+
|
77 |
+
guidance_scale: float
|
78 |
+
resolution: int
|
79 |
+
dynamic_resolution: bool
|
80 |
+
batch_size: int
|
81 |
+
dynamic_crops: bool
|
82 |
+
|
83 |
+
loss_fn: torch.nn.Module
|
84 |
+
action: ACTION_TYPES
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
loss_fn: torch.nn.Module,
|
89 |
+
target: PROMPT_EMBEDDING,
|
90 |
+
positive: PROMPT_EMBEDDING,
|
91 |
+
unconditional: PROMPT_EMBEDDING,
|
92 |
+
neutral: PROMPT_EMBEDDING,
|
93 |
+
settings: PromptSettings,
|
94 |
+
) -> None:
|
95 |
+
self.loss_fn = loss_fn
|
96 |
+
self.target = target
|
97 |
+
self.positive = positive
|
98 |
+
self.unconditional = unconditional
|
99 |
+
self.neutral = neutral
|
100 |
+
|
101 |
+
self.guidance_scale = settings.guidance_scale
|
102 |
+
self.resolution = settings.resolution
|
103 |
+
self.dynamic_resolution = settings.dynamic_resolution
|
104 |
+
self.batch_size = settings.batch_size
|
105 |
+
self.dynamic_crops = settings.dynamic_crops
|
106 |
+
self.action = settings.action
|
107 |
+
|
108 |
+
def _erase(
|
109 |
+
self,
|
110 |
+
target_latents: torch.FloatTensor, # "van gogh"
|
111 |
+
positive_latents: torch.FloatTensor, # "van gogh"
|
112 |
+
unconditional_latents: torch.FloatTensor, # ""
|
113 |
+
neutral_latents: torch.FloatTensor, # ""
|
114 |
+
) -> torch.FloatTensor:
|
115 |
+
"""Target latents are going not to have the positive concept."""
|
116 |
+
return self.loss_fn(
|
117 |
+
target_latents,
|
118 |
+
neutral_latents
|
119 |
+
- self.guidance_scale * (positive_latents - unconditional_latents)
|
120 |
+
)
|
121 |
+
|
122 |
+
|
123 |
+
def _enhance(
|
124 |
+
self,
|
125 |
+
target_latents: torch.FloatTensor, # "van gogh"
|
126 |
+
positive_latents: torch.FloatTensor, # "van gogh"
|
127 |
+
unconditional_latents: torch.FloatTensor, # ""
|
128 |
+
neutral_latents: torch.FloatTensor, # ""
|
129 |
+
):
|
130 |
+
"""Target latents are going to have the positive concept."""
|
131 |
+
return self.loss_fn(
|
132 |
+
target_latents,
|
133 |
+
neutral_latents
|
134 |
+
+ self.guidance_scale * (positive_latents - unconditional_latents)
|
135 |
+
)
|
136 |
+
|
137 |
+
def loss(
|
138 |
+
self,
|
139 |
+
**kwargs,
|
140 |
+
):
|
141 |
+
if self.action == "erase":
|
142 |
+
return self._erase(**kwargs)
|
143 |
+
|
144 |
+
elif self.action == "enhance":
|
145 |
+
return self._enhance(**kwargs)
|
146 |
+
|
147 |
+
else:
|
148 |
+
raise ValueError("action must be erase or enhance")
|
149 |
+
|
150 |
+
|
151 |
+
def load_prompts_from_yaml(path, attributes=[]):
|
152 |
+
with open(path, "r") as f:
|
153 |
+
prompts = yaml.safe_load(f)
|
154 |
+
print(prompts)
|
155 |
+
if len(prompts) == 0:
|
156 |
+
raise ValueError("prompts file is empty")
|
157 |
+
if len(attributes) != 0:
|
158 |
+
newprompts = []
|
159 |
+
for i in range(len(prompts)):
|
160 |
+
for att in attributes:
|
161 |
+
copy_ = copy.deepcopy(prompts[i])
|
162 |
+
copy_['target'] = att + ' ' + copy_['target']
|
163 |
+
copy_['positive'] = att + ' ' + copy_['positive']
|
164 |
+
copy_['neutral'] = att + ' ' + copy_['neutral']
|
165 |
+
copy_['unconditional'] = att + ' ' + copy_['unconditional']
|
166 |
+
newprompts.append(copy_)
|
167 |
+
else:
|
168 |
+
newprompts = copy.deepcopy(prompts)
|
169 |
+
|
170 |
+
print(newprompts)
|
171 |
+
print(len(prompts), len(newprompts))
|
172 |
+
prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
|
173 |
+
|
174 |
+
return prompt_settings
|
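A minimal usage sketch for the prompt utilities above, assuming the repository root is on the import path and a hypothetical style_prompts.yaml file (neither assumption is part of this upload). A YAML entry that only specifies target is completed by the fill_prompts root validator; note that the attributes expansion indexes positive, neutral, and unconditional directly, so entries must spell out all four prompt fields whenever attributes are passed.

# style_prompts.yaml (hypothetical):
#   - target: "van gogh"
#     action: "erase"
#     guidance_scale: 1.0
from utils.prompt_util import load_prompts_from_yaml

settings = load_prompts_from_yaml("style_prompts.yaml")
for s in settings:
    # positive was filled from target, unconditional defaults to "", neutral from unconditional
    print(s.target, "|", s.positive, "|", repr(s.unconditional), "|", s.action)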
utils/train_util.py
ADDED
@@ -0,0 +1,526 @@
1 |
+
from typing import Optional, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
|
6 |
+
from diffusers import UNet2DConditionModel, SchedulerMixin
|
7 |
+
from diffusers.image_processor import VaeImageProcessor
|
8 |
+
import sys
|
9 |
+
import os
|
10 |
+
# sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
|
11 |
+
# from imagesliders.model_util import SDXL_TEXT_ENCODER_TYPE
|
12 |
+
from diffusers.utils.torch_utils import randn_tensor
|
13 |
+
|
14 |
+
from transformers import CLIPTextModelWithProjection
|
15 |
+
|
16 |
+
SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
|
20 |
+
UNET_IN_CHANNELS = 4 # Stable Diffusion in_channels
|
21 |
+
VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8
|
22 |
+
|
23 |
+
UNET_ATTENTION_TIME_EMBED_DIM = 256 # XL
|
24 |
+
TEXT_ENCODER_2_PROJECTION_DIM = 1280
|
25 |
+
UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
|
26 |
+
|
27 |
+
|
28 |
+
def get_random_noise(
|
29 |
+
batch_size: int, height: int, width: int, generator: torch.Generator = None
|
30 |
+
) -> torch.Tensor:
|
31 |
+
return torch.randn(
|
32 |
+
(
|
33 |
+
batch_size,
|
34 |
+
UNET_IN_CHANNELS,
|
35 |
+
height // VAE_SCALE_FACTOR,
|
36 |
+
width // VAE_SCALE_FACTOR,
|
37 |
+
),
|
38 |
+
generator=generator,
|
39 |
+
device="cpu",
|
40 |
+
)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
|
45 |
+
latents = latents + noise_offset * torch.randn(
|
46 |
+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
47 |
+
)
|
48 |
+
return latents
|
49 |
+
|
50 |
+
|
51 |
+
def get_initial_latents(
|
52 |
+
scheduler: SchedulerMixin,
|
53 |
+
n_imgs: int,
|
54 |
+
height: int,
|
55 |
+
width: int,
|
56 |
+
n_prompts: int,
|
57 |
+
generator=None,
|
58 |
+
) -> torch.Tensor:
|
59 |
+
noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
|
60 |
+
n_prompts, 1, 1, 1
|
61 |
+
)
|
62 |
+
|
63 |
+
latents = noise * scheduler.init_noise_sigma
|
64 |
+
|
65 |
+
return latents
|
66 |
+
|
67 |
+
|
68 |
+
def text_tokenize(
|
69 |
+
tokenizer,  # SD v1/v2 use a single tokenizer; SDXL uses two
|
70 |
+
prompts,
|
71 |
+
):
|
72 |
+
return tokenizer(
|
73 |
+
prompts,
|
74 |
+
padding="max_length",
|
75 |
+
max_length=tokenizer.model_max_length,
|
76 |
+
truncation=True,
|
77 |
+
return_tensors="pt",
|
78 |
+
)
|
79 |
+
|
80 |
+
|
81 |
+
def text_encode(text_encoder, tokens):
|
82 |
+
tokens = tokens.to(text_encoder.device)
|
83 |
+
if isinstance(text_encoder, BertModel):
|
84 |
+
embed = text_encoder(**tokens, return_dict=False)[0]
|
85 |
+
elif isinstance(text_encoder, CLIPTextModel):
|
86 |
+
# embed = text_encoder(**tokens, return_dict=False)[0]
|
87 |
+
embed = text_encoder(tokens.input_ids, return_dict=False)[0]
|
88 |
+
else:
|
89 |
+
raise ValueError("text_encoder must be BertModel or CLIPTextModel")
|
90 |
+
return embed
|
91 |
+
|
92 |
+
def encode_prompts(
|
93 |
+
tokenizer,
|
94 |
+
text_encoder,
|
95 |
+
prompts: list[str],
|
96 |
+
):
|
97 |
+
# print(f"prompts: {prompts}")
|
98 |
+
text_tokens = text_tokenize(tokenizer, prompts)
|
99 |
+
# print(f"text_tokens: {text_tokens}")
|
100 |
+
text_embeddings = text_encode(text_encoder, text_tokens)
|
101 |
+
# print(f"text_embeddings: {text_embeddings}")
|
102 |
+
|
103 |
+
|
104 |
+
return text_embeddings
|
105 |
+
|
106 |
+
def prompt_replace(original, key="{prompt}", prompt=""):
|
107 |
+
if key not in original:
|
108 |
+
return original
|
109 |
+
|
110 |
+
if isinstance(prompt, list):
|
111 |
+
ret = []
|
112 |
+
for p in prompt:
|
113 |
+
p = p.replace(".", "")
|
114 |
+
r = original.replace(key, p)
|
115 |
+
r = r.capitalize()
|
116 |
+
ret.append(r)
|
117 |
+
else:
|
118 |
+
prompt = prompt.replace(".", "")
|
119 |
+
ret = original.replace(key, prompt)
|
120 |
+
ret = ret.capitalize()
|
121 |
+
return ret
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
def text_encode_xl(
|
126 |
+
text_encoder: SDXL_TEXT_ENCODER_TYPE,
|
127 |
+
tokens: torch.FloatTensor,
|
128 |
+
num_images_per_prompt: int = 1,
|
129 |
+
):
|
130 |
+
prompt_embeds = text_encoder(
|
131 |
+
tokens.to(text_encoder.device), output_hidden_states=True
|
132 |
+
)
|
133 |
+
pooled_prompt_embeds = prompt_embeds[0]
|
134 |
+
prompt_embeds = prompt_embeds.hidden_states[-2] # always penultimate layer
|
135 |
+
|
136 |
+
bs_embed, seq_len, _ = prompt_embeds.shape
|
137 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
138 |
+
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
139 |
+
|
140 |
+
return prompt_embeds, pooled_prompt_embeds
|
141 |
+
|
142 |
+
|
143 |
+
def encode_prompts_xl(
|
144 |
+
tokenizers: list[CLIPTokenizer],
|
145 |
+
text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
|
146 |
+
prompts: list[str],
|
147 |
+
num_images_per_prompt: int = 1,
|
148 |
+
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
|
149 |
+
# text_encoder and text_encoder_2's penultimate layer outputs
|
150 |
+
text_embeds_list = []
|
151 |
+
pooled_text_embeds = None # always text_encoder_2's pool
|
152 |
+
|
153 |
+
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
|
154 |
+
text_tokens_input_ids = text_tokenize(tokenizer, prompts)
|
155 |
+
text_embeds, pooled_text_embeds = text_encode_xl(
|
156 |
+
text_encoder, text_tokens_input_ids, num_images_per_prompt
|
157 |
+
)
|
158 |
+
|
159 |
+
text_embeds_list.append(text_embeds)
|
160 |
+
|
161 |
+
bs_embed = pooled_text_embeds.shape[0]
|
162 |
+
pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
|
163 |
+
bs_embed * num_images_per_prompt, -1
|
164 |
+
)
|
165 |
+
|
166 |
+
return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
|
167 |
+
|
168 |
+
|
169 |
+
def concat_embeddings(
|
170 |
+
unconditional: torch.FloatTensor,
|
171 |
+
conditional: torch.FloatTensor,
|
172 |
+
n_imgs: int,
|
173 |
+
):
|
174 |
+
if conditional.shape[0] == n_imgs and unconditional.shape[0] == 1:
|
175 |
+
return torch.cat([unconditional.repeat(n_imgs, 1, 1), conditional], dim=0)
|
176 |
+
return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
|
177 |
+
|
178 |
+
|
179 |
+
def predict_noise(
|
180 |
+
unet: UNet2DConditionModel,
|
181 |
+
scheduler: SchedulerMixin,
|
182 |
+
timestep: int,
|
183 |
+
latents: torch.FloatTensor,
|
184 |
+
text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings concatenated along the batch dim
|
185 |
+
guidance_scale=7.5,
|
186 |
+
) -> torch.FloatTensor:
|
187 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
188 |
+
latent_model_input = torch.cat([latents] * 2)
|
189 |
+
|
190 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
191 |
+
# batch_size = latents.shape[0]
|
192 |
+
# text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)
|
193 |
+
# predict the noise residual
|
194 |
+
noise_pred = unet(
|
195 |
+
latent_model_input,
|
196 |
+
timestep,
|
197 |
+
encoder_hidden_states=text_embeddings,
|
198 |
+
).sample
|
199 |
+
|
200 |
+
# perform guidance
|
201 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
202 |
+
guided_target = noise_pred_uncond + guidance_scale * (
|
203 |
+
noise_pred_text - noise_pred_uncond
|
204 |
+
)
|
205 |
+
|
206 |
+
return guided_target
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
@torch.no_grad()
|
211 |
+
def diffusion(
|
212 |
+
unet: UNet2DConditionModel,
|
213 |
+
scheduler: SchedulerMixin,
|
214 |
+
latents: torch.FloatTensor,
|
215 |
+
text_embeddings: torch.FloatTensor,
|
216 |
+
total_timesteps: int = 1000,
|
217 |
+
start_timesteps=0,
|
218 |
+
**kwargs,
|
219 |
+
):
|
220 |
+
# latents_steps = []
|
221 |
+
|
222 |
+
for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
|
223 |
+
noise_pred = predict_noise(
|
224 |
+
unet, scheduler, timestep, latents, text_embeddings, **kwargs
|
225 |
+
)
|
226 |
+
|
227 |
+
# compute the previous noisy sample x_t -> x_t-1
|
228 |
+
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
229 |
+
|
230 |
+
# return latents_steps
|
231 |
+
return latents
|
232 |
+
|
233 |
+
@torch.no_grad()
|
234 |
+
def get_noisy_image(
|
235 |
+
img,
|
236 |
+
vae,
|
237 |
+
generator,
|
238 |
+
unet: UNet2DConditionModel,
|
239 |
+
scheduler: SchedulerMixin,
|
240 |
+
total_timesteps: int = 1000,
|
241 |
+
start_timesteps=0,
|
242 |
+
|
243 |
+
**kwargs,
|
244 |
+
):
|
245 |
+
# latents_steps = []
|
246 |
+
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
|
247 |
+
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
|
248 |
+
|
249 |
+
image = img
|
250 |
+
# im_orig = image
|
251 |
+
device = vae.device
|
252 |
+
image = image_processor.preprocess(image).to(device)
|
253 |
+
|
254 |
+
init_latents = vae.encode(image).latent_dist.sample(None)
|
255 |
+
init_latents = vae.config.scaling_factor * init_latents
|
256 |
+
|
257 |
+
init_latents = torch.cat([init_latents], dim=0)
|
258 |
+
|
259 |
+
shape = init_latents.shape
|
260 |
+
|
261 |
+
noise = randn_tensor(shape, generator=generator, device=device)
|
262 |
+
|
263 |
+
time_ = total_timesteps
|
264 |
+
timestep = scheduler.timesteps[time_:time_+1]
|
265 |
+
# get latents
|
266 |
+
noised_latents = scheduler.add_noise(init_latents, noise, timestep)
|
267 |
+
|
268 |
+
return noised_latents, noise, init_latents
|
269 |
+
|
270 |
+
def subtract_noise(
|
271 |
+
latent: torch.FloatTensor,
|
272 |
+
noise: torch.FloatTensor,
|
273 |
+
timesteps: torch.IntTensor,
|
274 |
+
scheduler: SchedulerMixin,
|
275 |
+
) -> torch.FloatTensor:
|
276 |
+
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
277 |
+
# Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
|
278 |
+
# for the subsequent add_noise calls
|
279 |
+
scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device=latent.device)
|
280 |
+
alphas_cumprod = scheduler.alphas_cumprod.to(dtype=latent.dtype)
|
281 |
+
timesteps = timesteps.to(latent.device)
|
282 |
+
|
283 |
+
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
284 |
+
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
285 |
+
while len(sqrt_alpha_prod.shape) < len(latent.shape):
|
286 |
+
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
287 |
+
|
288 |
+
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
289 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
290 |
+
while len(sqrt_one_minus_alpha_prod.shape) < len(latent.shape):
|
291 |
+
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
292 |
+
|
293 |
+
denoised_latent = (latent - sqrt_one_minus_alpha_prod * noise) / sqrt_alpha_prod
|
294 |
+
return denoised_latent
|
295 |
+
def get_denoised_image(
|
296 |
+
latents: torch.FloatTensor,
|
297 |
+
noise_pred: torch.FloatTensor,
|
298 |
+
timestep: int,
|
299 |
+
# total_timesteps: int,
|
300 |
+
scheduler: SchedulerMixin,
|
301 |
+
vae,  # AutoencoderKL, used to scale and decode the latents
|
302 |
+
):
|
303 |
+
denoised_latents = subtract_noise(latents, noise_pred, timestep, scheduler)
|
304 |
+
denoised_latents = denoised_latents / vae.config.scaling_factor # 0.18215
|
305 |
+
denoised_img = vae.decode(denoised_latents).sample
|
306 |
+
# denoised_img = denoised_img.clamp(-1,1)
|
307 |
+
return denoised_img
|
308 |
+
|
309 |
+
|
310 |
+
def rescale_noise_cfg(
|
311 |
+
noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
|
312 |
+
):
|
313 |
+
|
314 |
+
std_text = noise_pred_text.std(
|
315 |
+
dim=list(range(1, noise_pred_text.ndim)), keepdim=True
|
316 |
+
)
|
317 |
+
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
318 |
+
# rescale the results from guidance (fixes overexposure)
|
319 |
+
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
320 |
+
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
321 |
+
noise_cfg = (
|
322 |
+
guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
323 |
+
)
|
324 |
+
|
325 |
+
return noise_cfg
|
326 |
+
|
327 |
+
|
328 |
+
def predict_noise_xl(
|
329 |
+
unet: UNet2DConditionModel,
|
330 |
+
scheduler: SchedulerMixin,
|
331 |
+
timestep: int,
|
332 |
+
latents: torch.FloatTensor,
|
333 |
+
text_embeddings: torch.FloatTensor,  # unconditional and conditional text embeddings concatenated along the batch dim
|
334 |
+
add_text_embeddings: torch.FloatTensor,  # pooled embeddings from text_encoder_2
|
335 |
+
add_time_ids: torch.FloatTensor,
|
336 |
+
guidance_scale=7.5,
|
337 |
+
guidance_rescale=0.7,
|
338 |
+
) -> torch.FloatTensor:
|
339 |
+
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
340 |
+
latent_model_input = torch.cat([latents] * 2)
|
341 |
+
|
342 |
+
latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
|
343 |
+
|
344 |
+
added_cond_kwargs = {
|
345 |
+
"text_embeds": add_text_embeddings,
|
346 |
+
"time_ids": add_time_ids,
|
347 |
+
}
|
348 |
+
|
349 |
+
# predict the noise residual
|
350 |
+
noise_pred = unet(
|
351 |
+
latent_model_input,
|
352 |
+
timestep,
|
353 |
+
encoder_hidden_states=text_embeddings,
|
354 |
+
added_cond_kwargs=added_cond_kwargs,
|
355 |
+
).sample
|
356 |
+
|
357 |
+
# perform guidance
|
358 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
359 |
+
guided_target = noise_pred_uncond + guidance_scale * (
|
360 |
+
noise_pred_text - noise_pred_uncond
|
361 |
+
)
|
362 |
+
|
363 |
+
guided_target = rescale_noise_cfg(
|
364 |
+
guided_target, noise_pred_text, guidance_rescale=guidance_rescale
|
365 |
+
)
|
366 |
+
|
367 |
+
return guided_target
|
368 |
+
|
369 |
+
|
370 |
+
@torch.no_grad()
|
371 |
+
def diffusion_xl(
|
372 |
+
unet: UNet2DConditionModel,
|
373 |
+
scheduler: SchedulerMixin,
|
374 |
+
latents: torch.FloatTensor,
|
375 |
+
text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
|
376 |
+
add_text_embeddings: torch.FloatTensor,
|
377 |
+
add_time_ids: torch.FloatTensor,
|
378 |
+
guidance_scale: float = 1.0,
|
379 |
+
total_timesteps: int = 1000,
|
380 |
+
start_timesteps=0,
|
381 |
+
):
|
382 |
+
# latents_steps = []
|
383 |
+
|
384 |
+
for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
|
385 |
+
noise_pred = predict_noise_xl(
|
386 |
+
unet,
|
387 |
+
scheduler,
|
388 |
+
timestep,
|
389 |
+
latents,
|
390 |
+
text_embeddings,
|
391 |
+
add_text_embeddings,
|
392 |
+
add_time_ids,
|
393 |
+
guidance_scale=guidance_scale,
|
394 |
+
guidance_rescale=0.7,
|
395 |
+
)
|
396 |
+
|
397 |
+
# compute the previous noisy sample x_t -> x_t-1
|
398 |
+
latents = scheduler.step(noise_pred, timestep, latents).prev_sample
|
399 |
+
|
400 |
+
# return latents_steps
|
401 |
+
return latents
|
402 |
+
|
403 |
+
|
404 |
+
# for XL
|
405 |
+
def get_add_time_ids(
|
406 |
+
height: int,
|
407 |
+
width: int,
|
408 |
+
dynamic_crops: bool = False,
|
409 |
+
dtype: torch.dtype = torch.float32,
|
410 |
+
):
|
411 |
+
if dynamic_crops:
|
412 |
+
# random float scale between 1 and 3
|
413 |
+
random_scale = torch.rand(1).item() * 2 + 1
|
414 |
+
original_size = (int(height * random_scale), int(width * random_scale))
|
415 |
+
# random position
|
416 |
+
crops_coords_top_left = (
|
417 |
+
torch.randint(0, original_size[0] - height, (1,)).item(),
|
418 |
+
torch.randint(0, original_size[1] - width, (1,)).item(),
|
419 |
+
)
|
420 |
+
target_size = (height, width)
|
421 |
+
else:
|
422 |
+
original_size = (height, width)
|
423 |
+
crops_coords_top_left = (0, 0)
|
424 |
+
target_size = (height, width)
|
425 |
+
|
426 |
+
# this is expected as 6
|
427 |
+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
428 |
+
|
429 |
+
# this is expected as 2816
|
430 |
+
passed_add_embed_dim = (
|
431 |
+
UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids) # 256 * 6
|
432 |
+
+ TEXT_ENCODER_2_PROJECTION_DIM # + 1280
|
433 |
+
)
|
434 |
+
if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
|
435 |
+
raise ValueError(
|
436 |
+
f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
437 |
+
)
|
438 |
+
|
439 |
+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
440 |
+
return add_time_ids
|
441 |
+
|
442 |
+
|
443 |
+
def get_optimizer(name: str):
|
444 |
+
name = name.lower()
|
445 |
+
|
446 |
+
if name.startswith("dadapt"):
|
447 |
+
import dadaptation
|
448 |
+
|
449 |
+
if name == "dadaptadam":
|
450 |
+
return dadaptation.DAdaptAdam
|
451 |
+
elif name == "dadaptlion":
|
452 |
+
return dadaptation.DAdaptLion
|
453 |
+
else:
|
454 |
+
raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
|
455 |
+
|
456 |
+
elif name.endswith("8bit"):
|
457 |
+
import bitsandbytes as bnb
|
458 |
+
|
459 |
+
if name == "adam8bit":
|
460 |
+
return bnb.optim.Adam8bit
|
461 |
+
elif name == "lion8bit":
|
462 |
+
return bnb.optim.Lion8bit
|
463 |
+
else:
|
464 |
+
raise ValueError("8bit optimizer must be adam8bit or lion8bit")
|
465 |
+
|
466 |
+
else:
|
467 |
+
if name == "adam":
|
468 |
+
return torch.optim.Adam
|
469 |
+
elif name == "adamw":
|
470 |
+
return torch.optim.AdamW
|
471 |
+
elif name == "lion":
|
472 |
+
from lion_pytorch import Lion
|
473 |
+
|
474 |
+
return Lion
|
475 |
+
elif name == "prodigy":
|
476 |
+
import prodigyopt
|
477 |
+
|
478 |
+
return prodigyopt.Prodigy
|
479 |
+
else:
|
480 |
+
raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
|
481 |
+
|
482 |
+
|
483 |
+
def get_lr_scheduler(
|
484 |
+
name: Optional[str],
|
485 |
+
optimizer: torch.optim.Optimizer,
|
486 |
+
max_iterations: Optional[int],
|
487 |
+
lr_min: Optional[float],
|
488 |
+
**kwargs,
|
489 |
+
):
|
490 |
+
if name == "cosine":
|
491 |
+
return torch.optim.lr_scheduler.CosineAnnealingLR(
|
492 |
+
optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
|
493 |
+
)
|
494 |
+
elif name == "cosine_with_restarts":
|
495 |
+
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
496 |
+
optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
|
497 |
+
)
|
498 |
+
elif name == "step":
|
499 |
+
return torch.optim.lr_scheduler.StepLR(
|
500 |
+
optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
|
501 |
+
)
|
502 |
+
elif name == "constant":
|
503 |
+
return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
|
504 |
+
elif name == "linear":
|
505 |
+
return torch.optim.lr_scheduler.LinearLR(
|
506 |
+
optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
|
507 |
+
)
|
508 |
+
else:
|
509 |
+
raise ValueError(
|
510 |
+
"Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
|
511 |
+
)
|
512 |
+
|
513 |
+
|
514 |
+
def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
|
515 |
+
max_resolution = bucket_resolution
|
516 |
+
min_resolution = bucket_resolution // 2
|
517 |
+
|
518 |
+
step = 64
|
519 |
+
|
520 |
+
min_step = min_resolution // step
|
521 |
+
max_step = max_resolution // step
|
522 |
+
|
523 |
+
height = torch.randint(min_step, max_step, (1,)).item() * step
|
524 |
+
width = torch.randint(min_step, max_step, (1,)).item() * step
|
525 |
+
|
526 |
+
return height, width
|
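To see how the SD v1.x helpers in utils/train_util.py compose, here is a hedged end-to-end sampling sketch rather than a definitive recipe: the checkpoint id, the 30-step DDIM schedule, the fixed seed, and the final VAE decode are assumptions drawn from the standard diffusers workflow, and the repository root is assumed to be importable.

import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from utils.train_util import (
    encode_prompts, concat_embeddings, get_initial_latents, diffusion,
)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # assumed checkpoint
scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
scheduler.set_timesteps(30)

# Classifier-free-guidance embeddings: [unconditional, conditional] along the batch dim.
cond = encode_prompts(pipe.tokenizer, pipe.text_encoder, ["a painting of a cat"])
uncond = encode_prompts(pipe.tokenizer, pipe.text_encoder, [""])
text_embeddings = concat_embeddings(uncond, cond, n_imgs=1)

latents = get_initial_latents(
    scheduler, n_imgs=1, height=512, width=512, n_prompts=1,
    generator=torch.Generator().manual_seed(0),
)
latents = diffusion(
    pipe.unet, scheduler, latents, text_embeddings,
    total_timesteps=30, guidance_scale=7.5,
)
# Decode to an image tensor in [-1, 1]; move the modules to CUDA for realistic runtimes.
image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample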