|
--- |
|
license: apache-2.0 |
|
--- |
|
|
|
# Model description |
|
|
|
A `BertForSequenceClassification` model fine-tuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.
|
|
|
|
|
# Usage |
|
|
|
Concatenate the input text with each candidate label and feed each pair to the model; the model outputs a score for every label. Below is an example.
|
|
|
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("CogComp/ZeroShotWiki")
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")
model.eval()

labels = ["sports", "business", "politics"]
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]

with torch.no_grad():
    for text in texts:
        label_score = {}
        for label in labels:
            # Encode the (text, label) pair as a single input sequence
            inputs = tokenizer(text, label, return_tensors='pt')
            out = model(**inputs)
            # The softmax probability of class 0 is the score for this label
            label_score[label] = float(torch.nn.functional.softmax(out[0], dim=-1)[0][0])
        print(label_score)  # Predict the label with the highest score
```
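Once the loop above has filled `label_score`, the prediction is simply the label with the highest score. A minimal sketch (the score values below are hypothetical, standing in for the model's actual output):

```python
# Hypothetical scores of the kind produced by the loop above
label_score = {"sports": 0.92, "business": 0.04, "politics": 0.03}

# The predicted label is the candidate with the highest score
predicted = max(label_score, key=label_score.get)
print(predicted)  # sports
```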