ken11
commited on
Commit
•
0b768a8
1
Parent(s):
8012371
add TF
Browse files- README.md +29 -0
- tf_model.h5 +3 -0
README.md
CHANGED
@@ -24,6 +24,7 @@ widget:
|
|
24 |
### Fill-Mask
|
25 |
このモデルではTokenizerにSentencepieceを利用しています
|
26 |
そのままでは`[MASK]`トークンのあとに[余計なトークンが混入する問題](https://ken11.jp/blog/sentencepiece-tokenizer-bug)があるので、利用する際には以下のようにする必要があります
|
|
|
27 |
```py
|
28 |
from transformers import (
|
29 |
AlbertForMaskedLM, AlbertTokenizerFast
|
@@ -51,6 +52,34 @@ print(tokenizer.convert_ids_to_tokens(result.tolist()))
|
|
51 |
# ['英語', '心理学', '数学', '医学', '日本語']
|
52 |
```
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
## Training Data
|
55 |
学習には
|
56 |
- [日本語Wikipediaの全文](https://ja.wikipedia.org/wiki/Wikipedia:%E3%83%87%E3%83%BC%E3%82%BF%E3%83%99%E3%83%BC%E3%82%B9%E3%83%80%E3%82%A6%E3%83%B3%E3%83%AD%E3%83%BC%E3%83%89)
|
|
|
24 |
### Fill-Mask
|
25 |
このモデルではTokenizerにSentencepieceを利用しています
|
26 |
そのままでは`[MASK]`トークンのあとに[余計なトークンが混入する問題](https://ken11.jp/blog/sentencepiece-tokenizer-bug)があるので、利用する際には以下のようにする必要があります
|
27 |
+
#### for PyTorch
|
28 |
```py
|
29 |
from transformers import (
|
30 |
AlbertForMaskedLM, AlbertTokenizerFast
|
|
|
52 |
# ['英語', '心理学', '数学', '医学', '日本語']
|
53 |
```
|
54 |
|
55 |
+
#### for TensorFlow
|
56 |
+
```py
|
57 |
+
from transformers import (
|
58 |
+
TFAlbertForMaskedLM, AlbertTokenizerFast
|
59 |
+
)
|
60 |
+
import tensorflow as tf
|
61 |
+
|
62 |
+
|
63 |
+
tokenizer = AlbertTokenizerFast.from_pretrained("ken11/albert-base-japanese-v1")
|
64 |
+
model = TFAlbertForMaskedLM.from_pretrained("ken11/albert-base-japanese-v1")
|
65 |
+
|
66 |
+
text = "大学で[MASK]の研究をしています"
|
67 |
+
tokenized_text = tokenizer.tokenize(text)
|
68 |
+
del tokenized_text[tokenized_text.index(tokenizer.mask_token) + 1]
|
69 |
+
|
70 |
+
input_ids = [tokenizer.cls_token_id]
|
71 |
+
input_ids.extend(tokenizer.convert_tokens_to_ids(tokenized_text))
|
72 |
+
input_ids.append(tokenizer.sep_token_id)
|
73 |
+
|
74 |
+
inputs = {"input_ids": [input_ids], "token_type_ids": [[0]*len(input_ids)], "attention_mask": [[1]*len(input_ids)]}
|
75 |
+
batch = {k: tf.convert_to_tensor(v, dtype=tf.int32) for k, v in inputs.items()}
|
76 |
+
output = model(**batch)[0]
|
77 |
+
result = tf.math.top_k(output[0, input_ids.index(tokenizer.mask_token_id)], k=5)
|
78 |
+
|
79 |
+
print(tokenizer.convert_ids_to_tokens(result.indices.numpy()))
|
80 |
+
# ['英語', '心理学', '数学', '医学', '日本語']
|
81 |
+
```
|
82 |
+
|
83 |
## Training Data
|
84 |
学習には
|
85 |
- [日本語Wikipediaの全文](https://ja.wikipedia.org/wiki/Wikipedia:%E3%83%87%E3%83%BC%E3%82%BF%E3%83%99%E3%83%BC%E3%82%B9%E3%83%80%E3%82%A6%E3%83%B3%E3%83%AD%E3%83%BC%E3%83%89)
|
tf_model.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:675280c441280bcfc97f943903cdbf69daad6823151346f3f785a440c15e69f6
|
3 |
+
size 65122008
|