dejanseo commited on
Commit
a24e9d7
·
verified ·
1 Parent(s): f5a0665

Upload example.py

Browse files
Files changed (1) hide show
  1. example.py +50 -0
example.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+
5
+ # Load tokenizer and model from Hugging Face model hub
6
+ model_name = "dejanseo/Intent-XS"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+ model.eval() # Set the model to evaluation mode
10
+
11
+ # Human-readable labels
12
+ label_map = {
13
+ 1: 'Commercial',
14
+ 2: 'Non-Commercial',
15
+ 3: 'Branded',
16
+ 4: 'Non-Branded',
17
+ 5: 'Informational',
18
+ 6: 'Navigational',
19
+ 7: 'Transactional',
20
+ 8: 'Commercial Investigation',
21
+ 9: 'Local',
22
+ 10: 'Entertainment'
23
+ }
24
+
25
+ # Function to perform inference
26
+ def get_predictions(text):
27
+ inputs = tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=512)
28
+ with torch.no_grad():
29
+ outputs = model(**inputs)
30
+ logits = outputs.logits
31
+ probabilities = torch.sigmoid(logits).squeeze()
32
+ predictions = (probabilities > 0.5).int()
33
+ return probabilities.numpy(), predictions.numpy()
34
+
35
+ # Streamlit user interface
36
+ st.title('Multi-label Classification with Intent-XS')
37
+ query = st.text_input("Enter your query:")
38
+
39
+ if st.button('Submit'):
40
+ if query:
41
+ probabilities, predictions = get_predictions(query)
42
+ result = {label_map[i+1]: f"Probability: {prob:.2f}" for i, prob in enumerate(probabilities) if predictions[i] == 1}
43
+ if result:
44
+ st.write("Predicted Categories:")
45
+ for label, prob in result.items():
46
+ st.write(f"{label}: {prob}")
47
+ else:
48
+ st.write("No relevant categories predicted.")
49
+ else:
50
+ st.write("Please enter a query to get predictions.")