Update README.md
README.md CHANGED
@@ -5,13 +5,9 @@ license: apache-2.0

 This model is the result of fine-tuning [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) on the [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) dataset.

-(1) …
-(2) Three epochs' worth of the data were counted as one epoch during training. Training was early-stopped at about 0.32 epochs (i.e. 11000 steps). (0.32 epochs at training time is roughly 1 epoch of the original data.)
-(3) The weights saved here are those of checkpoint-6000 (6000 steps).
+(1) Training was early-stopped at 11000 steps, which corresponds to 1 epoch of the loaded [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) dataset.
+(2) The weights saved here are the best checkpoint, checkpoint-6000 (6000 steps), which corresponds to about 0.63 epochs of the original dataset.

 The final model was saved after roughly 0.6 epochs of training.

@@ -24,4 +20,129 @@ license: apache-2.0

 ![eval_loss.jpg](docs/pictures/eval_loss.jpg)

+### Discussion
+
+(1) The optimum was reached in less than 1 epoch.
+
+This may be related to the language model.
+
+### Miscellaneous
+
+The code used to load the dataset for training:
+
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )

    parser.add_argument("--num_epochs", default=1, type=int)

    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Subsets of qgyd2021/few_shot_intent_sft to export.
    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        # "snips_built_in_intents_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    # Concatenate the "train" split of every subset into one JSONL file,
    # repeated num_epochs times.
    with open(args.train_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="train",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    # deprecated in newer `datasets` releases (see verification_mode)
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    # Likewise for the "test" split, written out as the validation set.
    with open(args.valid_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="test",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    return


if __name__ == '__main__':
    main()
```
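
For a quick sanity check, the `train.jsonl` / `valid.jsonl` files written by the script above can be loaded back with the `datasets` JSON builder. A minimal sketch (the file names are the script's defaults; this snippet is not part of the original repo):

```python
# Load the JSONL files produced by the export script back into a
# DatasetDict for inspection. Assumes the script's default file names.
from datasets import load_dataset

dataset = load_dataset(
    "json",
    data_files={"train": "train.jsonl", "valid": "valid.jsonl"},
)
print(dataset)
print(dataset["train"][0])
```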
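
The notes above mention early stopping and keeping the best checkpoint, but the diff does not show the training configuration. A hypothetical reconstruction with the `transformers` `Trainer` is sketched below; the `TrainingArguments` values, the `"text"` field name, and the use of `DataCollatorForLanguageModeling` are assumptions for illustration, not the author's actual setup:

```python
# Hypothetical early-stopping setup; NOT the author's actual training
# code. The "text" field name and all hyperparameter values are guesses.
from datasets import load_dataset
from transformers import (
    BertTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    GPT2LMHeadModel,
    Trainer,
    TrainingArguments,
)

tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

# train.jsonl / valid.jsonl as written by the export script above.
dataset = load_dataset("json", data_files={"train": "train.jsonl", "valid": "valid.jsonl"})


def tokenize(batch):
    # "text" is an assumed field name; adjust to the actual JSONL schema.
    return tokenizer(batch["text"], truncation=True, max_length=512)


tokenized = dataset.map(tokenize, batched=True, remove_columns=dataset["train"].column_names)

training_args = TrainingArguments(
    output_dir="output_dir",
    evaluation_strategy="steps",       # evaluate and save on the same step grid
    eval_steps=1000,
    save_steps=1000,
    load_best_model_at_end=True,       # restore the best checkpoint when stopping
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["valid"],
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()
```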
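
To try the fine-tuned model itself, the usual loading recipe for uer/gpt2-chinese-cluecorpussmall derivatives (`BertTokenizer` plus `GPT2LMHeadModel`) should apply. The model id below is a placeholder for this repository, and the prompt text is deliberately left elided, since the intent-recognition prompt template is not shown in this diff:

```python
# Minimal generation sketch. "<this-model-repo-id>" is a placeholder for
# the Hugging Face id of the model this README describes.
from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline

tokenizer = BertTokenizer.from_pretrained("<this-model-repo-id>")
model = GPT2LMHeadModel.from_pretrained("<this-model-repo-id>")
text_generator = TextGenerationPipeline(model, tokenizer)

# Replace "..." with a prompt in the dataset's few-shot intent format.
outputs = text_generator("...", max_length=256, do_sample=False)
print(outputs[0]["generated_text"])
```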