amiguel committed · verified
Commit 210f897 · 1 Parent(s): f649f2d

Update app.py

Files changed (1):
  1. app.py  +61 -25
app.py CHANGED
@@ -1,31 +1,67 @@
 import streamlit as st
+import pandas as pd
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-# Select the appropriate device
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+import tiktoken
+from transformers import GPT2Tokenizer, GPT2Model
 
 # Load the model and tokenizer
-model = AutoModelForSequenceClassification.from_pretrained("amiguel/item_class_scope", local_files_only=True).to(device)
-tokenizer = AutoTokenizer.from_pretrained("gpt2")
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+model = GPT2Model.from_pretrained("gpt2")
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
 
-def classify_text(text):
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
+def classify_review(text, model, tokenizer, device, max_length=128, pad_token_id=50256):
+    model.eval()
+    input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
+    input_ids = input_ids[:, :max_length]
+    input_ids = torch.nn.functional.pad(input_ids, (0, max_length - input_ids.shape[1]), value=pad_token_id)
     with torch.no_grad():
-        outputs = model(**inputs)
-    logits = outputs.logits
-    predicted_class_id = logits.argmax().item()
-    return "Proper Naming Notification" if predicted_class_id == 1 else "Wrong Naming Notification"
-
-st.title("Classification Naming")
-st.write("Classify naming notifications as proper or wrong.")
-
-text_input = st.text_area("Enter the text to classify:")
-
-if st.button("Classify"):
-    if text_input:
-        output = classify_text(text_input)
-        st.write("Classification Result:")
-        st.write(output)
-    else:
-        st.write("Please enter some text to classify.")
+        outputs = model(input_ids)
+    logits = outputs.last_hidden_state[:, -1, :]
+    predicted_label = torch.argmax(logits, dim=-1).item()
+
+    label_mapping = {
+        0: "Pressure Safety Device",
+        1: "Piping",
+        2: "Pressure Vessel (VIE)",
+        3: "FU Items",
+        4: "Non Structural Tank",
+        5: "Structure",
+        6: "Corrosion Monitoring",
+        7: "Flame Arrestor",
+        8: "Pressure Vessel (VII)",
+        9: "Lifting"
+    }
+    return label_mapping.get(predicted_label, "Unknown")
+
+def main():
+    st.title("ItemClass Scope Classifier")
+
+    input_option = st.radio("Select input option", ("Single Text Query", "Upload Table"))
+
+    if input_option == "Single Text Query":
+        text_query = st.text_input("Enter text query")
+        if st.button("Classify"):
+            if text_query:
+                predicted_label = classify_review(text_query, model, tokenizer, device)
+                st.write("Predicted Label:")
+                st.write(predicted_label)
+            else:
+                st.warning("Please enter a text query.")
+
+    elif input_option == "Upload Table":
+        uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"])
+        if uploaded_file is not None:
+            if uploaded_file.name.endswith(".csv"):
+                df = pd.read_csv(uploaded_file)
+            else:
+                df = pd.read_excel(uploaded_file)
+
+            text_column = st.selectbox("Select the text column", df.columns)
+            predicted_labels = [classify_review(text, model, tokenizer, device) for text in df[text_column]]
+            df["Predicted Label"] = predicted_labels
+            st.write(df)
+
+if __name__ == "__main__":
+    main()
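Note on the classification head: the committed classify_review runs the bare GPT2Model, which has no classification head, and takes an argmax over the 768-dimensional last hidden state, so the predicted index can be anything from 0 to 767 and most inputs will fall outside the 0-9 label map and come back as "Unknown". Below is a minimal sketch, not the committed code, of how the same ten labels could be predicted with a sequence-classification head (GPT2ForSequenceClassification). Loading the base "gpt2" checkpoint here leaves that head randomly initialized; in practice a fine-tuned checkpoint would be loaded instead, for example the amiguel/item_class_scope repository referenced by the pre-commit code (whether that repository's weights include such a head is an assumption).

# Sketch (not the committed code): GPT-2 with a 10-class sequence-classification
# head, so argmax over `logits` yields a class id in 0..9.
import torch
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

# Assumption: a fine-tuned checkpoint would normally go here; base "gpt2"
# weights leave the classification head randomly initialized.
model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=10)
model.config.pad_token_id = tokenizer.pad_token_id  # required once inputs are padded
model.to(device)
model.eval()

label_mapping = {
    0: "Pressure Safety Device", 1: "Piping", 2: "Pressure Vessel (VIE)",
    3: "FU Items", 4: "Non Structural Tank", 5: "Structure",
    6: "Corrosion Monitoring", 7: "Flame Arrestor",
    8: "Pressure Vessel (VII)", 9: "Lifting",
}

def classify_review(text, max_length=128):
    # Truncation and padding are handled by the tokenizer instead of manual F.pad.
    inputs = tokenizer(text, return_tensors="pt", truncation=True,
                       padding="max_length", max_length=max_length).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (1, num_labels)
    return label_mapping.get(logits.argmax(dim=-1).item(), "Unknown")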
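For the Upload Table path, the committed code classifies rows one at a time in a list comprehension. A small sketch of batching the same work, reusing the tokenizer, model, device, and label_mapping names from the sketch above (and assuming a model that exposes .logits):

# Sketch: classify a DataFrame column in batches instead of row by row.
import pandas as pd
import torch

def classify_column(df: pd.DataFrame, text_column: str, batch_size=16, max_length=128):
    texts = df[text_column].astype(str).tolist()
    labels = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", truncation=True,
                           padding=True, max_length=max_length).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits  # shape: (batch, num_labels)
        ids = logits.argmax(dim=-1).tolist()
        labels.extend(label_mapping.get(i, "Unknown") for i in ids)
    return labels

# Usage inside the Streamlit branch:
#     df["Predicted Label"] = classify_column(df, text_column)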