wjs0725 commited on
Commit
6fb545f
1 Parent(s): 06161fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +142 -157
app.py CHANGED
@@ -48,67 +48,80 @@ class SamplingOptions:
48
 
49
 
50
 
51
-
52
- class FluxEditor:
53
- def __init__(self, args):
54
- self.args = args
55
- self.device = torch.device(args.device)
56
- self.offload = args.offload
57
- self.name = args.name
58
- self.is_schnell = args.name == "flux-schnell"
59
-
60
- self.feature_path = 'feature'
61
- self.output_dir = 'result'
62
- self.add_sampling_metadata = True
63
-
64
- if self.name not in configs:
65
- available = ", ".join(configs.keys())
66
- raise ValueError(f"Got unknown model name: {name}, chose from {available}")
67
-
68
- # init all components
 
 
 
 
 
69
 
70
 
71
- if self.offload:
72
- self.model.cpu()
73
- torch.cuda.empty_cache()
74
- self.ae.encoder.to(self.device)
75
-
76
- @torch.inference_mode()
77
- def encode(self, init_image, torch_device, ae):
78
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
79
- init_image = init_image.unsqueeze(0)
80
- init_image = init_image.to(torch_device)
81
- ae = ae.cuda()
82
- with torch.no_grad():
83
- init_image = ae.encode(init_image.to()).to(torch.bfloat16)
84
- return init_image
85
-
86
- @spaces.GPU(duration=120)
87
- @torch.inference_mode()
88
- def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
89
-
90
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
91
- torch.cuda.empty_cache()
92
- seed = None
 
 
 
 
 
 
 
 
93
  # if seed == -1:
94
  # seed = None
95
 
96
- shape = init_image.shape
97
 
98
- new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
99
- new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
100
 
101
- init_image = init_image[:new_h, :new_w, :]
102
 
103
- width, height = init_image.shape[0], init_image.shape[1]
104
 
105
- self.ae = load_ae(self.name, device="cpu" if self.offload else self.device)
106
- init_image = self.encode(init_image, self.device, self.ae)
107
 
108
- print(init_image.shape)
109
 
110
- rng = torch.Generator(device="cpu")
111
- opts = SamplingOptions(
112
  source_prompt=source_prompt,
113
  target_prompt=target_prompt,
114
  width=width,
@@ -117,121 +130,93 @@ class FluxEditor:
117
  guidance=guidance,
118
  seed=seed,
119
  )
120
- if opts.seed is None:
121
- opts.seed = torch.Generator(device="cpu").seed()
122
 
123
- print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
124
- t0 = time.perf_counter()
125
-
126
- opts.seed = None
127
- if self.offload:
128
- self.ae = self.ae.cpu()
129
- torch.cuda.empty_cache()
130
- self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
131
 
132
- #############inverse#######################
133
- info = {}
134
- info['feature'] = {}
135
- info['inject_step'] = inject_step
136
 
137
- if not os.path.exists(self.feature_path):
138
- os.mkdir(self.feature_path)
 
 
139
 
 
 
 
 
140
 
141
- print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",self.device)
142
- self.t5 = load_t5(self.device, max_length=256 if self.name == "flux-schnell" else 512)
143
- self.clip = load_clip(self.device)
144
- self.model = load_flow_model(self.name, device="cpu" if self.offload else self.device)
145
 
146
-
147
- print("!!!!!!!!self.t5!!!!!!",next(self.t5.parameters()).device)
148
- print("!!!!!!!!self.clip!!!!!!",next(self.clip.parameters()).device)
149
- print("!!!!!!!!self.model!!!!!!",next(self.model.parameters()).device)
150
-
151
- device = torch.cuda.current_device()
152
- total_memory = torch.cuda.get_device_properties(device).total_memory
153
- allocated_memory = torch.cuda.memory_allocated(device)
154
- reserved_memory = torch.cuda.memory_reserved(device)
155
-
156
- print(f"Total memory: {total_memory / 1024**2:.2f} MB")
157
- print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
158
- print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
159
-
160
-
161
- with torch.no_grad():
162
- inp = prepare(self.t5, self.clip, init_image, prompt=opts.source_prompt)
163
- inp_target = prepare(self.t5, self.clip, init_image, prompt=opts.target_prompt)
164
- timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(self.name != "flux-schnell"))
165
 
166
- # offload TEs to CPU, load model to gpu
167
- if self.offload:
168
- self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
169
- torch.cuda.empty_cache()
170
- self.model = self.model.to(self.device)
171
 
172
  # inversion initial noise
173
- with torch.no_grad():
174
- z, info = denoise(self.model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
175
 
176
- inp_target["img"] = z
177
 
178
- timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(self.name != "flux-schnell"))
179
 
