Spaces: ar-houwei-chou (Runtime error)

Commit 6aee98f (parent: ed402eb), committed with message: "demo"

Files changed:
- README.md +4 -3
- app.py +2 -0
- app1.py +198 -0
- dataset.py +271 -0
- helper.py +79 -0
- model.py +131 -0
- options.py +47 -0
- requirements.txt +6 -0
- transformer.py +222 -0
- vocab.cond.vocab +31 -0
- vocab.py +41 -0
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title:
-emoji:
+title: Kobedemo
+emoji: 🐢
 colorFrom: red
-colorTo:
+colorTo: red
 sdk: streamlit
 sdk_version: 1.17.0
 app_file: app.py
 pinned: false
+python_version: 3.8.12
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,2 @@
import streamlit as st
st.write('You have selected:')
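
As committed, app.py renders a single static line; the interactive flow lives, disabled, in app1.py. A minimal sketch of how this entry point could expose a selection widget, with placeholder item names (hypothetical, not part of the commit):

import streamlit as st

st.write('You have selected:')
# Placeholder options; the real demo uses the product titles from app1.py.
choice = st.selectbox('please choose:', ('item 1', 'item 2'))
st.write('You selected:', choice)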
app1.py ADDED
@@ -0,0 +1,198 @@
"""
import argparse
import random
import numpy as np
import pytorch_lightning as pl
import torch
from dataset import KobeDataModule
from model import KobeModel
from options import add_args, add_options
from typing import List
import warnings
import logging
from transformers.models.bert.tokenization_bert import BertTokenizer
import sentencepiece as spm


logging.getLogger("lightning").setLevel(logging.ERROR)


def fxn():
    warnings.warn("deprecated", DeprecationWarning)


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fxn()


parser = argparse.ArgumentParser()
add_options(parser)
args = parser.parse_args()
args.vocab_file = "bert-base-chinese"
args.cond_vocab_file = "./vocab.cond.model"
add_args(args)

# model = KobeModel(args)


class Example:
    title_token_ids: List[int]
    condition_token_ids: List[int]
    fact_token_ids: List[int]

    def __init__(self, title_token_ids: List[int], condition_token_ids: List[int], fact_token_ids: List[int]):
        self.title_token_ids = title_token_ids
        self.condition_token_ids = condition_token_ids
        self.fact_token_ids = fact_token_ids


text_tokenizer = BertTokenizer.from_pretrained(args.vocab_file)
cond_tokenizer = spm.SentencePieceProcessor()
cond_tokenizer.Load(args.cond_vocab_file)

# model = model.load_from_checkpoint("/root/kobe-v2/1ja19m5t/checkpoints/epoch=19-step=66080.ckpt", args=args)
# model = model.load_from_checkpoint("/root/kobe-v2/37ht1cvz/checkpoints/epoch=11-step=396384.ckpt", args=args)

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=-1)
"""

import streamlit as st

st.write("Most appearing words including stopwords")

"""
choice = st.selectbox(
    'Select the items you want?',
    ('Pen', 'Pencil', 'Eraser', 'Sharpener', 'Notebook'))

input1 = st.selectbox(
    'please choose:',
    ("1.天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架",
     "2.dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮",
     "3.coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828",
     "4.highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅",
     "5.欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺",
     "6.飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统",
     "7.不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装",
     "8.香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油",
     "9.迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男",
     "10.半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g"))

st.write('You selected:', input1)

title = ""
fact = ""

# st.selectbox returns the selected option string, so match on its numeric
# prefix (comparing input1 to an int would never be true).
if input1.startswith("1."):
    title = "天猫直发百诺碳纤维三脚架单反相机三角架c2690tb1摄影脚架"
    fact = "太字节Terabyte,计算机存储容量单位,也常用TB来表示。百诺公司创建于1996年,早期与日本合作,后通过自身技术创新与努力,逐渐在国内外抑尘设备行业赢得一席单反就是指单镜头反光,即SLRSingleLensReflex,单反相机就是拥有这个功能的相机。技巧拍摄往往都离不开三脚架的帮助,如夜景拍摄、微距拍摄等方面。"

if input1.startswith("2."):
    title = "dacom飞鱼p10运动型跑步蓝牙耳机入耳式头戴挂耳塞式7级防水苹果安卓手机通用可接听电话音乐篮牙"
    fact = "移动电话,或称为无线电话,通常称为手机,原本只是一种通讯工具,早期又有大哥大的俗称,是可以在较广范围运动型是德国精神病学家克雷奇默提出的身体类型之一。跑步,是指陆生动物使用足部,移动最快捷的方法。蓝牙耳机就是将蓝牙技术应用在免持耳机上,让使用者可以免除恼人电线的牵绊,自在地以各种方式轻松通话。"

if input1.startswith("3."):
    title = "coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828"
    fact = "聚氯乙烯,英文简称PVCPolyvinylchloride,是氯乙烯单体vinylchloridem女包,这个名词是箱包的性别分类衍生词。贝壳包beikebao女士包种类的一种,因为其外形酷似贝壳的外形而得名。蔻驰为美国经典皮件品牌COACH,一像以简洁、耐用的风格特色赢得消费者的喜爱。"

if input1.startswith("4."):
    title = "highcook韩库韩国进口蓝宝石近无烟炒锅家用不粘锅电磁炉炒菜锅"
    fact = "蓝宝石,是刚玉宝石中除红色的红宝石之外,其它颜色刚玉宝石的通称,主要成分是氧化铝Al2O3。电磁炉又称为电磁灶,1957年第一台家用电磁炉诞生于德国。家用是汉语词汇,出自管子权修,解释为家庭日常使用的。不粘锅即做饭不会粘锅底的锅,是因为锅底采用了不粘涂层,常见的、不粘性能最好的有特氟龙涂层和陶瓷涂层。"

if input1.startswith("5."):
    title = "欧式复古亚麻布料沙发面料定做飘窗垫窗台垫榻榻米垫抱枕diy布艺"
    fact = "沙发是个外来词,根据英语单词sofa音译而来。面料就是用来制作服装的材料。飘窗垫,就是放在飘窗的台面上的垫子。复古与怀旧,有时候很难区分。"

if input1.startswith("6."):
    title = "飞利浦电动剃须刀sp9851充电式带多功能理容配件和智能清洁系统"
    fact = "能够完成一种或者几种生理功能的多个器官按照一定的次序组合在一起的结构叫做系统。配件,指装配机械的零件或部件;也指损坏后重新安装上的零件或部件。清洁是由奥利维耶阿萨亚斯执导,张曼玉、尼克诺尔蒂主演的剧情片,于2004年9月1日在法国上映。飞利浦,1891年成立于荷兰,主要生产照明、家庭电器、医疗系统方面的产品。"

if input1.startswith("7."):
    title = "不锈钢牛排刀叉西餐餐具全套筷子刀叉勺三件套欧式加厚24件礼盒装"
    fact = "西餐餐具具体有大盘子、小盘子、浅碟、深碟、吃沙拉用的叉子、叉肉用的叉子、喝汤用的汤匙、吃甜点用的汤匙三件咳嗽,贫穷与爱情触不到的恋人...不锈钢指耐空气、蒸汽、水等弱腐蚀介质和酸、碱、盐等化学浸蚀性介质腐蚀的钢,又称不锈耐酸钢。2,4D丁酯,无色油状液体。"

if input1.startswith("8."):
    title = "香百年汽车香水挂件车载香水香薰悬挂吊坠车用车内装饰挂饰精油"
    fact = "悬挂系统是汽车的车架与车桥或车轮之间的一切传力连接装置的总称,其作用是传递作用在车轮和车架之间的力和吊坠,一种首饰,配戴在脖子上的饰品,多为金属制,特别是不锈钢制和银制,也有矿石、水晶、玉石等制的,主汽车香水AutoPerfume是一种混合了香精油、固定剂与酒精的液体,用来让汽车车内拥有持久且悦人的精油是从植物的花、叶、茎、根或果实中,通过水蒸气蒸馏法、挤压法、冷浸法或溶剂提取法提炼萃取的挥发性芳"

if input1.startswith("9."):
    title = "迪士尼小学生书包儿童男孩13一46年级美国队长男童减负12周岁男"
    fact = "美国队长是每一男孩心中的英雄人物,迪士尼美国队长款的小学生书包,按照美国队长的防护盾牌设计,泛着丝丝银光,帅气有型,而且还有很多英雄款式哦脊背处采用柔软舒适的脊椎防护设计,减轻孩子的背部压力。而且前方盾牌还能拆卸下来,当做斜挎包使用,满足淘气小男孩的英雄梦。"

if input1.startswith("10."):
    title = "半饱良味潮汕猪肉脯宅人食堂潮汕小吃特产碳烤猪肉干120g"
    fact = "潮汕,不是潮州潮州一词始于隋文帝开皇十年,距今不到两千年。猪肉脯是一种用猪肉经腌制、烘烤的片状肉制品,食用方便、制作考究、美味可口、耐贮藏和便于运输的中式传统发达国家都有全国统一的急救电话号码。特产指某地特有的或特别著名的产品,有文化内涵或历史,亦指只有在某地才生产的一种产品。"


input2 = st.selectbox(
    'please choose category:',
    ("1: 家庭主妇",
     "2: 烹饪达人",
     "3: 买鞋控",
     "4: 数码达人",
     "5: 吃货",
     "6: 爱包人",
     "7: 高富帅",
     ))

st.write('You selected:', input2)
# Keep only the numeric category id, e.g. "2: 烹饪达人" -> "<2>".
aspect = "<" + input2.split(":")[0] + ">"


input3 = st.selectbox(
    'please choose aspect:',
    ("1: appearance",
     "2: texture",
     "3: function"))

st.write('You selected:', input3)

cond = ""
if input3.startswith("1:"):
    cond = "<a>"
if input3.startswith("2:"):
    cond = "<b>"
if input3.startswith("3:"):
    cond = "<c>"

# cond = cond + " " + aspect
cond = aspect + " " + cond
# print(title)
# print(fact)
# print(cond)
tokenizer = text_tokenizer
title_token_ids = tokenizer.encode(title, add_special_tokens=False)
condition_token_ids = cond_tokenizer.EncodeAsIds(cond)
fact_token_ids = tokenizer.encode(fact, add_special_tokens=False)

e = Example(title_token_ids, condition_token_ids, fact_token_ids)


dm = KobeDataModule(
    [e],
    args.text_vocab_path,
    args.max_seq_len,
    1,
    1,
)
for d in dm.test_dataloader():
    st.write("result:")
    st.write(''.join(model.test_step(d, 1)).replace(" ", ""))
"""
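
The disabled block above turns one product into model inputs with two tokenizers: bert-base-chinese WordPiece for the title and fact text, and a SentencePiece model for the condition string. That encoding step in isolation (a sketch; it assumes bert-base-chinese can be downloaded and that ./vocab.cond.model, which the code references but this commit does not add, is present):

import sentencepiece as spm
from transformers.models.bert.tokenization_bert import BertTokenizer

text_tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
cond_tokenizer = spm.SentencePieceProcessor()
cond_tokenizer.Load("./vocab.cond.model")

title = "coach蔻驰贝壳包pvc单肩手提斜挎大号贝壳包女包5828"
cond = "<6> <a>"  # user category <6> plus aspect <a> (appearance)

title_ids = text_tokenizer.encode(title, add_special_tokens=False)
cond_ids = cond_tokenizer.EncodeAsIds(cond)
print(title_ids[:5], cond_ids)  # WordPiece ids for the title, condition ids for "<6> <a>"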
dataset.py ADDED
@@ -0,0 +1,271 @@
import glob
import sys
from dataclasses import dataclass
from typing import List

import pytorch_lightning as pl
import sentencepiece as spm
import torch
import webdataset as wds  # used by from_processed below
from torch.functional import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.dataloader import DataLoader

import helper as helpers  # helper.py provides get_bert_vocab_size


@dataclass
class Example:
    title_token_ids: List[int]
    description_token_ids: List[int]
    condition_token_ids: List[int]
    fact_token_ids: List[int]
    description: str
    title: str


@dataclass
class TensorDict:
    def detach(self):
        detached_dict = {
            field: getattr(self, field).detach()
            if isinstance(getattr(self, field), torch.Tensor)
            else getattr(self, field)
            for field in self.__dataclass_fields__
        }
        return self.__class__(**detached_dict)

    def cpu(self):
        detached_dict = {
            field: getattr(self, field).cpu()
            if isinstance(getattr(self, field), torch.Tensor)
            else getattr(self, field)
            for field in self.__dataclass_fields__
        }
        return self.__class__(**detached_dict)


@dataclass
class Batched(TensorDict):
    # Source
    title_token_ids: torch.Tensor
    title_token_ids_mask: torch.Tensor
    # Attribute Fusion
    cond_title_token_ids: torch.Tensor
    cond_title_token_ids_mask: torch.Tensor
    # Knowledge Incorporation
    fact_token_ids: torch.Tensor
    fact_token_ids_mask: torch.Tensor
    title_fact_token_ids: torch.Tensor
    title_fact_token_ids_mask: torch.Tensor
    # Attribute Fusion + Knowledge Incorporation
    cond_title_fact_token_ids: torch.Tensor
    cond_title_fact_token_ids_mask: torch.Tensor
    # Target (disabled in this inference-only demo)
    # description_token_ids: torch.Tensor
    # description_token_ids_mask: torch.Tensor
    # descriptions: List[str]
    # titles: List[str]


@dataclass
class EncodedBatch(TensorDict):
    context_encodings: torch.Tensor
    context_encodings_mask: torch.Tensor


@dataclass
class DecodedBatch:
    loss: float
    acc: float
    generated: List[str]
    descriptions: List[str]
    titles: List[str]


def from_processed(url: str, train=False):
    urls = sorted(glob.glob(url))

    def my_split_by_worker(urls):
        wi = torch.utils.data.get_worker_info()
        if wi is None:
            return urls
        else:
            return urls[wi.id::wi.num_workers]

    def my_split_by_node(urls):
        node_id, node_count = torch.distributed.get_rank(), torch.distributed.get_world_size()
        return urls[node_id::node_count]

    if train:
        return (
            wds.WebDataset(urls)
            # wds.WebDataset(urls, nodesplitter=my_split_by_node)
            # wds.WebDataset(urls, nodesplitter=wds.split_by_node)
            .shuffle(size=10000000, initial=100000)
            .decode()
            .map(lambda d: Example(**d["json"]))
        )
    else:
        # Debug path: prints the first example and stops the process,
        # so the return below is unreachable. The Streamlit demo bypasses
        # this function by passing an in-memory list of Examples.
        print(list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))[0])
        sys.exit()
        return list(wds.WebDataset(url).decode().map(lambda d: Example(**d["json"])))
        # return list(wds.WebDataset(urls, nodesplitter=my_split_by_node).decode().map(lambda d: Example(**d["json"])))
        # return list(wds.WebDataset(urls, nodesplitter=wds.split_by_node).decode().map(lambda d: Example(**d["json"])))


def get_collate_fn(text_vocab_size: int, max_seq_length: int):
    def collate_fn(examples: List[Example]) -> Batched:
        from vocab import BOS_ID, EOS_ID  # local vocab module in this Space

        title_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.title_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.fact_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        """
        description_token_ids = pad_sequence(
            [
                torch.tensor(
                    [BOS_ID] + e.description_token_ids[: max_seq_length - 2] + [EOS_ID]
                )
                for e in examples
            ]
        )
        """
        cond_title_token_ids = pad_sequence(
            [
                torch.tensor(
                    (
                        [BOS_ID]
                        + [
                            cond_id + text_vocab_size
                            for cond_id in e.condition_token_ids
                        ]
                        + e.title_token_ids
                    )[: max_seq_length - 1]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        title_fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    ([BOS_ID] + e.title_token_ids + [EOS_ID] + e.fact_token_ids)[
                        : max_seq_length - 1
                    ]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        cond_title_fact_token_ids = pad_sequence(
            [
                torch.tensor(
                    (
                        [BOS_ID]
                        + [
                            cond_id + text_vocab_size
                            for cond_id in e.condition_token_ids
                        ]
                        + e.title_token_ids
                        + [EOS_ID]
                        + e.fact_token_ids
                    )[: max_seq_length - 1]
                    + [EOS_ID]
                )
                for e in examples
            ]
        )
        # descriptions = [e.description for e in examples]
        # titles = [e.title for e in examples]
        return Batched(
            title_token_ids=title_token_ids,
            title_token_ids_mask=(title_token_ids == 0).T,
            fact_token_ids=fact_token_ids,
            fact_token_ids_mask=(fact_token_ids == 0).T,
            cond_title_token_ids=cond_title_token_ids,
            cond_title_token_ids_mask=(cond_title_token_ids == 0).T,
            title_fact_token_ids=title_fact_token_ids,
            title_fact_token_ids_mask=(title_fact_token_ids == 0).T,
            cond_title_fact_token_ids=cond_title_fact_token_ids,
            cond_title_fact_token_ids_mask=(cond_title_fact_token_ids == 0).T,
            # description_token_ids="",
            # description_token_ids_mask=(description_token_ids == 0).T,
            # descriptions="",
            # titles="",
        )

    return collate_fn


class KobeDataModule(pl.LightningDataModule):
    def __init__(
        self,
        test_data: str,
        vocab_path: str,
        max_seq_length: int,
        batch_size: int,
        num_workers: int,
    ):
        super().__init__()
        self.test_data = test_data
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.text_vocab_size = helpers.get_bert_vocab_size(vocab_path)

    """
    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )
    """

    def test_dataloader(self):
        return DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=get_collate_fn(self.text_vocab_size, self.max_seq_length),
        )


if __name__ == "__main__":
    # Smoke-test the test loader; the module only accepts test data now.
    dm = KobeDataModule(
        test_data="saved/processed/test.tar",
        vocab_path="bert-base-chinese",
        max_seq_length=512,
        batch_size=32,
        num_workers=8,
    )
    dm.setup("test")
    max_len = 0
    from tqdm import tqdm

    tqdm_iter = tqdm(dm.test_dataloader())
    for batch in tqdm_iter:
        max_len = max(max_len, batch.cond_title_fact_token_ids.shape[0])
        # max_len = max(max_len, batch.description_token_ids.shape[0])  # needs the disabled description fields
        tqdm_iter.set_description(f"max len = {max_len}")
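
A quick way to see what get_collate_fn produces is to feed it one toy Example and inspect the padded tensors. A sketch (the token ids are made up; importing from vocab loads bert-base-chinese at import time, and 21128 is that checkpoint's vocabulary size):

from dataset import Example, get_collate_fn

e = Example(
    title_token_ids=[3300, 3301, 3302],   # hypothetical WordPiece ids
    description_token_ids=[],
    condition_token_ids=[5, 7],           # hypothetical condition ids
    fact_token_ids=[4000, 4001],
    description="",
    title="",
)

collate = get_collate_fn(text_vocab_size=21128, max_seq_length=256)
batch = collate([e])
# The title becomes [BOS] t1 t2 t3 [EOS]: shape (5, 1), i.e. (seq_len, batch).
print(batch.title_token_ids.shape)
# Condition ids are shifted by text_vocab_size into their own id range
# before being prepended to the title.
print(batch.cond_title_token_ids[:, 0])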
helper.py ADDED
@@ -0,0 +1,79 @@
import sentencepiece as spm
import torch
import torch.nn.functional as F
from transformers.models.bert.tokenization_bert import BertTokenizer

BASELINE = "baseline"
KOBE_ATTRIBUTE = "kobe-attr"
KOBE_KNOWLEDGE = "kobe-know"
KOBE_FULL = "kobe-full"


def get_bert_vocab_size(vocab_path: str) -> int:
    tokenizer = BertTokenizer.from_pretrained(vocab_path)
    return tokenizer.vocab_size


def get_vocab_size(vocab_path: str) -> int:
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_path)
    return len(tokenizer)


# Metrics
def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
    assert logits.dim() == 2
    assert targets.dim() == 1
    pred = logits.argmax(dim=1)
    return (pred == targets).sum().item() / targets.shape[0]


def top_k_top_p_sampling(
    logits, top_k=0, top_p=0.0, temperature=1, filter_value=-float("Inf")
) -> int:
    """Sample from a filtered distribution of logits using top-k and/or nucleus (top-p) filtering

    Args:
        logits: logits distribution shape (vocabulary size)
        top_k >0: keep only top k tokens with highest probability (top-k filtering).
        top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    logits /= temperature
    assert (
        logits.dim() == 1
    )  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    # Sample from the filtered distribution
    probabilities = F.softmax(logits, dim=-1)
    next_token = torch.multinomial(probabilities, 1)

    return int(next_token.item())


def diversity(tokenized_lines, n=4) -> int:
    """Defined as the unique number of ngrams generated on the test set."""
    n_grams_all = []
    for line in tokenized_lines:
        n_grams = list(zip(*[line[i:] for i in range(n)]))
        n_grams_all += n_grams

    return len(set(n_grams_all))
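
top_k_top_p_sampling mutates its input and expects a one-dimensional logits vector, so a toy call looks like this (a sketch with arbitrary logits; clone() protects the original tensor from the in-place filtering):

import torch
from helper import diversity, top_k_top_p_sampling

logits = torch.tensor([5.0, 2.0, 1.0, 0.5, 0.1])  # five-token toy vocabulary
token = top_k_top_p_sampling(logits.clone(), top_p=0.9)
print(token)  # 0: the dominant token alone exceeds the 0.9 probability mass

# diversity() counts unique n-grams across generated token sequences.
print(diversity([[1, 2, 3, 4, 5], [1, 2, 3, 4, 6]], n=4))  # 3 unique 4-grams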
model.py ADDED
@@ -0,0 +1,131 @@
from typing import List, Tuple

import numpy as np
import pytorch_lightning as pl
import sentencepiece as spm
import torch
import torch.nn as nn

# from sacrebleu.metrics.bleu import BLEU, _get_tokenizer
from torch import optim
from torch.nn.init import xavier_uniform_
from transformers.models.bert.tokenization_bert import BertTokenizer

# import wandb
from dataset import Batched, DecodedBatch

# from models.scheduler import WarmupDecayLR
from transformer import Decoder, Encoder

# from kobe.utils import helpers


class KobeModel(pl.LightningModule):
    def __init__(self, args):
        super(KobeModel, self).__init__()

        self.encoder = Encoder(
            vocab_size=args.text_vocab_size + args.cond_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_encoder_layers,
            dropout=args.dropout,
            mode=args.mode,
        )
        self.decoder = Decoder(
            vocab_size=args.text_vocab_size,
            max_seq_len=args.max_seq_len,
            d_model=args.d_model,
            nhead=args.nhead,
            num_layers=args.num_decoder_layers,
            dropout=args.dropout,
        )
        self.lr = args.lr
        self.d_model = args.d_model
        self.loss = nn.CrossEntropyLoss(
            reduction="mean", ignore_index=0, label_smoothing=0.1
        )
        self._reset_parameters()

        self.decoding_strategy = args.decoding_strategy
        self.vocab = BertTokenizer.from_pretrained(args.text_vocab_path)
        # self.bleu = BLEU(tokenize=args.tokenize)
        # self.sacre_tokenizer = _get_tokenizer(args.tokenize)()
        # self.bert_scorer = BERTScorer(lang=args.tokenize, rescale_with_baseline=True)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def _tokenwise_loss_acc(
        self, logits: torch.Tensor, batch: Batched
    ) -> Tuple[torch.Tensor, float]:
        # NOTE: relies on the description fields, which are disabled on
        # Batched in this inference-only demo build.
        unmask = ~batch.description_token_ids_mask.T[1:]
        unmasked_logits = logits[unmask]
        unmasked_targets = batch.description_token_ids[1:][unmask]
        # acc = helpers.accuracy(unmasked_logits, unmasked_targets)
        return self.loss(logits.transpose(1, 2), batch.description_token_ids[1:]), 1

    def training_step(self, batch: Batched, batch_idx: int):
        encoded = self.encoder.forward(batch)
        logits = self.decoder.forward(batch, encoded)
        loss, acc = self._tokenwise_loss_acc(logits, batch)
        # NOTE: configure_optimizers returns no scheduler, so this step call
        # only works once a scheduler is configured again.
        self.lr_schedulers().step()
        self.log("train/loss", loss.item())
        self.log("train/acc", acc)
        return loss

    def _shared_eval_step(self, batch: Batched, batch_idx: int) -> List[str]:
        encoded = self.encoder.forward(batch)
        # logits = self.decoder.forward(batch, encoded)
        # loss, acc = self._tokenwise_loss_acc(logits, batch)

        preds = self.decoder.predict(
            encoded_batch=encoded, decoding_strategy=self.decoding_strategy
        )
        generated = self.vocab.batch_decode(preds.T.tolist(), skip_special_tokens=True)

        return generated
        # Unreachable in this build: needs the loss/acc and description
        # fields disabled above.
        # return DecodedBatch(
        #     loss=loss.item(),
        #     acc=acc,
        #     generated=generated,
        #     descriptions=batch.descriptions,
        #     titles=batch.titles,
        # )

    def validation_step(self, batch, batch_idx):
        return self._shared_eval_step(batch, batch_idx)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        return self._shared_eval_step(batch, batch_idx)

    def _shared_epoch_end(self, outputs: List[DecodedBatch], prefix):
        # NOTE: this logging path assumes DecodedBatch outputs and wandb,
        # both disabled above, so it is dead code in this demo build.
        loss = np.mean([o.loss for o in outputs])
        acc = np.mean([o.acc for o in outputs])
        self.log(f"{prefix}/loss", loss)
        self.log(f"{prefix}/acc", acc)
        print(outputs)

        generated = [g for o in outputs for g in o.generated]
        references = [r for o in outputs for r in o.descriptions]
        titles = [r for o in outputs for r in o.titles]

        # Examples
        columns = ["Generated", "Reference"]
        data = list(zip(generated[:256:16], references[:256:16]))
        table = wandb.Table(data=data, columns=columns)
        self.logger.experiment.log({f"examples/{prefix}": table})

    def validation_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "val")

    def test_epoch_end(self, outputs):
        self._shared_epoch_end(outputs, "test")

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.98))
        # scheduler = WarmupDecayLR(optimizer, warmup_steps=10000, d_model=self.d_model)
        return [optimizer]
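
KobeModel only reads a handful of fields off args, so it can be built without the CLI. A sketch (values mirror the options.py defaults; sizing the text vocabulary downloads bert-base-chinese, and cond_vocab_size=31 matches vocab.cond.vocab):

from argparse import Namespace

import helper as helpers
from model import KobeModel

args = Namespace(
    text_vocab_path="bert-base-chinese",
    text_vocab_size=helpers.get_bert_vocab_size("bert-base-chinese"),
    cond_vocab_size=31,
    max_seq_len=256,
    d_model=512,
    nhead=8,
    num_encoder_layers=6,
    num_decoder_layers=6,
    dropout=0.1,
    mode=helpers.BASELINE,
    lr=1,
    decoding_strategy="greedy",
)
model = KobeModel(args)
print(sum(p.numel() for p in model.parameters()))  # parameter count of the untrained model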
options.py ADDED
@@ -0,0 +1,47 @@
from argparse import ArgumentParser, Namespace

import helper as helpers  # this Space ships the module as helper.py


def add_options(parser: ArgumentParser):
    # fmt: off
    # Dataset
    parser.add_argument("--train-data", default="saved/processed/train-*.tar", type=str)
    parser.add_argument("--valid-data", default="saved/processed/valid.tar", type=str)
    parser.add_argument("--test-data", default="saved/processed/test.tar", type=str)
    parser.add_argument("--text-vocab-path", default="bert-base-chinese", type=str, help="BertTokenizer used to preprocess the corpus")
    parser.add_argument("--cond-vocab-path", default="./vocab.cond.model", type=str)
    parser.add_argument("--num-workers", default=8, help="Number of data loaders", type=int)
    parser.add_argument("--tokenize", default="zh", help="Tokenization method used to compute sacrebleu, diversity, and BERTScore, defaulted to Chinese", type=str)

    # Model
    parser.add_argument("--d-model", default=512, type=int)
    parser.add_argument("--nhead", default=8, type=int)
    parser.add_argument("--num-encoder-layers", default=6, type=int)
    parser.add_argument("--num-decoder-layers", default=6, type=int)
    parser.add_argument("--max-seq-len", default=256, type=int)
    parser.add_argument("--mode", default="baseline", type=str, choices=[
        helpers.BASELINE, helpers.KOBE_ATTRIBUTE, helpers.KOBE_KNOWLEDGE, helpers.KOBE_FULL])

    # Training
    parser.add_argument("--name", default="exp", type=str, help="experiment name")
    parser.add_argument("--gpu", default=1, type=int)
    parser.add_argument("--grad-clip", default=1.0, type=float, help="clip threshold of gradients")
    parser.add_argument("--epochs", default=30, type=int, help="number of epochs to train")
    parser.add_argument("--patience", default=10, type=int, help="early stopping patience")
    parser.add_argument("--lr", default=1, type=float, help="learning rate")
    parser.add_argument("--dropout", default=0.1, type=float, help="dropout rate")
    parser.add_argument("--batch-size", default=64, type=int)
    parser.add_argument("--seed", default=42, type=int)

    # Evaluation
    parser.add_argument("--test", action="store_true", help="only do evaluation")
    parser.add_argument("--load-file", required=False, type=str, help="path to the checkpoint (.ckpt) for evaluation")
    parser.add_argument("--decoding-strategy", default="greedy", type=str, choices=["greedy", "nucleus"], help="Whether to use greedy decoding or nucleus sampling (https://arxiv.org/abs/1904.09751)")

    # fmt: on


def add_args(args: Namespace):
    args.text_vocab_size = helpers.get_bert_vocab_size(args.text_vocab_path)
    args.cond_vocab_size = helpers.get_vocab_size(args.cond_vocab_path)
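
add_options and add_args are meant to be called back to back; parse_args([]) picks up every declared default, and add_args then derives the two vocabulary sizes (a sketch; the vocab lookups hit bert-base-chinese and ./vocab.cond.model):

from argparse import ArgumentParser

from options import add_args, add_options

parser = ArgumentParser()
add_options(parser)
args = parser.parse_args([])  # [] means: take the declared defaults
add_args(args)                # fills in text_vocab_size and cond_vocab_size
print(args.d_model, args.mode, args.decoding_strategy)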
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch==1.10.0
transformers==4.25.1
sentencepiece
pytorch-lightning==1.6.4
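
The code above also imports streamlit (provided by the Spaces SDK), webdataset, cached_property, numpy, and tqdm, none of which are pinned here; missing modules are a plausible cause of the Space's "Runtime error" status. A fuller list might look like this (a sketch; the unpinned additions are assumptions, not taken from the commit):

torch==1.10.0
transformers==4.25.1
sentencepiece
pytorch-lightning==1.6.4
webdataset
cached-property
numpy
tqdm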
transformer.py ADDED
@@ -0,0 +1,222 @@
import math
from typing import Tuple

import torch
import torch.nn as nn
from cached_property import cached_property
from torch.nn.modules.transformer import (
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
)

from dataset import Batched, EncodedBatch
from vocab import BOS_ID, EOS_ID, PAD_ID
import helper as helpers  # mode constants and nucleus sampling


class PositionalEncoding(nn.Module):
    def __init__(self, dropout, dim, max_len=5000):
        """
        initialization of required variables and functions
        :param dropout: dropout probability
        :param dim: hidden size
        :param max_len: maximum length
        """
        super(PositionalEncoding, self).__init__()
        # positional encoding initialization
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        # term to divide
        div_term = torch.exp(
            (torch.arange(0, dim, 2, dtype=torch.float) * -(math.log(10000.0) / dim))
        )
        # sinusoidal positional encoding
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(1)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, emb):
        """
        create positional encoding
        :param emb: word embedding
        :return: positional encoding representation
        """
        emb *= math.sqrt(self.dim)
        emb = emb + self.pe[: emb.size(0)]  # [len, batch, size]
        emb = self.dropout(emb)
        return emb


class Encoder(nn.Module):
    @staticmethod
    def from_args(args) -> "Encoder":
        return Encoder(
            args.text_vocab_size + args.cond_vocab_size,
            args.max_seq_len,
            args.d_model,
            args.nhead,
            args.num_encoder_layers,
            args.dropout,
            args.mode,
        )

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        nhead: int,
        num_layers: int,
        dropout: float,
        mode: str,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.input_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(dropout, d_model)
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, d_model * 4, dropout, norm_first=True
        )
        self.encoder = TransformerEncoder(
            encoder_layer, num_layers, nn.LayerNorm(d_model)
        )
        self.mode = mode

    @cached_property
    def device(self):
        return list(self.parameters())[0].device

    def forward(self, batched: Batched) -> EncodedBatch:
        src, src_key_padding_mask = Encoder._get_input(batched, self.mode)
        src = self.input_embedding(src)
        src = self.pos_encoder(src)
        token_encodings = self.encoder.forward(
            src=src, src_key_padding_mask=src_key_padding_mask
        )
        return EncodedBatch(
            context_encodings=token_encodings,
            context_encodings_mask=src_key_padding_mask,
        )

    @staticmethod
    def _get_input(batched: Batched, mode: str) -> Tuple[torch.Tensor, torch.Tensor]:
        return {
            helpers.BASELINE: (batched.title_token_ids, batched.title_token_ids_mask),
            helpers.KOBE_ATTRIBUTE: (
                batched.cond_title_token_ids,
                batched.cond_title_token_ids_mask,
            ),
            helpers.KOBE_KNOWLEDGE: (
                batched.title_fact_token_ids,
                batched.title_fact_token_ids_mask,
            ),
            helpers.KOBE_FULL: (
                batched.cond_title_fact_token_ids,
                batched.cond_title_fact_token_ids_mask,
            ),
        }[mode]


class Decoder(nn.Module):
    @staticmethod
    def from_args(args) -> "Decoder":
        return Decoder(
            args.text_vocab_size,
            args.max_seq_len,
            args.d_model,
            args.nhead,
            args.num_decoder_layers,  # decoder depth
            args.dropout,
        )

    def __init__(
        self,
        vocab_size: int,
        max_seq_len: int,
        d_model: int,
        nhead: int,
        num_layers: int,
        dropout: float,
    ):
        super(Decoder, self).__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(dropout, d_model)
        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, 4 * d_model, dropout, norm_first=True
        )
        self.decoder = TransformerDecoder(
            decoder_layer, num_layers, nn.LayerNorm(d_model)
        )
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, batch: Batched, encoded_batch: EncodedBatch) -> torch.Tensor:
        tgt = self.embedding(batch.description_token_ids[:-1])
        tgt = self.pos_encoder(tgt)
        tgt_mask = Decoder.generate_square_subsequent_mask(tgt.shape[0], tgt.device)
        outputs = self.decoder(
            tgt=tgt,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=batch.description_token_ids_mask[:, :-1],
            memory=encoded_batch.context_encodings,
            memory_key_padding_mask=encoded_batch.context_encodings_mask,
        )
        return self.output(outputs)

    def predict(self, encoded_batch: EncodedBatch, decoding_strategy: str):
        batch_size = encoded_batch.context_encodings.shape[1]
        tgt = torch.tensor(
            [BOS_ID] * batch_size, device=encoded_batch.context_encodings.device
        ).unsqueeze(dim=0)
        tgt_mask = Decoder.generate_square_subsequent_mask(self.max_seq_len, tgt.device)
        pred_all = []
        for idx in range(self.max_seq_len):
            tgt_emb = self.pos_encoder(self.embedding(tgt))
            outputs = self.decoder(
                tgt_emb,
                tgt_mask=tgt_mask[: idx + 1, : idx + 1],
                memory=encoded_batch.context_encodings,
                memory_key_padding_mask=encoded_batch.context_encodings_mask,
            )
            logits = self.output(outputs[-1])

            if decoding_strategy == "greedy":
                pred_step = logits.argmax(dim=1).tolist()
            elif decoding_strategy == "nucleus":
                pred_step = [
                    helpers.top_k_top_p_sampling(logits[i], top_p=0.95)
                    for i in range(batch_size)
                ]
            else:
                raise NotImplementedError
            for b in range(batch_size):
                if pred_all and pred_all[-1][b].item() in [EOS_ID, PAD_ID]:
                    pred_step[b] = PAD_ID
            if all([pred == PAD_ID for pred in pred_step]):
                break
            pred_step = torch.tensor(pred_step, device=tgt.device)
            pred_all.append(pred_step)

            if idx < self.max_seq_len - 1:
                tgt_step = pred_step.unsqueeze(dim=0)
                tgt = torch.cat([tgt, tgt_step], dim=0)

        preds = torch.stack(pred_all)
        return preds

    @staticmethod
    def generate_square_subsequent_mask(sz: int, device: torch.device) -> torch.Tensor:
        r"""
        Generate a square mask for the sequence. The masked positions are filled with
        float('-inf'). Unmasked positions are filled with float(0.0).
        """
        return torch.triu(
            torch.full((sz, sz), float("-inf"), device=device), diagonal=1
        )
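
The causal mask is what makes Decoder.predict autoregressive: position i may only attend to positions at or before i. Its shape and values are easy to verify, as is the (seq_len, batch, d_model) layout PositionalEncoding expects (a sketch):

import torch
from transformer import Decoder, PositionalEncoding

mask = Decoder.generate_square_subsequent_mask(4, torch.device("cpu"))
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])

pe = PositionalEncoding(dropout=0.0, dim=512)
out = pe(torch.zeros(10, 2, 512))  # (seq_len, batch, d_model)
print(out.shape)  # torch.Size([10, 2, 512])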
vocab.cond.vocab ADDED
@@ -0,0 +1,31 @@
<pad>	0
<s>	0
</s>	0
<unk>	0
▁<a>	-1.261
▁<c>	-1.82603
<3>	-2.17158
<0>	-2.26491
<1>	-2.34
<2>	-2.36126
▁<b>	-2.89
<4>	-3.31157
<5>	-3.54753
<6>	-5.02554
<7>	-5.3972
<11>	-5.51923
<8>	-6.03597
<9>	-6.14342
<10>	-6.17248
<13>	-6.28137
<12>	-6.84099
<17>	-7.69198
<14>	-7.72356
<15>	-8.15065
<20>	-9.37115
<18>	-9.52068
<19>	-9.52068
<16>	-9.55347
<23>	-9.58737
<25>	-9.95894
<24>	-12.9547
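
This is the SentencePiece vocabulary for the condition tokenizer: one piece per line with its log probability, covering the three aspect tokens (▁<a>, ▁<b>, ▁<c>) and the user-category tokens <0> through <25>. Loading the companion binary model (the code expects ./vocab.cond.model, which this commit does not include) would look like this sketch:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("./vocab.cond.model")     # binary model matching vocab.cond.vocab
print(len(sp))                    # 31 pieces, matching this file
print(sp.EncodeAsIds("<2> <a>"))  # category <2> plus aspect <a>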
vocab.py ADDED
@@ -0,0 +1,41 @@
import tempfile
from argparse import ArgumentParser

import sentencepiece as spm
from transformers.models.bert.tokenization_bert import BertTokenizer

# Load the text tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

BOS_TOKEN = tokenizer.cls_token
EOS_TOKEN = tokenizer.sep_token
UNK_TOKEN = tokenizer.unk_token
PAD_ID = tokenizer.pad_token_id
BOS_ID = tokenizer.cls_token_id
EOS_ID = tokenizer.sep_token_id
UNK_ID = tokenizer.unk_token_id

# Build the condition (attribute) tokenizer
if __name__ == "__main__":
    parser = ArgumentParser()
    # fmt: off
    parser.add_argument("--input", nargs="+", required=True)
    parser.add_argument("--vocab-file", type=str, required=True)
    parser.add_argument("--vocab-size", type=int, default=31)
    parser.add_argument("--algo", type=str, default="bpe", choices=["bpe", "word"])
    # fmt: on
    args = parser.parse_args()
    print("Building token vocabulary")
    with tempfile.NamedTemporaryFile("w") as f:
        # concatenate input files
        for input_fname in args.input:
            with open(input_fname) as input_f:
                f.write(input_f.read() + "\n")
        f.flush()  # ensure the concatenated text is on disk before training
        # run sentencepiece with bpe
        spm.SentencePieceTrainer.Train(
            f"--add_dummy_prefix=false --pad_id=0 --bos_id=1 --eos_id=2 --unk_id=3 "
            f"--vocab_size={args.vocab_size} "
            f"--model_prefix={args.vocab_file} --model_type={args.algo} "
            f"--input={f.name}"
        )
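
The module-level constants are the special-token ids every sequence in dataset.py is built from; for bert-base-chinese they resolve to the standard BERT ids:

from vocab import BOS_ID, EOS_ID, PAD_ID, UNK_ID

# bert-base-chinese special tokens: [PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102
print(PAD_ID, UNK_ID, BOS_ID, EOS_ID)  # 0 100 101 102

# Rebuilding the condition vocabulary (hypothetical input file name):
#   python vocab.py --input conditions.txt --vocab-file vocab.cond --vocab-size 31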