Spaces:
Runtime error
Runtime error
tdnathmlenthusiast
commited on
Commit
•
900738d
1
Parent(s):
29b1ee9
Update app.py
Browse files
app.py
CHANGED
@@ -1,32 +1,26 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import
|
3 |
import torch
|
4 |
|
5 |
-
#
|
6 |
-
|
7 |
-
|
8 |
-
model = BertForSequenceClassification.from_pretrained(model_path)
|
9 |
|
10 |
-
#
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
def classify_text(text):
|
15 |
-
inputs = tokenizer(text, return_tensors="pt",
|
16 |
-
padding=True, truncation=True)
|
17 |
-
with torch.no_grad():
|
18 |
-
outputs = model(**inputs)
|
19 |
logits = outputs.logits
|
20 |
probabilities = torch.softmax(logits, dim=1)
|
21 |
-
return probabilities[0]
|
22 |
-
|
23 |
|
|
|
24 |
iface = gr.Interface(
|
25 |
-
fn=
|
26 |
-
inputs=gr.inputs.
|
27 |
-
outputs=gr.outputs.Label(num_top_classes=
|
28 |
-
live=True
|
29 |
-
interpretation="default"
|
30 |
)
|
31 |
|
|
|
32 |
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
3 |
import torch
|
4 |
|
5 |
+
# Initialize the tokenizer and model
|
6 |
+
tokenizer = AutoTokenizer.from_pretrained("laptop_data.pkl") # Replace with your model name or path
|
7 |
+
model = AutoModelForSequenceClassification.from_pretrained("laptop_data.pkl") # Replace with your model name or path
|
|
|
8 |
|
9 |
+
# Define the function for classifying laptops
|
10 |
+
def classify_laptop(description):
|
11 |
+
inputs = tokenizer(description, return_tensors="pt", padding=True, truncation=True)
|
12 |
+
outputs = model(**inputs)
|
|
|
|
|
|
|
|
|
|
|
13 |
logits = outputs.logits
|
14 |
probabilities = torch.softmax(logits, dim=1)
|
15 |
+
return {label: prob.item() for label, prob in zip(model.config.id2label.values(), probabilities[0])}
|
|
|
16 |
|
17 |
+
# Create the Gradio interface
|
18 |
iface = gr.Interface(
|
19 |
+
fn=classify_laptop,
|
20 |
+
inputs=gr.inputs.Textboxbox(),
|
21 |
+
outputs=gr.outputs.Label(num_top_classes=5),
|
22 |
+
live=True
|
|
|
23 |
)
|
24 |
|
25 |
+
# Launch the Gradio interface
|
26 |
iface.launch()
|