Ashoka74 committed on
Commit
6feb2f9
1 Parent(s): c216749

Update app_2.py

Browse files
Files changed (1) hide show
  1. app_2.py +41 -36
app_2.py CHANGED
@@ -104,41 +104,6 @@ download_models()
104
 
105
 
106
 
107
- @spaces.GPU()
108
- def infer(
109
- prompt,
110
- image,
111
- do_rembg=False,
112
- seed=42,
113
- randomize_seed=False,
114
- guidance_scale=3.0,
115
- num_inference_steps=50,
116
- reference_conditioning_scale=1.0,
117
- negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
118
- progress=gr.Progress(track_tqdm=True),
119
- ):
120
- # if do_rembg:
121
- # remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, device)
122
- # else:
123
- # remove_bg_fn = None
124
- if randomize_seed:
125
- seed = random.randint(0, MAX_SEED)
126
- images, preprocessed_image = run_pipeline(
127
- pipe,
128
- num_views=NUM_VIEWS,
129
- text=prompt,
130
- image=image,
131
- height=HEIGHT,
132
- width=WIDTH,
133
- num_inference_steps=num_inference_steps,
134
- guidance_scale=guidance_scale,
135
- seed=seed,
136
- remove_bg_fn=None,
137
- reference_conditioning_scale=reference_conditioning_scale,
138
- negative_prompt=negative_prompt,
139
- device=device,
140
- )
141
- return images
142
 
143
 
144
  try:
@@ -381,6 +346,45 @@ def encode_prompt_pair(positive_prompt, negative_prompt):
381
 
382
  return c, uc
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  @spaces.GPU(duration=60)
385
  @torch.inference_mode()
386
  def pytorch2numpy(imgs, quant=True):
@@ -1228,7 +1232,8 @@ with block:
1228
  example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1229
  example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1230
 
1231
- run_button.click(fn=infer,
 
1232
  inputs=[
1233
  "high quality",
1234
  extracted_fg,
 
104
 
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
  try:
 
346
 
347
  return c, uc
348
 
349
+ @spaces.GPU(duration=60)
350
+ @torch.inference_mode()
351
+ def infer(
352
+ prompt,
353
+ image,
354
+ do_rembg=False,
355
+ seed=42,
356
+ randomize_seed=False,
357
+ guidance_scale=3.0,
358
+ num_inference_steps=50,
359
+ reference_conditioning_scale=1.0,
360
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
361
+ progress=gr.Progress(track_tqdm=True),
362
+ ):
363
+ # if do_rembg:
364
+ # remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, device)
365
+ # else:
366
+ # remove_bg_fn = None
367
+ if randomize_seed:
368
+ seed = random.randint(0, MAX_SEED)
369
+
370
+ images, preprocessed_image = run_pipeline(
371
+ pipe,
372
+ num_views=NUM_VIEWS,
373
+ text=prompt,
374
+ image=image,
375
+ height=HEIGHT,
376
+ width=WIDTH,
377
+ num_inference_steps=num_inference_steps,
378
+ guidance_scale=guidance_scale,
379
+ seed=seed,
380
+ remove_bg_fn=None,
381
+ reference_conditioning_scale=reference_conditioning_scale,
382
+ negative_prompt=negative_prompt,
383
+ device=device,
384
+ )
385
+ return images
386
+
387
+
388
  @spaces.GPU(duration=60)
389
  @torch.inference_mode()
390
  def pytorch2numpy(imgs, quant=True):
 
1232
  example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1233
  example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1234
 
1235
+ run_button.click(
1236
+ fn=infer,
1237
  inputs=[
1238
  "high quality",
1239
  extracted_fg,