Spaces:
Running
on
Zero
Running
on
Zero
NIRVANALAN
commited on
Commit
•
1fa8ce9
1
Parent(s):
c3a2df4
update
Browse files- configs/i23d_args.json +1 -1
- nsr/train_util_diffusion.py +26 -9
configs/i23d_args.json
CHANGED
@@ -30,7 +30,7 @@
|
|
30 |
"log_interval": 50,
|
31 |
"eval_interval": 5000,
|
32 |
"save_interval": 10000,
|
33 |
-
"resume_checkpoint": "/
|
34 |
"resume_cldm_checkpoint": "",
|
35 |
"resume_checkpoint_EG3D": "",
|
36 |
"use_fp16": false,
|
|
|
30 |
"log_interval": 50,
|
31 |
"eval_interval": 5000,
|
32 |
"save_interval": 10000,
|
33 |
+
"resume_checkpoint": "checkpoints/objaverse/objaverse-dit/i23d/model_joint_denoise_rec_model2990000.safetensors",
|
34 |
"resume_cldm_checkpoint": "",
|
35 |
"resume_checkpoint_EG3D": "",
|
36 |
"use_fp16": false,
|
nsr/train_util_diffusion.py
CHANGED
@@ -32,6 +32,8 @@ from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
|
|
32 |
parse_resume_step_from_filename)
|
33 |
|
34 |
import dnnlib
|
|
|
|
|
35 |
|
36 |
from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
|
37 |
|
@@ -758,25 +760,40 @@ class TrainLoopDiffusionWithRec(TrainLoop):
|
|
758 |
model=None,
|
759 |
model_name='ddpm',
|
760 |
resume_checkpoint=None):
|
761 |
-
|
762 |
-
|
763 |
-
|
|
|
|
|
|
|
|
|
764 |
|
765 |
if model is None:
|
766 |
model = self.model
|
767 |
|
768 |
-
if resume_checkpoint and Path(resume_checkpoint).exists():
|
769 |
if dist_util.get_rank() == 0:
|
770 |
# ! rank 0 return will cause all other ranks to hang
|
771 |
-
logger.log(
|
772 |
-
f"loading model from checkpoint: {resume_checkpoint}...")
|
773 |
map_location = {
|
774 |
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
|
775 |
} # configure map_location properly
|
776 |
-
|
777 |
logger.log(f'mark {model_name} loading ')
|
778 |
-
|
779 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
780 |
logger.log(f'mark {model_name} loading finished')
|
781 |
|
782 |
model_state_dict = model.state_dict()
|
|
|
32 |
parse_resume_step_from_filename)
|
33 |
|
34 |
import dnnlib
|
35 |
+
from safetensors.torch import load_file
|
36 |
+
from huggingface_hub import hf_hub_download
|
37 |
|
38 |
from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
|
39 |
|
|
|
760 |
model=None,
|
761 |
model_name='ddpm',
|
762 |
resume_checkpoint=None):
|
763 |
+
# load safetensors from hf
|
764 |
+
|
765 |
+
hf_loading = '.safetensors' in self.resume_checkpoint
|
766 |
+
if not hf_loading:
|
767 |
+
if resume_checkpoint is None:
|
768 |
+
resume_checkpoint, self.resume_step = find_resume_checkpoint(
|
769 |
+
self.resume_checkpoint, model_name) or self.resume_checkpoint
|
770 |
|
771 |
if model is None:
|
772 |
model = self.model
|
773 |
|
774 |
+
if hf_loading or (resume_checkpoint and Path(resume_checkpoint).exists()):
|
775 |
if dist_util.get_rank() == 0:
|
776 |
# ! rank 0 return will cause all other ranks to hang
|
|
|
|
|
777 |
map_location = {
|
778 |
'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
|
779 |
} # configure map_location properly
|
|
|
780 |
logger.log(f'mark {model_name} loading ')
|
781 |
+
|
782 |
+
if hf_loading:
|
783 |
+
logger.log(
|
784 |
+
f"loading model from huggingface: yslan/LN3Diff/{self.resume_checkpoint}...")
|
785 |
+
else:
|
786 |
+
logger.log(
|
787 |
+
f"loading model from checkpoint: {resume_checkpoint}...")
|
788 |
+
|
789 |
+
if hf_loading:
|
790 |
+
model_path = hf_hub_download(repo_id="yslan/LN3Diff",
|
791 |
+
filename=self.resume_checkpoint)
|
792 |
+
resume_state_dict = load_file(model_path)
|
793 |
+
else:
|
794 |
+
resume_state_dict = dist_util.load_state_dict(
|
795 |
+
resume_checkpoint, map_location=map_location)
|
796 |
+
|
797 |
logger.log(f'mark {model_name} loading finished')
|
798 |
|
799 |
model_state_dict = model.state_dict()
|