qgyd2021 commited on
Commit
a849b40
1 Parent(s): 06e742c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +127 -6
README.md CHANGED
@@ -5,13 +5,9 @@ license: apache-2.0
5
 
6
  这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果.
7
 
 
8
 
9
- 1)因为 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 `*_prompt` 子集是动态生成的,因此首先,生成 3 个 epoch 的数据作为训练集和验证集。
10
-
11
- (2)3 个 epoch 的数据在训练时算 1 个 epoch。训练到大约 0.32 个 epoch 时(即 11000 steps)处 Early Stop。 (训练时的 0.32 个 epoch,相当于原始数据 3 个 epoch 的 1 个 epoch)。
12
-
13
- (3)此处保存的是 checkpoint-6000 (6000 steps)的权重。
14
-
15
 
16
 
17
  最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
@@ -24,4 +20,129 @@ license: apache-2.0
24
  ![eval_loss.jpg](docs/pictures/eval_loss.jpg)
25
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
5
 
6
  这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果.
7
 
8
+ (1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。
9
 
10
+ 2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。
 
 
 
 
 
11
 
12
 
13
  最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。
 
20
  ![eval_loss.jpg](docs/pictures/eval_loss.jpg)
21
 
22
 
23
+ ### 讨论
24
+
25
+ (1)最优解在不到 1 个 epoch 处得到。
26
+
27
+ 这可能跟语言模型。
28
+
29
+
30
+ ### 其它
31
+
32
+ 训练时加载数据集的代码
33
+ ```python
34
+ #!/usr/bin/python3
35
+ # -*- coding: utf-8 -*-
36
+ import argparse
37
+ import json
38
+
39
+ from datasets import load_dataset
40
+ from datasets.download.download_manager import DownloadMode
41
+ from tqdm import tqdm
42
+
43
+ from project_settings import project_path
44
+
45
+
46
+ def get_args():
47
+ parser = argparse.ArgumentParser()
48
+ parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
49
+ parser.add_argument("--dataset_split", default=None, type=str)
50
+ parser.add_argument(
51
+ "--dataset_cache_dir",
52
+ default=(project_path / "hub_datasets").as_posix(),
53
+ type=str
54
+ )
55
+
56
+ parser.add_argument("--num_epochs", default=1, type=int)
57
+
58
+ parser.add_argument("--train_subset", default="train.jsonl", type=str)
59
+ parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
60
+ args = parser.parse_args()
61
+ return args
62
+
63
+
64
+ def main():
65
+ args = get_args()
66
+
67
+ name_list = [
68
+ # "a_intent_prompt",
69
+ "amazon_massive_intent_en_us_prompt",
70
+ "amazon_massive_intent_zh_cn_prompt",
71
+ "atis_intents_prompt",
72
+ "banking77_prompt",
73
+ "bi_text11_prompt",
74
+ "bi_text27_prompt",
75
+ # "book6_prompt",
76
+ "carer_prompt",
77
+ "chatbots_prompt",
78
+ "chinese_news_title_prompt",
79
+ "cmid_4class_prompt",
80
+ "cmid_36class_prompt",
81
+ "coig_cqia_prompt",
82
+ "conv_intent_prompt",
83
+ "crosswoz_prompt",
84
+ "dmslots_prompt",
85
+ "dnd_style_intents_prompt",
86
+ "emo2019_prompt",
87
+ "finance21_prompt",
88
+ "ide_intent_prompt",
89
+ "intent_classification_prompt",
90
+ "jarvis_intent_prompt",
91
+ "mobile_assistant_prompt",
92
+ "mtop_intent_prompt",
93
+ "out_of_scope_prompt",
94
+ "ri_sawoz_domain_prompt",
95
+ "ri_sawoz_general_prompt",
96
+ "small_talk_prompt",
97
+ "smp2017_task1_prompt",
98
+ "smp2019_task1_domain_prompt",
99
+ "smp2019_task1_intent_prompt",
100
+ # "snips_built_in_intents_prompt",
101
+ "star_wars_prompt",
102
+ "suicide_intent_prompt",
103
+ "snips_built_in_intents_prompt",
104
+ "telemarketing_intent_cn_prompt",
105
+ "telemarketing_intent_en_prompt",
106
+ "vira_intents_prompt",
107
+ ]
108
+
109
+ with open(args.train_subset, "w", encoding="utf-8") as f:
110
+ for _ in range(args.num_epochs):
111
+ for name in name_list:
112
+ print(name)
113
+ dataset = load_dataset(
114
+ path=args.dataset_path,
115
+ name=name,
116
+ split="train",
117
+ cache_dir=args.dataset_cache_dir,
118
+ download_mode=DownloadMode.FORCE_REDOWNLOAD,
119
+ ignore_verifications=True
120
+ )
121
+ for sample in tqdm(dataset):
122
+ row = json.dumps(sample, ensure_ascii=False)
123
+ f.write("{}\n".format(row))
124
+
125
+ with open(args.valid_subset, "w", encoding="utf-8") as f:
126
+ for _ in range(args.num_epochs):
127
+ for name in name_list:
128
+ print(name)
129
+ dataset = load_dataset(
130
+ path=args.dataset_path,
131
+ name=name,
132
+ split="test",
133
+ cache_dir=args.dataset_cache_dir,
134
+ download_mode=DownloadMode.FORCE_REDOWNLOAD,
135
+ ignore_verifications=True
136
+ )
137
+ for sample in tqdm(dataset):
138
+ row = json.dumps(sample, ensure_ascii=False)
139
+ f.write("{}\n".format(row))
140
+
141
+ return
142
+
143
+
144
+ if __name__ == '__main__':
145
+ main()
146
+
147
+ ```
148