Yardenfren commited on
Commit
496c3d0
1 Parent(s): 9bc5970

Update inf.py

Browse files
Files changed (1) hide show
  1. inf.py +3 -2
inf.py CHANGED
@@ -17,8 +17,7 @@ class InferencePipeline:
17
  def __init__(self, hf_token: str | None = None):
18
  self.hf_token = hf_token
19
  self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
20
- self.device = torch.device(
21
- 'cuda:0' if torch.cuda.is_available() else 'cpu')
22
  if self.device.type == 'cpu':
23
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
24
  self.base_model_id, use_auth_token=self.hf_token, cache_dir='./cache')
@@ -109,6 +108,8 @@ class InferencePipeline:
109
  guidance_scale: float,
110
  num_images_per_prompt: int = 1
111
  ) -> PIL.Image.Image:
 
 
112
  if not torch.cuda.is_available():
113
  raise gr.Error('CUDA is not available.')
114
 
 
17
  def __init__(self, hf_token: str | None = None):
18
  self.hf_token = hf_token
19
  self.base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
20
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
21
  if self.device.type == 'cpu':
22
  self.pipe = StableDiffusionXLPipeline.from_pretrained(
23
  self.base_model_id, use_auth_token=self.hf_token, cache_dir='./cache')
 
108
  guidance_scale: float,
109
  num_images_per_prompt: int = 1
110
  ) -> PIL.Image.Image:
111
+
112
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
113
  if not torch.cuda.is_available():
114
  raise gr.Error('CUDA is not available.')
115