shibing624
commited on
Commit
·
9820460
1
Parent(s):
06088e1
Update README.md
Browse files
README.md
CHANGED
@@ -40,12 +40,14 @@ print(i)
|
|
40 |
import operator
|
41 |
import torch
|
42 |
from transformers import BertTokenizer, BertForMaskedLM
|
|
|
43 |
|
44 |
tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
|
45 |
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
|
|
|
46 |
|
47 |
texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
|
48 |
-
outputs = model(**tokenizer(texts, padding=True, return_tensors='pt'))
|
49 |
|
50 |
def get_errors(corrected_text, origin_text):
|
51 |
details = []
|
|
|
40 |
import operator
|
41 |
import torch
|
42 |
from transformers import BertTokenizer, BertForMaskedLM
|
43 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
44 |
|
45 |
tokenizer = BertTokenizer.from_pretrained("shibing624/macbert4csc-base-chinese")
|
46 |
model = BertForMaskedLM.from_pretrained("shibing624/macbert4csc-base-chinese")
|
47 |
+
model = model.to(device)
|
48 |
|
49 |
texts = ["今天新情很好", "你找到你最喜欢的工作,我也很高心。"]
|
50 |
+
outputs = model(**tokenizer(texts, padding=True, return_tensors='pt').to(device))
|
51 |
|
52 |
def get_errors(corrected_text, origin_text):
|
53 |
details = []
|