Spaces:
Runtime error
Runtime error
Commit
·
8b850ac
1
Parent(s):
74abd6a
[ERCBCM] Optimize the model interfaces and prints.
Browse files- ercbcm/ERCBCM.py +0 -1
- ercbcm/__init__.py +2 -3
- ercbcm/model_loader.py +4 -31
ercbcm/ERCBCM.py
CHANGED
|
@@ -6,7 +6,6 @@ class ERCBCM(nn.Module):
|
|
| 6 |
def __init__(self):
|
| 7 |
super(ERCBCM, self).__init__()
|
| 8 |
print('>>> ERCBCM Init!')
|
| 9 |
-
|
| 10 |
self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
| 11 |
|
| 12 |
def forward(self, text, label):
|
|
|
|
| 6 |
def __init__(self):
|
| 7 |
super(ERCBCM, self).__init__()
|
| 8 |
print('>>> ERCBCM Init!')
|
|
|
|
| 9 |
self.bert_base = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
| 10 |
|
| 11 |
def forward(self, text, label):
|
ercbcm/__init__.py
CHANGED
|
@@ -7,17 +7,16 @@ sys.path.insert(0, myPath + '/../')
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
| 10 |
-
from ercbcm.model_loader import
|
| 11 |
from ercbcm.ERCBCM import ERCBCM
|
| 12 |
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
| 13 |
|
| 14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
-
print('>>> GPU Available?', torch.cuda.is_available())
|
| 16 |
|
| 17 |
# ==========
|
| 18 |
|
| 19 |
model_for_predict = ERCBCM().to(device)
|
| 20 |
-
|
| 21 |
|
| 22 |
def predict(sentence, name):
|
| 23 |
label = torch.tensor([0])
|
|
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
|
| 10 |
+
from ercbcm.model_loader import load
|
| 11 |
from ercbcm.ERCBCM import ERCBCM
|
| 12 |
from modules.tokenizer import tokenizer, normalize_v2, PAD_TOKEN_ID
|
| 13 |
|
| 14 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 15 |
|
| 16 |
# ==========
|
| 17 |
|
| 18 |
model_for_predict = ERCBCM().to(device)
|
| 19 |
+
load('ercbcm/model.pt', model_for_predict, device)
|
| 20 |
|
| 21 |
def predict(sentence, name):
|
| 22 |
label = torch.tensor([0])
|
ercbcm/model_loader.py
CHANGED
|
@@ -1,35 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
def save_checkpoint(save_path, model, valid_loss):
|
| 6 |
-
if save_path == None:
|
| 7 |
-
return
|
| 8 |
-
state_dict = {'model_state_dict': model.state_dict(),
|
| 9 |
-
'valid_loss': valid_loss}
|
| 10 |
-
torch.save(state_dict, save_path)
|
| 11 |
-
print('[SAVE] Model has been saved successfully to \'{}\''.format(save_path))
|
| 12 |
-
|
| 13 |
-
def load_checkpoint(load_path, model, device):
|
| 14 |
-
if load_path == None:
|
| 15 |
-
return
|
| 16 |
state_dict = torch.load(load_path, map_location=device)
|
| 17 |
-
print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
|
| 18 |
model.load_state_dict(state_dict['model_state_dict'])
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
|
| 22 |
-
if save_path == None:
|
| 23 |
-
return
|
| 24 |
-
state_dict = {'train_loss_list': train_loss_list,
|
| 25 |
-
'valid_loss_list': valid_loss_list,
|
| 26 |
-
'global_steps_list': global_steps_list}
|
| 27 |
-
torch.save(state_dict, save_path)
|
| 28 |
-
print('[SAVE] Model with matrics has been saved successfully to \'{}\''.format(save_path))
|
| 29 |
-
|
| 30 |
-
def load_metrics(load_path, device):
|
| 31 |
-
if load_path == None:
|
| 32 |
-
return
|
| 33 |
-
state_dict = torch.load(load_path, map_location=device)
|
| 34 |
-
print('[LOAD] Model with matrics has been loaded successfully from \'{}\''.format(load_path))
|
| 35 |
-
return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
|
|
|
|
| 1 |
import torch
|
| 2 |
|
| 3 |
+
def load(load_path, model, device):
|
| 4 |
+
if load_path == None: return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
state_dict = torch.load(load_path, map_location=device)
|
|
|
|
| 6 |
model.load_state_dict(state_dict['model_state_dict'])
|
| 7 |
+
print('[LOAD] Model has been loaded successfully from \'{}\''.format(load_path))
|
| 8 |
+
return state_dict['valid_loss']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|