Jordan Legg commited on
Commit
aed3a85
Β·
1 Parent(s): 61a1fb1
Files changed (1) hide show
  1. app.py +29 -11
app.py CHANGED
@@ -15,18 +15,34 @@ MAX_IMAGE_SIZE = 2048
15
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
16
 
17
  @spaces.GPU()
18
- def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
19
  if randomize_seed:
20
  seed = random.randint(0, MAX_SEED)
21
  generator = torch.Generator().manual_seed(seed)
22
- image = pipe(
23
- prompt=prompt,
24
- width=width,
25
- height=height,
26
- num_inference_steps=num_inference_steps,
27
- generator=generator,
28
- guidance_scale=0.0
29
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  return image, seed
31
 
32
  # Define example prompts
@@ -91,7 +107,9 @@ with gr.Blocks(css=css) as demo:
91
  )
92
  run_button = gr.Button("Run", scale=0)
93
 
94
- result = gr.Image(label="Result", show_label=False)
 
 
95
 
96
  with gr.Accordion("Advanced Settings", open=False):
97
  seed = gr.Slider(
@@ -139,7 +157,7 @@ with gr.Blocks(css=css) as demo:
139
  gr.on(
140
  triggers=[run_button.click, prompt.submit],
141
  fn=infer,
142
- inputs=[prompt, seed, randomize_seed, width, height, num_inference_steps],
143
  outputs=[result, seed]
144
  )
145
 
 
15
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
16
 
17
  @spaces.GPU()
18
+ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
19
  if randomize_seed:
20
  seed = random.randint(0, MAX_SEED)
21
  generator = torch.Generator().manual_seed(seed)
22
+
23
+ if init_image is not None:
24
+ # Process img2img
25
+ init_image = pipe.preprocess(init_image).unsqueeze(0).to(device, dtype)
26
+ image = pipe(
27
+ prompt=prompt,
28
+ init_image=init_image,
29
+ width=width,
30
+ height=height,
31
+ num_inference_steps=num_inference_steps,
32
+ generator=generator,
33
+ guidance_scale=0.0
34
+ ).images[0]
35
+ else:
36
+ # Process text2img
37
+ image = pipe(
38
+ prompt=prompt,
39
+ width=width,
40
+ height=height,
41
+ num_inference_steps=num_inference_steps,
42
+ generator=generator,
43
+ guidance_scale=0.0
44
+ ).images[0]
45
+
46
  return image, seed
47
 
48
  # Define example prompts
 
107
  )
108
  run_button = gr.Button("Run", scale=0)
109
 
110
+ with gr.Row():
111
+ init_image = gr.Image(label="Initial Image (optional)", type="pil", optional=True)
112
+ result = gr.Image(label="Result", show_label=False)
113
 
114
  with gr.Accordion("Advanced Settings", open=False):
115
  seed = gr.Slider(
 
157
  gr.on(
158
  triggers=[run_button.click, prompt.submit],
159
  fn=infer,
160
+ inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
161
  outputs=[result, seed]
162
  )
163