Update
Browse files
app.py
CHANGED
@@ -82,12 +82,8 @@ def main():
|
|
82 |
value=0.7,
|
83 |
label='Truncation psi')
|
84 |
truncation_type = gr.Dropdown(
|
85 |
-
|
86 |
-
|
87 |
-
'Multimodal (L2)',
|
88 |
-
'Global',
|
89 |
-
],
|
90 |
-
value='Multimodal (LPIPS)',
|
91 |
label='Truncation Type')
|
92 |
run_button = gr.Button('Run')
|
93 |
with gr.Column():
|
|
|
82 |
value=0.7,
|
83 |
label='Truncation psi')
|
84 |
truncation_type = gr.Dropdown(
|
85 |
+
model.TRUNCATION_TYPES,
|
86 |
+
value=model.TRUNCATION_TYPES[0],
|
|
|
|
|
|
|
|
|
87 |
label='Truncation Type')
|
88 |
run_button = gr.Button('Run')
|
89 |
with gr.Column():
|
model.py
CHANGED
@@ -54,6 +54,11 @@ class Model:
|
|
54 |
'giraffes_512',
|
55 |
'parrots_512',
|
56 |
]
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
def __init__(self, device: str | torch.device):
|
59 |
self.device = torch.device(device)
|
@@ -193,12 +198,12 @@ class Model:
|
|
193 |
truncation_type: str) -> np.ndarray:
|
194 |
z = self.generate_z(seed)
|
195 |
ws = self.compute_w(z)
|
196 |
-
if truncation_type ==
|
197 |
w0 = self.model.mapping.w_avg
|
198 |
else:
|
199 |
-
if truncation_type ==
|
200 |
distance_type = 'lpips'
|
201 |
-
elif truncation_type ==
|
202 |
distance_type = 'l2'
|
203 |
else:
|
204 |
raise ValueError
|
|
|
54 |
'giraffes_512',
|
55 |
'parrots_512',
|
56 |
]
|
57 |
+
TRUNCATION_TYPES = [
|
58 |
+
'Multimodal (LPIPS)',
|
59 |
+
'Multimodal (L2)',
|
60 |
+
'Global',
|
61 |
+
]
|
62 |
|
63 |
def __init__(self, device: str | torch.device):
|
64 |
self.device = torch.device(device)
|
|
|
198 |
truncation_type: str) -> np.ndarray:
|
199 |
z = self.generate_z(seed)
|
200 |
ws = self.compute_w(z)
|
201 |
+
if truncation_type == self.TRUNCATION_TYPES[2]:
|
202 |
w0 = self.model.mapping.w_avg
|
203 |
else:
|
204 |
+
if truncation_type == self.TRUNCATION_TYPES[0]:
|
205 |
distance_type = 'lpips'
|
206 |
+
elif truncation_type == self.TRUNCATION_TYPES[1]:
|
207 |
distance_type = 'l2'
|
208 |
else:
|
209 |
raise ValueError
|