180
- # denoise initial noise
181
- x, _ = denoise(self.model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
182
 
183
- # offload model, load autoencoder to gpu
184
- if self.offload:
185
- self.model.cpu()
186
- torch.cuda.empty_cache()
187
- self.ae.decoder.to(x.device)
188
 
189
- # decode latents to pixel space
190
- x = unpack(x.float(), opts.width, opts.height)
191
-
192
- output_name = os.path.join(self.output_dir, "img_{idx}.jpg")
193
- if not os.path.exists(self.output_dir):
194
- os.makedirs(self.output_dir)
195
- idx = 0
 
196
  else:
197
- fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
198
- if len(fns) > 0:
199
- idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
200
- else:
201
- idx = 0
202
-
203
- ae = ae.cuda()
204
- with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
205
- x = self.ae.decode(x)
206
-
207
- if torch.cuda.is_available():
208
- torch.cuda.synchronize()
209
- t1 = time.perf_counter()
210
-
211
- fn = output_name.format(idx=idx)
212
- print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
213
- # bring into PIL format and save
214
- x = x.clamp(-1, 1)
215
- x = embed_watermark(x.float())
216
- x = rearrange(x[0], "c h w -> h w c")
217
-
218
- img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
219
- exif_data = Image.Exif()
220
- exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
221
- exif_data[ExifTags.Base.Make] = "Black Forest Labs"
222
- exif_data[ExifTags.Base.Model] = self.name
223
- if self.add_sampling_metadata:
224
- exif_data[ExifTags.Base.ImageDescription] = source_prompt
225
- img.save(fn, exif=exif_data, quality=95, subsampling=0)
226
 
227
-
228
- print("End Edit")
229
- return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
 
232
 
233
  def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_available() else "cpu", offload: bool = False):
234
- editor = FluxEditor(args)
235
  is_schnell = model_name == "flux-schnell"
236
 
237
  with gr.Blocks() as demo:
@@ -273,7 +258,7 @@ def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_availab
273
  output_image = gr.Image(label="Generated Image")
274
 
275
  generate_btn.click(
276
- fn=editor.edit,
277
  inputs=[init_image, source_prompt, target_prompt, num_steps, inject_step, guidance],
278
  outputs=[output_image]
279
  )
@@ -282,16 +267,16 @@ def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_availab
282
  return demo
283
 
284
 
285
- if __name__ == "__main__":
286
- import argparse
287
- parser = argparse.ArgumentParser(description="Flux")
288
- parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
289
- parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
290
- parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
291
- parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
292
 
293
- parser.add_argument("--port", type=int, default=41035)
294
- args = parser.parse_args()
295
 
296
- demo = create_demo(args.name, args.device)
297
- demo.launch()
 
48
 
49
 
50
 
51
+ offload = False
52
+ name = "flux-dev"
53
+ is_schnell = False
54
+ feature_path = 'feature'
55
+ output_dir = 'result'
56
+ add_sampling_metadata = True
57
+ # class FluxEditor:
58
+ # def __init__(self, args):
59
+ # self.args = args
60
+ # self.device = torch.device(args.device)
61
+ # self.offload = args.offload
62
+ # self.name = args.name
63
+ # self.is_schnell = args.name == "flux-schnell"
64
+
65
+ # self.feature_path = 'feature'
66
+ # self.output_dir = 'result'
67
+ # self.add_sampling_metadata = True
68
+
69
+ # if self.name not in configs:
70
+ # available = ", ".join(configs.keys())
71
+ # raise ValueError(f"Got unknown model name: {name}, chose from {available}")
72
+
73
+ # # init all components
74
 
75
 
76
+ # if self.offload:
77
+ # self.model.cpu()
78
+ # torch.cuda.empty_cache()
79
+ # self.ae.encoder.to(self.device)
80
+ ae = load_ae(name, device="cpu" if offload else device)
81
+ t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
82
+ clip = load_clip(device)
83
+ model = load_flow_model(name, device="cpu" if offload else device)
84
+ print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
85
+ print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
86
+ print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
87
+ print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
88
+
89
+ @torch.inference_mode()
90
+ def encode(init_image, torch_device, ae):
91
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
92
+ init_image = init_image.unsqueeze(0)
93
+ init_image = init_image.to(torch_device)
94
+ ae = ae.cuda()
95
+ with torch.no_grad():
96
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
97
+ return init_image
98
+
99
+ @spaces.GPU(duration=120)
100
+ @torch.inference_mode()
101
+ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
102
+
103
+ device = "cuda" if torch.cuda.is_available() else "cpu"
104
+ torch.cuda.empty_cache()
105
+ seed = None
106
  # if seed == -1:
107
  # seed = None
108
 
109
+ shape = init_image.shape
110
 
