tousin23 commited on
Commit
ecc5fc4
·
verified ·
1 Parent(s): 7569ad7

Create configs.py

Browse files
Files changed (1) hide show
  1. configs.py +68 -0
configs.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ parser = argparse.ArgumentParser(description="hyper-parameter for R2GenGPT")
4
+ # ========================= Dataset Configs ==========================
5
+ parser.add_argument('--test', action='store_true', help="only run test set")
6
+ parser.add_argument('--validate', action='store_true', help="only run validation set")
7
+ parser.add_argument('--dataset', type=str, default='mimic_cxr', help="iu-xray or mimic-cxr")
8
+ parser.add_argument('--annotation', type=str, default=r'./data/mimic_cxr/annotation.json', help="annotation file of the dataset")
9
+ parser.add_argument('--base_dir', type=str, default=r'./data/mimic_cxr/images', help="base dir to help find images")
10
+ parser.add_argument('--batch_size', default=6, type=int, help="use for training duration per worker")
11
+ parser.add_argument('--val_batch_size', default=16, type=int, help="use for validation duration per worker")
12
+ parser.add_argument('--test_batch_size', default=16, type=int, help="use for testing duration per worker")
13
+ parser.add_argument('--prefetch_factor', default=4, type=int, help="use for training duration per worker")
14
+ parser.add_argument('--num_workers', default=8, type=int, help="Cpu num for dataloaders")
15
+
16
+ # ========================= Model Settings ============================
17
+ parser.add_argument('--vision_model', default='microsoft/swin-base-patch4-window7-224', type=str, help="vision model to use")
18
+ parser.add_argument('--llama_model', default='meta-llama/Llama-2-7b-chat-hf', type=str, help="LLM model to use")
19
+ parser.add_argument('--freeze_vm', default=True, type=lambda x: (str(x).lower() == 'true'), help='freeze vision model')
20
+ parser.add_argument('--llm_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether use lora for LLM model")
21
+ parser.add_argument('--llm_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
22
+ parser.add_argument('--llm_alpha', default=16, type=int, help='Scaling factor.')
23
+ parser.add_argument('--vis_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether use lora for vision model")
24
+ parser.add_argument('--vis_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
25
+ parser.add_argument('--vis_alpha', default=16, type=int, help='Scaling factor.')
26
+ parser.add_argument('--lora_dropout', default=0.1, type=float, help='lora dropout')
27
+ parser.add_argument('--global_only', default=False, type=lambda x: (str(x).lower() == 'true'), help='use global embedding only')
28
+ parser.add_argument('--low_resource', default=False, type=bool)
29
+ parser.add_argument('--end_sym', default='</s>', type=str)
30
+
31
+ # ======================== SavedModel Configs ===========================
32
+ parser.add_argument('--savedmodel_path', type=str, default='save/mimic/v1')
33
+ parser.add_argument('--ckpt_file', type=str, default=None, help='the checkpoint file to load')
34
+ parser.add_argument('--delta_file', type=str, default=None, help='the delta file to load')
35
+ parser.add_argument('--weights', type=list, default=[0.5, 0.5])
36
+ parser.add_argument('--scorer_types', type=list, default=['Bleu_4', 'CIDEr'])
37
+
38
+ # ========================= Learning Configs ==========================
39
+ parser.add_argument('--learning_rate', default=1e-4, type=float, help='initial learning rate')
40
+ parser.add_argument('--gradient_clip_val', default=None, type=int, help='gradient clip value')
41
+
42
+ # ========================= Decoding Settings ==========================
43
+ parser.add_argument('--beam_size', type=int, default=3)
44
+ parser.add_argument('--do_sample', type=bool, default=False)
45
+ parser.add_argument('--no_repeat_ngram_size', type=int, default=2)
46
+ parser.add_argument('--num_beam_groups', type=int, default=1)
47
+ parser.add_argument('--min_new_tokens', type=int, default=80)
48
+ parser.add_argument('--max_new_tokens', type=int, default=120)
49
+ parser.add_argument('--max_length', type=int, default=100)
50
+ parser.add_argument('--repetition_penalty', type=float, default=2.0)
51
+ parser.add_argument('--length_penalty', type=float, default=2.0)
52
+ parser.add_argument('--diversity_penalty', type=float, default=0)
53
+ parser.add_argument('--temperature', type=float, default=0)
54
+
55
+ # ====================== Pytorch Lightning ===========================
56
+ parser.add_argument('--devices', type=int, default=2, help='how many gpus to use')
57
+ parser.add_argument('--num_nodes', type=int, default=1, help='Number of GPU nodes for distributed training.')
58
+ parser.add_argument('--accelerator', type=str, default="gpu", choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps"], help='accelerator types')
59
+ parser.add_argument('--strategy', type=str, default="ddp", help='default ddp for multi-gpus')
60
+ parser.add_argument('--precision', type=str, default='bf16-mixed', help='16 or 32 bf16-mixed, using for original pytorch amp auto cast')
61
+ parser.add_argument('--limit_val_batches', type=float, default=1.0, help='How much of validation dataset to check (float = fraction, int = num_batches).')
62
+ parser.add_argument('--limit_test_batches', type=float, default=1.0, help='How much of test dataset to check (float = fraction, int = num_batches).')
63
+ parser.add_argument('--limit_train_batches', type=float, default=1.0, help='How much of training dataset to check (float = fraction, int = num_batches)')
64
+ parser.add_argument('--max_epochs', type=int, default=3, help='Stop training once this number of epochs is reached')
65
+ parser.add_argument('--every_n_train_steps', type=int, default=0, help='How many training steps to save a checkpoint')
66
+ parser.add_argument('--val_check_interval', type=float, default=1.0, help='How often to check the validation set')
67
+ parser.add_argument('--accumulate_grad_batches', type=int, default=1, help='Accumulates gradients over k batches before stepping the optimizer')
68
+ parser.add_argument("--num_sanity_val_steps", type=int, default=2, help='Sanity check runs n validation batches before starting the training routine')