Spaces:
Running
on
Zero
Running
on
Zero
realantonvoronov
commited on
Commit
•
b5f551b
1
Parent(s):
be0617b
remove apply_late_temperature checkbox
Browse files- app.py +3 -8
- models/pipeline.py +2 -5
app.py
CHANGED
@@ -28,8 +28,7 @@ def infer(
|
|
28 |
smooth_start_si=2,
|
29 |
turn_off_cfg_start_si=10,
|
30 |
more_diverse=True,
|
31 |
-
|
32 |
-
last_scale_temp=None,
|
33 |
progress=gr.Progress(track_tqdm=True),
|
34 |
):
|
35 |
if randomize_seed:
|
@@ -49,7 +48,6 @@ def infer(
|
|
49 |
turn_off_cfg_start_si=turn_off_cfg_start_si,
|
50 |
turn_on_cfg_start_si=turn_on_cfg_start_si,
|
51 |
seed=seed,
|
52 |
-
apply_late_temperature=apply_late_temperature,
|
53 |
last_scale_temp=last_scale_temp,
|
54 |
)[0]
|
55 |
|
@@ -141,17 +139,15 @@ with gr.Blocks(css=css) as demo:
|
|
141 |
)
|
142 |
with gr.Row():
|
143 |
more_diverse = gr.Checkbox(label="More diverse", value=True)
|
144 |
-
apply_late_temperature = gr.Checkbox(label="Temperature after disabling CFG", value=False)
|
145 |
last_scale_temp = gr.Slider(
|
146 |
-
label="
|
147 |
minimum=0.1,
|
148 |
maximum=10,
|
149 |
step=0.1,
|
150 |
value=0.1,
|
151 |
)
|
152 |
|
153 |
-
|
154 |
-
gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
|
155 |
gr.on(
|
156 |
triggers=[run_button.click, prompt.submit],
|
157 |
fn=infer,
|
@@ -167,7 +163,6 @@ with gr.Blocks(css=css) as demo:
|
|
167 |
smooth_start_si,
|
168 |
turn_off_cfg_start_si,
|
169 |
more_diverse,
|
170 |
-
apply_late_temperature,
|
171 |
last_scale_temp,
|
172 |
],
|
173 |
outputs=[result, seed],
|
|
|
28 |
smooth_start_si=2,
|
29 |
turn_off_cfg_start_si=10,
|
30 |
more_diverse=True,
|
31 |
+
last_scale_temp=1,
|
|
|
32 |
progress=gr.Progress(track_tqdm=True),
|
33 |
):
|
34 |
if randomize_seed:
|
|
|
48 |
turn_off_cfg_start_si=turn_off_cfg_start_si,
|
49 |
turn_on_cfg_start_si=turn_on_cfg_start_si,
|
50 |
seed=seed,
|
|
|
51 |
last_scale_temp=last_scale_temp,
|
52 |
)[0]
|
53 |
|
|
|
139 |
)
|
140 |
with gr.Row():
|
141 |
more_diverse = gr.Checkbox(label="More diverse", value=True)
|
|
|
142 |
last_scale_temp = gr.Slider(
|
143 |
+
label="Temperature after disabling CFG",
|
144 |
minimum=0.1,
|
145 |
maximum=10,
|
146 |
step=0.1,
|
147 |
value=0.1,
|
148 |
)
|
149 |
|
150 |
+
gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=False)# cache_mode="lazy")
|
|
|
151 |
gr.on(
|
152 |
triggers=[run_button.click, prompt.submit],
|
153 |
fn=infer,
|
|
|
163 |
smooth_start_si,
|
164 |
turn_off_cfg_start_si,
|
165 |
more_diverse,
|
|
|
166 |
last_scale_temp,
|
167 |
],
|
168 |
outputs=[result, seed],
|
models/pipeline.py
CHANGED
@@ -93,8 +93,7 @@ class SwittiPipeline:
|
|
93 |
turn_off_cfg_start_si: int = 10,
|
94 |
turn_on_cfg_start_si: int = 0,
|
95 |
image_size: tuple[int, int] = (512, 512),
|
96 |
-
|
97 |
-
last_scale_temp: None | float = None,
|
98 |
) -> torch.Tensor | list[PILImage]:
|
99 |
"""
|
100 |
only used for inference, on autoregressive mode
|
@@ -107,8 +106,6 @@ class SwittiPipeline:
|
|
107 |
:param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
|
108 |
:return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
|
109 |
"""
|
110 |
-
if not apply_late_temperature:
|
111 |
-
last_scale_temp = None
|
112 |
assert not self.switti.training
|
113 |
switti = self.switti
|
114 |
vae = self.vae
|
@@ -200,7 +197,7 @@ class SwittiPipeline:
|
|
200 |
# default const cfg
|
201 |
t = cfg
|
202 |
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
|
203 |
-
|
204 |
logits_BlV = logits_BlV / last_scale_temp
|
205 |
|
206 |
if apply_smooth and si >= smooth_start_si:
|
|
|
93 |
turn_off_cfg_start_si: int = 10,
|
94 |
turn_on_cfg_start_si: int = 0,
|
95 |
image_size: tuple[int, int] = (512, 512),
|
96 |
+
last_scale_temp: float = 1.,
|
|
|
97 |
) -> torch.Tensor | list[PILImage]:
|
98 |
"""
|
99 |
only used for inference, on autoregressive mode
|
|
|
106 |
:param more_smooth: sampling using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
|
107 |
:return: if return_pil: list of PIL Images, else: torch.tensor (B, 3, H, W) in [0, 1]
|
108 |
"""
|
|
|
|
|
109 |
assert not self.switti.training
|
110 |
switti = self.switti
|
111 |
vae = self.vae
|
|
|
197 |
# default const cfg
|
198 |
t = cfg
|
199 |
logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
|
200 |
+
else:
|
201 |
logits_BlV = logits_BlV / last_scale_temp
|
202 |
|
203 |
if apply_smooth and si >= smooth_start_si:
|