realantonvoronov commited on
Commit
bfb7c0b
1 Parent(s): 5487649

pass device to swittihf init

Browse files
Files changed (1) hide show
  1. models/pipeline.py +1 -1
models/pipeline.py CHANGED
@@ -28,7 +28,7 @@ class SwittiPipeline:
28
 
29
  @classmethod
30
  def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
31
- switti = SwittiHF.from_pretrained(pretrained_model_name_or_path).to(device)
32
  vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
33
  text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
34
  text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)
 
28
 
29
  @classmethod
30
  def from_pretrained(cls, pretrained_model_name_or_path, device="cuda"):
31
+ switti = SwittiHF.from_pretrained(pretrained_model_name_or_path, device=device).to(device)
32
  vae = VQVAEHF.from_pretrained(cls.vae_path).to(device)
33
  text_encoder = FrozenCLIPEmbedder(cls.text_encoder_path, device=device)
34
  text_encoder_2 = FrozenCLIPEmbedder(cls.text_encoder_2_path, device=device)