BestWishYsh
commited on
Commit
•
ac6cff8
1
Parent(s):
a5e282c
Update app.py
Browse filesFix weight loading.
app.py
CHANGED
@@ -91,7 +91,8 @@ class MagicTimeController:
|
|
91 |
self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
|
92 |
self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
93 |
self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
94 |
-
|
|
|
95 |
# self.tokenizer = tokenizer
|
96 |
# self.text_encoder = text_encoder
|
97 |
# self.vae = vae
|
@@ -99,6 +100,7 @@ class MagicTimeController:
|
|
99 |
# self.text_model = text_model
|
100 |
|
101 |
self.update_motion_module(self.motion_module_list[0])
|
|
|
102 |
self.update_dreambooth(self.dreambooth_list[0])
|
103 |
|
104 |
|
@@ -121,9 +123,10 @@ class MagicTimeController:
|
|
121 |
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
|
122 |
self.vae.load_state_dict(converted_vae_checkpoint)
|
123 |
|
124 |
-
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.
|
|
|
125 |
self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
126 |
-
|
127 |
text_model = copy.deepcopy(self.text_model)
|
128 |
self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
|
129 |
|
@@ -143,6 +146,13 @@ class MagicTimeController:
|
|
143 |
assert len(unexpected) == 0
|
144 |
return gr.Dropdown()
|
145 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
def magictime(
|
148 |
self,
|
@@ -155,6 +165,7 @@ class MagicTimeController:
|
|
155 |
seed_textbox,
|
156 |
):
|
157 |
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
|
|
|
158 |
if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
|
159 |
|
160 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|
|
|
91 |
self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
|
92 |
self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
|
93 |
self.text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
94 |
+
self.unet_model = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs))
|
95 |
+
|
96 |
# self.tokenizer = tokenizer
|
97 |
# self.text_encoder = text_encoder
|
98 |
# self.vae = vae
|
|
|
100 |
# self.text_model = text_model
|
101 |
|
102 |
self.update_motion_module(self.motion_module_list[0])
|
103 |
+
self.update_motion_module_2(self.motion_module_list[0])
|
104 |
self.update_dreambooth(self.dreambooth_list[0])
|
105 |
|
106 |
|
|
|
123 |
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, self.vae.config)
|
124 |
self.vae.load_state_dict(converted_vae_checkpoint)
|
125 |
|
126 |
+
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, self.unet_model.config)
|
127 |
+
self.unet = copy.deepcopy(self.unet_model)
|
128 |
self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
129 |
+
|
130 |
text_model = copy.deepcopy(self.text_model)
|
131 |
self.text_encoder = convert_ldm_clip_text_model(text_model, dreambooth_state_dict)
|
132 |
|
|
|
146 |
assert len(unexpected) == 0
|
147 |
return gr.Dropdown()
|
148 |
|
149 |
+
def update_motion_module_2(self, motion_module_dropdown):
|
150 |
+
self.selected_motion_module = motion_module_dropdown
|
151 |
+
motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
|
152 |
+
motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
|
153 |
+
_, unexpected = self.unet_model.load_state_dict(motion_module_state_dict, strict=False)
|
154 |
+
assert len(unexpected) == 0
|
155 |
+
return gr.Dropdown()
|
156 |
|
157 |
def magictime(
|
158 |
self,
|
|
|
165 |
seed_textbox,
|
166 |
):
|
167 |
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
|
168 |
+
if self.selected_motion_module != motion_module_dropdown: self.update_motion_module_2(motion_module_dropdown)
|
169 |
if self.selected_dreambooth != dreambooth_dropdown: self.update_dreambooth(dreambooth_dropdown)
|
170 |
|
171 |
if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
|