NIRVANALAN commited on
Commit
1fa8ce9
1 Parent(s): c3a2df4
configs/i23d_args.json CHANGED
@@ -30,7 +30,7 @@
30
  "log_interval": 50,
31
  "eval_interval": 5000,
32
  "save_interval": 10000,
33
- "resume_checkpoint": "/nas/shared/V2V/yslan/logs/nips24/LSGM/t23d/FM/9cls/i23d/dit-L2-pixart-lognorm-rmsnorm-layernorm_before_pooled/gpu7-batch40-lr1e-4-bf16-qknorm-ctd3/model_joint_denoise_rec_model2990000.pt",
34
  "resume_cldm_checkpoint": "",
35
  "resume_checkpoint_EG3D": "",
36
  "use_fp16": false,
 
30
  "log_interval": 50,
31
  "eval_interval": 5000,
32
  "save_interval": 10000,
33
+ "resume_checkpoint": "checkpoints/objaverse/objaverse-dit/i23d/model_joint_denoise_rec_model2990000.safetensors",
34
  "resume_cldm_checkpoint": "",
35
  "resume_checkpoint_EG3D": "",
36
  "use_fp16": false,
nsr/train_util_diffusion.py CHANGED
@@ -32,6 +32,8 @@ from guided_diffusion.train_util import (TrainLoop, calc_average_loss,
32
  parse_resume_step_from_filename)
33
 
34
  import dnnlib
 
 
35
 
36
  from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
37
 
@@ -758,25 +760,40 @@ class TrainLoopDiffusionWithRec(TrainLoop):
758
  model=None,
759
  model_name='ddpm',
760
  resume_checkpoint=None):
761
- if resume_checkpoint is None:
762
- resume_checkpoint, self.resume_step = find_resume_checkpoint(
763
- self.resume_checkpoint, model_name) or self.resume_checkpoint
 
 
 
 
764
 
765
  if model is None:
766
  model = self.model
767
 
768
- if resume_checkpoint and Path(resume_checkpoint).exists():
769
  if dist_util.get_rank() == 0:
770
  # ! rank 0 return will cause all other ranks to hang
771
- logger.log(
772
- f"loading model from checkpoint: {resume_checkpoint}...")
773
  map_location = {
774
  'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
775
  } # configure map_location properly
776
-
777
  logger.log(f'mark {model_name} loading ')
778
- resume_state_dict = dist_util.load_state_dict(
779
- resume_checkpoint, map_location=map_location)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
  logger.log(f'mark {model_name} loading finished')
781
 
782
  model_state_dict = model.state_dict()
 
32
  parse_resume_step_from_filename)
33
 
34
  import dnnlib
35
+ from safetensors.torch import load_file
36
+ from huggingface_hub import hf_hub_download
37
 
38
  from nsr.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
39
 
 
760
  model=None,
761
  model_name='ddpm',
762
  resume_checkpoint=None):
763
+ # load safetensors from hf
764
+
765
+ hf_loading = '.safetensors' in self.resume_checkpoint
766
+ if not hf_loading:
767
+ if resume_checkpoint is None:
768
+ resume_checkpoint, self.resume_step = find_resume_checkpoint(
769
+ self.resume_checkpoint, model_name) or self.resume_checkpoint
770
 
771
  if model is None:
772
  model = self.model
773
 
774
+ if hf_loading or (resume_checkpoint and Path(resume_checkpoint).exists()):
775
  if dist_util.get_rank() == 0:
776
  # ! rank 0 return will cause all other ranks to hang
 
 
777
  map_location = {
778
  'cuda:%d' % 0: 'cuda:%d' % dist_util.get_rank()
779
  } # configure map_location properly
 
780
  logger.log(f'mark {model_name} loading ')
781
+
782
+ if hf_loading:
783
+ logger.log(
784
+ f"loading model from huggingface: yslan/LN3Diff/{self.resume_checkpoint}...")
785
+ else:
786
+ logger.log(
787
+ f"loading model from checkpoint: {resume_checkpoint}...")
788
+
789
+ if hf_loading:
790
+ model_path = hf_hub_download(repo_id="yslan/LN3Diff",
791
+ filename=self.resume_checkpoint)
792
+ resume_state_dict = load_file(model_path)
793
+ else:
794
+ resume_state_dict = dist_util.load_state_dict(
795
+ resume_checkpoint, map_location=map_location)
796
+
797
  logger.log(f'mark {model_name} loading finished')
798
 
799
  model_state_dict = model.state_dict()