sileod commited on
Commit
04e7b78
1 Parent(s): 967ee6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -71
app.py CHANGED
@@ -1,72 +1,3 @@
1
- import gradio as gr
2
- from transformers import pipeline
3
- import torch
4
-
5
- # Initialize the zero-shot classification pipeline
6
- try:
7
- classifier = pipeline(
8
- "zero-shot-classification",
9
- model="models/tasksource/ModernBERT-nli",
10
- device=0 if torch.cuda.is_available() else -1
11
- )
12
- except Exception as e:
13
- print(f"Error loading model: {e}")
14
- classifier = None
15
-
16
- def classify_text(text, candidate_labels):
17
- """
18
- Perform zero-shot classification on input text.
19
-
20
- Args:
21
- text (str): Input text to classify
22
- candidate_labels (str): Comma-separated string of possible labels
23
-
24
- Returns:
25
- dict: Dictionary with labels as keys and confidence scores as values
26
- """
27
- if classifier is None:
28
- # Return default response when model fails to load
29
- return {"error": 1.0}
30
-
31
- try:
32
- # Convert comma-separated string to list
33
- labels = [label.strip() for label in candidate_labels.split(",")]
34
-
35
- # Perform classification
36
- result = classifier(text, labels)
37
-
38
- # Convert to dictionary format that Gradio Label expects
39
- return {label: float(score) for label, score in zip(result["labels"], result["scores"])}
40
-
41
- except Exception as e:
42
- print(f"Classification error: {e}")
43
- return {"error": 1.0}
44
 
45
- # Create Gradio interface
46
- iface = gr.Interface(
47
- fn=classify_text,
48
- inputs=[
49
- gr.Textbox(
50
- label="Text to classify",
51
- placeholder="Enter text here...",
52
- value="all cats are blue"
53
- ),
54
- gr.Textbox(
55
- label="Possible labels (comma-separated)",
56
- placeholder="Enter labels...",
57
- value="true,false"
58
- )
59
- ],
60
- outputs=gr.Label(label="Classification Results"),
61
- title="Zero-Shot Text Classification",
62
- description="Classify text into given categories without any training examples.",
63
- examples=[
64
- ["all cats are blue", "true,false"],
65
- ["the sky is above us", "true,false"],
66
- ["birds can fly", "true,false,unknown"]
67
- ]
68
- )
69
-
70
- # Launch the app
71
- if __name__ == "__main__":
72
- iface.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import gradio as gr
3
+ gr.Interface.load("models/tasksource/ModernBERT-base-nli").launch()