qixun commited on
Commit
b09569b
·
verified ·
1 Parent(s): e12e88c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +42 -3
README.md CHANGED
@@ -1,3 +1,42 @@
1
- ---
2
- license: gpl-3.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: gpl-3.0
3
+
4
+ widget:
5
+ - text: "宵凉百念集孤[MASK],暗雨鸣廊睡未能。生计坐怜秋一叶,归程冥想浪千层。寒心国事浑难料,堆眼官资信可憎。此去梦中应不忘,顺承门内近觚棱。"
6
+ ---
7
+ 适用于中国古典诗歌的bert模型,在搜韵开源的语料上以16的batch_size训练了110万步左右,loss稳定低于1。
8
+
9
+ 使用方法如下:
10
+
11
+ ```python
12
+ from transformers import BertTokenizer, BertForMaskedLM
13
+ import torch
14
+
15
+ # 加载分词器
16
+ tokenizer = BertTokenizer.from_pretrained("qixun/bert-chinese-poem")
17
+
18
+ # 加载模型
19
+ model = BertForMaskedLM.from_pretrained("qixun/bert-chinese-poem")
20
+
21
+ # 输入文本
22
+ text = "宵凉百念集孤[MASK],暗雨鸣廊睡未能。生计坐怜秋一叶,归程冥想浪千层。寒心国事浑难料,堆眼官资信可憎。此去梦中应不忘,顺承门内近觚棱。"
23
+
24
+ # 分词
25
+ inputs = tokenizer(text, return_tensors="pt")
26
+
27
+ # 模型推理
28
+ with torch.no_grad():
29
+ outputs = model(**inputs)
30
+
31
+ # 获取[MASK]标记的位置
32
+ mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
33
+
34
+ # 获取预测的token_id
35
+ predicted_token_id = outputs.logits[0, mask_token_index].argmax(axis=-1).item()
36
+
37
+ # 获取预测的词
38
+ predicted_token = tokenizer.decode([predicted_token_id])
39
+
40
+ print(f"预测的词是:{predicted_token}")
41
+
42
+ ```