qninhdt committed · Commit 020afa7 · 1 Parent(s): 6ef7ab3
.gitignore CHANGED
@@ -1,4 +1,5 @@
 __pycache__
 datasets
 wandb
+checkpoints
 *.zip
baselines/cyclegan-cut/train.py CHANGED
@@ -14,6 +14,7 @@ if __name__ == "__main__":
     dataset_size = len(dataset)  # get the number of images in the dataset.
 
     model = create_model(opt)  # create a model given opt.model and other options
+
     print("The number of training images = %d" % dataset_size)
 
     visualizer = Visualizer(
main.py CHANGED
@@ -6,12 +6,16 @@ from swim.modules.dataset import SwimDataModule
 from lightning import Trainer
 from lightning.pytorch.loggers import WandbLogger
 
+torch.set_float32_matmul_precision("medium")
+
 config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")
 
 model = instantiate_from_config(config.model)
 model.learning_rate = config.model.base_learning_rate
 
-datamodule = SwimDataModule(img_size=32)
+datamodule = SwimDataModule(
+    root_dir="/cm/shared/ninhnq3/datasets/swim_data", batch_size=2, img_size=512
+)
 
 logger = WandbLogger(project="swim", name="autoencoder_kl")
 
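For context, a minimal sketch of what main.py looks like after this commit, assembled from the hunk above plus standard Lightning boilerplate. The `import torch` line, the import path for `instantiate_from_config`, and the `Trainer(...)`/`fit` call are assumptions not shown in the diff:

```python
# Hedged sketch of the post-commit main.py; pieces not visible in the hunk
# (imports, Trainer setup) are assumptions.
import torch
from omegaconf import OmegaConf
from lightning import Trainer
from lightning.pytorch.loggers import WandbLogger

from swim.modules.dataset import SwimDataModule
from swim.util import instantiate_from_config  # assumed import path

# "medium" lets float32 matmuls run at reduced internal precision
# (TF32/bfloat16) on tensor-core GPUs, trading accuracy for throughput.
torch.set_float32_matmul_precision("medium")

config = OmegaConf.load("configs/autoencoder/autoencoder_kl_32x32x4.yaml")

model = instantiate_from_config(config.model)
model.learning_rate = config.model.base_learning_rate

# The commit replaces SwimDataModule(img_size=32) with an explicit dataset
# root, a small batch size, and full 512px images.
datamodule = SwimDataModule(
    root_dir="/cm/shared/ninhnq3/datasets/swim_data", batch_size=2, img_size=512
)

logger = WandbLogger(project="swim", name="autoencoder_kl")
trainer = Trainer(logger=logger)  # assumed: the real flags are not in the hunk
trainer.fit(model, datamodule=datamodule)
```

Note that the hunk adds a call to `torch.set_float32_matmul_precision` without adding an `import torch`, so the import presumably already exists earlier in the file.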
swim/modules/losses/contperceptual.py CHANGED
@@ -94,24 +94,31 @@ class LPIPSWithDiscriminator(nn.Module):
         split="train",
         weights=None,
     ):
-        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
-        if self.perceptual_weight > 0:
-            p_loss = self.perceptual_loss(
-                inputs.contiguous(), reconstructions.contiguous()
-            )
-            rec_loss = rec_loss + self.perceptual_weight * p_loss
-
-        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
-        weighted_nll_loss = nll_loss
-        if weights is not None:
-            weighted_nll_loss = weights * nll_loss
-        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
-        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
-        kl_loss = posteriors.kl()
-        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
 
         # now the GAN part
         if optimizer_idx == 0:
+
+            rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+
+            if self.perceptual_weight > 0:
+                p_loss = self.perceptual_loss(
+                    inputs.contiguous(), reconstructions.contiguous()
+                )
+                rec_loss = rec_loss + self.perceptual_weight * p_loss
+
+            nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
+            weighted_nll_loss = nll_loss
+            if weights is not None:
+                weighted_nll_loss = weights * nll_loss
+            weighted_nll_loss = (
+                torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
+            )
+            print(nll_loss.shape)
+
+            nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+            kl_loss = posteriors.kl()
+            kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
+
             # generator update
             if cond is None:
                 assert not self.disc_conditional
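The relocated block is the usual VAE objective from LPIPSWithDiscriminator: an L1 reconstruction term (optionally mixed with a perceptual loss), turned into a Gaussian NLL via the learned scalar `logvar`, plus the KL of the diagonal-Gaussian posterior against a standard normal. A self-contained sketch of just that math, with hypothetical shapes and the LPIPS/GAN terms omitted:

```python
# Standalone sketch of the reconstruction-NLL + KL terms the hunk relocates
# (hypothetical tensors; perceptual loss and discriminator terms omitted).
import torch

B = 2
inputs = torch.rand(B, 3, 64, 64)
reconstructions = torch.rand(B, 3, 64, 64)
logvar = torch.zeros(())  # a learnable scalar parameter in the real module

# L1 reconstruction error, converted to a Gaussian NLL via the learned logvar.
rec_loss = torch.abs(inputs - reconstructions)
nll_loss = rec_loss / torch.exp(logvar) + logvar
nll_loss = torch.sum(nll_loss) / B  # sum over pixels, mean over the batch

# KL(N(mean, exp(logvar_z)) || N(0, 1)) for a diagonal Gaussian posterior,
# standing in for posteriors.kl() in the real module.
mean = torch.randn(B, 4, 8, 8)
logvar_z = torch.randn(B, 4, 8, 8)
kl_loss = 0.5 * torch.sum(mean**2 + logvar_z.exp() - 1.0 - logvar_z) / B

print(nll_loss.item(), kl_loss.item())
```

Moving this block under `optimizer_idx == 0` presumably avoids computing the comparatively expensive reconstruction and KL terms on discriminator steps, where the generator losses are not needed. The `print(nll_loss.shape)` call in the committed version looks like leftover debug output.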