kamwoh commited on
Commit
f44c040
1 Parent(s): 617065a

fixed cpu error

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -80,8 +80,11 @@ def prepare_pipeline(model_name):
80
  if 'dpo' in OUTPUT_DIR:
81
  args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
82
 
83
- pipe = load_pipeline(args, torch.float16, 'cuda')
84
- pipe = pipe.to(torch.float16)
 
 
 
85
 
86
  pipe.verbose = True
87
  pipe.v = 're'
@@ -116,7 +119,7 @@ def prepare_pipeline(model_name):
116
  ID2NAME = open('data/dogs/class_names.txt').readlines()
117
  ID2NAME = [line.strip() for line in ID2NAME]
118
 
119
- return pipe, MAPPING, ID2NAME
120
 
121
 
122
  def download_file(url, local_path):
@@ -159,11 +162,11 @@ def process_text(text, MAPPING, ID2NAME):
159
 
160
 
161
  def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
162
- generator = torch.Generator(device='cuda')
163
- generator = generator.manual_seed(int(seed))
164
-
165
  try:
166
- pipe, MAPPING, ID2NAME = prepare_pipeline(model_name)
 
 
 
167
 
168
  prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
169
  negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
@@ -179,7 +182,8 @@ def generate_images(model_name, prompt, negative_prompt, num_inference_steps, gu
179
  f"The error message: {e}")
180
  finally:
181
  gc.collect()
182
- torch.cuda.empty_cache()
 
183
 
184
  return images, '; '.join(part2id)
185
 
 
80
  if 'dpo' in OUTPUT_DIR:
81
  args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
82
 
83
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
84
+ weight_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
85
+
86
+ pipe = load_pipeline(args, weight_dtype, device)
87
+ pipe = pipe.to(weight_dtype)
88
 
89
  pipe.verbose = True
90
  pipe.v = 're'
 
119
  ID2NAME = open('data/dogs/class_names.txt').readlines()
120
  ID2NAME = [line.strip() for line in ID2NAME]
121
 
122
+ return pipe, MAPPING, ID2NAME, device
123
 
124
 
125
  def download_file(url, local_path):
 
162
 
163
 
164
  def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
 
 
 
165
  try:
166
+ pipe, MAPPING, ID2NAME, device = prepare_pipeline(model_name)
167
+
168
+ generator = torch.Generator(device=device)
169
+ generator = generator.manual_seed(int(seed))
170
 
171
  prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
172
  negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
 
182
  f"The error message: {e}")
183
  finally:
184
  gc.collect()
185
+ if torch.cuda.is_available():
186
+ torch.cuda.empty_cache()
187
 
188
  return images, '; '.join(part2id)
189