wjs0725 commited on
Commit
1b839c2
1 Parent(s): 49d340b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -52
app.py CHANGED
@@ -22,29 +22,8 @@ from flux.util import (configs, embed_watermark, load_ae, load_clip, load_flow_m
22
  from huggingface_hub import login
23
  login(token=os.getenv('Token'))
24
 
25
-
26
  import torch
27
 
28
- # device = torch.cuda.current_device()
29
- # print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
30
- # total_memory = torch.cuda.get_device_properties(device).total_memory
31
- # allocated_memory = torch.cuda.memory_allocated(device)
32
- # reserved_memory = torch.cuda.memory_reserved(device)
33
-
34
- # print(f"Total memory: {total_memory / 1024**2:.2f} MB")
35
- # print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
36
- # print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
37
-
38
- device = "cuda" if torch.cuda.is_available() else "cpu"
39
- name = 'flux-dev'
40
- ae = load_ae(name, device)
41
- t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
42
- clip = load_clip(device)
43
- model = load_flow_model(name, device=device)
44
- print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
45
- print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
46
- print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
47
- print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
48
 
49
  @dataclass
50
  class SamplingOptions:
@@ -57,27 +36,29 @@ class SamplingOptions:
57
  guidance: float
58
  seed: int | None
59
 
60
-
61
-
62
- offload = False
63
- name = "flux-dev"
64
- is_schnell = False
65
- feature_path = 'feature'
66
- output_dir = 'result'
67
- add_sampling_metadata = True
68
-
69
-
70
-
71
  @torch.inference_mode()
72
  def encode(init_image, torch_device):
73
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
74
  init_image = init_image.unsqueeze(0)
75
  init_image = init_image.to(torch_device)
76
- ae = ae.cuda()
77
  with torch.no_grad():
78
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
79
  return init_image
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  @spaces.GPU(duration=120)
82
  @torch.inference_mode()
83
  def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
@@ -85,8 +66,6 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
85
  device = "cuda" if torch.cuda.is_available() else "cpu"
86
  torch.cuda.empty_cache()
87
  seed = None
88
- # if seed == -1:
89
- # seed = None
90
 
91
  shape = init_image.shape
92
 
@@ -97,8 +76,12 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
97
 
98
  width, height = init_image.shape[0], init_image.shape[1]
99
 
100
-
101
- init_image = encode(init_image, device)
 
 
 
 
102
 
103
  print(init_image.shape)
104
 
@@ -125,26 +108,12 @@ def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guida
125
  info['feature'] = {}
126
  info['inject_step'] = inject_step
127
 
128
- print("!!!!!!!!!!!!device!!!!!!!!!!!!!!",device)
129
- print("!!!!!!!!self.t5!!!!!!",next(t5.parameters()).device)
130
- print("!!!!!!!!self.clip!!!!!!",next(clip.parameters()).device)
131
- print("!!!!!!!!self.model!!!!!!",next(model.parameters()).device)
132
-
133
- # device = torch.cuda.current_device()
134
- # total_memory = torch.cuda.get_device_properties(device).total_memory
135
- # allocated_memory = torch.cuda.memory_allocated(device)
136
- # reserved_memory = torch.cuda.memory_reserved(device)
137
-
138
- # print(f"Total memory: {total_memory / 1024**2:.2f} MB")
139
- # print(f"Allocated memory: {allocated_memory / 1024**2:.2f} MB")
140
- # print(f"Reserved memory: {reserved_memory / 1024**2:.2f} MB")
141
-
142
  with torch.no_grad():
143
  inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
144
  inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
145
  timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
146
 
147
- # inversion initial noise
148
  with torch.no_grad():
149
  z, info = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
150
 
 
22
  from huggingface_hub import login
23
  login(token=os.getenv('Token'))
24
 
 
25
  import torch
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @dataclass
29
  class SamplingOptions:
 
36
  guidance: float
37
  seed: int | None
38
 
 
 
 
 
 
 
 
 
 
 
 
39
  @torch.inference_mode()
40
  def encode(init_image, torch_device):
41
  init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
42
  init_image = init_image.unsqueeze(0)
43
  init_image = init_image.to(torch_device)
 
44
  with torch.no_grad():
45
  init_image = ae.encode(init_image.to()).to(torch.bfloat16)
46
  return init_image
47
 
48
+
49
+ device = "cuda" if torch.cuda.is_available() else "cpu"
50
+ name = 'flux-dev'
51
+ ae = load_ae(name, device)
52
+ t5 = load_t5(device, max_length=256 if name == "flux-schnell" else 512)
53
+ clip = load_clip(device)
54
+ model = load_flow_model(name, device=device)
55
+ offload = False
56
+ name = "flux-dev"
57
+ is_schnell = False
58
+ feature_path = 'feature'
59
+ output_dir = 'result'
60
+ add_sampling_metadata = True
61
+
62
  @spaces.GPU(duration=120)
63
  @torch.inference_mode()
64
  def edit(init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
 
66
  device = "cuda" if torch.cuda.is_available() else "cpu"
67
  torch.cuda.empty_cache()
68
  seed = None
 
 
69
 
70
  shape = init_image.shape
71
 
 
76
 
77
  width, height = init_image.shape[0], init_image.shape[1]
78
 
79
+
80
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
81
+ init_image = init_image.unsqueeze(0)
82
+ init_image = init_image.to(device)
83
+ with torch.no_grad():
84
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
85
 
86
  print(init_image.shape)
87
 
 
108
  info['feature'] = {}
109
  info['inject_step'] = inject_step
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  with torch.no_grad():
112
  inp = prepare(t5, clip, init_image, prompt=opts.source_prompt)
113
  inp_target = prepare(t5, clip, init_image, prompt=opts.target_prompt)
114
  timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
115
 
116
+ # inversion initial noise
117
  with torch.no_grad():
118
  z, info = denoise(model, **inp, timesteps=timesteps, guidance=1, inverse=True, info=info)
119