rebuild+depth
This view is limited to 50 files because it contains too many changes.
- .gitattributes +0 -0
- .gitignore +128 -0
- LICENSE +0 -0
- README.md +0 -0
- app.py +4 -2
- configs/stable-diffusion/app.yaml +0 -0
- configs/stable-diffusion/test_keypose.yaml +0 -87
- configs/stable-diffusion/test_mask.yaml +0 -87
- configs/stable-diffusion/test_mask_sketch.yaml +0 -87
- configs/stable-diffusion/test_sketch.yaml +0 -87
- configs/stable-diffusion/test_sketch_edit.yaml +0 -87
- configs/stable-diffusion/train_keypose.yaml +0 -87
- configs/stable-diffusion/train_mask.yaml +0 -87
- configs/stable-diffusion/train_sketch.yaml +0 -87
- dataset_coco.py +0 -138
- demo/demos.py +26 -1
- demo/model.py +69 -6
- dist_util.py +0 -91
- environment.yaml +0 -0
- examples/edit_cat/edge.png +0 -0
- examples/edit_cat/edge_2.png +0 -0
- examples/edit_cat/im.png +0 -0
- examples/edit_cat/mask.png +0 -0
- examples/keypose/iron.png +0 -0
- examples/seg/dinner.png +0 -0
- examples/seg/motor.png +0 -0
- examples/seg_sketch/edge.png +0 -0
- examples/seg_sketch/mask.png +0 -0
- examples/sketch/car.png +0 -0
- examples/sketch/girl.jpeg +0 -0
- examples/sketch/human.png +0 -0
- examples/sketch/scenery.jpg +0 -0
- examples/sketch/scenery2.jpg +0 -0
- gradio_keypose.py +0 -254
- gradio_sketch.py +0 -147
- ldm/data/__init__.py +0 -0
- ldm/data/base.py +0 -0
- ldm/data/imagenet.py +0 -0
- ldm/data/lsun.py +0 -0
- ldm/lr_scheduler.py +0 -0
- ldm/models/autoencoder.py +0 -0
- ldm/models/diffusion/__init__.py +0 -0
- ldm/models/diffusion/classifier.py +0 -0
- ldm/models/diffusion/ddim.py +0 -0
- ldm/models/diffusion/ddpm.py +0 -0
- ldm/models/diffusion/dpm_solver/__init__.py +0 -0
- ldm/models/diffusion/dpm_solver/dpm_solver.py +0 -0
- ldm/models/diffusion/dpm_solver/sampler.py +0 -0
- ldm/models/diffusion/plms.py +0 -0
- ldm/modules/attention.py +0 -0
.gitattributes
CHANGED
File without changes

.gitignore
ADDED
@@ -0,0 +1,128 @@
+# ignored folders
+models
+
+# ignored folders
+tmp/*
+
+*.DS_Store
+.idea
+
+# ignored files
+version.py
+
+# ignored files with suffix
+# *.html
+# *.png
+# *.jpeg
+# *.jpg
+# *.gif
+# *.pth
+# *.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.pyc
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
LICENSE
CHANGED
File without changes

README.md
CHANGED
File without changes

app.py
CHANGED
@@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0')
 
 from demo.model import Model_all
 import gradio as gr
-from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg
+from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth
 import torch
 import subprocess
 import shlex
 from huggingface_hub import hf_hub_url
 
 urls = {
-    'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'],
+    'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth'],
     'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'],
     'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'],
 }
@@ -72,5 +72,7 @@ with gr.Blocks(css='style.css') as demo:
             create_demo_draw(model.process_draw)
         with gr.TabItem('Segmentation'):
             create_demo_seg(model.process_seg)
+        with gr.TabItem('Depth'):
+            create_demo_depth(model.process_depth)
 
 demo.queue().launch(debug=True, server_name='0.0.0.0')
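The hunks above register the new depth adapter checkpoint in the `urls` table but do not show the code that consumes it. Since app.py imports hf_hub_url, subprocess and shlex, a plausible startup step resolves each (repo_id, filename) pair to a download URL and fetches it into models/. The sketch below is an assumption about that step (including the wget call and the shortened file list), not the repository's actual download loop.

import os
import shlex
import subprocess

from huggingface_hub import hf_hub_url

# Hypothetical, shortened copy of the table from app.py; only the new depth checkpoint is listed here.
urls = {
    'TencentARC/T2I-Adapter': ['models/t2iadapter_depth_sd14v1.pth'],
}

os.makedirs('models', exist_ok=True)
for repo_id, filenames in urls.items():
    for filename in filenames:
        url = hf_hub_url(repo_id, filename)                       # resolve repo file to a download URL
        target = os.path.join('models', os.path.basename(filename))
        if not os.path.exists(target):
            # assumption: fetch with wget, mirroring the subprocess/shlex imports in app.py
            subprocess.run(shlex.split(f'wget {url} -O {target}'), check=True)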
configs/stable-diffusion/app.yaml
CHANGED
File without changes

configs/stable-diffusion/test_keypose.yaml
DELETED
@@ -1,87 +0,0 @@
name: test_keypose
model:
  base_learning_rate: 1.0e-04
  target: ldm.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config: #__is_unconditional__
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        version: models/clip-vit-large-patch14

logger:
  print_freq: 100
  save_checkpoint_freq: !!float 1e4
  use_tb_logger: true
  wandb:
    project: ~
    resume_id: ~
dist_params:
  backend: nccl
  port: 29500
training:
  lr: !!float 1e-5
  save_freq: 1e4
configs/stable-diffusion/test_mask.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: test_mask)

configs/stable-diffusion/test_mask_sketch.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: test_mask_sketch)

configs/stable-diffusion/test_sketch.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: test_sketch)

configs/stable-diffusion/test_sketch_edit.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: test_sketch_edit)

configs/stable-diffusion/train_keypose.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: train_keypose)

configs/stable-diffusion/train_mask.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: train_mask)

configs/stable-diffusion/train_sketch.yaml
DELETED
@@ -1,87 +0,0 @@
(87-line config identical to test_keypose.yaml above, except name: train_sketch)
dataset_coco.py
DELETED
@@ -1,138 +0,0 @@
import torch
import json
import cv2
import torch
import os
from basicsr.utils import img2tensor, tensor2img
import random

class dataset_coco():
    def __init__(self, path_json, root_path, image_size, mode='train'):
        super(dataset_coco, self).__init__()
        with open(path_json, 'r', encoding='utf-8') as fp:
            data = json.load(fp)
        data = data['images']
        self.paths = []
        self.root_path = root_path
        for file in data:
            input_path = file['filepath']
            if mode == 'train':
                if 'val' not in input_path:
                    self.paths.append(file)
            else:
                if 'val' in input_path:
                    self.paths.append(file)

    def __getitem__(self, idx):
        file = self.paths[idx]
        input_path = file['filepath']
        input_name = file['filename']
        path = os.path.join(self.root_path, input_path, input_name)
        im = cv2.imread(path)
        im = cv2.resize(im, (512,512))
        im = img2tensor(im, bgr2rgb=True, float32=True)/255.
        sentences = file['sentences']
        sentence = sentences[int(random.random()*len(sentences))]['raw'].strip('.')
        return {'im':im, 'sentence':sentence}

    def __len__(self):
        return len(self.paths)


class dataset_coco_mask():
    def __init__(self, path_json, root_path_im, root_path_mask, image_size):
        super(dataset_coco_mask, self).__init__()
        with open(path_json, 'r', encoding='utf-8') as fp:
            data = json.load(fp)
        data = data['annotations']
        self.files = []
        self.root_path_im = root_path_im
        self.root_path_mask = root_path_mask
        for file in data:
            name = "%012d.png"%file['image_id']
            self.files.append({'name':name, 'sentence':file['caption']})

    def __getitem__(self, idx):
        file = self.files[idx]
        name = file['name']
        # print(os.path.join(self.root_path_im, name))
        im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
        im = cv2.resize(im, (512,512))
        im = img2tensor(im, bgr2rgb=True, float32=True)/255.

        mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
        mask = cv2.resize(mask, (512,512))
        mask = img2tensor(mask, bgr2rgb=True, float32=True)[0].unsqueeze(0)#/255.

        sentence = file['sentence']
        return {'im':im, 'mask':mask, 'sentence':sentence}

    def __len__(self):
        return len(self.files)


class dataset_coco_mask_color():
    def __init__(self, path_json, root_path_im, root_path_mask, image_size):
        super(dataset_coco_mask_color, self).__init__()
        with open(path_json, 'r', encoding='utf-8') as fp:
            data = json.load(fp)
        data = data['annotations']
        self.files = []
        self.root_path_im = root_path_im
        self.root_path_mask = root_path_mask
        for file in data:
            name = "%012d.png"%file['image_id']
            self.files.append({'name':name, 'sentence':file['caption']})

    def __getitem__(self, idx):
        file = self.files[idx]
        name = file['name']
        # print(os.path.join(self.root_path_im, name))
        im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
        im = cv2.resize(im, (512,512))
        im = img2tensor(im, bgr2rgb=True, float32=True)/255.

        mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
        mask = cv2.resize(mask, (512,512))
        mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.

        sentence = file['sentence']
        return {'im':im, 'mask':mask, 'sentence':sentence}

    def __len__(self):
        return len(self.files)

class dataset_coco_mask_color_sig():
    def __init__(self, path_json, root_path_im, root_path_mask, image_size):
        super(dataset_coco_mask_color_sig, self).__init__()
        with open(path_json, 'r', encoding='utf-8') as fp:
            data = json.load(fp)
        data = data['annotations']
        self.files = []
        self.root_path_im = root_path_im
        self.root_path_mask = root_path_mask
        reg = {}
        for file in data:
            name = "%012d.png"%file['image_id']
            if name in reg:
                continue
            self.files.append({'name':name, 'sentence':file['caption']})
            reg[name] = name

    def __getitem__(self, idx):
        file = self.files[idx]
        name = file['name']
        # print(os.path.join(self.root_path_im, name))
        im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg')))
        im = cv2.resize(im, (512,512))
        im = img2tensor(im, bgr2rgb=True, float32=True)/255.

        mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0]
        mask = cv2.resize(mask, (512,512))
        mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255.

        sentence = file['sentence']
        return {'im':im, 'mask':mask, 'sentence':sentence, 'name': name}

    def __len__(self):
        return len(self.files)
demo/demos.py
CHANGED
@@ -85,7 +85,32 @@ def create_demo_seg(process):
                 with gr.Row():
                     type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff')
                     run_button = gr.Button(label="Run")
-                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=1, step=0.1)
+                scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
+                fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
+                base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
+            with gr.Column():
+                result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
+        ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model]
+        run_button.click(fn=process, inputs=ips, outputs=[result])
+    return demo
+
+def create_demo_depth(process):
+    with gr.Blocks() as demo:
+        with gr.Row():
+            gr.Markdown('## T2I-Adapter (Depth)')
+        with gr.Row():
+            with gr.Column():
+                input_img = gr.Image(source='upload', type="numpy")
+                prompt = gr.Textbox(label="Prompt")
+                neg_prompt = gr.Textbox(label="Negative Prompt",
+                            value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
+                pos_prompt = gr.Textbox(label="Positive Prompt",
+                            value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed')
+                with gr.Row():
+                    type_in = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map')
+                    run_button = gr.Button(label="Run")
+                con_strength = gr.Slider(label="Controling Strength (The guidance strength of the depth map to the result)", minimum=0, maximum=1, value=1, step=0.1)
                 scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                 fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
                 base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
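The new create_demo_depth follows the same contract as the other create_demo_* builders: components are laid out inside gr.Blocks, gathered into a list whose order matches the processing function's positional signature, and bound with run_button.click. Below is a minimal, self-contained sketch of that pattern; the dummy components and dummy_process function are illustrative stand-ins, not code from the repository.

import gradio as gr

def dummy_process(input_img, prompt, scale):
    # Stand-in for Model_all.process_depth; just echoes what it received.
    has_image = 'yes' if input_img is not None else 'no'
    return f'prompt={prompt}, scale={scale}, image={has_image}'

def create_demo_dummy(process):
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                input_img = gr.Image(type="numpy")
                prompt = gr.Textbox(label="Prompt")
                scale = gr.Slider(minimum=0.1, maximum=30.0, value=7.5, step=0.1, label="Guidance Scale")
                run_button = gr.Button("Run")
            with gr.Column():
                result = gr.Textbox(label="Output")
        # The order of this list must match process(...)'s positional parameters.
        ips = [input_img, prompt, scale]
        run_button.click(fn=process, inputs=ips, outputs=[result])
    return demo

if __name__ == '__main__':
    create_demo_dummy(dummy_process).queue().launch()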
demo/model.py
CHANGED
@@ -4,7 +4,9 @@ from pytorch_lightning import seed_everything
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.modules.encoders.adapter import Adapter
 from ldm.util import instantiate_from_config
-from model_edge import pidinet
+from ldm.modules.structure_condition.model_edge import pidinet
+from ldm.modules.structure_condition.model_seg import seger, Colorize
+from ldm.modules.structure_condition.midas.api import MiDaSInference
 import gradio as gr
 from omegaconf import OmegaConf
 import mmcv
@@ -13,7 +15,6 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)
 import os
 import cv2
 import numpy as np
-from seger import seger, Colorize
 import torch.nn.functional as F
 
 def preprocessing(image, device):
@@ -136,10 +137,8 @@ class Model_all:
         self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
                                     use_conv=False).to(device)
         self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device))
-        self.model_edge = pidinet()
-
-        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
-        self.model_edge.to(device)
+        self.model_edge = pidinet().to(device)
+        self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()})
 
         # segmentation part
         self.model_seger = seger().to(device)
@@ -147,6 +146,11 @@
         self.coler = Colorize(n=182)
         self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
         self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device))
+        self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device)
+
+        # depth part
+        self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
+        self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device))
 
         # keypose part
         self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True,
@@ -248,6 +252,65 @@
 
         return [im_edge, x_samples_ddim]
 
+    @torch.no_grad()
+    def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
+                      con_strength, base_model):
+        if self.current_base != base_model:
+            ckpt = os.path.join("models", base_model)
+            pl_sd = torch.load(ckpt, map_location="cuda")
+            if "state_dict" in pl_sd:
+                sd = pl_sd["state_dict"]
+            else:
+                sd = pl_sd
+            self.base_model.load_state_dict(sd, strict=False)
+            self.current_base = base_model
+            if 'anything' in base_model.lower():
+                self.load_vae()
+
+        con_strength = int((1 - con_strength) * 50)
+        if fix_sample == 'True':
+            seed_everything(42)
+        im = cv2.resize(input_img, (512, 512))
+
+        if type_in == 'Depth':
+            im_depth = im.copy()
+            depth = img2tensor(im).unsqueeze(0) / 255.
+        elif type_in == 'Image':
+            im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0
+            depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1)
+            depth -= torch.min(depth)
+            depth /= torch.max(depth)
+            im_depth = tensor2img(depth)
+
+        # extract condition features
+        c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt])
+        nc = self.base_model.get_learned_conditioning([neg_prompt])
+        features_adapter = self.model_depth(depth.to(self.device))
+        shape = [4, 64, 64]
+
+        # sampling
+        samples_ddim, _ = self.sampler.sample(S=50,
+                                              conditioning=c,
+                                              batch_size=1,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=nc,
+                                              eta=0.0,
+                                              x_T=None,
+                                              features_adapter1=features_adapter,
+                                              mode='sketch',
+                                              con_strength=con_strength)
+
+        x_samples_ddim = self.base_model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        x_samples_ddim = x_samples_ddim.to('cpu')
+        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
+        x_samples_ddim = 255. * x_samples_ddim
+        x_samples_ddim = x_samples_ddim.astype(np.uint8)
+
+        return [im_depth, x_samples_ddim]
+
     @torch.no_grad()
     def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale,
                     con_strength, base_model):
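process_depth above conditions generation on a depth map: when the user supplies a plain image, the MiDaS model predicts depth, the prediction is min-max normalized and repeated to three channels, and the result feeds the depth adapter. The standalone sketch below mirrors that preprocessing without the repository's MiDaSInference wrapper; loading MiDaS through torch.hub is an assumption made only to keep the example self-contained.

import cv2
import numpy as np
import torch

def estimate_depth_map(image_bgr: np.ndarray, device: str = 'cpu') -> np.ndarray:
    """Return a 512x512x3 uint8 depth map normalized to [0, 255]."""
    # assumption: torch.hub MiDaS stands in for ldm.modules.structure_condition.midas.api.MiDaSInference
    midas = torch.hub.load('intel-isl/MiDaS', 'DPT_Hybrid').to(device).eval()
    transform = torch.hub.load('intel-isl/MiDaS', 'transforms').dpt_transform

    im = cv2.resize(image_bgr, (512, 512))
    im_rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    with torch.no_grad():
        depth = midas(transform(im_rgb).to(device))                       # (1, H', W')
        depth = torch.nn.functional.interpolate(
            depth.unsqueeze(1), size=(512, 512), mode='bicubic', align_corners=False
        ).squeeze(1)
    depth -= depth.min()
    depth /= depth.max()                                                  # same min/max normalization as process_depth
    depth_3ch = depth.repeat(3, 1, 1).permute(1, 2, 0).cpu().numpy()
    return (255. * depth_3ch).astype(np.uint8)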
dist_util.py
DELETED
@@ -1,91 +0,0 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DataParallel, DistributedDataParallel


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def get_dist_info():
    if dist.is_available():
        initialized = dist.is_initialized()
    else:
        initialized = False
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


def master_only(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper

def get_bare_model(net):
    """Get bare model, especially under wrapping with
    DistributedDataParallel or DataParallel.
    """
    if isinstance(net, (DataParallel, DistributedDataParallel)):
        net = net.module
    return net
environment.yaml
CHANGED
File without changes

examples/edit_cat/edge.png
DELETED
Binary file (5.98 kB)

examples/edit_cat/edge_2.png
DELETED
Binary file (13.3 kB)

examples/edit_cat/im.png
DELETED
Binary file (508 kB)

examples/edit_cat/mask.png
DELETED
Binary file (4.65 kB)

examples/keypose/iron.png
DELETED
Binary file (15.6 kB)

examples/seg/dinner.png
DELETED
Binary file (17.8 kB)

examples/seg/motor.png
DELETED
Binary file (20.9 kB)

examples/seg_sketch/edge.png
DELETED
Binary file (12.9 kB)

examples/seg_sketch/mask.png
DELETED
Binary file (22.2 kB)

examples/sketch/car.png
DELETED
Binary file (13.2 kB)

examples/sketch/girl.jpeg
DELETED
Binary file (214 kB)

examples/sketch/human.png
DELETED
Binary file (768 kB)

examples/sketch/scenery.jpg
DELETED
Binary file (99.8 kB)

examples/sketch/scenery2.jpg
DELETED
Binary file (144 kB)
gradio_keypose.py
DELETED
@@ -1,254 +0,0 @@
import os
import os.path as osp

import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf
import mmcv
from mmdet.apis import inference_detector, init_detector
from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result)

skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10],
            [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]

pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0],
                  [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0],
                  [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]]

pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0],
                   [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0],
                   [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255],
                   [51, 153, 255], [51, 153, 255], [51, 153, 255]]

def imshow_keypoints(img,
                     pose_result,
                     skeleton=None,
                     kpt_score_thr=0.1,
                     pose_kpt_color=None,
                     pose_link_color=None,
                     radius=4,
                     thickness=1):
    """Draw keypoints and links on an image.

    Args:
        img (ndarry): The image to draw poses on.
        pose_result (list[kpts]): The poses to draw. Each element kpts is
            a set of K keypoints as an Kx3 numpy.ndarray, where each
            keypoint is represented as x, y, score.
        kpt_score_thr (float, optional): Minimum score of keypoints
            to be shown. Default: 0.3.
        pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None,
            the keypoint will not be drawn.
        pose_link_color (np.array[Mx3]): Color of M links. If None, the
            links will not be drawn.
        thickness (int): Thickness of lines.
    """

    img_h, img_w, _ = img.shape
    img = np.zeros(img.shape)

    for idx, kpts in enumerate(pose_result):
        if idx > 1:
            continue
        kpts = kpts['keypoints']
        # print(kpts)
        kpts = np.array(kpts, copy=False)

        # draw each point on image
        if pose_kpt_color is not None:
            assert len(pose_kpt_color) == len(kpts)

            for kid, kpt in enumerate(kpts):
                x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2]

                if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None:
                    # skip the point that should not be drawn
                    continue

                color = tuple(int(c) for c in pose_kpt_color[kid])
                cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1)

        # draw links
        if skeleton is not None and pose_link_color is not None:
            assert len(pose_link_color) == len(skeleton)

            for sk_id, sk in enumerate(skeleton):
                pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1]))
                pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1]))

                if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0
                        or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr
                        or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None):
                    # skip the link that should not be drawn
                    continue
                color = tuple(int(c) for c in pose_link_color[sk_id])
                cv2.line(img, pos1, pos2, color, thickness=thickness)

    return img

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    if "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
    else:
        sd = pl_sd
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)

    model.cuda()
    model.eval()
    return model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = OmegaConf.load("configs/stable-diffusion/test_keypose.yaml")
config.model.params.cond_stage_config.params.device = device
model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
current_base = 'sd-v1-4.ckpt'
model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
model_ad.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth"))
sampler = PLMSSampler(model)
## mmpose
det_config = 'models/faster_rcnn_r50_fpn_coco.py'
det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
pose_config = 'models/hrnet_w48_coco_256x192.py'
pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth'
det_cat_id = 1
bbox_thr = 0.2
## detector
det_config_mmcv = mmcv.Config.fromfile(det_config)
det_model = init_detector(det_config_mmcv, det_checkpoint, device=device)
pose_config_mmcv = mmcv.Config.fromfile(pose_config)
pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device)
W, H = 512, 512


def process(input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
    global current_base
    if current_base != base_model:
        ckpt = os.path.join("models", base_model)
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
        model.load_state_dict(sd, strict=False)
        current_base = base_model
    con_strength = int((1-con_strength)*50)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img,(W,H))

    if type_in == 'Keypose':
        im_pose = im.copy()
        im = img2tensor(im).unsqueeze(0)/255.
    elif type_in == 'Image':
        image = im.copy()
        im = img2tensor(im).unsqueeze(0)/255.
        mmdet_results = inference_detector(det_model, image)
        # keep the person class bounding boxes.
        person_results = process_mmdet_results(mmdet_results, det_cat_id)

        # optional
        return_heatmap = False
        dataset = pose_model.cfg.data['test']['type']

        # e.g. use ('backbone', ) to return backbone feature
        output_layer_names = None
        pose_results, returned_outputs = inference_top_down_pose_model(
            pose_model,
            image,
            person_results,
            bbox_thr=bbox_thr,
            format='xyxy',
            dataset=dataset,
            dataset_info=None,
            return_heatmap=return_heatmap,
            outputs=output_layer_names)

        # show the results
        im_pose = imshow_keypoints(
            image,
            pose_results,
            skeleton=skeleton,
            pose_kpt_color=pose_kpt_color,
            pose_link_color=pose_link_color,
            radius=2,
            thickness=2)
        im_pose = cv2.resize(im_pose,(W,H))

    with torch.no_grad():
        c = model.get_learned_conditioning([prompt])
        nc = model.get_learned_conditioning([neg_prompt])
        # extract condition features
        pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255.
        pose = pose.unsqueeze(0)
        features_adapter = model_ad(pose.to(device))

        shape = [4, W//8, H//8]

        # sampling
        samples_ddim, _ = sampler.sample(S=50,
                                         conditioning=c,
                                         batch_size=1,
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=nc,
                                         eta=0.0,
                                         x_T=None,
                                         features_adapter1=features_adapter,
                                         mode = 'sketch',
|
212 |
-
con_strength = con_strength)
|
213 |
-
|
214 |
-
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
215 |
-
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
216 |
-
x_samples_ddim = x_samples_ddim.to('cpu')
|
217 |
-
x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
|
218 |
-
x_samples_ddim = 255.*x_samples_ddim
|
219 |
-
x_samples_ddim = x_samples_ddim.astype(np.uint8)
|
220 |
-
|
221 |
-
return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim]
|
222 |
-
|
223 |
-
DESCRIPTION = '''# T2I-Adapter (Keypose)
|
224 |
-
[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)
|
225 |
-
|
226 |
-
This gradio demo is for keypose-guided generation. The current functions include:
|
227 |
-
- Keypose to Image Generation
|
228 |
-
- Image to Image Generation
|
229 |
-
- Generation with **Anything** setting
|
230 |
-
'''
|
231 |
-
block = gr.Blocks().queue()
|
232 |
-
with block:
|
233 |
-
with gr.Row():
|
234 |
-
gr.Markdown(DESCRIPTION)
|
235 |
-
with gr.Row():
|
236 |
-
with gr.Column():
|
237 |
-
input_img = gr.Image(source='upload', type="numpy")
|
238 |
-
prompt = gr.Textbox(label="Prompt")
|
239 |
-
neg_prompt = gr.Textbox(label="Negative Prompt",
|
240 |
-
value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
|
241 |
-
with gr.Row():
|
242 |
-
type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)')
|
243 |
-
fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)')
|
244 |
-
run_button = gr.Button(label="Run")
|
245 |
-
con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1)
|
246 |
-
scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
|
247 |
-
base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
|
248 |
-
with gr.Column():
|
249 |
-
result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
|
250 |
-
ips = [input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
|
251 |
-
run_button.click(fn=process, inputs=ips, outputs=[result])
|
252 |
-
|
253 |
-
block.launch(server_name='0.0.0.0')
|
254 |
-
|
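Note: the deleted process() above converts the "Controlling Strength" slider into the integer con_strength passed to the PLMS sampler and turns the drawn keypose canvas into the 1x3xHxW tensor the adapter consumes. The snippet below is a minimal, standalone sketch of those two conversions; the helper names (slider_to_con_strength, pose_map_to_tensor) are not part of the repository, and the tensor layout mirrors what img2tensor(..., bgr2rgb=True)/255. produces in the deleted code.

import numpy as np
import torch

def slider_to_con_strength(slider_value: float, num_steps: int = 50) -> int:
    # Same arithmetic as in process(): the sampler receives an integer in [0, num_steps];
    # a slider value of 1.0 maps to 0 and a value of 0.0 maps to num_steps.
    return int((1 - slider_value) * num_steps)

def pose_map_to_tensor(im_pose: np.ndarray) -> torch.Tensor:
    # im_pose is the BGR uint8 canvas produced by imshow_keypoints (skeleton on black).
    t = torch.from_numpy(im_pose[:, :, ::-1].copy()).float() / 255.  # BGR -> RGB, [0, 1]
    return t.permute(2, 0, 1).unsqueeze(0)                            # 1x3xHxW

if __name__ == "__main__":
    print(slider_to_con_strength(1.0), slider_to_con_strength(0.5))   # 0 25
    dummy = np.zeros((512, 512, 3), dtype=np.uint8)
    print(pose_map_to_tensor(dummy).shape)                            # torch.Size([1, 3, 512, 512])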
gradio_sketch.py
DELETED
@@ -1,147 +0,0 @@
import os
import os.path as osp

import cv2
import numpy as np
import torch
from basicsr.utils import img2tensor, tensor2img
from pytorch_lightning import seed_everything
from ldm.models.diffusion.plms import PLMSSampler
from ldm.modules.encoders.adapter import Adapter
from ldm.util import instantiate_from_config
from model_edge import pidinet
import gradio as gr
from omegaconf import OmegaConf


def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    if "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
    else:
        sd = pl_sd
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    # if len(m) > 0 and verbose:
    #     print("missing keys:")
    #     print(m)
    # if len(u) > 0 and verbose:
    #     print("unexpected keys:")
    #     print(u)

    model.cuda()
    model.eval()
    return model


device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml")
config.model.params.cond_stage_config.params.device = device
model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device)
current_base = 'sd-v1-4.ckpt'
model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device)
model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth"))
# PiDiNet edge detector used to turn an input photo into a sketch map.
net_G = pidinet()
ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict']
net_G.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()})
net_G.to(device)
sampler = PLMSSampler(model)
save_memory = True
W, H = 512, 512


def process(input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model):
    global current_base
    if current_base != base_model:
        ckpt = os.path.join("models", base_model)
        pl_sd = torch.load(ckpt, map_location="cpu")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
        model.load_state_dict(sd, strict=False)  # load_model_from_config(config, os.path.join("models", base_model)).to(device)
        current_base = base_model
    con_strength = int((1 - con_strength) * 50)
    if fix_sample == 'True':
        seed_everything(42)
    im = cv2.resize(input_img, (W, H))

    if type_in == 'Sketch':
        # binarise the sketch map; lines are expected to be bright on a black background
        if color_back == 'White':
            im = 255 - im
        im_edge = im.copy()
        im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0) / 255.
        im = im > 0.5
        im = im.float()
    elif type_in == 'Image':
        im = img2tensor(im).unsqueeze(0) / 255.
        im = net_G(im.to(device))[-1]
        im = im > 0.5
        im = im.float()
        im_edge = tensor2img(im)

    with torch.no_grad():
        c = model.get_learned_conditioning([prompt])
        nc = model.get_learned_conditioning([neg_prompt])
        # extract condition features
        features_adapter = model_ad(im.to(device))
        shape = [4, W // 8, H // 8]

        # sampling
        samples_ddim, _ = sampler.sample(S=50,
                                         conditioning=c,
                                         batch_size=1,
                                         shape=shape,
                                         verbose=False,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=nc,
                                         eta=0.0,
                                         x_T=None,
                                         features_adapter1=features_adapter,
                                         mode='sketch',
                                         con_strength=con_strength)

        x_samples_ddim = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_samples_ddim = x_samples_ddim.to('cpu')
        x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0]
        x_samples_ddim = 255. * x_samples_ddim
        x_samples_ddim = x_samples_ddim.astype(np.uint8)

    return [im_edge, x_samples_ddim]


DESCRIPTION = '''# T2I-Adapter (Sketch)
[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter)

This gradio demo is for sketch-guided generation. The current functions include:
- Sketch to Image Generation
- Image to Image Generation
- Generation with **Anything** setting
'''
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(source='upload', type="numpy")
            prompt = gr.Textbox(label="Prompt")
            neg_prompt = gr.Textbox(label="Negative Prompt",
                                    value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face')
            with gr.Row():
                type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)')
                color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)')
            run_button = gr.Button(label="Run")
            con_strength = gr.Slider(label="Controlling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1)
            scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1)
            fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)')
            base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use')
        with gr.Column():
            result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto')
    ips = [input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model]
    run_button.click(fn=process, inputs=ips, outputs=[result])

block.launch(server_name='0.0.0.0')
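Note: in the deleted process() above, a hand-drawn sketch is inverted when it was drawn on a white background and then hard-thresholded at 0.5 before being passed to the adapter. The snippet below is a minimal, standalone sketch of that preprocessing; prepare_sketch is a hypothetical helper name, not a function from the repository.

import numpy as np
import torch

def prepare_sketch(sketch: np.ndarray, white_background: bool) -> torch.Tensor:
    # sketch: HxW or HxWx3 uint8 drawing; returns a 1x1xHxW binary float tensor,
    # with lines bright on a black background, as the sketch adapter expects.
    if sketch.ndim == 3:
        sketch = sketch[:, :, 0]              # the demo keeps a single channel
    if white_background:
        sketch = 255 - sketch                 # dark strokes become bright lines
    t = torch.from_numpy(sketch.copy()).float() / 255.
    t = (t > 0.5).float()                     # hard threshold, as in the demo
    return t.unsqueeze(0).unsqueeze(0)

if __name__ == "__main__":
    dummy = np.full((512, 512), 255, dtype=np.uint8)            # blank white page
    print(prepare_sketch(dummy, white_background=True).sum())   # tensor(0.)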
ldm/data/__init__.py
CHANGED
File without changes
|
ldm/data/base.py
CHANGED
File without changes
|
ldm/data/imagenet.py
CHANGED
File without changes
|
ldm/data/lsun.py
CHANGED
File without changes
|
ldm/lr_scheduler.py
CHANGED
File without changes
|
ldm/models/autoencoder.py
CHANGED
File without changes
|
ldm/models/diffusion/__init__.py
CHANGED
File without changes
|
ldm/models/diffusion/classifier.py
CHANGED
File without changes
|
ldm/models/diffusion/ddim.py
CHANGED
File without changes
|
ldm/models/diffusion/ddpm.py
CHANGED
File without changes
|
ldm/models/diffusion/dpm_solver/__init__.py
CHANGED
File without changes
|
ldm/models/diffusion/dpm_solver/dpm_solver.py
CHANGED
File without changes
|
ldm/models/diffusion/dpm_solver/sampler.py
CHANGED
File without changes
|
ldm/models/diffusion/plms.py
CHANGED
File without changes
|
ldm/modules/attention.py
CHANGED
File without changes
|