qninhdt
commited on
Commit
·
020afa7
1
Parent(s):
6ef7ab3
cc
Browse files- .gitignore +1 -0
- baselines/cyclegan-cut/train.py +1 -0
- main.py +5 -1
- swim/modules/losses/contperceptual.py +22 -15
.gitignore
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
__pycache__
|
2 |
datasets
|
3 |
wandb
|
|
|
4 |
*.zip
|
|
|
1 |
__pycache__
|
2 |
datasets
|
3 |
wandb
|
4 |
+
checkpoints
|
5 |
*.zip
|
baselines/cyclegan-cut/train.py
CHANGED
@@ -14,6 +14,7 @@ if __name__ == "__main__":
|
|
14 |
dataset_size = len(dataset) # get the number of images in the dataset.
|
15 |
|
16 |
model = create_model(opt) # create a model given opt.model and other options
|
|
|
17 |
print("The number of training images = %d" % dataset_size)
|
18 |
|
19 |
visualizer = Visualizer(
|
|
|
14 |
dataset_size = len(dataset) # get the number of images in the dataset.
|
15 |
|
16 |
model = create_model(opt) # create a model given opt.model and other options
|
17 |
+
|
18 |
print("The number of training images = %d" % dataset_size)
|
19 |
|
20 |
visualizer = Visualizer(
|
main.py
CHANGED
@@ -6,12 +6,16 @@ from swim.modules.dataset import SwimDataModule
|
|
6 |
from lightning import Trainer
|
7 |
from lightning.pytorch.loggers import WandbLogger
|
8 |
|
|
|
|
|
9 |
config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
|
10 |
|
11 |
model = instantiate_from_config(config.model)
|
12 |
model.learning_rate = config.model.base_learning_rate
|
13 |
|
14 |
-
datamodule = SwimDataModule(
|
|
|
|
|
15 |
|
16 |
logger = WandbLogger(project="swim", name="autoencoder_kl")
|
17 |
|
|
|
6 |
from lightning import Trainer
|
7 |
from lightning.pytorch.loggers import WandbLogger
|
8 |
|
9 |
+
torch.set_float32_matmul_precision("medium")
|
10 |
+
|
11 |
config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
|
12 |
|
13 |
model = instantiate_from_config(config.model)
|
14 |
model.learning_rate = config.model.base_learning_rate
|
15 |
|
16 |
+
datamodule = SwimDataModule(
|
17 |
+
root_dir="/cm/shared/ninhnq3/datasets/swim_data", batch_size=2, img_size=512
|
18 |
+
)
|
19 |
|
20 |
logger = WandbLogger(project="swim", name="autoencoder_kl")
|
21 |
|
swim/modules/losses/contperceptual.py
CHANGED
@@ -94,24 +94,31 @@ class LPIPSWithDiscriminator(nn.Module):
|
|
94 |
split="train",
|
95 |
weights=None,
|
96 |
):
|
97 |
-
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
98 |
-
if self.perceptual_weight > 0:
|
99 |
-
p_loss = self.perceptual_loss(
|
100 |
-
inputs.contiguous(), reconstructions.contiguous()
|
101 |
-
)
|
102 |
-
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
103 |
-
|
104 |
-
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
105 |
-
weighted_nll_loss = nll_loss
|
106 |
-
if weights is not None:
|
107 |
-
weighted_nll_loss = weights * nll_loss
|
108 |
-
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
109 |
-
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
110 |
-
kl_loss = posteriors.kl()
|
111 |
-
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
112 |
|
113 |
# now the GAN part
|
114 |
if optimizer_idx == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
# generator update
|
116 |
if cond is None:
|
117 |
assert not self.disc_conditional
|
|
|
94 |
split="train",
|
95 |
weights=None,
|
96 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
|
98 |
# now the GAN part
|
99 |
if optimizer_idx == 0:
|
100 |
+
|
101 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
102 |
+
|
103 |
+
if self.perceptual_weight > 0:
|
104 |
+
p_loss = self.perceptual_loss(
|
105 |
+
inputs.contiguous(), reconstructions.contiguous()
|
106 |
+
)
|
107 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
108 |
+
|
109 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
110 |
+
weighted_nll_loss = nll_loss
|
111 |
+
if weights is not None:
|
112 |
+
weighted_nll_loss = weights * nll_loss
|
113 |
+
weighted_nll_loss = (
|
114 |
+
torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
115 |
+
)
|
116 |
+
print(nll_loss.shape)
|
117 |
+
|
118 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
119 |
+
kl_loss = posteriors.kl()
|
120 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
121 |
+
|
122 |
# generator update
|
123 |
if cond is None:
|
124 |
assert not self.disc_conditional
|