ModernBERT-nli / app.py
import gradio as gr
from transformers import pipeline
import torch

# Initialize the zero-shot classification pipeline
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="models/tasksource/ModernBERT-nli",
        device=0 if torch.cuda.is_available() else -1
    )
except Exception as e:
    print(f"Error loading model: {e}")
    classifier = None

def classify_text(text, candidate_labels):
    """
    Perform zero-shot classification on input text.

    Args:
        text (str): Input text to classify
        candidate_labels (str): Comma-separated string of possible labels

    Returns:
        dict: Dictionary with labels as keys and confidence scores as values
    """
    if classifier is None:
        # Return default response when model fails to load
        return {"error": 1.0}

    try:
        # Convert comma-separated string to list
        labels = [label.strip() for label in candidate_labels.split(",")]

        # Perform classification
        result = classifier(text, labels)

        # Convert to dictionary format that Gradio Label expects
        return {label: float(score) for label, score in zip(result["labels"], result["scores"])}
    except Exception as e:
        print(f"Classification error: {e}")
        return {"error": 1.0}
# Create Gradio interface
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(
            label="Text to classify",
            placeholder="Enter text here...",
            value="all cats are blue"
        ),
        gr.Textbox(
            label="Possible labels (comma-separated)",
            placeholder="Enter labels...",
            value="true,false"
        )
    ],
    outputs=gr.Label(label="Classification Results"),
    title="Zero-Shot Text Classification",
    description="Classify text into given categories without any training examples.",
    examples=[
        ["all cats are blue", "true,false"],
        ["the sky is above us", "true,false"],
        ["birds can fly", "true,false,unknown"]
    ]
)

# Launch the app
if __name__ == "__main__":
    iface.launch(share=True)