File size: 4,766 Bytes
f729637
 
 
cc77e93
51d6ebd
c9129a5
 
a849b40
06e742c
a849b40
06e742c
c9129a5
 
51d6ebd
6599416
 
 
ebb49a0
e420f46
3dd6650
e420f46
 
a849b40
 
 
 
ef5bd87
 
 
fe14f27
a849b40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6599416
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
---
license: apache-2.0
---
## few_shot_intent_gpt2

这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果. 

(1)训练在(11000 steps)处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。

(2)此处保存的是 checkpoint-6000 (6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。


最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。

你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。


**Eval Loss** 见下图:

![eval_loss.jpg](docs/pictures/eval_loss.jpg)


### 讨论

(1)最优解在不到 1 个 epoch 处得到。

* 这可能说明 GPT2 模型大小,相对于任务复杂度来说太小了。

* 模型进入到局部最终解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。



### 其它

训练时加载数据集的代码
```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str
    )

    parser.add_argument("--num_epochs", default=1, type=int)

    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        # "snips_built_in_intents_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    with open(args.train_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="train",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    with open(args.valid_subset, "w", encoding="utf-8") as f:
        for _ in range(args.num_epochs):
            for name in name_list:
                print(name)
                dataset = load_dataset(
                    path=args.dataset_path,
                    name=name,
                    split="test",
                    cache_dir=args.dataset_cache_dir,
                    download_mode=DownloadMode.FORCE_REDOWNLOAD,
                    ignore_verifications=True
                )
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))

    return


if __name__ == '__main__':
    main()

```