cache bert models (extractive sum)
Browse files
extractive_summarizer/bert_parent.py
CHANGED
@@ -1,13 +1,18 @@
|
|
1 |
from typing import List, Union
|
2 |
|
3 |
-
import numpy as np
|
4 |
import torch
|
|
|
|
|
5 |
from numpy import ndarray
|
6 |
from transformers import (AlbertModel, AlbertTokenizer, BertModel,
|
7 |
BertTokenizer, DistilBertModel, DistilBertTokenizer,
|
8 |
PreTrainedModel, PreTrainedTokenizer, XLMModel,
|
9 |
XLMTokenizer, XLNetModel, XLNetTokenizer)
|
10 |
|
|
|
|
|
|
|
|
|
11 |
|
12 |
class BertParent(object):
|
13 |
"""
|
@@ -49,8 +54,9 @@ class BertParent(object):
|
|
49 |
if custom_model:
|
50 |
self.model = custom_model.to(self.device)
|
51 |
else:
|
52 |
-
self.model = base_model.from_pretrained(
|
53 |
-
model, output_hidden_states=True).to(self.device)
|
|
|
54 |
|
55 |
if custom_tokenizer:
|
56 |
self.tokenizer = custom_tokenizer
|
@@ -59,6 +65,7 @@ class BertParent(object):
|
|
59 |
|
60 |
self.model.eval()
|
61 |
|
|
|
62 |
def tokenize_input(self, text: str) -> torch.tensor:
|
63 |
"""
|
64 |
Tokenizes the text input.
|
|
|
1 |
from typing import List, Union
|
2 |
|
|
|
3 |
import torch
|
4 |
+
import streamlit as st
|
5 |
+
import numpy as np
|
6 |
from numpy import ndarray
|
7 |
from transformers import (AlbertModel, AlbertTokenizer, BertModel,
|
8 |
BertTokenizer, DistilBertModel, DistilBertTokenizer,
|
9 |
PreTrainedModel, PreTrainedTokenizer, XLMModel,
|
10 |
XLMTokenizer, XLNetModel, XLNetTokenizer)
|
11 |
|
12 |
+
@st.cache()
|
13 |
+
def load_hf_model(base_model, model_name, device):
|
14 |
+
model = base_model.from_pretrained(model_name, output_hidden_states=True).to(device)
|
15 |
+
return model
|
16 |
|
17 |
class BertParent(object):
|
18 |
"""
|
|
|
54 |
if custom_model:
|
55 |
self.model = custom_model.to(self.device)
|
56 |
else:
|
57 |
+
# self.model = base_model.from_pretrained(
|
58 |
+
# model, output_hidden_states=True).to(self.device)
|
59 |
+
self.model = load_hf_model(base_model, model, self.device)
|
60 |
|
61 |
if custom_tokenizer:
|
62 |
self.tokenizer = custom_tokenizer
|
|
|
65 |
|
66 |
self.model.eval()
|
67 |
|
68 |
+
|
69 |
def tokenize_input(self, text: str) -> torch.tensor:
|
70 |
"""
|
71 |
Tokenizes the text input.
|