Ashoka74 commited on
Commit
a74b98c
ยท
verified ยท
1 Parent(s): 291f99d

Update merged_files3.py

Browse files
Files changed (1) hide show
  1. merged_files3.py +1 -1
merged_files3.py CHANGED
@@ -495,7 +495,7 @@ dtype = torch.float16 # Use float16 consistently for all models
495
  sd_offset = sf.load_file(model_path) # Use device variable
496
  sd_origin = unet.state_dict()
497
  keys = sd_origin.keys()
498
- sd_offset = {k: v.to(device) for k, v in sd_offset.items()} # Move each tensor to GPUunet.load_state_dict(sd_merged, strict=True)
499
  del sd_offset, sd_origin, sd_merged, keys
500
 
501
 
 
495
  sd_offset = sf.load_file(model_path) # Use device variable
496
  sd_origin = unet.state_dict()
497
  keys = sd_origin.keys()
498
+ sd_merged = {k: v.to(device) for k, v in sd_offset.items()} # Move each tensor to GPUunet.load_state_dict(sd_merged, strict=True)
499
  del sd_offset, sd_origin, sd_merged, keys
500
 
501