Spaces:
Runtime error
Runtime error
Update merged_files3.py
Browse files- merged_files3.py +2 -3
merged_files3.py
CHANGED
@@ -492,11 +492,10 @@ device = torch.device('cuda')
|
|
492 |
dtype = torch.float16 # Use float16 consistently for all models
|
493 |
|
494 |
|
495 |
-
sd_offset = sf.load_file(model_path
|
496 |
sd_origin = unet.state_dict()
|
497 |
keys = sd_origin.keys()
|
498 |
-
|
499 |
-
unet.load_state_dict(sd_merged, strict=True)
|
500 |
del sd_offset, sd_origin, sd_merged, keys
|
501 |
|
502 |
|
|
|
492 |
dtype = torch.float16 # Use float16 consistently for all models
|
493 |
|
494 |
|
495 |
+
sd_offset = sf.load_file(model_path) # Use device variable
|
496 |
sd_origin = unet.state_dict()
|
497 |
keys = sd_origin.keys()
|
498 |
+
sd_offset = {k: v.to(device) for k, v in sd_offset.items()} # Move each tensor to GPUunet.load_state_dict(sd_merged, strict=True)
|
|
|
499 |
del sd_offset, sd_origin, sd_merged, keys
|
500 |
|
501 |
|