wide_analysis_space / model_predict.py
hsuvaskakoty's picture
initial commit
3c77d98 verified
raw
history blame
805 Bytes
# Predict outcome labels for input text with a Hugging Face text-classification pipeline.
from functools import lru_cache

import torch
from transformers import pipeline
# Outcome names the classifier distinguishes, listed in class-index order.
_OUTCOME_NAMES = (
    'delete',
    'keep',
    'merge',
    'no consensus',
    'speedy keep',
    'speedy delete',
    'redirect',
    'withdrawn',
)

# Maps each outcome name to [class index, generic label name], where the
# generic label is "LABEL_<index>" -- the name predict_text matches against
# the 'label' field of the pipeline's output.
label_mapping = {
    name: [index, f'LABEL_{index}']
    for index, name in enumerate(_OUTCOME_NAMES)
}
@lru_cache(maxsize=4)
def _get_classifier(model_name):
    """Build a text-classification pipeline that reports every label's score, once per model.

    Cached so repeated predictions do not reload the model weights each call;
    the bound keeps at most a few models resident in memory.
    """
    # top_k=None is the documented replacement for the deprecated
    # return_all_scores=True: it returns a score for every label.
    return pipeline("text-classification", model=model_name, top_k=None)


def predict_text(text, model_name):
    """Classify *text* with a Hugging Face model and return one score per outcome.

    Args:
        text: Input forwarded to the text-classification pipeline. If a batch
            is passed, only the first input's scores are used.
        model_name: Model identifier (or model object) forwarded to
            ``pipeline(model=...)``; the loaded pipeline is cached per value.

    Returns:
        A dict with every key of ``label_mapping``, mapped to the score the
        model reported for it. Outcomes the model did not report stay 0.0.
    """
    results = _get_classifier(model_name)(text)

    # Depending on the transformers version, a single input yields either a
    # nested [[{'label', 'score'}, ...]] or a flat [{'label', 'score'}, ...].
    if results and isinstance(results[0], list):
        scores = results[0]
    else:
        scores = results

    # Reverse lookup from the label name the model emits to the outcome key.
    # Accept both the generic "LABEL_<n>" name and the outcome name itself,
    # for models whose config already names the labels.
    label_to_key = {}
    for key, (_index, model_label) in label_mapping.items():
        label_to_key[model_label] = key
        label_to_key[key] = key

    final_scores = dict.fromkeys(label_mapping, 0.0)
    for item in scores:
        key = label_to_key.get(item['label'])
        if key is not None:
            final_scores[key] = item['score']
    return final_scores