---
license: apache-2.0
---
## few_shot_intent_gpt2

这个模型是基于 [uer/gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) 模型在 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集上微调的结果。

(1)训练在 11000 steps 处 Early Stop。这相当于加载的 [qgyd2021/few_shot_intent_sft](https://huggingface.co/datasets/qgyd2021/few_shot_intent_sft) 数据集的 1 个 epoch 处。

(2)此处保存的是 checkpoint-6000(6000 steps)的最优权重。这相当于原数据集的 0.63 个 epoch 处。

最终的模型大约是在训练了 0.6 个 epoch 时保存的结果。

你可以在此处体验该模型 [qgyd2021/gpt2_chat](https://huggingface.co/spaces/qgyd2021/gpt2_chat)。

**Eval Loss** 见下图:

![eval_loss.jpg](docs/pictures/eval_loss.jpg)

### 讨论

(1)最优解在不到 1 个 epoch 处得到。

* 这可能说明 GPT2 模型的规模相对于任务复杂度来说太小了。
* 模型陷入局部最优解而无法跳出,应考虑使用较大的学习率,或更换学习率调度器。

### 其它

训练时加载数据集的代码:

```python
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Export few_shot_intent_sft prompt subsets to JSONL train/valid files.

Every subset listed in ``main`` is loaded with ``datasets.load_dataset`` and
its rows are written one JSON object per line. The train file contains the
concatenation of all subsets repeated ``--num_epochs`` times; the valid file
contains the concatenation of all subsets' "test" splits exactly once.
"""
import argparse
import json

from datasets import load_dataset
from datasets.download.download_manager import DownloadMode
from tqdm import tqdm

from project_settings import project_path


def get_args():
    """Parse the command-line options of the export script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", default="qgyd2021/few_shot_intent_sft", type=str)
    # NOTE(review): parsed but not read anywhere in this script; kept so that
    # existing command lines continue to work.
    parser.add_argument("--dataset_split", default=None, type=str)
    parser.add_argument(
        "--dataset_cache_dir",
        default=(project_path / "hub_datasets").as_posix(),
        type=str,
    )
    # How many times the training rows are repeated in --train_subset.
    parser.add_argument("--num_epochs", default=1, type=int)
    parser.add_argument("--train_subset", default="train.jsonl", type=str)
    parser.add_argument("--valid_subset", default="valid.jsonl", type=str)
    args = parser.parse_args()
    return args


def _load_subsets(args, name_list, split):
    """Load ``split`` of every subset in ``name_list`` once, preserving list order.

    Returns a list of datasets, one per name, in the same order as ``name_list``.
    """
    subsets = []
    for name in name_list:
        print(name)
        dataset = load_dataset(
            path=args.dataset_path,
            name=name,
            split=split,
            cache_dir=args.dataset_cache_dir,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            # NOTE(review): newer `datasets` releases deprecate
            # `ignore_verifications` in favour of `verification_mode`;
            # confirm against the installed version before upgrading.
            ignore_verifications=True,
        )
        subsets.append(dataset)
    return subsets


def _write_jsonl(filename, subsets, repeats=1):
    """Write every row of ``subsets`` to ``filename`` as JSON lines, ``repeats`` times over."""
    with open(filename, "w", encoding="utf-8") as f:
        for _ in range(repeats):
            for dataset in subsets:
                for sample in tqdm(dataset):
                    row = json.dumps(sample, ensure_ascii=False)
                    f.write("{}\n".format(row))


def main():
    """Download the selected subsets and write the train/valid JSONL files."""
    args = get_args()

    # Commented-out entries are subsets excluded from the export.
    name_list = [
        # "a_intent_prompt",
        "amazon_massive_intent_en_us_prompt",
        "amazon_massive_intent_zh_cn_prompt",
        "atis_intents_prompt",
        "banking77_prompt",
        "bi_text11_prompt",
        "bi_text27_prompt",
        # "book6_prompt",
        "carer_prompt",
        "chatbots_prompt",
        "chinese_news_title_prompt",
        "cmid_4class_prompt",
        "cmid_36class_prompt",
        "coig_cqia_prompt",
        "conv_intent_prompt",
        "crosswoz_prompt",
        "dmslots_prompt",
        "dnd_style_intents_prompt",
        "emo2019_prompt",
        "finance21_prompt",
        "ide_intent_prompt",
        "intent_classification_prompt",
        "jarvis_intent_prompt",
        "mobile_assistant_prompt",
        "mtop_intent_prompt",
        "out_of_scope_prompt",
        "ri_sawoz_domain_prompt",
        "ri_sawoz_general_prompt",
        "small_talk_prompt",
        "smp2017_task1_prompt",
        "smp2019_task1_domain_prompt",
        "smp2019_task1_intent_prompt",
        # "snips_built_in_intents_prompt",
        "star_wars_prompt",
        "suicide_intent_prompt",
        "snips_built_in_intents_prompt",
        "telemarketing_intent_cn_prompt",
        "telemarketing_intent_en_prompt",
        "vira_intents_prompt",
    ]

    # Each subset is loaded once and reused for every epoch, so
    # FORCE_REDOWNLOAD costs one download per subset instead of one per epoch.
    # Loading happens before the output file is opened, so a failed download
    # does not leave a truncated file behind.
    train_subsets = _load_subsets(args, name_list, split="train")
    _write_jsonl(args.train_subset, train_subsets, repeats=args.num_epochs)

    # Validation rows are written exactly once: repeating them per epoch would
    # only duplicate evaluation work.
    valid_subsets = _load_subsets(args, name_list, split="test")
    _write_jsonl(args.valid_subset, valid_subsets)


if __name__ == '__main__':
    main()
```