Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,7 @@ from utils.unet import UNet3DConditionModel
|
|
15 |
from utils.pipeline_magictime import MagicTimePipeline
|
16 |
from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
|
17 |
|
18 |
-
pretrained_model_path = "./ckpts/Base_Model/stable-diffusion-v1-5"
|
19 |
inference_config_path = "./sample_configs/RealisticVision.yaml"
|
20 |
magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
|
21 |
magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
|
@@ -91,16 +91,10 @@ class MagicTimeController:
|
|
91 |
self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
92 |
|
93 |
self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
94 |
-
|
95 |
-
self.update_dreambooth(self.dreambooth_list[0])
|
96 |
self.update_motion_module(self.motion_module_list[0])
|
97 |
-
|
98 |
-
|
99 |
-
magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
|
100 |
-
self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
|
101 |
-
self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
|
102 |
-
self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
|
103 |
-
|
104 |
|
105 |
def refresh_motion_module(self):
|
106 |
motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
|
@@ -126,6 +120,13 @@ class MagicTimeController:
|
|
126 |
|
127 |
text_model = copy.deepcopy(self.text_model)
|
128 |
self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
return gr.Dropdown()
|
130 |
|
131 |
def update_motion_module(self, motion_module_dropdown):
|
@@ -147,8 +148,8 @@ class MagicTimeController:
|
|
147 |
height_slider,
|
148 |
seed_textbox,
|
149 |
):
|
150 |
-
if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
|
151 |
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
|
|
|
152 |
|
153 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
154 |
|
|
|
15 |
from utils.pipeline_magictime import MagicTimePipeline
|
16 |
from utils.util import save_videos_grid, convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint, load_diffusers_lora_unet, convert_ldm_clip_text_model
|
17 |
|
18 |
+
pretrained_model_path = "runwayml/stable-diffusion-v1-5"
|
19 |
inference_config_path = "./sample_configs/RealisticVision.yaml"
|
20 |
magic_adapter_s_path = "./ckpts/Magic_Weights/magic_adapter_s/magic_adapter_s.ckpt"
|
21 |
magic_adapter_t_path = "./ckpts/Magic_Weights/magic_adapter_t"
|
|
|
91 |
self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
92 |
|
93 |
self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
94 |
+
|
|
|
95 |
self.update_motion_module(self.motion_module_list[0])
|
96 |
+
self.update_dreambooth(self.dreambooth_list[0])
|
97 |
+
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
def refresh_motion_module(self):
|
100 |
motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
|
|
|
120 |
|
121 |
text_model = copy.deepcopy(self.text_model)
|
122 |
self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
|
123 |
+
|
124 |
+
from swift import Swift
|
125 |
+
magic_adapter_s_state_dict = torch.load(magic_adapter_s_path, map_location="cpu")
|
126 |
+
self.unet = load_diffusers_lora_unet(self.unet, magic_adapter_s_state_dict, alpha=1.0)
|
127 |
+
self.unet = Swift.from_pretrained(self.unet, magic_adapter_t_path)
|
128 |
+
self.text_encoder = Swift.from_pretrained(self.text_encoder, magic_text_encoder_path)
|
129 |
+
|
130 |
return gr.Dropdown()
|
131 |
|
132 |
def update_motion_module(self, motion_module_dropdown):
|
|
|
148 |
height_slider,
|
149 |
seed_textbox,
|
150 |
):
|
|
|
151 |
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
|
152 |
+
if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
|
153 |
|
154 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
155 |
|