realantonvoronov commited on
Commit
a42955d
1 Parent(s): 8b9fcd0

fix not working apply_late_temperature

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. models/pipeline.py +3 -0
app.py CHANGED
@@ -48,6 +48,7 @@ def infer(
48
  turn_off_cfg_start_si=turn_off_cfg_start_si,
49
  turn_on_cfg_start_si=turn_on_cfg_start_si,
50
  seed=seed,
 
51
  last_scale_temp=last_scale_temp,
52
  )[0]
53
 
@@ -163,8 +164,6 @@ with gr.Blocks(css=css) as demo:
163
  step=0.1,
164
  value=1,
165
  )
166
- if not apply_late_temperature:
167
- last_scale_temp = None
168
 
169
 
170
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
@@ -183,6 +182,7 @@ with gr.Blocks(css=css) as demo:
183
  smooth_start_si,
184
  turn_off_cfg_start_si,
185
  more_diverse,
 
186
  last_scale_temp,
187
  ],
188
  outputs=[result, seed],
 
48
  turn_off_cfg_start_si=turn_off_cfg_start_si,
49
  turn_on_cfg_start_si=turn_on_cfg_start_si,
50
  seed=seed,
51
+ apply_late_temperature=apply_late_temperature,
52
  last_scale_temp=last_scale_temp,
53
  )[0]
54
 
 
164
  step=0.1,
165
  value=1,
166
  )
 
 
167
 
168
 
169
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
 
182
  smooth_start_si,
183
  turn_off_cfg_start_si,
184
  more_diverse,
185
+ apply_late_temperature,
186
  last_scale_temp,
187
  ],
188
  outputs=[result, seed],
models/pipeline.py CHANGED
@@ -93,6 +93,7 @@ class SwittiPipeline:
93
  turn_off_cfg_start_si: int = 10,
94
  turn_on_cfg_start_si: int = 0,
95
  image_size: tuple[int, int] = (512, 512),
 
96
  last_scale_temp: None | float = None,
97
  ) -> torch.Tensor | list[PILImage]:
98
  """
@@ -106,6 +107,8 @@ class SwittiPipeline:
106
  :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
107
  :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
108
  """
 
 
109
  assert not self.switti.training
110
  switti = self.switti
111
  vae = self.vae
 
93
  turn_off_cfg_start_si: int = 10,
94
  turn_on_cfg_start_si: int = 0,
95
  image_size: tuple[int, int] = (512, 512),
96
+ apply_late_temperature: bool = False,
97
  last_scale_temp: None | float = None,
98
  ) -> torch.Tensor | list[PILImage]:
99
  """
 
107
  :param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
108
  :return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
109
  """
110
+ if not apply_late_temperature:
111
+ last_scale_temp = None
112
  assert not self.switti.training
113
  switti = self.switti
114
  vae = self.vae