import gradio as gr
from transformers import pipeline
import torch

# Initialize the zero-shot classification pipeline once, at import time.
# Any failure while loading is reported and leaves `classifier` as None so the
# UI can still start; classify_text() then returns an "error" label.
# NOTE(review): the model id looks like it may be a local path rather than a
# Hub repo id -- confirm it resolves in the deployment environment.
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="models/tasksource/ModernBERT-nli",
        # Use the first GPU when CUDA is available, otherwise the CPU (-1).
        device=0 if torch.cuda.is_available() else -1,
    )
except Exception as e:
    print(f"Error loading model: {e}")
    classifier = None


def classify_text(text, candidate_labels):
    """
    Perform zero-shot classification on input text.

    Args:
        text (str): Input text to classify.
        candidate_labels (str): Comma-separated string of possible labels.
            Surrounding whitespace is stripped from each label and blank
            entries (e.g. from "a,,b" or a trailing comma) are ignored.

    Returns:
        dict: Labels as keys and confidence scores (float) as values.
            ``{"error": 1.0}`` is returned instead when the model failed to
            load, when no non-blank label was supplied, or when
            classification raised an exception.
    """
    if classifier is None:
        # Return default response when model fails to load
        return {"error": 1.0}

    try:
        # Convert the comma-separated string to a list, dropping blank
        # entries so an empty string is never scored as a real label.
        labels = [
            label.strip()
            for label in candidate_labels.split(",")
            if label.strip()
        ]
        if not labels:
            print("Classification error: no candidate labels provided")
            return {"error": 1.0}

        # Perform classification
        result = classifier(text, labels)

        # Convert to the {label: score} dictionary that Gradio Label expects
        return {
            label: float(score)
            for label, score in zip(result["labels"], result["scores"])
        }
    except Exception as e:
        # Boundary handler: the UI shows an "error" label rather than a
        # traceback; the cause is printed for whoever is running the app.
        print(f"Classification error: {e}")
        return {"error": 1.0}


# Create Gradio interface
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(
            label="Text to classify",
            placeholder="Enter text here...",
            value="all cats are blue",
        ),
        gr.Textbox(
            label="Possible labels (comma-separated)",
            placeholder="Enter labels...",
            value="true,false",
        ),
    ],
    outputs=gr.Label(label="Classification Results"),
    title="Zero-Shot Text Classification",
    description="Classify text into given categories without any training examples.",
    examples=[
        ["all cats are blue", "true,false"],
        ["the sky is above us", "true,false"],
        ["birds can fly", "true,false,unknown"],
    ],
)

# Launch the app
if __name__ == "__main__":
    iface.launch(share=True)