ohayonguy
commited on
Commit
•
8e5e901
1
Parent(s):
ec598ae
fixed impot
Browse files- app.py +1 -1
- lightning_models/mmse_rectified_flow.py +2 -1
app.py
CHANGED
@@ -156,7 +156,7 @@ demo = gr.Interface(
|
|
156 |
gr.Image(type="filepath", label="Input"),
|
157 |
gr.Radio(['aligned', 'unaligned'], type="value", value='unaligned', label='Image Alignment'),
|
158 |
gr.Number(label="Rescaling factor", value=2),
|
159 |
-
gr.Number(label="Number of flow steps", value=25),
|
160 |
], [
|
161 |
gr.Image(type="numpy", label="Output (The whole image)"),
|
162 |
gr.File(label="Download the output image")
|
|
|
156 |
gr.Image(type="filepath", label="Input"),
|
157 |
gr.Radio(['aligned', 'unaligned'], type="value", value='unaligned', label='Image Alignment'),
|
158 |
gr.Number(label="Rescaling factor", value=2),
|
159 |
+
gr.Number(label="Number of flow steps (a higher value leads to better image quality at the expense of runtime)", value=25),
|
160 |
], [
|
161 |
gr.Image(type="numpy", label="Output (The whole image)"),
|
162 |
gr.File(label="Download the output image")
|
lightning_models/mmse_rectified_flow.py
CHANGED
@@ -8,7 +8,8 @@ from torch.nn.functional import mse_loss
|
|
8 |
from torch.nn.functional import sigmoid
|
9 |
from torch.optim import AdamW
|
10 |
from torch_ema import ExponentialMovingAverage as EMA
|
11 |
-
from torchmetrics.image import FrechetInceptionDistance
|
|
|
12 |
from torchvision.transforms.functional import to_pil_image
|
13 |
from torchvision.utils import save_image
|
14 |
|
|
|
8 |
from torch.nn.functional import sigmoid
|
9 |
from torch.optim import AdamW
|
10 |
from torch_ema import ExponentialMovingAverage as EMA
|
11 |
+
from torchmetrics.image.fid import FrechetInceptionDistance
|
12 |
+
from torchmetrics.image.inception import InceptionScore
|
13 |
from torchvision.transforms.functional import to_pil_image
|
14 |
from torchvision.utils import save_image
|
15 |
|