chainyo commited on
Commit
c75c928
1 Parent(s): 8e87ed1

create training loop

Browse files
Files changed (1) hide show
  1. training_loop.py +93 -0
training_loop.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal command:
3
+ python training_loop.py --hub_dir "segments/sidewalk-semantic"
4
+
5
+ Maximal command:
6
+ python training_loop.py --hub_dir "segments/sidewalk-semantic" --batch_size 32 --learning_rate 6e-5 --model_flavor 0 --seed 42 --split train
7
+ """
8
+
9
+ import json
10
+ import torch
11
+
12
+ from pytorch_lightning import Trainer, callbacks, seed_everything
13
+ from pytorch_lightning.loggers import WandbLogger
14
+
15
+ from dataloader import SidewalkSegmentationDataLoader
16
+ from model import SidewalkSegmentationModel
17
+
18
+
19
+ def main(
20
+ hub_dir: str,
21
+ batch_size: int = 32,
22
+ learning_rate: float = 6e-5,
23
+ model_flavor: int = 0,
24
+ seed: int = 42,
25
+ split: str = "train",
26
+ ):
27
+ seed_everything(seed)
28
+ logger = WandbLogger(project="sidewalk-segmentation")
29
+ gpu_value = 1 if torch.cuda.is_available() else 0
30
+
31
+ id2label_file = json.load(open("id2label.json", "r"))
32
+ id2label = {int(k): v for k, v in id2label_file.items()}
33
+ num_labels = len(id2label)
34
+
35
+ model = SidewalkSegmentationModel(
36
+ num_labels=num_labels,
37
+ id2label=id2label,
38
+ model_flavor=model_flavor,
39
+ learning_rate=learning_rate,
40
+ )
41
+ data_module = SidewalkSegmentationDataLoader(
42
+ hub_dir=hub_dir,
43
+ batch_size=batch_size,
44
+ split=split,
45
+ )
46
+ data_module.setup()
47
+
48
+ checkpoint_callback = callbacks.ModelCheckpoint(
49
+ dirpath="checkpoints",
50
+ save_top_k=1,
51
+ verbose=True,
52
+ monitor="val_mean_iou",
53
+ mode="max",
54
+ )
55
+ early_stopping_callback = callbacks.EarlyStopping(
56
+ monitor="val_mean_iou",
57
+ patience=5,
58
+ verbose=True,
59
+ mode="max",
60
+ )
61
+
62
+ trainer = Trainer(
63
+ max_epochs=200,
64
+ progress_bar_refresh_rate=10,
65
+ gpus=gpu_value,
66
+ logger=logger,
67
+ callbacks=[checkpoint_callback, early_stopping_callback],
68
+ deterministic=False,
69
+ )
70
+ trainer.fit(model, data_module)
71
+
72
+
73
+ if __name__ == "__main__":
74
+ import argparse
75
+
76
+ parser = argparse.ArgumentParser()
77
+ parser.add_argument("--hub_dir", type=str, required=True)
78
+ parser.add_argument("--batch_size", type=int, default=32)
79
+ parser.add_argument("--learning_rate", type=float, default=6e-5)
80
+ parser.add_argument("--model_flavor", type=int, default=0)
81
+ parser.add_argument("--seed", type=int, default=42)
82
+ parser.add_argument("--split", type=str, default="train")
83
+ args = parser.parse_args()
84
+
85
+ main(
86
+ hub_dir=args.hub_dir,
87
+ batch_size=args.batch_size,
88
+ learning_rate=args.learning_rate,
89
+ model_flavor=args.model_flavor,
90
+ seed=args.seed,
91
+ split=args.split,
92
+ )
93
+