wjs0725 commited on
Commit
fbd66df
1 Parent(s): fd38fd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -46,17 +46,9 @@ class SamplingOptions:
46
  guidance: float
47
  seed: int | None
48
 
49
- @torch.inference_mode()
50
- def encode(init_image, torch_device, ae):
51
- init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
52
- init_image = init_image.unsqueeze(0)
53
- init_image = init_image.to(torch_device)
54
- ae = ae.cuda()
55
- with torch.no_grad():
56
- init_image = ae.encode(init_image.to()).to(torch.bfloat16)
57
- return init_image
58
 
59
 
 
60
  class FluxEditor:
61
  def __init__(self, args):
62
  self.args = args
@@ -87,6 +79,16 @@ class FluxEditor:
87
  self.model.cpu()
88
  torch.cuda.empty_cache()
89
  self.ae.encoder.to(self.device)
 
 
 
 
 
 
 
 
 
 
90
 
91
  @torch.inference_mode()
92
  def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
@@ -103,7 +105,7 @@ class FluxEditor:
103
  init_image = init_image[:new_h, :new_w, :]
104
 
105
  width, height = init_image.shape[0], init_image.shape[1]
106
- init_image = encode(init_image, self.device, self.ae)
107
 
108
  print(init_image.shape)
109
 
 
46
  guidance: float
47
  seed: int | None
48
 
 
 
 
 
 
 
 
 
 
49
 
50
 
51
+ @spaces.GPU(duration=30)
52
  class FluxEditor:
53
  def __init__(self, args):
54
  self.args = args
 
79
  self.model.cpu()
80
  torch.cuda.empty_cache()
81
  self.ae.encoder.to(self.device)
82
+
83
+ @torch.inference_mode()
84
+ def encode(init_image, torch_device, ae):
85
+ init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1
86
+ init_image = init_image.unsqueeze(0)
87
+ init_image = init_image.to(torch_device)
88
+ ae = ae.cuda()
89
+ with torch.no_grad():
90
+ init_image = ae.encode(init_image.to()).to(torch.bfloat16)
91
+ return init_image
92
 
93
  @torch.inference_mode()
94
  def edit(self, init_image, source_prompt, target_prompt, num_steps, inject_step, guidance, seed):
 
105
  init_image = init_image[:new_h, :new_w, :]
106
 
107
  width, height = init_image.shape[0], init_image.shape[1]
108
+ init_image = self.encode(init_image, self.device, self.ae)
109
 
110
  print(init_image.shape)
111