# Zero-shot-test / app.py
# Author: micole66
# Commit 842fcb5 (verified): "Rename App.py to app.py"
# File size: 1.57 kB
from transformers import pipeline
import torch
import gradio as gr
# Hugging Face checkpoint backing the zero-shot classifier.
MODEL_NAME = "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"

# Pipeline device index: 0 selects the first CUDA GPU, -1 runs on CPU.
if torch.cuda.is_available():
    _device = 0
else:
    _device = -1

# Build the zero-shot classification pipeline once, at import time, so every
# call to classify_text reuses the same loaded model.
classifier = pipeline(
    "zero-shot-classification",
    model=MODEL_NAME,
    device=_device,
)
def classify_text(text, labels):
    """Classify *text* against a comma-separated list of candidate labels.

    Args:
        text: The text to classify. Empty or whitespace-only input (or None)
            is rejected with a prompt instead of being sent to the model.
        labels: Candidate labels separated by commas, e.g.
            "politics, sports, technology". Whitespace around each label is
            stripped, empty entries (from "a,,b" or a trailing comma) are
            ignored, and duplicates are dropped, keeping the first occurrence.

    Returns:
        One "label: NN.NN%" line per label (each terminated by a newline), in
        the order the pipeline returns them. transformers documents that order
        as most likely first. If the text or the label list is empty, a short
        message asking for input is returned instead.
    """
    if text is None or not text.strip():
        return "Please enter some text to classify."

    # dict.fromkeys de-duplicates while preserving the user's label order.
    candidate_labels = list(
        dict.fromkeys(
            label.strip() for label in (labels or "").split(",") if label.strip()
        )
    )
    if not candidate_labels:
        return "Please enter at least one label (comma-separated)."

    # multi_label=False: the pipeline normalizes scores across the candidate
    # labels, so the reported percentages sum to roughly 100%.
    result = classifier(text, candidate_labels, multi_label=False)

    return "".join(
        f"{label}: {score * 100:.2f}%\n"
        for label, score in zip(result["labels"], result["scores"])
    )
# Build the Gradio UI: two text inputs (the text and the comma-separated
# candidate labels) passed to classify_text, whose string result is shown in a
# single output textbox.
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(label="Enter text to classify", lines=3),
        gr.Textbox(
            label="Enter labels (comma-separated)",
            value="politics, sports, technology, entertainment",
        ),
    ],
    outputs=gr.Textbox(label="Classification Results"),
    title="Zero-Shot Text Classification",
    description="Enter text and labels to classify the text into different categories.",
)

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module, for example from tests,
# then just builds `iface` without blocking on a running server.
if __name__ == "__main__":
    iface.launch()