111
+ new_h = shape[0] if shape[0] % 16 == 0 else shape[0] - shape[0] % 16
112
+ new_w = shape[1] if shape[1] % 16 == 0 else shape[1] - shape[1] % 16
113
 
114
+ init_image = init_image[:new_h, :new_w, :]
115
 
116
+ width, height = init_image.shape[0], init_image.shape[1]
117
 
118
+
119
+ init_image = encode(init_image, device, ae)
120
 
121
+ print(init_image.shape)
122
 
123
+ rng = torch.Generator(device="cpu")
124
+ opts = SamplingOptions(
125
  source_prompt=source_prompt,
126
  target_prompt=target_prompt,
127
  width=width,
 
130
  guidance=guidance,
131
  seed=seed,
132
  )
133
+ if opts.seed is None:
134
+ opts.seed = torch.Generator(device="cpu").seed()
135
 
136
+ print(f"Generating with seed {opts.seed}:\n{opts.source_prompt}")
137
+ t0 = time.perf_counter()
 
 
 
 
 
 
138
 
139
+ opts.seed = None
 
 
 
140
 
141
+ #############inverse#######################
142
+ info = {}
143
+ info['feature'] = {}
144
+ info['inject_step'] = inject_step
145
 
146
+ print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
147
+ print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
148
+ print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
149
+ print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
150
 
151
+ device = torch.cuda.current_device()
152
+ total_memory = torch.cuda.get_device_properties(device).total_memory
153
+ allocated_memory = torch.cuda.memory_allocated(device)
154
+ reserved_memory = torch.cuda.memory_reserved(device)
155
 
156
+ print(f"Total memory: {total_memory / 1024**2:.2f} MB")
157
+ print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
158
+ print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ with torch.no_grad():
161
+ inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
162
+ inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
163
+ timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
 
164
 
165
  # inversion initial noise
166
+ with torch.no_grad():
167
+ z, info = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
168
 
169
+ inp_target["img"] = z
170
 
171
+ timesteps = get_schedule(opts.num_steps, inp_target["img"].shape[1], shift=(name != "flux-schnell"))
172
 
173
+ # denoise initial noise
174
+ x, _ = denoise(model, **inp_target, timesteps=timesteps, guidance=guidance, inverse=False, info=info)
175
 
176
+ # decode latents to pixel space
177
+ x = unpack(x.float(), opts.width, opts.height)
 
 
 
178
 
179
+ output_name = os.path.join(output_dir, "img_{idx}.jpg")
180
+ if not os.path.exists(output_dir):
181
+ os.makedirs(output_dir)
182
+ idx = 0
183
+ else:
184
+ fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
185
+ if len(fns) > 0:
186
+ idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
187
  else:
188
+ idx = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
+ ae = ae.cuda()
191
+ with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
192
+ x = ae.decode(x)
193
+
194
+ if torch.cuda.is_available():
195
+ torch.cuda.synchronize()
196
+ t1 = time.perf_counter()
197
+
198
+ fn = output_name.format(idx=idx)
199
+ print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
200
+ # bring into PIL format and save
201
+ x = x.clamp(-1, 1)
202
+ x = embed_watermark(x.float())
203
+ x = rearrange(x[0], "c h w -> h w c")
204
+
205
+ img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
206
+ exif_data = Image.Exif()
207
+ exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
208
+ exif_data[ExifTags.Base.Make] = "Black Forest Labs"
209
+ exif_data[ExifTags.Base.Model] = name
210
+ if add_sampling_metadata:
211
+ exif_data[ExifTags.Base.ImageDescription] = source_prompt
212
+ img.save(fn, exif=exif_data, quality=95, subsampling=0)
213
+
214
+ print("End Edit")
215
+ return img
216
 
217
 
218
 
219
  def create_demo(model_name: str, device: str = "cuda:0" if torch.cuda.is_available() else "cpu", offload: bool = False):
 
220
  is_schnell = model_name == "flux-schnell"
221
 
222
  with gr.Blocks() as demo:
 
258
  output_image = gr.Image(label="Generated Image")
259
 
260
  generate_btn.click(
261
+ fn=edit,
262
  inputs=[init_image, source_prompt, target_prompt, num_steps, inject_step, guidance],
263
  outputs=[output_image]
264
  )
 
267
  return demo
268
 
269
 
270
+ # if __name__ == "__main__":
271
+ # import argparse
272
+ # parser = argparse.ArgumentParser(description="Flux")
273
+ # parser.add_argument("--name", type=str, default="flux-dev", choices=list(configs.keys()), help="Model name")
274
+ # parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="Device to use")
275
+ # parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
276
+ # parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
277
 
278
+ # parser.add_argument("--port", type=int, default=41035)
279
+ # args = parser.parse_args()
280
 
281
+ demo = create_demo("flux-dev", "cuda")
282
+ demo.launch()