import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import torch
import numpy as np
import contextlib
import plotly.express as px
import pandas as pd
from PIL import Image
import datetime
import os
import psutil
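# Log each app load/rerun with a timestamp (simple hit counter).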
with open("hit_log.txt", mode='a') as file:
    file.write(str(datetime.datetime.now()) + '\n')
MAX_GRAPH_ROWS = 10
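# Markdown descriptions shown in the sidebar for each selectable model.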
MODEL_DESC = {
    'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. `business` -> `This text is about business.`""",
    'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. `business` -> `This text is about business.`""",
    'XLM Roberta XNLI (cross-lingual)': """XLM Roberta, a cross-lingual model, with a classification head trained on XNLI. Supported languages include: _English, French, Spanish, German, Greek, Bulgarian, Russian, Turkish, Arabic, Vietnamese, Thai, Chinese, Hindi, Swahili, and Urdu_.

Note that this model seems to be less reliable than the English-only models when classifying longer sequences.

Examples were automatically translated and may contain grammatical mistakes.

Sequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. `business` -> `This text is about business.`""",
}
ZSL_DESC = """*Update 2024/05/02: This app demo'd the use of pre-trained NLI models for zero-shot topic classification right around the 2020 announcement of GPT-3. It was originally released as a standalone app on the Hugging Face servers before the introduction of HF spaces. It is kept here for posterity.*

Recently, the NLP science community has begun to pay increasing attention to zero-shot and few-shot applications, such as in the [paper from OpenAI](https://arxiv.org/abs/2005.14165) introducing GPT-3. This demo shows how 🤗 Transformers can be used for zero-shot topic classification, the task of predicting a topic that the model has not been trained on."""
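# Copy-pasteable snippet shown when 'Show code snippet' is checked; the chosen
# model ID is substituted for the single `{}` placeholder via str.format().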
CODE_DESC = """```python
from transformers import pipeline

classifier = pipeline('zero-shot-classification',
                      model='{}')
hypothesis_template = 'This text is about {{}}.'  # the template used in this demo
classifier(sequence, labels,
           hypothesis_template=hypothesis_template,
           multi_label=multi_label)
# {{'sequence': ..., 'labels': ..., 'scores': ...}}
```"""
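# Hugging Face Hub model IDs keyed by the display names used above.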
model_ids = {
    'Bart MNLI': 'facebook/bart-large-mnli',
    'Bart MNLI + Yahoo Answers': 'joeddav/bart-large-mnli-yahoo-answers',
    'XLM Roberta XNLI (cross-lingual)': 'joeddav/xlm-roberta-large-xnli'
}
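# transformers pipelines take a device index: 0 = first GPU, -1 = CPU.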
device = 0 if torch.cuda.is_available() else -1
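# Load all models once per process; st.cache_resource shares them across reruns and sessions.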
@st.cache_resource
def load_models():
    return {id: AutoModelForSequenceClassification.from_pretrained(id) for id in model_ids.values()}
models = load_models()
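# Cache tokenizers so switching models doesn't re-instantiate them on every rerun.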
@st.cache_resource(show_spinner=False)
def load_tokenizer(tok_id):
    return AutoTokenizer.from_pretrained(tok_id)
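# Cache predictions per unique (model, sequence, labels, template, multi_class) call;
# hash_funcs tells Streamlit to skip hashing torch parameters it may encounter.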
@st.cache_data(show_spinner=False, hash_funcs={
    torch.nn.Parameter: lambda _: None
})
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class):
    classifier = pipeline(
        'zero-shot-classification',
        model=models[nli_model_id],
        tokenizer=load_tokenizer(nli_model_id),
        device=device
    )
    outputs = classifier(
        sequence,
        candidate_labels=labels,
        hypothesis_template=hypothesis_template,
        multi_label=multi_class
    )
    return outputs['labels'], outputs['scores']
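# Example texts and candidate labels ship as texts-<model-name>.json alongside the app;
# a 'Custom' entry lets the user type their own input.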
def load_examples(model_id):
    model_id_stripped = model_id.split('/')[-1]
    df = pd.read_json(f'texts-{model_id_stripped}.json')
    names = df.name.values.tolist()
    mapping = {df['name'].iloc[i]: (df['text'].iloc[i], df['labels'].iloc[i]) for i in range(len(names))}
    names.append('Custom')
    mapping['Custom'] = ('', '')
    return names, mapping
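# Horizontal bar chart of the top label confidences (as percentages).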
def plot_result(top_topics, scores):
    top_topics = np.array(top_topics)
    scores = np.array(scores)
    scores *= 100
    fig = px.bar(x=scores, y=top_topics, orientation='h',
                 labels={'x': 'Confidence', 'y': 'Label'},
                 text=scores,
                 range_x=(0, 115),
                 title='Top Predictions',
                 color=np.linspace(0, 1, len(scores)),
                 color_continuous_scale='GnBu')
    fig.update(layout_coloraxis_showscale=False)
    fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
    st.plotly_chart(fig)
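# Page layout: the sidebar holds the model picker and its description; the main pane
# collects the text plus candidate topics and renders the predictions chart.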
def main():
with open("style.css") as f:
st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)
logo = Image.open('huggingface_logo.png')
st.sidebar.image(logo, width=120)
st.sidebar.markdown(ZSL_DESC)
model_desc = st.sidebar.selectbox('Model', list(MODEL_DESC.keys()), 0)
do_print_code = st.sidebar.checkbox('Show code snippet', False)
st.sidebar.markdown('#### Model Description')
st.sidebar.markdown(MODEL_DESC[model_desc])
st.sidebar.markdown('Originally proposed by [Yin et al. (2019)](https://arxiv.org/abs/1909.00161). Read more in our [blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html).')
model_id = model_ids[model_desc]
ex_names, ex_map = load_examples(model_id)
st.title('Zero Shot Topic Classification')
example = st.selectbox('Choose an example', ex_names)
height = min((len(ex_map[example][0].split()) + 1) * 2, 200)
sequence = st.text_area('Text', ex_map[example][0], key='sequence', height=height)
labels = st.text_input('Possible topics (separated by `,`)', ex_map[example][1], max_chars=1000)
multi_class = st.checkbox('Allow multiple correct topics', value=True)
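    # Hypothesis template from the Yin et al. NLI formulation: each candidate label
    # is inserted into '{}' to form a hypothesis scored against the input text.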
    hypothesis_template = "This text is about {}."

    labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
    if len(labels) == 0 or len(sequence) == 0:
        st.write('Enter some text and at least one possible topic to see predictions.')
        return

    if do_print_code:
        st.markdown(CODE_DESC.format(model_id))

    with st.spinner('Classifying...'):
        top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class)

    plot_result(top_topics[::-1][-MAX_GRAPH_ROWS:], scores[::-1][-MAX_GRAPH_ROWS:])
if __name__ == '__main__':
    main()