Ashoka74 commited on
Commit
291f99d
ยท
verified ยท
1 Parent(s): d93491f

Update merged_files3.py

Browse files
Files changed (1) hide show
  1. merged_files3.py +2 -3
merged_files3.py CHANGED
@@ -492,11 +492,10 @@ device = torch.device('cuda')
492
  dtype = torch.float16 # Use float16 consistently for all models
493
 
494
 
495
- sd_offset = sf.load_file(model_path, device=device) # Use device variable
496
  sd_origin = unet.state_dict()
497
  keys = sd_origin.keys()
498
- sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
499
- unet.load_state_dict(sd_merged, strict=True)
500
  del sd_offset, sd_origin, sd_merged, keys
501
 
502
 
 
492
  dtype = torch.float16 # Use float16 consistently for all models
493
 
494
 
495
+ sd_offset = sf.load_file(model_path) # Use device variable
496
  sd_origin = unet.state_dict()
497
  keys = sd_origin.keys()
498
+ sd_offset = {k: v.to(device) for k, v in sd_offset.items()} # Move each tensor to GPUunet.load_state_dict(sd_merged, strict=True)
 
499
  del sd_offset, sd_origin, sd_merged, keys
500
 
501