---
license: apache-2.0
---
## Model description
A `BertForSequenceClassification` model fine-tuned on Wikipedia for zero-shot text classification. For details, see our NAACL'22 paper.
## Usage
Concatenate the input text with each candidate label and feed each pair to the model. The model outputs a score for each label; the label with the highest score is the prediction. Below is an example.
```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("CogComp/ZeroShotWiki")
model = AutoModelForSequenceClassification.from_pretrained("CogComp/ZeroShotWiki")

labels = ["sports", "business", "politics"]
texts = ["As of the 2018 FIFA World Cup, twenty-one final tournaments have been held and a total of 79 national teams have competed."]

with torch.no_grad():
    for text in texts:
        label_score = {}
        for label in labels:
            # Encode the (text, label) pair as a single sequence-pair input
            inputs = tokenizer(text, label, return_tensors='pt')
            out = model(**inputs)
            # Use the softmax probability of class 0 as the score for this label
            label_score[label] = float(torch.nn.functional.softmax(out.logits, dim=-1)[0][0])
        print(label_score)  # Predict the label with the highest score
```
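To turn the printed scores into a final prediction, take the label with the highest score, as noted in the comment above. A minimal sketch continuing from the `label_score` dictionary built in the loop:

```python
# Pick the candidate label with the highest score as the prediction
predicted_label = max(label_score, key=label_score.get)
print(predicted_label)
```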