tdnathmlenthusiast commited on
Commit
900738d
1 Parent(s): 29b1ee9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -21
app.py CHANGED
@@ -1,32 +1,26 @@
1
  import gradio as gr
2
- from transformers import BertTokenizer, BertForSequenceClassification
3
  import torch
4
 
5
- # Load the tokenizer and model
6
- model_path = "laptop_data.pkl" # Replace with the actual path
7
- tokenizer = BertTokenizer.from_pretrained(model_path)
8
- model = BertForSequenceClassification.from_pretrained(model_path)
9
 
10
- # Set the model to evaluation mode
11
- model.eval()
12
-
13
-
14
- def classify_text(text):
15
- inputs = tokenizer(text, return_tensors="pt",
16
- padding=True, truncation=True)
17
- with torch.no_grad():
18
- outputs = model(**inputs)
19
  logits = outputs.logits
20
  probabilities = torch.softmax(logits, dim=1)
21
- return probabilities[0].tolist()
22
-
23
 
 
24
  iface = gr.Interface(
25
- fn=classify_text,
26
- inputs=gr.inputs.Textbox(),
27
- outputs=gr.outputs.Label(num_top_classes=2),
28
- live=True,
29
- interpretation="default"
30
  )
31
 
 
32
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
 
5
+ # Initialize the tokenizer and model
6
+ tokenizer = AutoTokenizer.from_pretrained("laptop_data.pkl") # Replace with your model name or path
7
+ model = AutoModelForSequenceClassification.from_pretrained("laptop_data.pkl") # Replace with your model name or path
 
8
 
9
+ # Define the function for classifying laptops
10
+ def classify_laptop(description):
11
+ inputs = tokenizer(description, return_tensors="pt", padding=True, truncation=True)
12
+ outputs = model(**inputs)
 
 
 
 
 
13
  logits = outputs.logits
14
  probabilities = torch.softmax(logits, dim=1)
15
+ return {label: prob.item() for label, prob in zip(model.config.id2label.values(), probabilities[0])}
 
16
 
17
+ # Create the Gradio interface
18
  iface = gr.Interface(
19
+ fn=classify_laptop,
20
+ inputs=gr.inputs.Textboxbox(),
21
+ outputs=gr.outputs.Label(num_top_classes=5),
22
+ live=True
 
23
  )
24
 
25
+ # Launch the Gradio interface
26
  iface.launch()