Spaces:
Running
Running
import gradio as gr | |
from transformers import pipeline | |
import torch | |
# Build the zero-shot classification pipeline once, when the module is imported.
# If construction fails for any reason, `classifier` is set to None so that
# classify_text() can answer with its error response instead of crashing.
# NOTE(review): "models/tasksource/ModernBERT-nli" looks like a local directory
# path rather than the Hub id "tasksource/ModernBERT-nli" — confirm the
# directory ships with the app, otherwise loading fails on every start.
try:
    classifier = pipeline(
        task="zero-shot-classification",
        model="models/tasksource/ModernBERT-nli",
        # First CUDA device when one is available, otherwise -1 (CPU).
        device=(0 if torch.cuda.is_available() else -1),
    )
except Exception as load_err:
    classifier = None
    print(f"Error loading model: {load_err}")
def _parse_labels(candidate_labels):
    """
    Turn a comma-separated label string into a clean list of labels.

    Whitespace around each label is stripped, empty entries (from ",,", a
    leading/trailing comma, or a blank string) are dropped, and duplicates are
    removed while keeping first-seen order. Non-string input yields [].
    """
    if not isinstance(candidate_labels, str):
        return []
    stripped = (part.strip() for part in candidate_labels.split(","))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


def classify_text(text, candidate_labels):
    """
    Perform zero-shot classification on input text.

    Args:
        text (str): Input text to classify.
        candidate_labels (str): Comma-separated string of possible labels,
            e.g. "true,false". Surrounding whitespace is stripped, empty
            entries are ignored and repeated labels are used only once.

    Returns:
        dict: Mapping of label -> confidence score (float), in the order the
        pipeline returned them, as expected by ``gr.Label``.
        ``{"error": 1.0}`` is returned (and a message printed) when no usable
        label is given, the text is blank or not a string, the model failed
        to load, or the classification call raises.
    """
    # Validate cheap inputs first, so degenerate requests never reach the model.
    # Duplicates matter: the pipeline's default single-label mode normalizes
    # scores across all labels, so repeated labels would share probability
    # mass and then collapse into a single dict key with the wrong score.
    labels = _parse_labels(candidate_labels)
    if not labels:
        print("Classification error: no non-empty candidate labels were given")
        return {"error": 1.0}
    if not isinstance(text, str) or not text.strip():
        print("Classification error: text to classify is empty")
        return {"error": 1.0}

    if classifier is None:
        # The model failed to load at startup; keep the UI responsive.
        return {"error": 1.0}

    try:
        result = classifier(text, labels)
    except Exception as e:
        print(f"Classification error: {e}")
        return {"error": 1.0}

    # Convert to the {label: confidence} dictionary format Gradio's Label expects.
    return {label: float(score) for label, score in zip(result["labels"], result["scores"])}
# ---------------------------------------------------------------------------
# Gradio UI: free text plus comma-separated labels in, label scores out.
# Components are created in the same order as before (text box, labels box,
# output label) and then handed to gr.Interface.
# ---------------------------------------------------------------------------
_text_input = gr.Textbox(
    value="all cats are blue",
    label="Text to classify",
    placeholder="Enter text here...",
)
_labels_input = gr.Textbox(
    value="true,false",
    label="Possible labels (comma-separated)",
    placeholder="Enter labels...",
)
_scores_output = gr.Label(label="Classification Results")

# Clickable example rows: [text, comma-separated labels].
_EXAMPLES = [
    ["all cats are blue", "true,false"],
    ["the sky is above us", "true,false"],
    ["birds can fly", "true,false,unknown"],
]

iface = gr.Interface(
    fn=classify_text,
    inputs=[_text_input, _labels_input],
    outputs=_scores_output,
    title="Zero-Shot Text Classification",
    description="Classify text into given categories without any training examples.",
    examples=_EXAMPLES,
)

# Start the web server only when run as a script, not when imported.
# NOTE(review): share=True asks Gradio for a temporary public link; when run
# locally this exposes the app to anyone with the URL — confirm it is intended.
if __name__ == "__main__":
    iface.launch(share=True)