diff --git a/.gitattributes b/.gitattributes old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3c76fd4add404777ce8d8130784b3ad40911382c --- /dev/null +++ b/.gitignore @@ -0,0 +1,128 @@ +# ignored folders +models + +# ignored folders +tmp/* + +*.DS_Store +.idea + +# ignored files +version.py + +# ignored files with suffix +# *.html +# *.png +# *.jpeg +# *.jpg +# *.gif +# *.pth +# *.zip + +# template + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.pyc +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ \ No newline at end of file diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/app.py b/app.py index a10c15a0f5db01580eb65bd317c80c494e2f99ae..48f05574e6d68ca012bf032de27ac884609cf793 100755 --- a/app.py +++ b/app.py @@ -8,14 +8,14 @@ os.system('mim install mmcv-full==1.7.0') from demo.model import Model_all import gradio as gr -from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg +from demo.demos import create_demo_keypose, create_demo_sketch, create_demo_draw, create_demo_seg, create_demo_depth import torch import subprocess import shlex from huggingface_hub import hf_hub_url urls = { - 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth'], + 'TencentARC/T2I-Adapter':['models/t2iadapter_keypose_sd14v1.pth', 'models/t2iadapter_seg_sd14v1.pth', 'models/t2iadapter_sketch_sd14v1.pth', 'models/t2iadapter_depth_sd14v1.pth'], 'CompVis/stable-diffusion-v-1-4-original':['sd-v1-4.ckpt'], 'andite/anything-v4.0':['anything-v4.0-pruned.ckpt', 'anything-v4.0.vae.pt'], } @@ -72,5 +72,7 @@ with gr.Blocks(css='style.css') as demo: create_demo_draw(model.process_draw) with gr.TabItem('Segmentation'): create_demo_seg(model.process_seg) + with gr.TabItem('Depth'): + create_demo_depth(model.process_depth) demo.queue().launch(debug=True, server_name='0.0.0.0') \ No newline at end of file diff --git a/configs/stable-diffusion/app.yaml b/configs/stable-diffusion/app.yaml old mode 100644 new mode 100755 diff --git a/configs/stable-diffusion/test_keypose.yaml b/configs/stable-diffusion/test_keypose.yaml deleted file mode 100644 index cdd57e17f59c133f6c1ca10e7235a6b9961f5108..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/test_keypose.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: test_keypose -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/test_mask.yaml b/configs/stable-diffusion/test_mask.yaml deleted file mode 100644 index f2d6f40ac7bd15131f0930919434f6b828dac83c..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/test_mask.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: test_mask -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/test_mask_sketch.yaml b/configs/stable-diffusion/test_mask_sketch.yaml deleted file mode 100644 index fc5f3a0199fd7a703f9279d4f2f981c7fab7e850..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/test_mask_sketch.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: test_mask_sketch -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/test_sketch.yaml b/configs/stable-diffusion/test_sketch.yaml deleted file mode 100644 index 7b92d28668e4da3942355cc495647952a1b099d7..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/test_sketch.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: test_sketch -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/test_sketch_edit.yaml b/configs/stable-diffusion/test_sketch_edit.yaml deleted file mode 100644 index c6583d00017ad23f783cf50230e3738e25f7630a..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/test_sketch_edit.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: test_sketch_edit -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/train_keypose.yaml b/configs/stable-diffusion/train_keypose.yaml deleted file mode 100644 index c84c8fcd3ed9ad79c5e750259b56c263f8f70312..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/train_keypose.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: train_keypose -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/train_mask.yaml b/configs/stable-diffusion/train_mask.yaml deleted file mode 100644 index 0b7a7e9385ee545eb06041d0e1c6217384e5d266..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/train_mask.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: train_mask -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/configs/stable-diffusion/train_sketch.yaml b/configs/stable-diffusion/train_sketch.yaml deleted file mode 100644 index 90da3b6a8f6e270d561934afa4d78f1157af8c81..0000000000000000000000000000000000000000 --- a/configs/stable-diffusion/train_sketch.yaml +++ /dev/null @@ -1,87 +0,0 @@ -name: train_sketch -model: - base_learning_rate: 1.0e-04 - target: ldm.models.diffusion.ddpm.LatentDiffusion - params: - linear_start: 0.00085 - linear_end: 0.0120 - num_timesteps_cond: 1 - log_every_t: 200 - timesteps: 1000 - first_stage_key: "jpg" - cond_stage_key: "txt" - image_size: 64 - channels: 4 - cond_stage_trainable: false # Note: different from the one we trained before - conditioning_key: crossattn - monitor: val/loss_simple_ema - scale_factor: 0.18215 - use_ema: False - - scheduler_config: # 10000 warmup steps - target: ldm.lr_scheduler.LambdaLinearScheduler - params: - warm_up_steps: [ 10000 ] - cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases - f_start: [ 1.e-6 ] - f_max: [ 1. ] - f_min: [ 1. ] - - unet_config: - target: ldm.modules.diffusionmodules.openaimodel.UNetModel - params: - image_size: 32 # unused - in_channels: 4 - out_channels: 4 - model_channels: 320 - attention_resolutions: [ 4, 2, 1 ] - num_res_blocks: 2 - channel_mult: [ 1, 2, 4, 4 ] - num_heads: 8 - use_spatial_transformer: True - transformer_depth: 1 - context_dim: 768 - use_checkpoint: True - legacy: False - - first_stage_config: - target: ldm.models.autoencoder.AutoencoderKL - params: - embed_dim: 4 - monitor: val/rec_loss - ddconfig: - double_z: true - z_channels: 4 - resolution: 256 - in_channels: 3 - out_ch: 3 - ch: 128 - ch_mult: - - 1 - - 2 - - 4 - - 4 - num_res_blocks: 2 - attn_resolutions: [] - dropout: 0.0 - lossconfig: - target: torch.nn.Identity - - cond_stage_config: #__is_unconditional__ - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder - params: - version: models/clip-vit-large-patch14 - -logger: - print_freq: 100 - save_checkpoint_freq: !!float 1e4 - use_tb_logger: true - wandb: - project: ~ - resume_id: ~ -dist_params: - backend: nccl - port: 29500 -training: - lr: !!float 1e-5 - save_freq: 1e4 \ No newline at end of file diff --git a/dataset_coco.py b/dataset_coco.py deleted file mode 100644 index 30ef5801d3a6d8526aa930add7a0c43d2af71a9f..0000000000000000000000000000000000000000 --- a/dataset_coco.py +++ /dev/null @@ -1,138 +0,0 @@ -import torch -import json -import cv2 -import torch -import os -from basicsr.utils import img2tensor, tensor2img -import random - -class dataset_coco(): - def __init__(self, path_json, root_path, image_size, mode='train'): - super(dataset_coco, self).__init__() - with open(path_json, 'r', encoding='utf-8') as fp: - data = json.load(fp) - data = data['images'] - self.paths = [] - self.root_path = root_path - for file in data: - input_path = file['filepath'] - if mode == 'train': - if 'val' not in input_path: - self.paths.append(file) - else: - if 'val' in input_path: - self.paths.append(file) - - def __getitem__(self, idx): - file = self.paths[idx] - input_path = file['filepath'] - input_name = file['filename'] - path = os.path.join(self.root_path, input_path, input_name) - im = cv2.imread(path) - im = cv2.resize(im, (512,512)) - im = img2tensor(im, bgr2rgb=True, float32=True)/255. - sentences = file['sentences'] - sentence = sentences[int(random.random()*len(sentences))]['raw'].strip('.') - return {'im':im, 'sentence':sentence} - - def __len__(self): - return len(self.paths) - - -class dataset_coco_mask(): - def __init__(self, path_json, root_path_im, root_path_mask, image_size): - super(dataset_coco_mask, self).__init__() - with open(path_json, 'r', encoding='utf-8') as fp: - data = json.load(fp) - data = data['annotations'] - self.files = [] - self.root_path_im = root_path_im - self.root_path_mask = root_path_mask - for file in data: - name = "%012d.png"%file['image_id'] - self.files.append({'name':name, 'sentence':file['caption']}) - - def __getitem__(self, idx): - file = self.files[idx] - name = file['name'] - # print(os.path.join(self.root_path_im, name)) - im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg'))) - im = cv2.resize(im, (512,512)) - im = img2tensor(im, bgr2rgb=True, float32=True)/255. - - mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0] - mask = cv2.resize(mask, (512,512)) - mask = img2tensor(mask, bgr2rgb=True, float32=True)[0].unsqueeze(0)#/255. - - sentence = file['sentence'] - return {'im':im, 'mask':mask, 'sentence':sentence} - - def __len__(self): - return len(self.files) - - -class dataset_coco_mask_color(): - def __init__(self, path_json, root_path_im, root_path_mask, image_size): - super(dataset_coco_mask_color, self).__init__() - with open(path_json, 'r', encoding='utf-8') as fp: - data = json.load(fp) - data = data['annotations'] - self.files = [] - self.root_path_im = root_path_im - self.root_path_mask = root_path_mask - for file in data: - name = "%012d.png"%file['image_id'] - self.files.append({'name':name, 'sentence':file['caption']}) - - def __getitem__(self, idx): - file = self.files[idx] - name = file['name'] - # print(os.path.join(self.root_path_im, name)) - im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg'))) - im = cv2.resize(im, (512,512)) - im = img2tensor(im, bgr2rgb=True, float32=True)/255. - - mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0] - mask = cv2.resize(mask, (512,512)) - mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255. - - sentence = file['sentence'] - return {'im':im, 'mask':mask, 'sentence':sentence} - - def __len__(self): - return len(self.files) - -class dataset_coco_mask_color_sig(): - def __init__(self, path_json, root_path_im, root_path_mask, image_size): - super(dataset_coco_mask_color_sig, self).__init__() - with open(path_json, 'r', encoding='utf-8') as fp: - data = json.load(fp) - data = data['annotations'] - self.files = [] - self.root_path_im = root_path_im - self.root_path_mask = root_path_mask - reg = {} - for file in data: - name = "%012d.png"%file['image_id'] - if name in reg: - continue - self.files.append({'name':name, 'sentence':file['caption']}) - reg[name] = name - - def __getitem__(self, idx): - file = self.files[idx] - name = file['name'] - # print(os.path.join(self.root_path_im, name)) - im = cv2.imread(os.path.join(self.root_path_im, name.replace('.png','.jpg'))) - im = cv2.resize(im, (512,512)) - im = img2tensor(im, bgr2rgb=True, float32=True)/255. - - mask = cv2.imread(os.path.join(self.root_path_mask, name))#[:,:,0] - mask = cv2.resize(mask, (512,512)) - mask = img2tensor(mask, bgr2rgb=True, float32=True)/255.#[0].unsqueeze(0)#/255. - - sentence = file['sentence'] - return {'im':im, 'mask':mask, 'sentence':sentence, 'name': name} - - def __len__(self): - return len(self.files) \ No newline at end of file diff --git a/demo/demos.py b/demo/demos.py index 87e0a51d4c20d211143958c8dc9c732edea4d52d..140a112014651dce8f73c7bb0dc310fff0384b4a 100755 --- a/demo/demos.py +++ b/demo/demos.py @@ -85,7 +85,32 @@ def create_demo_seg(process): with gr.Row(): type_in = gr.inputs.Radio(['Segmentation', 'Image'], type="value", default='Image', label='You can input an image or a segmentation. If you choose to input a segmentation, it must correspond to the coco-stuff') run_button = gr.Button(label="Run") - con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=0.4, step=0.1) + con_strength = gr.Slider(label="Controling Strength (The guidance strength of the segmentation to the result)", minimum=0, maximum=1, value=1, step=0.1) + scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1) + fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)') + base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use') + with gr.Column(): + result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') + ips = [input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model] + run_button.click(fn=process, inputs=ips, outputs=[result]) + return demo + +def create_demo_depth(process): + with gr.Blocks() as demo: + with gr.Row(): + gr.Markdown('## T2I-Adapter (Depth)') + with gr.Row(): + with gr.Column(): + input_img = gr.Image(source='upload', type="numpy") + prompt = gr.Textbox(label="Prompt") + neg_prompt = gr.Textbox(label="Negative Prompt", + value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face') + pos_prompt = gr.Textbox(label="Positive Prompt", + value = 'crafted, elegant, meticulous, magnificent, maximum details, extremely hyper aesthetic, intricately detailed') + with gr.Row(): + type_in = gr.inputs.Radio(['Depth', 'Image'], type="value", default='Image', label='You can input an image or a depth map') + run_button = gr.Button(label="Run") + con_strength = gr.Slider(label="Controling Strength (The guidance strength of the depth map to the result)", minimum=0, maximum=1, value=1, step=0.1) scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=7.5, step=0.1) fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)') base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use') diff --git a/demo/model.py b/demo/model.py index 29eb90bf236cda9287f10e6be40a52e74ba40ab0..9ef3693e0d7906bc7bddaac40a0dfa0d10f15044 100755 --- a/demo/model.py +++ b/demo/model.py @@ -4,7 +4,9 @@ from pytorch_lightning import seed_everything from ldm.models.diffusion.plms import PLMSSampler from ldm.modules.encoders.adapter import Adapter from ldm.util import instantiate_from_config -from model_edge import pidinet +from ldm.modules.structure_condition.model_edge import pidinet +from ldm.modules.structure_condition.model_seg import seger, Colorize +from ldm.modules.structure_condition.midas.api import MiDaSInference import gradio as gr from omegaconf import OmegaConf import mmcv @@ -13,7 +15,6 @@ from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process import os import cv2 import numpy as np -from seger import seger, Colorize import torch.nn.functional as F def preprocessing(image, device): @@ -136,10 +137,8 @@ class Model_all: self.model_sketch = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) self.model_sketch.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth", map_location=device)) - self.model_edge = pidinet() - ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict'] - self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in ckp.items()}) - self.model_edge.to(device) + self.model_edge = pidinet().to(device) + self.model_edge.load_state_dict({k.replace('module.', ''): v for k, v in torch.load('models/table5_pidinet.pth', map_location=device)['state_dict'].items()}) # segmentation part self.model_seger = seger().to(device) @@ -147,6 +146,11 @@ class Model_all: self.coler = Colorize(n=182) self.model_seg = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) self.model_seg.load_state_dict(torch.load("models/t2iadapter_seg_sd14v1.pth", map_location=device)) + self.depth_model = MiDaSInference(model_type='dpt_hybrid').to(device) + + # depth part + self.model_depth = Adapter(cin=3*64, channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) + self.model_depth.load_state_dict(torch.load("models/t2iadapter_depth_sd14v1.pth", map_location=device)) # keypose part self.model_pose = Adapter(cin=int(3 * 64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, @@ -248,6 +252,65 @@ class Model_all: return [im_edge, x_samples_ddim] + @torch.no_grad() + def process_depth(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, + con_strength, base_model): + if self.current_base != base_model: + ckpt = os.path.join("models", base_model) + pl_sd = torch.load(ckpt, map_location="cuda") + if "state_dict" in pl_sd: + sd = pl_sd["state_dict"] + else: + sd = pl_sd + self.base_model.load_state_dict(sd, strict=False) + self.current_base = base_model + if 'anything' in base_model.lower(): + self.load_vae() + + con_strength = int((1 - con_strength) * 50) + if fix_sample == 'True': + seed_everything(42) + im = cv2.resize(input_img, (512, 512)) + + if type_in == 'Depth': + im_depth = im.copy() + depth = img2tensor(im).unsqueeze(0) / 255. + elif type_in == 'Image': + im = img2tensor(im).unsqueeze(0) / 127.5 - 1.0 + depth = self.depth_model(im.to(self.device)).repeat(1, 3, 1, 1) + depth -= torch.min(depth) + depth /= torch.max(depth) + im_depth = tensor2img(depth) + + # extract condition features + c = self.base_model.get_learned_conditioning([prompt + ', ' + pos_prompt]) + nc = self.base_model.get_learned_conditioning([neg_prompt]) + features_adapter = self.model_depth(depth.to(self.device)) + shape = [4, 64, 64] + + # sampling + samples_ddim, _ = self.sampler.sample(S=50, + conditioning=c, + batch_size=1, + shape=shape, + verbose=False, + unconditional_guidance_scale=scale, + unconditional_conditioning=nc, + eta=0.0, + x_T=None, + features_adapter1=features_adapter, + mode='sketch', + con_strength=con_strength) + + x_samples_ddim = self.base_model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.to('cpu') + x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0] + x_samples_ddim = 255. * x_samples_ddim + x_samples_ddim = x_samples_ddim.astype(np.uint8) + + return [im_depth, x_samples_ddim] + @torch.no_grad() def process_seg(self, input_img, type_in, prompt, neg_prompt, pos_prompt, fix_sample, scale, con_strength, base_model): diff --git a/dist_util.py b/dist_util.py deleted file mode 100644 index 47441a48932a86d5556b1167ef327aa3b1ec8173..0000000000000000000000000000000000000000 --- a/dist_util.py +++ /dev/null @@ -1,91 +0,0 @@ -# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 -import functools -import os -import subprocess -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from torch.nn.parallel import DataParallel, DistributedDataParallel - - -def init_dist(launcher, backend='nccl', **kwargs): - if mp.get_start_method(allow_none=True) is None: - mp.set_start_method('spawn') - if launcher == 'pytorch': - _init_dist_pytorch(backend, **kwargs) - elif launcher == 'slurm': - _init_dist_slurm(backend, **kwargs) - else: - raise ValueError(f'Invalid launcher type: {launcher}') - - -def _init_dist_pytorch(backend, **kwargs): - rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) - - -def _init_dist_slurm(backend, port=None): - """Initialize slurm distributed training environment. - - If argument ``port`` is not specified, then the master port will be system - environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system - environment variable, then a default port ``29500`` will be used. - - Args: - backend (str): Backend of torch.distributed. - port (int, optional): Master port. Defaults to None. - """ - proc_id = int(os.environ['SLURM_PROCID']) - ntasks = int(os.environ['SLURM_NTASKS']) - node_list = os.environ['SLURM_NODELIST'] - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(proc_id % num_gpus) - addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') - # specify master port - if port is not None: - os.environ['MASTER_PORT'] = str(port) - elif 'MASTER_PORT' in os.environ: - pass # use MASTER_PORT in the environment variable - else: - # 29500 is torch.distributed default port - os.environ['MASTER_PORT'] = '29500' - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(ntasks) - os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) - os.environ['RANK'] = str(proc_id) - dist.init_process_group(backend=backend) - - -def get_dist_info(): - if dist.is_available(): - initialized = dist.is_initialized() - else: - initialized = False - if initialized: - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - return rank, world_size - - -def master_only(func): - - @functools.wraps(func) - def wrapper(*args, **kwargs): - rank, _ = get_dist_info() - if rank == 0: - return func(*args, **kwargs) - - return wrapper - -def get_bare_model(net): - """Get bare model, especially under wrapping with - DistributedDataParallel or DataParallel. - """ - if isinstance(net, (DataParallel, DistributedDataParallel)): - net = net.module - return net diff --git a/environment.yaml b/environment.yaml old mode 100644 new mode 100755 diff --git a/examples/edit_cat/edge.png b/examples/edit_cat/edge.png deleted file mode 100644 index fb897952c2fc901cae210e7171052fee4e5ac392..0000000000000000000000000000000000000000 Binary files a/examples/edit_cat/edge.png and /dev/null differ diff --git a/examples/edit_cat/edge_2.png b/examples/edit_cat/edge_2.png deleted file mode 100644 index 14dc10d2f86bcebbe9ddf0970b9000b9eebc5bc8..0000000000000000000000000000000000000000 Binary files a/examples/edit_cat/edge_2.png and /dev/null differ diff --git a/examples/edit_cat/im.png b/examples/edit_cat/im.png deleted file mode 100644 index 1d5e2f9a097e736e5f14313f1844163320aba331..0000000000000000000000000000000000000000 Binary files a/examples/edit_cat/im.png and /dev/null differ diff --git a/examples/edit_cat/mask.png b/examples/edit_cat/mask.png deleted file mode 100644 index 546302cb3269683aea50c1ada3721f0bf80f27a6..0000000000000000000000000000000000000000 Binary files a/examples/edit_cat/mask.png and /dev/null differ diff --git a/examples/keypose/iron.png b/examples/keypose/iron.png deleted file mode 100644 index f435ff43ff5dda2fd438938d5779f3d5a0ef1cec..0000000000000000000000000000000000000000 Binary files a/examples/keypose/iron.png and /dev/null differ diff --git a/examples/seg/dinner.png b/examples/seg/dinner.png deleted file mode 100644 index 3cc0607baa38eb2d5d79e0bdd4b6456fadffdfd3..0000000000000000000000000000000000000000 Binary files a/examples/seg/dinner.png and /dev/null differ diff --git a/examples/seg/motor.png b/examples/seg/motor.png deleted file mode 100644 index 88826007b270f41f8f3e662c46f7fdbe79545dea..0000000000000000000000000000000000000000 Binary files a/examples/seg/motor.png and /dev/null differ diff --git a/examples/seg_sketch/edge.png b/examples/seg_sketch/edge.png deleted file mode 100644 index d54c9d30ca211f4dbfdb59adeb904d71023171df..0000000000000000000000000000000000000000 Binary files a/examples/seg_sketch/edge.png and /dev/null differ diff --git a/examples/seg_sketch/mask.png b/examples/seg_sketch/mask.png deleted file mode 100644 index b555b71d56087c5469f1035eabe872ada12f72bb..0000000000000000000000000000000000000000 Binary files a/examples/seg_sketch/mask.png and /dev/null differ diff --git a/examples/sketch/car.png b/examples/sketch/car.png deleted file mode 100644 index 98923d8ebde6800b0c030b7b8eeb520cdcc99c08..0000000000000000000000000000000000000000 Binary files a/examples/sketch/car.png and /dev/null differ diff --git a/examples/sketch/girl.jpeg b/examples/sketch/girl.jpeg deleted file mode 100644 index ab3f98ca995691dc60b2d24ce6f30a0e5088b2e0..0000000000000000000000000000000000000000 Binary files a/examples/sketch/girl.jpeg and /dev/null differ diff --git a/examples/sketch/human.png b/examples/sketch/human.png deleted file mode 100644 index 646628c758479f5401618c4e49ded080b3db00b1..0000000000000000000000000000000000000000 Binary files a/examples/sketch/human.png and /dev/null differ diff --git a/examples/sketch/scenery.jpg b/examples/sketch/scenery.jpg deleted file mode 100644 index 80c293c5f80e7207ccdd5c53302c57ebdb5a0ff7..0000000000000000000000000000000000000000 Binary files a/examples/sketch/scenery.jpg and /dev/null differ diff --git a/examples/sketch/scenery2.jpg b/examples/sketch/scenery2.jpg deleted file mode 100644 index 1f230502ded225e24538fb4b32eb0094cc0437ae..0000000000000000000000000000000000000000 Binary files a/examples/sketch/scenery2.jpg and /dev/null differ diff --git a/gradio_keypose.py b/gradio_keypose.py deleted file mode 100644 index bd01a1273b088379d2564a08c0d1801a1cb65008..0000000000000000000000000000000000000000 --- a/gradio_keypose.py +++ /dev/null @@ -1,254 +0,0 @@ -import os -import os.path as osp - -import cv2 -import numpy as np -import torch -from basicsr.utils import img2tensor, tensor2img -from pytorch_lightning import seed_everything -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config -from model_edge import pidinet -import gradio as gr -from omegaconf import OmegaConf -import mmcv -from mmdet.apis import inference_detector, init_detector -from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result) - -skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], - [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]] - -pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], - [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], - [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]] - -pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0], - [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0], - [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], - [51, 153, 255], [51, 153, 255], [51, 153, 255]] - -def imshow_keypoints(img, - pose_result, - skeleton=None, - kpt_score_thr=0.1, - pose_kpt_color=None, - pose_link_color=None, - radius=4, - thickness=1): - """Draw keypoints and links on an image. - - Args: - img (ndarry): The image to draw poses on. - pose_result (list[kpts]): The poses to draw. Each element kpts is - a set of K keypoints as an Kx3 numpy.ndarray, where each - keypoint is represented as x, y, score. - kpt_score_thr (float, optional): Minimum score of keypoints - to be shown. Default: 0.3. - pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, - the keypoint will not be drawn. - pose_link_color (np.array[Mx3]): Color of M links. If None, the - links will not be drawn. - thickness (int): Thickness of lines. - """ - - img_h, img_w, _ = img.shape - img = np.zeros(img.shape) - - for idx, kpts in enumerate(pose_result): - if idx > 1: - continue - kpts = kpts['keypoints'] - # print(kpts) - kpts = np.array(kpts, copy=False) - - # draw each point on image - if pose_kpt_color is not None: - assert len(pose_kpt_color) == len(kpts) - - for kid, kpt in enumerate(kpts): - x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] - - if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None: - # skip the point that should not be drawn - continue - - color = tuple(int(c) for c in pose_kpt_color[kid]) - cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1) - - # draw links - if skeleton is not None and pose_link_color is not None: - assert len(pose_link_color) == len(skeleton) - - for sk_id, sk in enumerate(skeleton): - pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) - pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) - - if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0 - or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr - or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None): - # skip the link that should not be drawn - continue - color = tuple(int(c) for c in pose_link_color[sk_id]) - cv2.line(img, pos1, pos2, color, thickness=thickness) - - return img - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - - model.cuda() - model.eval() - return model - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -config = OmegaConf.load("configs/stable-diffusion/test_keypose.yaml") -config.model.params.cond_stage_config.params.device = device -model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device) -current_base = 'sd-v1-4.ckpt' -model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) -model_ad.load_state_dict(torch.load("models/t2iadapter_keypose_sd14v1.pth")) -sampler = PLMSSampler(model) -## mmpose -det_config = 'models/faster_rcnn_r50_fpn_coco.py' -det_checkpoint = 'models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' -pose_config = 'models/hrnet_w48_coco_256x192.py' -pose_checkpoint = 'models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth' -det_cat_id = 1 -bbox_thr = 0.2 -## detector -det_config_mmcv = mmcv.Config.fromfile(det_config) -det_model = init_detector(det_config_mmcv, det_checkpoint, device=device) -pose_config_mmcv = mmcv.Config.fromfile(pose_config) -pose_model = init_pose_model(pose_config_mmcv, pose_checkpoint, device=device) -W, H = 512, 512 - - -def process(input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model): - global current_base - if current_base != base_model: - ckpt = os.path.join("models", base_model) - pl_sd = torch.load(ckpt, map_location="cpu") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - model.load_state_dict(sd, strict=False) - current_base = base_model - con_strength = int((1-con_strength)*50) - if fix_sample == 'True': - seed_everything(42) - im = cv2.resize(input_img,(W,H)) - - if type_in == 'Keypose': - im_pose = im.copy() - im = img2tensor(im).unsqueeze(0)/255. - elif type_in == 'Image': - image = im.copy() - im = img2tensor(im).unsqueeze(0)/255. - mmdet_results = inference_detector(det_model, image) - # keep the person class bounding boxes. - person_results = process_mmdet_results(mmdet_results, det_cat_id) - - # optional - return_heatmap = False - dataset = pose_model.cfg.data['test']['type'] - - # e.g. use ('backbone', ) to return backbone feature - output_layer_names = None - pose_results, returned_outputs = inference_top_down_pose_model( - pose_model, - image, - person_results, - bbox_thr=bbox_thr, - format='xyxy', - dataset=dataset, - dataset_info=None, - return_heatmap=return_heatmap, - outputs=output_layer_names) - - # show the results - im_pose = imshow_keypoints( - image, - pose_results, - skeleton=skeleton, - pose_kpt_color=pose_kpt_color, - pose_link_color=pose_link_color, - radius=2, - thickness=2) - im_pose = cv2.resize(im_pose,(W,H)) - - with torch.no_grad(): - c = model.get_learned_conditioning([prompt]) - nc = model.get_learned_conditioning([neg_prompt]) - # extract condition features - pose = img2tensor(im_pose, bgr2rgb=True, float32=True)/255. - pose = pose.unsqueeze(0) - features_adapter = model_ad(pose.to(device)) - - shape = [4, W//8, H//8] - - # sampling - samples_ddim, _ = sampler.sample(S=50, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=scale, - unconditional_conditioning=nc, - eta=0.0, - x_T=None, - features_adapter1=features_adapter, - mode = 'sketch', - con_strength = con_strength) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.to('cpu') - x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0] - x_samples_ddim = 255.*x_samples_ddim - x_samples_ddim = x_samples_ddim.astype(np.uint8) - - return [im_pose[:,:,::-1].astype(np.uint8), x_samples_ddim] - -DESCRIPTION = '''# T2I-Adapter (Keypose) -[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter) - -This gradio demo is for keypose-guided generation. The current functions include: -- Keypose to Image Generation -- Image to Image Generation -- Generation with **Anything** setting -''' -block = gr.Blocks().queue() -with block: - with gr.Row(): - gr.Markdown(DESCRIPTION) - with gr.Row(): - with gr.Column(): - input_img = gr.Image(source='upload', type="numpy") - prompt = gr.Textbox(label="Prompt") - neg_prompt = gr.Textbox(label="Negative Prompt", - value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face') - with gr.Row(): - type_in = gr.inputs.Radio(['Keypose', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a keypose map)') - fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed to produce a fixed output)') - run_button = gr.Button(label="Run") - con_strength = gr.Slider(label="Controling Strength (The guidance strength of the keypose to the result)", minimum=0, maximum=1, value=1, step=0.1) - scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1) - base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use') - with gr.Column(): - result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') - ips = [input_img, type_in, prompt, neg_prompt, fix_sample, scale, con_strength, base_model] - run_button.click(fn=process, inputs=ips, outputs=[result]) - -block.launch(server_name='0.0.0.0') - diff --git a/gradio_sketch.py b/gradio_sketch.py deleted file mode 100644 index ef55de98df297483a81e56e0ba52e4cedd88d236..0000000000000000000000000000000000000000 --- a/gradio_sketch.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -import os.path as osp - -import cv2 -import numpy as np -import torch -from basicsr.utils import img2tensor, tensor2img -from pytorch_lightning import seed_everything -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config -from model_edge import pidinet -import gradio as gr -from omegaconf import OmegaConf - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - # if len(m) > 0 and verbose: - # print("missing keys:") - # print(m) - # if len(u) > 0 and verbose: - # print("unexpected keys:") - # print(u) - - model.cuda() - model.eval() - return model - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -config = OmegaConf.load("configs/stable-diffusion/test_sketch.yaml") -config.model.params.cond_stage_config.params.device = device -model = load_model_from_config(config, "models/sd-v1-4.ckpt").to(device) -current_base = 'sd-v1-4.ckpt' -model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) -model_ad.load_state_dict(torch.load("models/t2iadapter_sketch_sd14v1.pth")) -net_G = pidinet() -ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict'] -net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) -net_G.to(device) -sampler = PLMSSampler(model) -save_memory=True -W, H = 512, 512 - - -def process(input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model): - global current_base - if current_base != base_model: - ckpt = os.path.join("models", base_model) - pl_sd = torch.load(ckpt, map_location="cpu") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - model.load_state_dict(sd, strict=False) #load_model_from_config(config, os.path.join("models", base_model)).to(device) - current_base = base_model - con_strength = int((1-con_strength)*50) - if fix_sample == 'True': - seed_everything(42) - im = cv2.resize(input_img,(W,H)) - - if type_in == 'Sketch': - if color_back == 'White': - im = 255-im - im_edge = im.copy() - im = img2tensor(im)[0].unsqueeze(0).unsqueeze(0)/255. - im = im>0.5 - im = im.float() - elif type_in == 'Image': - im = img2tensor(im).unsqueeze(0)/255. - im = net_G(im.to(device))[-1] - im = im>0.5 - im = im.float() - im_edge = tensor2img(im) - - with torch.no_grad(): - c = model.get_learned_conditioning([prompt]) - nc = model.get_learned_conditioning([neg_prompt]) - # extract condition features - features_adapter = model_ad(im.to(device)) - shape = [4, W//8, H//8] - - # sampling - samples_ddim, _ = sampler.sample(S=50, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=scale, - unconditional_conditioning=nc, - eta=0.0, - x_T=None, - features_adapter1=features_adapter, - mode = 'sketch', - con_strength = con_strength) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.to('cpu') - x_samples_ddim = x_samples_ddim.permute(0, 2, 3, 1).numpy()[0] - x_samples_ddim = 255.*x_samples_ddim - x_samples_ddim = x_samples_ddim.astype(np.uint8) - - return [im_edge, x_samples_ddim] - -DESCRIPTION = '''# T2I-Adapter (Sketch) -[Paper](https://arxiv.org/abs/2302.08453) [GitHub](https://github.com/TencentARC/T2I-Adapter) - -This gradio demo is for sketch-guided generation. The current functions include: -- Sketch to Image Generation -- Image to Image Generation -- Generation with **Anything** setting -''' -block = gr.Blocks().queue() -with block: - with gr.Row(): - gr.Markdown(DESCRIPTION) - with gr.Row(): - with gr.Column(): - input_img = gr.Image(source='upload', type="numpy") - prompt = gr.Textbox(label="Prompt") - neg_prompt = gr.Textbox(label="Negative Prompt", - value='ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face') - with gr.Row(): - type_in = gr.inputs.Radio(['Sketch', 'Image'], type="value", default='Image', label='Input Types\n (You can input an image or a sketch)') - color_back = gr.inputs.Radio(['White', 'Black'], type="value", default='Black', label='Color of the sketch background\n (Only work for sketch input)') - run_button = gr.Button(label="Run") - con_strength = gr.Slider(label="Controling Strength (The guidance strength of the sketch to the result)", minimum=0, maximum=1, value=0.4, step=0.1) - scale = gr.Slider(label="Guidance Scale (Classifier free guidance)", minimum=0.1, maximum=30.0, value=9, step=0.1) - fix_sample = gr.inputs.Radio(['True', 'False'], type="value", default='False', label='Fix Sampling\n (Fix the random seed)') - base_model = gr.inputs.Radio(['sd-v1-4.ckpt', 'anything-v4.0-pruned.ckpt'], type="value", default='sd-v1-4.ckpt', label='The base model you want to use') - with gr.Column(): - result = gr.Gallery(label='Output', show_label=False, elem_id="gallery").style(grid=2, height='auto') - ips = [input_img, type_in, color_back, prompt, neg_prompt, fix_sample, scale, con_strength, base_model] - run_button.click(fn=process, inputs=ips, outputs=[result]) - -block.launch(server_name='0.0.0.0') - diff --git a/ldm/data/__init__.py b/ldm/data/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/data/base.py b/ldm/data/base.py old mode 100644 new mode 100755 diff --git a/ldm/data/imagenet.py b/ldm/data/imagenet.py old mode 100644 new mode 100755 diff --git a/ldm/data/lsun.py b/ldm/data/lsun.py old mode 100644 new mode 100755 diff --git a/ldm/lr_scheduler.py b/ldm/lr_scheduler.py old mode 100644 new mode 100755 diff --git a/ldm/models/autoencoder.py b/ldm/models/autoencoder.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/__init__.py b/ldm/models/diffusion/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/classifier.py b/ldm/models/diffusion/classifier.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/dpm_solver/__init__.py b/ldm/models/diffusion/dpm_solver/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py old mode 100644 new mode 100755 diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py old mode 100644 new mode 100755 diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py old mode 100644 new mode 100755 diff --git a/ldm/modules/diffusionmodules/__init__.py b/ldm/modules/diffusionmodules/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py old mode 100644 new mode 100755 diff --git a/ldm/modules/diffusionmodules/openaimodel.py b/ldm/modules/diffusionmodules/openaimodel.py old mode 100644 new mode 100755 index a6004723896201987fac5d5f7e0b9ff4fc04c39b..d6e089a6786da3e977398cc7a1f83ec7fd4a4ff6 --- a/ldm/modules/diffusionmodules/openaimodel.py +++ b/ldm/modules/diffusionmodules/openaimodel.py @@ -7,7 +7,6 @@ import numpy as np import torch as th import torch.nn as nn import torch.nn.functional as F -from dist_util import init_dist, master_only, get_bare_model, get_dist_info from ldm.modules.diffusionmodules.util import ( checkpoint, diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py old mode 100644 new mode 100755 diff --git a/ldm/modules/distributions/__init__.py b/ldm/modules/distributions/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/modules/distributions/distributions.py b/ldm/modules/distributions/distributions.py old mode 100644 new mode 100755 diff --git a/ldm/modules/ema.py b/ldm/modules/ema.py old mode 100644 new mode 100755 diff --git a/ldm/modules/encoders/__init__.py b/ldm/modules/encoders/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/modules/encoders/adapter.py b/ldm/modules/encoders/adapter.py old mode 100644 new mode 100755 diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py old mode 100644 new mode 100755 diff --git a/ldm/modules/image_degradation/__init__.py b/ldm/modules/image_degradation/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/modules/image_degradation/bsrgan.py b/ldm/modules/image_degradation/bsrgan.py old mode 100644 new mode 100755 diff --git a/ldm/modules/image_degradation/bsrgan_light.py b/ldm/modules/image_degradation/bsrgan_light.py old mode 100644 new mode 100755 diff --git a/ldm/modules/image_degradation/utils/test.png b/ldm/modules/image_degradation/utils/test.png old mode 100644 new mode 100755 diff --git a/ldm/modules/image_degradation/utils_image.py b/ldm/modules/image_degradation/utils_image.py old mode 100644 new mode 100755 diff --git a/ldm/modules/losses/__init__.py b/ldm/modules/losses/__init__.py old mode 100644 new mode 100755 diff --git a/ldm/modules/losses/contperceptual.py b/ldm/modules/losses/contperceptual.py old mode 100644 new mode 100755 diff --git a/ldm/modules/losses/vqperceptual.py b/ldm/modules/losses/vqperceptual.py old mode 100644 new mode 100755 diff --git a/ldm/modules/structure_condition/__init__.py b/ldm/modules/structure_condition/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..40a96afc6ff09d58a702b76e3f7dd412fe975e26 --- /dev/null +++ b/ldm/modules/structure_condition/__init__.py @@ -0,0 +1 @@ +# -*- coding: utf-8 -*- diff --git a/experiments/README.md b/ldm/modules/structure_condition/midas/__init__.py old mode 100644 new mode 100755 similarity index 100% rename from experiments/README.md rename to ldm/modules/structure_condition/midas/__init__.py diff --git a/ldm/modules/structure_condition/midas/api.py b/ldm/modules/structure_condition/midas/api.py new file mode 100755 index 0000000000000000000000000000000000000000..a601c72480732339b8737813e7154a52ffca6fa7 --- /dev/null +++ b/ldm/modules/structure_condition/midas/api.py @@ -0,0 +1,175 @@ +# based on https://github.com/isl-org/MiDaS +import os + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.structure_condition.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.structure_condition.midas.midas.midas_net import MidasNet +from ldm.modules.structure_condition.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.structure_condition.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt" + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir='models') + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction diff --git a/ldm/modules/structure_condition/midas/midas/__init__.py b/ldm/modules/structure_condition/midas/midas/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ldm/modules/structure_condition/midas/midas/base_model.py b/ldm/modules/structure_condition/midas/midas/base_model.py new file mode 100755 index 0000000000000000000000000000000000000000..5cf430239b47ec5ec07531263f26f5c24a2311cd --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. + + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/ldm/modules/structure_condition/midas/midas/blocks.py b/ldm/modules/structure_condition/midas/midas/blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..2145d18fa98060a618536d9a64fe6589e9be4f78 --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/ldm/modules/structure_condition/midas/midas/dpt_depth.py b/ldm/modules/structure_condition/midas/midas/dpt_depth.py new file mode 100755 index 0000000000000000000000000000000000000000..4e9aab5d2767dffea39da5b3f30e2798688216f1 --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/ldm/modules/structure_condition/midas/midas/midas_net.py b/ldm/modules/structure_condition/midas/midas/midas_net.py new file mode 100755 index 0000000000000000000000000000000000000000..8a954977800b0a0f48807e80fa63041910e33c1f --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/ldm/modules/structure_condition/midas/midas/midas_net_custom.py b/ldm/modules/structure_condition/midas/midas/midas_net_custom.py new file mode 100755 index 0000000000000000000000000000000000000000..50e4acb5e53d5fabefe3dde16ab49c33c2b7797c --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/ldm/modules/structure_condition/midas/midas/transforms.py b/ldm/modules/structure_condition/midas/midas/transforms.py new file mode 100755 index 0000000000000000000000000000000000000000..350cbc11662633ad7f8968eb10be2e7de6e384e9 --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/ldm/modules/structure_condition/midas/midas/vit.py b/ldm/modules/structure_condition/midas/midas/vit.py new file mode 100755 index 0000000000000000000000000000000000000000..ea46b1be88b261b0dec04f3da0256f5f66f88a74 --- /dev/null +++ b/ldm/modules/structure_condition/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/ldm/modules/structure_condition/midas/utils.py b/ldm/modules/structure_condition/midas/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/ldm/modules/structure_condition/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/model_edge.py b/ldm/modules/structure_condition/model_edge.py old mode 100644 new mode 100755 similarity index 100% rename from model_edge.py rename to ldm/modules/structure_condition/model_edge.py diff --git a/seger.py b/ldm/modules/structure_condition/model_seg.py similarity index 100% rename from seger.py rename to ldm/modules/structure_condition/model_seg.py diff --git a/ldm/modules/structure_condition/utils.py b/ldm/modules/structure_condition/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..af6bcb9e1116a431a39579f4bbdde3a9e868e0b4 --- /dev/null +++ b/ldm/modules/structure_condition/utils.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +import cv2 +import numpy as np + +skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], + [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]] + +pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], + [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], + [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]] + +pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0], + [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0], + [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], + [51, 153, 255], [51, 153, 255], [51, 153, 255]] + + +def imshow_keypoints(img, + pose_result, + kpt_score_thr=0.1, + radius=2, + thickness=2): + """Draw keypoints and links on an image. + + Args: + img (ndarry): The image to draw poses on. + pose_result (list[kpts]): The poses to draw. Each element kpts is + a set of K keypoints as an Kx3 numpy.ndarray, where each + keypoint is represented as x, y, score. + kpt_score_thr (float, optional): Minimum score of keypoints + to be shown. Default: 0.3. + thickness (int): Thickness of lines. + """ + + img_h, img_w, _ = img.shape + img = np.zeros(img.shape) + + for idx, kpts in enumerate(pose_result): + if idx > 1: + continue + kpts = kpts['keypoints'] + # print(kpts) + kpts = np.array(kpts, copy=False) + + # draw each point on image + assert len(pose_kpt_color) == len(kpts) + + for kid, kpt in enumerate(kpts): + x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] + + if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None: + # skip the point that should not be drawn + continue + + color = tuple(int(c) for c in pose_kpt_color[kid]) + cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1) + + # draw links + + for sk_id, sk in enumerate(skeleton): + pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) + pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) + + if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0 + or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr + or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None): + # skip the link that should not be drawn + continue + color = tuple(int(c) for c in pose_link_color[sk_id]) + cv2.line(img, pos1, pos2, color, thickness=thickness) + + return img diff --git a/ldm/modules/x_transformer.py b/ldm/modules/x_transformer.py old mode 100644 new mode 100755 diff --git a/ldm/util.py b/ldm/util.py old mode 100644 new mode 100755 diff --git a/load_json.py b/load_json.py deleted file mode 100644 index 839e79220748fdb9cb6f6d2365d32e30e360a455..0000000000000000000000000000000000000000 --- a/load_json.py +++ /dev/null @@ -1,6 +0,0 @@ -import json - -def load_json(path): - with open(path,'r',encoding = 'utf-8') as fp: - data = json.load(fp) - return data['images'] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f6d1f1d4ca0d841fc7ae0dddb4d759cd256f7f97..614404fca39c73a052e7318697032e762d5105af 100755 --- a/requirements.txt +++ b/requirements.txt @@ -16,4 +16,5 @@ openmim mmpose mmdet psutil -blobfile \ No newline at end of file +blobfile +timm \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index a24d541676407eee1bea271179ffd1d80c6a8e79..0000000000000000000000000000000000000000 --- a/setup.py +++ /dev/null @@ -1,13 +0,0 @@ -from setuptools import setup, find_packages - -setup( - name='latent-diffusion', - version='0.0.1', - description='', - packages=find_packages(), - install_requires=[ - 'torch', - 'numpy', - 'tqdm', - ], -) \ No newline at end of file diff --git a/test_keypose.py b/test_keypose.py deleted file mode 100644 index 5f546d3e526ee7d9c45b63da1a4dbbf397b898c4..0000000000000000000000000000000000000000 --- a/test_keypose.py +++ /dev/null @@ -1,466 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image -from pytorch_lightning import seed_everything - -from dataset_coco import dataset_coco, dataset_coco_mask_color_sig -from dist_util import get_bare_model, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config -import mmcv -from mmdet.apis import inference_detector, init_detector -from mmpose.apis import (inference_top_down_pose_model, init_pose_model, process_mmdet_results, vis_pose_result) - -skeleton = [[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], - [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]] - -pose_kpt_color = [[51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], - [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0], - [0, 255, 0], [255, 128, 0], [0, 255, 0], [255, 128, 0]] - -pose_link_color = [[0, 255, 0], [0, 255, 0], [255, 128, 0], [255, 128, 0], - [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], [0, 255, 0], [255, 128, 0], - [0, 255, 0], [255, 128, 0], [51, 153, 255], [51, 153, 255], [51, 153, 255], [51, 153, 255], - [51, 153, 255], [51, 153, 255], [51, 153, 255]] - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - - return resume_state - -def imshow_keypoints(img, - pose_result, - skeleton=None, - kpt_score_thr=0.1, - pose_kpt_color=None, - pose_link_color=None, - radius=4, - thickness=1): - """Draw keypoints and links on an image. - - Args: - img (ndarry): The image to draw poses on. - pose_result (list[kpts]): The poses to draw. Each element kpts is - a set of K keypoints as an Kx3 numpy.ndarray, where each - keypoint is represented as x, y, score. - kpt_score_thr (float, optional): Minimum score of keypoints - to be shown. Default: 0.3. - pose_kpt_color (np.array[Nx3]`): Color of N keypoints. If None, - the keypoint will not be drawn. - pose_link_color (np.array[Mx3]): Color of M links. If None, the - links will not be drawn. - thickness (int): Thickness of lines. - """ - - img_h, img_w, _ = img.shape - img = np.zeros(img.shape) - - for idx, kpts in enumerate(pose_result): - if idx > 1: - continue - kpts = kpts['keypoints'] - # print(kpts) - kpts = np.array(kpts, copy=False) - - # draw each point on image - if pose_kpt_color is not None: - assert len(pose_kpt_color) == len(kpts) - - for kid, kpt in enumerate(kpts): - x_coord, y_coord, kpt_score = int(kpt[0]), int(kpt[1]), kpt[2] - - if kpt_score < kpt_score_thr or pose_kpt_color[kid] is None: - # skip the point that should not be drawn - continue - - color = tuple(int(c) for c in pose_kpt_color[kid]) - cv2.circle(img, (int(x_coord), int(y_coord)), radius, color, -1) - - # draw links - if skeleton is not None and pose_link_color is not None: - assert len(pose_link_color) == len(skeleton) - - for sk_id, sk in enumerate(skeleton): - pos1 = (int(kpts[sk[0], 0]), int(kpts[sk[0], 1])) - pos2 = (int(kpts[sk[1], 0]), int(kpts[sk[1], 1])) - - if (pos1[0] <= 0 or pos1[0] >= img_w or pos1[1] <= 0 or pos1[1] >= img_h or pos2[0] <= 0 - or pos2[0] >= img_w or pos2[1] <= 0 or pos2[1] >= img_h or kpts[sk[0], 2] < kpt_score_thr - or kpts[sk[1], 2] < kpt_score_thr or pose_link_color[sk_id] is None): - # skip the link that should not be drawn - continue - color = tuple(int(c) for c in pose_link_color[sk_id]) - cv2.line(img, pos1, pos2, color, thickness=thickness) - - return img - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="An Iron man" -) -parser.add_argument( - "--neg_prompt", - type=str, - default="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" -) -parser.add_argument( - "--path_cond", - type=str, - default="examples/keypose/iron.png" -) -parser.add_argument( - "--type_in", - type=str, - default="sketch" -) -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--device", - type=str, - default="cuda" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--ckpt_ad", - type=str, - default='models/t2iadapter_keypose_sd14v1.pth' -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/test_keypose.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=10, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=-1, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) - -## mmpose part ## -parser.add_argument( - '--det_config', - help='Config file for detection', - default='models/faster_rcnn_r50_fpn_coco.py' -) -parser.add_argument( - '--det_checkpoint', - help='Checkpoint file for detection', - default='models/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' -) -parser.add_argument( - '--pose_config', - help='Config file for pose', - default='models/hrnet_w48_coco_256x192.py' -) -parser.add_argument( - '--pose_checkpoint', - help='Checkpoint file for pose', - default='models/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth' -) -parser.add_argument( - '--det-cat-id', - type=int, - default=1, - help='Category id for bounding box detection model' -) -parser.add_argument( - '--bbox-thr', - type=float, - default=0.2, - help='Bounding box score threshold' -) - -opt = parser.parse_args() - -if __name__ == '__main__': - # seed_everything(42) - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - device=opt.device - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # Adaptor - model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad.load_state_dict(torch.load(opt.ckpt_ad)) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - - for v_idx in range(opt.n_samples): - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - elif opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - c = model.get_learned_conditioning([opt.prompt]) - - # costumer input - if opt.type_in == 'pose': - pose = cv2.imread(opt.path_cond) - elif opt.type_in == 'image': - # im = cv2.imread(opt.path_cond) - image = cv2.imread(opt.path_cond) - det_config_mmcv = mmcv.Config.fromfile(opt.det_config) - det_model = init_detector(det_config_mmcv, opt.det_checkpoint, device=device) - pose_config_mmcv = mmcv.Config.fromfile(opt.pose_config) - pose_model = init_pose_model(pose_config_mmcv, opt.pose_checkpoint, device=device) - - mmdet_results = inference_detector(det_model, opt.path_cond) - # keep the person class bounding boxes. - person_results = process_mmdet_results(mmdet_results, opt.det_cat_id) - - # optional - return_heatmap = False - dataset = pose_model.cfg.data['test']['type'] - - # e.g. use ('backbone', ) to return backbone feature - output_layer_names = None - pose_results, returned_outputs = inference_top_down_pose_model( - pose_model, - opt.path_cond, - person_results, - bbox_thr=opt.bbox_thr, - format='xyxy', - dataset=dataset, - dataset_info=None, - return_heatmap=return_heatmap, - outputs=output_layer_names) - - # show the results - pose = imshow_keypoints( - image, - pose_results, - skeleton=skeleton, - pose_kpt_color=pose_kpt_color, - pose_link_color=pose_link_color, - radius=2, - thickness=2) - - else: - raise TypeError('Wrong input condition.') - - pose = cv2.resize(pose,(512,512)) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'pose_idx%04d.png'%(v_idx)), pose) - - pose = img2tensor(pose, bgr2rgb=True, float32=True)/255. - pose = pose.unsqueeze(0) - - features_adapter = model_ad(pose.to(device)) - - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - - samples_ddim, intermediates = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.get_learned_conditioning([opt.neg_prompt]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter, - mode = 'pose' - ) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_idx%04d_s%04d.png'%(v_idx, id_sample)), img[:,:,::-1]) \ No newline at end of file diff --git a/test_seg.py b/test_seg.py deleted file mode 100644 index fa6a8c1e4d67ca5d6e4e606f4c287bb8554a07eb..0000000000000000000000000000000000000000 --- a/test_seg.py +++ /dev/null @@ -1,304 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image -from pytorch_lightning import seed_everything - -from dataset_coco import dataset_coco, dataset_coco_mask_color_sig -from dist_util import get_bare_model, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="A black Honda motorcycle parked in front of a garage" -) -parser.add_argument( - "--neg_prompt", - type=str, - default="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" -) -parser.add_argument( - "--path_cond", - type=str, - default="examples/seg/motor.png" -) -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--device", - type=str, - default="cuda" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--ckpt_ad", - type=str, - default="models/t2iadapter_seg_sd14v1.pth" -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/test_mask.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=10, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=-1, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -opt = parser.parse_args() - -if __name__ == '__main__': - # seed_everything(42) - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - device=opt.device - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # Adaptor - model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad.load_state_dict(torch.load(opt.ckpt_ad)) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - - for v_idx in range(opt.n_samples): - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - elif opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - c = model.get_learned_conditioning([opt.prompt]) - - # costumer input - mask = cv2.imread(opt.path_cond) - mask = cv2.resize(mask,(512,512)) - mask = img2tensor(mask, bgr2rgb=True, float32=True)/255. - mask = mask.unsqueeze(0) - - im_mask = tensor2img(mask) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'mask_idx%04d.png'%(v_idx)), im_mask) - - features_adapter = model_ad(mask.to(device)) - - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - - samples_ddim, intermediates = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.get_learned_conditioning([opt.neg_prompt]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter, - mode = 'mask' - ) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_idx%04d_s%04d.png'%(v_idx, id_sample)), img[:,:,::-1]) \ No newline at end of file diff --git a/test_seg_sketch.py b/test_seg_sketch.py deleted file mode 100644 index a7a00f77601ecec7683df475dc8ae5cf3f5c890a..0000000000000000000000000000000000000000 --- a/test_seg_sketch.py +++ /dev/null @@ -1,327 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image -from pytorch_lightning import seed_everything - -from dataset_coco import dataset_coco, dataset_coco_mask_color_sig -from dist_util import get_bare_model, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="An all white kitchen with an electric stovetop" -) -parser.add_argument( - "--neg_prompt", - type=str, - default="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" -) -parser.add_argument( - "--path_cond", - type=str, - default="examples/seg_sketch/mask.png" -) -parser.add_argument( - "--path_cond2", - type=str, - default="examples/seg_sketch/edge.png" -) -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--device", - type=str, - default="cuda" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--ckpt_ad1", - type=str, - default='models/t2iadapter_sketch_sd14v1.pth') -parser.add_argument( - "--ckpt_ad2", - type=str, - default='models/t2iadapter_seg_sd14v1.pth' -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/test_mask_sketch.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=10, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=-1, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -opt = parser.parse_args() - -if __name__ == '__main__': - # seed_everything(42) - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - device = opt.device - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # Adaptor - model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad2 = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad.load_state_dict(torch.load(opt.ckpt_ad2)) - model_ad2.load_state_dict(torch.load(opt.ckpt_ad1)) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - - for v_idx in range(opt.n_samples): - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - elif opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - c = model.get_learned_conditioning([opt.prompt]) - - # costumer input - mask = cv2.imread(opt.path_cond) - mask = cv2.resize(mask,(512,512)) - mask = img2tensor(mask, bgr2rgb=True, float32=True)/255. - mask = mask.unsqueeze(0) - - edge = cv2.imread(opt.path_cond2) - edge = cv2.resize(edge,(512,512)) - edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0)/255. - - # edge = 1-edge # for white background - edge = edge>0.5 - edge = edge.float() - - im_mask = tensor2img(mask) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'mask_idx%04d.png'%(v_idx)), im_mask) - im_edge = tensor2img(edge) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'edge_idx%04d.png'%(v_idx)), im_edge) - - features_adapter1 = model_ad2(edge.to(device)) - features_adapter2 = model_ad(mask.to(device)) - - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - - samples_ddim, intermediates = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.get_learned_conditioning([opt.neg_prompt]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter1, - features_adapter2=features_adapter2, - mode = 'mul' - ) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_idx%04d_s%04d.png'%(v_idx, id_sample)), img[:,:,::-1]) \ No newline at end of file diff --git a/test_sketch.py b/test_sketch.py deleted file mode 100644 index e90ae50a1c34f1384ac0f73d30e5ac7fd7511fe2..0000000000000000000000000000000000000000 --- a/test_sketch.py +++ /dev/null @@ -1,334 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image -from pytorch_lightning import seed_everything - -from dataset_coco import dataset_coco, dataset_coco_mask_color_sig -from dist_util import get_bare_model, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config -from model_edge import pidinet - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - if "state_dict" in pl_sd: - sd = pl_sd["state_dict"] - else: - sd = pl_sd - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="A car with flying wings" -) -parser.add_argument( - "--neg_prompt", - type=str, - default="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" -) -parser.add_argument( - "--path_cond", - type=str, - default="examples/sketch/car.png" -) -parser.add_argument( - "--type_in", - type=str, - default="sketch" -) -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--device", - type=str, - default="cuda" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--ckpt_ad", - type=str, - default="models/t2iadapter_sketch_sd14v1.pth" -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/test_sketch.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=10, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=-1, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -opt = parser.parse_args() - -if __name__ == '__main__': - # seed_everything(42) - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - device=opt.device - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # Adaptor - model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad.load_state_dict(torch.load(opt.ckpt_ad)) - - # edge_generator - net_G = pidinet() - ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict'] - net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) - net_G.to(device) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - - - for v_idx in range(opt.n_samples): - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - elif opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - c = model.get_learned_conditioning([opt.prompt]) - - if opt.type_in == 'sketch': - # costumer input - edge = cv2.imread(opt.path_cond) - edge = cv2.resize(edge,(512,512)) - edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0)/255. - - # edge = 1-edge # for white background - edge = edge>0.5 - edge = edge.float() - elif opt.type_in == 'image': - im = cv2.imread(opt.path_cond) - im = cv2.resize(im,(512,512)) - im = img2tensor(im).unsqueeze(0)/255. - edge = net_G(im.cuda(non_blocking=True))[-1] - - edge = edge>0.5 - edge = edge.float() - else: - raise TypeError('Wrong input condition.') - - im_edge = tensor2img(edge) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'edge_idx%04d.png'%(v_idx)), im_edge) - - features_adapter = model_ad(edge.to(device)) - - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - - samples_ddim, intermediates = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.get_learned_conditioning([opt.neg_prompt]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter, - mode = 'sketch' - ) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_idx%04d_s%04d.png'%(v_idx, id_sample)), img[:,:,::-1]) \ No newline at end of file diff --git a/test_sketch_edit.py b/test_sketch_edit.py deleted file mode 100644 index 9d31c05169826d3f7f2a3187a6ab74f6635daea4..0000000000000000000000000000000000000000 --- a/test_sketch_edit.py +++ /dev/null @@ -1,331 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image -from pytorch_lightning import seed_everything - -from dataset_coco import dataset_coco, dataset_coco_mask_color_sig -from dist_util import get_bare_model, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--prompt", - type=str, - nargs="?", - default="A white cat" -) -parser.add_argument( - "--neg_prompt", - type=str, - default="ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, extra limbs, disfigured, deformed, body out of frame, bad anatomy, watermark, signature, cut off, low contrast, underexposed, overexposed, bad art, beginner, amateur, distorted face" -) -parser.add_argument( - "--path_cond", - type=str, - default="examples/edit_cat/edge_2.png" -) -parser.add_argument( - "--path_x0", - type=str, - default="examples/edit_cat/im.png" -) -parser.add_argument( - "--path_mask", - type=str, - default="examples/edit_cat/mask.png" -) -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--device", - type=str, - default="cuda" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--ckpt_ad", - type=str, - default="models/t2iadapter_sketch_sd14v1.pth" -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/test_sketch_edit.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=10, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=-1, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -opt = parser.parse_args() - -if __name__ == '__main__': - # seed_everything(42) - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - device = opt.device - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # Adaptor - model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - model_ad.load_state_dict(torch.load(opt.ckpt_ad)) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - - - for v_idx in range(opt.n_samples): - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model) - elif opt.plms: - sampler = PLMSSampler(model) - else: - sampler = DDIMSampler(model) - c = model.get_learned_conditioning([opt.prompt]) - - # costumer input - edge = cv2.imread(opt.path_cond) - edge = cv2.resize(edge,(512,512)) - edge = img2tensor(edge)[0].unsqueeze(0).unsqueeze(0)/255. - - # edge = 1-edge # for white background - edge = edge>0.5 - edge = edge.float() - - im_edge = tensor2img(edge) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'edge_idx%04d.png'%(v_idx)), im_edge) - - features_adapter = model_ad(edge.to(device)) - - # latent of original image - x0 = cv2.imread(opt.path_x0) - x0 = img2tensor(x0).unsqueeze(0)/255. - x0 = model.encode_first_stage((x0.to(device)*2-1.).cuda(non_blocking=True)) - x0 = model.get_first_stage_encoding(x0) - - # inpainting mask - mask = cv2.imread(opt.path_mask) - mask = cv2.resize(mask, (64, 64)) - mask = 1 - img2tensor(mask).unsqueeze(0)/255. - mask = mask>0.5 - mask = mask.float()[:,0:1,:,:].to(device) - - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - mask=mask, x0=x0.to(device), - batch_size=1, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.get_learned_conditioning([opt.neg_prompt]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter - ) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_idx%04d_s%04d.png'%(v_idx, id_sample)), img[:,:,::-1]) \ No newline at end of file diff --git a/train_seg.py b/train_seg.py deleted file mode 100644 index 2ab4ae49dae9799afafff6b6a7a78544745a80d8..0000000000000000000000000000000000000000 --- a/train_seg.py +++ /dev/null @@ -1,373 +0,0 @@ -from load_json import load_json -import cv2 -import torch -import os -from basicsr.utils import img2tensor, tensor2img, scandir, get_time_str, get_root_logger, get_env_info -from dataset_coco import dataset_coco_mask_color -import argparse -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from omegaconf import OmegaConf -from ldm.util import instantiate_from_config -from ldm.modules.encoders.adapter import Adapter -from PIL import Image -import numpy as np -import torch.nn as nn -import matplotlib.pyplot as plt -import time -import os.path as osp -from basicsr.utils.options import copy_opt_file, dict2str -import logging -from dist_util import init_dist, master_only, get_bare_model, get_dist_info - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - # else: - # if opt['path'].get('resume_state'): - # resume_state_path = opt['path']['resume_state'] - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - # check_resume(opt, resume_state['iter']) - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="ckp/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/train_mask.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=1, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=0, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -opt = parser.parse_args() - -if __name__ == '__main__': - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - - # distributed setting - init_dist(opt.launcher) - torch.backends.cudnn.benchmark = True - device='cuda' - torch.cuda.set_device(opt.local_rank) - - # dataset - path_json_train = 'coco_stuff/mask/annotations/captions_train2017.json' - path_json_val = 'coco_stuff/mask/annotations/captions_val2017.json' - train_dataset = dataset_coco_mask_color(path_json_train, - root_path_im='coco/train2017', - root_path_mask='coco_stuff/mask/train2017_color', - image_size=512 - ) - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - val_dataset = dataset_coco_mask_color(path_json_val, - root_path_im='coco/val2017', - root_path_mask='coco_stuff/mask/val2017_color', - image_size=512 - ) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=opt.bsize, - shuffle=(train_sampler is None), - num_workers=opt.num_workers, - pin_memory=True, - sampler=train_sampler) - val_dataloader = torch.utils.data.DataLoader( - val_dataset, - batch_size=1, - shuffle=False, - num_workers=1, - pin_memory=False) - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # sketch encoder - model_ad = Adapter(cin=int(3*64), channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - - - # to gpus - model_ad = torch.nn.parallel.DistributedDataParallel( - model_ad, - device_ids=[opt.local_rank], - output_device=opt.local_rank) - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[opt.local_rank], - output_device=opt.local_rank) - # device_ids=[torch.cuda.current_device()]) - - # optimizer - params = list(model_ad.parameters()) - optimizer = torch.optim.AdamW(params, lr=config['training']['lr']) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - start_epoch = 0 - current_iter = 0 - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - else: - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - resume_optimizers = resume_state['optimizers'] - optimizer.load_state_dict(resume_optimizers) - logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") - start_epoch = resume_state['epoch'] - current_iter = resume_state['iter'] - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - # training - logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') - for epoch in range(start_epoch, opt.epochs): - train_dataloader.sampler.set_epoch(epoch) - # train - for _, data in enumerate(train_dataloader): - current_iter += 1 - with torch.no_grad(): - c = model.module.get_learned_conditioning(data['sentence']) - z = model.module.encode_first_stage((data['im']*2-1.).cuda(non_blocking=True)) - z = model.module.get_first_stage_encoding(z) - - mask = data['mask'] - optimizer.zero_grad() - model.zero_grad() - features_adapter = model_ad(mask) - l_pixel, loss_dict = model(z, c=c, features_adapter = features_adapter) - l_pixel.backward() - optimizer.step() - - if (current_iter+1)%opt.print_fq == 0: - logger.info(loss_dict) - - # save checkpoint - rank, _ = get_dist_info() - if (rank==0) and ((current_iter+1)%config['training']['save_freq'] == 0): - save_filename = f'model_ad_{current_iter+1}.pth' - save_path = os.path.join(experiments_root, 'models', save_filename) - save_dict = {} - model_ad_bare = get_bare_model(model_ad) - state_dict = model_ad_bare.state_dict() - for key, param in state_dict.items(): - if key.startswith('module.'): # remove unnecessary 'module.' - key = key[7:] - save_dict[key] = param.cpu() - torch.save(save_dict, save_path) - # save state - state = {'epoch': epoch, 'iter': current_iter+1, 'optimizers': optimizer.state_dict()} - save_filename = f'{current_iter+1}.state' - save_path = os.path.join(experiments_root, 'training_states', save_filename) - torch.save(state, save_path) - - # val - rank, _ = get_dist_info() - if rank==0: - for data in val_dataloader: - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model.module) - elif opt.plms: - sampler = PLMSSampler(model.module) - else: - sampler = DDIMSampler(model.module) - c = model.module.get_learned_conditioning(data['sentence']) - mask = data['mask'] - im_mask = tensor2img(mask) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'mask_%04d.png'%epoch), im_mask) - features_adapter = model_ad(mask) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.module.get_learned_conditioning(opt.n_samples * [""]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter) - x_samples_ddim = model.module.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - img = cv2.putText(img.copy(), data['sentence'][0], (10,30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_e%04d_s%04d.png'%(epoch, id_sample)), img[:,:,::-1]) - break diff --git a/train_sketch.py b/train_sketch.py deleted file mode 100644 index a4672dfef9195800500dc5a4438fa78224c671db..0000000000000000000000000000000000000000 --- a/train_sketch.py +++ /dev/null @@ -1,400 +0,0 @@ -import argparse -import logging -import os -import os.path as osp -import time - -import cv2 -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -from basicsr.utils import (get_env_info, get_root_logger, get_time_str, - img2tensor, scandir, tensor2img) -from basicsr.utils.options import copy_opt_file, dict2str -from omegaconf import OmegaConf -from PIL import Image - -from dataset_coco import dataset_coco_mask_color -from dist_util import get_bare_model, get_dist_info, init_dist, master_only -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.dpm_solver import DPMSolverSampler -from ldm.models.diffusion.plms import PLMSSampler -from ldm.modules.encoders.adapter import Adapter -from ldm.util import instantiate_from_config -from load_json import load_json -from model_edge import pidinet - - -def load_model_from_config(config, ckpt, verbose=False): - print(f"Loading model from {ckpt}") - pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - sd = pl_sd["state_dict"] - model = instantiate_from_config(config.model) - m, u = model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - model.cuda() - model.eval() - return model - -@master_only -def mkdir_and_rename(path): - """mkdirs. If path exists, rename it with timestamp and create a new one. - - Args: - path (str): Folder path. - """ - if osp.exists(path): - new_name = path + '_archived_' + get_time_str() - print(f'Path already exists. Rename it to {new_name}', flush=True) - os.rename(path, new_name) - os.makedirs(path, exist_ok=True) - os.makedirs(osp.join(experiments_root, 'models')) - os.makedirs(osp.join(experiments_root, 'training_states')) - os.makedirs(osp.join(experiments_root, 'visualization')) - -def load_resume_state(opt): - resume_state_path = None - if opt.auto_resume: - state_path = osp.join('experiments', opt.name, 'training_states') - if osp.isdir(state_path): - states = list(scandir(state_path, suffix='state', recursive=False, full_path=False)) - if len(states) != 0: - states = [float(v.split('.state')[0]) for v in states] - resume_state_path = osp.join(state_path, f'{max(states):.0f}.state') - opt.resume_state_path = resume_state_path - # else: - # if opt['path'].get('resume_state'): - # resume_state_path = opt['path']['resume_state'] - - if resume_state_path is None: - resume_state = None - else: - device_id = torch.cuda.current_device() - resume_state = torch.load(resume_state_path, map_location=lambda storage, loc: storage.cuda(device_id)) - # check_resume(opt, resume_state['iter']) - return resume_state - -parser = argparse.ArgumentParser() -parser.add_argument( - "--bsize", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--epochs", - type=int, - default=10000, - help="the prompt to render" -) -parser.add_argument( - "--num_workers", - type=int, - default=8, - help="the prompt to render" -) -parser.add_argument( - "--use_shuffle", - type=bool, - default=True, - help="the prompt to render" -) -parser.add_argument( - "--dpm_solver", - action='store_true', - help="use dpm_solver sampling", -) -parser.add_argument( - "--plms", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--auto_resume", - action='store_true', - help="use plms sampling", -) -parser.add_argument( - "--ckpt", - type=str, - default="models/sd-v1-4.ckpt", - help="path to checkpoint of model", -) -parser.add_argument( - "--config", - type=str, - default="configs/stable-diffusion/train_sketch.yaml", - help="path to config which constructs model", -) -parser.add_argument( - "--print_fq", - type=int, - default=100, - help="path to config which constructs model", -) -parser.add_argument( - "--H", - type=int, - default=512, - help="image height, in pixel space", -) -parser.add_argument( - "--W", - type=int, - default=512, - help="image width, in pixel space", -) -parser.add_argument( - "--C", - type=int, - default=4, - help="latent channels", -) -parser.add_argument( - "--f", - type=int, - default=8, - help="downsampling factor", -) -parser.add_argument( - "--ddim_steps", - type=int, - default=50, - help="number of ddim sampling steps", -) -parser.add_argument( - "--n_samples", - type=int, - default=1, - help="how many samples to produce for each given prompt. A.k.a. batch size", -) -parser.add_argument( - "--ddim_eta", - type=float, - default=0.0, - help="ddim eta (eta=0.0 corresponds to deterministic sampling", -) -parser.add_argument( - "--scale", - type=float, - default=7.5, - help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", -) -parser.add_argument( - "--gpus", - default=[0,1,2,3], - help="gpu idx", -) -parser.add_argument( - '--local_rank', - default=0, - type=int, - help='node rank for distributed training' -) -parser.add_argument( - '--launcher', - default='pytorch', - type=str, - help='node rank for distributed training' -) -parser.add_argument( - '--l_cond', - default=4, - type=int, - help='number of scales' -) -opt = parser.parse_args() - -if __name__ == '__main__': - config = OmegaConf.load(f"{opt.config}") - opt.name = config['name'] - - # distributed setting - init_dist(opt.launcher) - torch.backends.cudnn.benchmark = True - device='cuda' - torch.cuda.set_device(opt.local_rank) - - # dataset - path_json_train = 'coco_stuff/mask/annotations/captions_train2017.json' - path_json_val = 'coco_stuff/mask/annotations/captions_val2017.json' - train_dataset = dataset_coco_mask_color(path_json_train, - root_path_im='coco/train2017', - root_path_mask='coco_stuff/mask/train2017_color', - image_size=512 - ) - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) - val_dataset = dataset_coco_mask_color(path_json_val, - root_path_im='coco/val2017', - root_path_mask='coco_stuff/mask/val2017_color', - image_size=512 - ) - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=opt.bsize, - shuffle=(train_sampler is None), - num_workers=opt.num_workers, - pin_memory=True, - sampler=train_sampler) - val_dataloader = torch.utils.data.DataLoader( - val_dataset, - batch_size=1, - shuffle=False, - num_workers=1, - pin_memory=False) - - # edge_generator - net_G = pidinet() - ckp = torch.load('models/table5_pidinet.pth', map_location='cpu')['state_dict'] - net_G.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) - net_G.cuda() - - # stable diffusion - model = load_model_from_config(config, f"{opt.ckpt}").to(device) - - # sketch encoder - model_ad = Adapter(channels=[320, 640, 1280, 1280][:4], nums_rb=2, ksize=1, sk=True, use_conv=False).to(device) - - # to gpus - model_ad = torch.nn.parallel.DistributedDataParallel( - model_ad, - device_ids=[opt.local_rank], - output_device=opt.local_rank) - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[opt.local_rank], - output_device=opt.local_rank) - # device_ids=[torch.cuda.current_device()]) - net_G = torch.nn.parallel.DistributedDataParallel( - net_G, - device_ids=[opt.local_rank], - output_device=opt.local_rank) - # device_ids=[torch.cuda.current_device()]) - - # optimizer - params = list(model_ad.parameters()) - optimizer = torch.optim.AdamW(params, lr=config['training']['lr']) - - experiments_root = osp.join('experiments', opt.name) - - # resume state - resume_state = load_resume_state(opt) - if resume_state is None: - mkdir_and_rename(experiments_root) - start_epoch = 0 - current_iter = 0 - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - else: - # WARNING: should not use get_root_logger in the above codes, including the called functions - # Otherwise the logger will not be properly initialized - log_file = osp.join(experiments_root, f"train_{opt.name}_{get_time_str()}.log") - logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) - logger.info(get_env_info()) - logger.info(dict2str(config)) - resume_optimizers = resume_state['optimizers'] - optimizer.load_state_dict(resume_optimizers) - logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") - start_epoch = resume_state['epoch'] - current_iter = resume_state['iter'] - - # copy the yml file to the experiment root - copy_opt_file(opt.config, experiments_root) - - - # training - logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter}') - for epoch in range(start_epoch, opt.epochs): - train_dataloader.sampler.set_epoch(epoch) - # train - for _, data in enumerate(train_dataloader): - current_iter += 1 - with torch.no_grad(): - edge = net_G(data['im'].cuda(non_blocking=True))[-1] - edge = edge>0.5 - edge = edge.float() - c = model.module.get_learned_conditioning(data['sentence']) - z = model.module.encode_first_stage((data['im']*2-1.).cuda(non_blocking=True)) - z = model.module.get_first_stage_encoding(z) - - optimizer.zero_grad() - model.zero_grad() - features_adapter = model_ad(edge) - l_pixel, loss_dict = model(z, c=c, features_adapter = features_adapter) - l_pixel.backward() - optimizer.step() - - if (current_iter+1)%opt.print_fq == 0: - logger.info(loss_dict) - - # save checkpoint - rank, _ = get_dist_info() - if (rank==0) and ((current_iter+1)%config['training']['save_freq'] == 0): - save_filename = f'model_ad_{current_iter+1}.pth' - save_path = os.path.join(experiments_root, 'models', save_filename) - save_dict = {} - model_ad_bare = get_bare_model(model_ad) - state_dict = model_ad_bare.state_dict() - for key, param in state_dict.items(): - if key.startswith('module.'): # remove unnecessary 'module.' - key = key[7:] - save_dict[key] = param.cpu() - torch.save(save_dict, save_path) - # save state - state = {'epoch': epoch, 'iter': current_iter+1, 'optimizers': optimizer.state_dict()} - save_filename = f'{current_iter+1}.state' - save_path = os.path.join(experiments_root, 'training_states', save_filename) - torch.save(state, save_path) - - # val - rank, _ = get_dist_info() - if rank==0: - for data in val_dataloader: - with torch.no_grad(): - if opt.dpm_solver: - sampler = DPMSolverSampler(model.module) - elif opt.plms: - sampler = PLMSSampler(model.module) - else: - sampler = DDIMSampler(model.module) - print(data['im'].shape) - c = model.module.get_learned_conditioning(data['sentence']) - edge = net_G(data['im'].cuda(non_blocking=True))[-1] - edge = edge>0.5 - edge = edge.float() - im_edge = tensor2img(edge) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'edge_%04d.png'%epoch), im_edge) - features_adapter = model_ad(edge) - shape = [opt.C, opt.H // opt.f, opt.W // opt.f] - samples_ddim, _ = sampler.sample(S=opt.ddim_steps, - conditioning=c, - batch_size=opt.n_samples, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt.scale, - unconditional_conditioning=model.module.get_learned_conditioning(opt.n_samples * [""]), - eta=opt.ddim_eta, - x_T=None, - features_adapter1=features_adapter) - x_samples_ddim = model.module.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() - for id_sample, x_sample in enumerate(x_samples_ddim): - x_sample = 255.*x_sample - img = x_sample.astype(np.uint8) - img = cv2.putText(img.copy(), data['sentence'][0], (10,30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,255,0), 2) - cv2.imwrite(os.path.join(experiments_root, 'visualization', 'sample_e%04d_s%04d.png'%(epoch, id_sample)), img[:,:,::-1]) - break