Spaces:
Runtime error
Runtime error
Joe Davison
commited on
Commit
·
85dd546
1
Parent(s):
039194f
fix model caching error
Browse files
app.py
CHANGED
@@ -13,6 +13,9 @@ import psutil
|
|
13 |
with open("hit_log.txt", mode='a') as file:
|
14 |
file.write(str(datetime.datetime.now()) + '\n')
|
15 |
|
|
|
|
|
|
|
16 |
MODEL_DESC = {
|
17 |
'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
18 |
'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
@@ -58,10 +61,22 @@ models = load_models()
|
|
58 |
def load_tokenizer(tok_id):
|
59 |
return AutoTokenizer.from_pretrained(tok_id)
|
60 |
|
61 |
-
@st.cache(allow_output_mutation=True, show_spinner=False
|
62 |
-
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
return outputs['labels'], outputs['scores']
|
66 |
|
67 |
def load_examples(model_id):
|
@@ -88,7 +103,6 @@ def plot_result(top_topics, scores):
|
|
88 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
89 |
st.plotly_chart(fig)
|
90 |
|
91 |
-
|
92 |
|
93 |
def main():
|
94 |
with open("style.css") as f:
|
@@ -124,18 +138,11 @@ def main():
|
|
124 |
st.markdown(CODE_DESC.format(model_id))
|
125 |
|
126 |
with st.spinner('Classifying...'):
|
127 |
-
top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class
|
128 |
-
|
129 |
-
plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
|
130 |
-
|
131 |
-
if "socat" not in [p.name() for p in psutil.process_iter()]:
|
132 |
-
os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')
|
133 |
-
|
134 |
-
|
135 |
|
|
|
136 |
|
137 |
|
138 |
|
139 |
if __name__ == '__main__':
|
140 |
main()
|
141 |
-
|
|
|
13 |
with open("hit_log.txt", mode='a') as file:
|
14 |
file.write(str(datetime.datetime.now()) + '\n')
|
15 |
|
16 |
+
|
17 |
+
MAX_GRAPH_ROWS = 10
|
18 |
+
|
19 |
MODEL_DESC = {
|
20 |
'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
21 |
'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
|
|
61 |
def load_tokenizer(tok_id):
|
62 |
return AutoTokenizer.from_pretrained(tok_id)
|
63 |
|
64 |
+
@st.cache(allow_output_mutation=True, show_spinner=False, hash_funcs={
|
65 |
+
torch.nn.Parameter: lambda _: None
|
66 |
+
})
|
67 |
+
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class):
|
68 |
+
classifier = pipeline(
|
69 |
+
'zero-shot-classification',
|
70 |
+
model=models[nli_model_id],
|
71 |
+
tokenizer=load_tokenizer(nli_model_id),
|
72 |
+
device=device
|
73 |
+
)
|
74 |
+
outputs = classifier(
|
75 |
+
sequence,
|
76 |
+
candidate_labels=labels,
|
77 |
+
hypothesis_template=hypothesis_template,
|
78 |
+
multi_label=multi_class
|
79 |
+
)
|
80 |
return outputs['labels'], outputs['scores']
|
81 |
|
82 |
def load_examples(model_id):
|
|
|
103 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
104 |
st.plotly_chart(fig)
|
105 |
|
|
|
106 |
|
107 |
def main():
|
108 |
with open("style.css") as f:
|
|
|
138 |
st.markdown(CODE_DESC.format(model_id))
|
139 |
|
140 |
with st.spinner('Classifying...'):
|
141 |
+
top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
+
plot_result(top_topics[::-1][-MAX_GRAPH_ROWS:], scores[::-1][-MAX_GRAPH_ROWS:])
|
144 |
|
145 |
|
146 |
|
147 |
if __name__ == '__main__':
|
148 |
main()
|
|