Create configs.py
Browse files- configs.py +68 -0
configs.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def _str2bool(value):
    """Parse a command-line token as a boolean.

    Only the text ``true`` (any letter case) maps to ``True``; every other
    value maps to ``False``.  argparse passes the raw string to ``type``, so
    ``type=bool`` cannot be used here: ``bool('False')`` is ``True`` because
    any non-empty string is truthy.
    """
    return str(value).lower() == 'true'


# Command-line parser holding every hyper-parameter defined below.
parser = argparse.ArgumentParser(description="hyper-parameter for R2GenGPT")

# ========================= Dataset Configs ==========================
parser.add_argument('--test', action='store_true', help="only run test set")
parser.add_argument('--validate', action='store_true', help="only run validation set")
parser.add_argument('--dataset', type=str, default='mimic_cxr', help="iu-xray or mimic-cxr")
parser.add_argument('--annotation', type=str, default=r'./data/mimic_cxr/annotation.json', help="annotation file of the dataset")
parser.add_argument('--base_dir', type=str, default=r'./data/mimic_cxr/images', help="base dir to help find images")
parser.add_argument('--batch_size', default=6, type=int, help="use for training duration per worker")
parser.add_argument('--val_batch_size', default=16, type=int, help="use for validation duration per worker")
parser.add_argument('--test_batch_size', default=16, type=int, help="use for testing duration per worker")
parser.add_argument('--prefetch_factor', default=4, type=int, help="use for training duration per worker")
parser.add_argument('--num_workers', default=8, type=int, help="Cpu num for dataloaders")

# ========================= Model Settings ============================
parser.add_argument('--vision_model', default='microsoft/swin-base-patch4-window7-224', type=str, help="vision model to use")
parser.add_argument('--llama_model', default='meta-llama/Llama-2-7b-chat-hf', type=str, help="LLM model to use")
parser.add_argument('--freeze_vm', default=True, type=_str2bool, help='freeze vision model')
parser.add_argument('--llm_use_lora', default=False, type=_str2bool, help="whether use lora for LLM model")
parser.add_argument('--llm_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
parser.add_argument('--llm_alpha', default=16, type=int, help='Scaling factor.')
parser.add_argument('--vis_use_lora', default=False, type=_str2bool, help="whether use lora for vision model")
parser.add_argument('--vis_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
parser.add_argument('--vis_alpha', default=16, type=int, help='Scaling factor.')
parser.add_argument('--lora_dropout', default=0.1, type=float, help='lora dropout')
parser.add_argument('--global_only', default=False, type=_str2bool, help='use global embedding only')
# Was type=bool, which turned "--low_resource False" into True.
parser.add_argument('--low_resource', default=False, type=_str2bool)
parser.add_argument('--end_sym', default='</s>', type=str)

# ======================== SavedModel Configs ===========================
parser.add_argument('--savedmodel_path', type=str, default='save/mimic/v1')
parser.add_argument('--ckpt_file', type=str, default=None, help='the checkpoint file to load')
parser.add_argument('--delta_file', type=str, default=None, help='the delta file to load')
# Was type=list, which split a single token into characters
# ("0.5" -> ['0', '.', '5']).  nargs='+' collects one or more tokens into a
# list and converts each one individually; the defaults are unchanged.
parser.add_argument('--weights', type=float, nargs='+', default=[0.5, 0.5], help='space-separated list of floats')
parser.add_argument('--scorer_types', type=str, nargs='+', default=['Bleu_4', 'CIDEr'], help='space-separated list of scorer names')

# ========================= Learning Configs ==========================
parser.add_argument('--learning_rate', default=1e-4, type=float, help='initial learning rate')
# Was type=int, which rejected fractional values such as 0.5.
parser.add_argument('--gradient_clip_val', default=None, type=float, help='gradient clip value')

# ========================= Decoding Settings ==========================
parser.add_argument('--beam_size', type=int, default=3)
# Was type=bool, which turned "--do_sample False" into True.
parser.add_argument('--do_sample', type=_str2bool, default=False)
parser.add_argument('--no_repeat_ngram_size', type=int, default=2)
parser.add_argument('--num_beam_groups', type=int, default=1)
parser.add_argument('--min_new_tokens', type=int, default=80)
parser.add_argument('--max_new_tokens', type=int, default=120)
parser.add_argument('--max_length', type=int, default=100)
parser.add_argument('--repetition_penalty', type=float, default=2.0)
parser.add_argument('--length_penalty', type=float, default=2.0)
parser.add_argument('--diversity_penalty', type=float, default=0)
parser.add_argument('--temperature', type=float, default=0)

# ====================== Pytorch Lightning ===========================
parser.add_argument('--devices', type=int, default=2, help='how many gpus to use')
parser.add_argument('--num_nodes', type=int, default=1, help='Number of GPU nodes for distributed training.')
parser.add_argument('--accelerator', type=str, default="gpu", choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps"], help='accelerator types')
parser.add_argument('--strategy', type=str, default="ddp", help='default ddp for multi-gpus')
parser.add_argument('--precision', type=str, default='bf16-mixed', help='16 or 32 bf16-mixed, using for original pytorch amp auto cast')
parser.add_argument('--limit_val_batches', type=float, default=1.0, help='How much of validation dataset to check (float = fraction, int = num_batches).')
parser.add_argument('--limit_test_batches', type=float, default=1.0, help='How much of test dataset to check (float = fraction, int = num_batches).')
parser.add_argument('--limit_train_batches', type=float, default=1.0, help='How much of training dataset to check (float = fraction, int = num_batches)')
parser.add_argument('--max_epochs', type=int, default=3, help='Stop training once this number of epochs is reached')
parser.add_argument('--every_n_train_steps', type=int, default=0, help='How many training steps to save a checkpoint')
parser.add_argument('--val_check_interval', type=float, default=1.0, help='How often to check the validation set')
parser.add_argument('--accumulate_grad_batches', type=int, default=1, help='Accumulates gradients over k batches before stepping the optimizer')
parser.add_argument("--num_sanity_val_steps", type=int, default=2, help='Sanity check runs n validation batches before starting the training routine')
|