attaelahi commited on
Commit
9f7fbd6
·
1 Parent(s): 3d182e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+
6
+ # Load the pre-trained model and tokenizer
7
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
8
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
9
+
10
+ # Suppress warning about weights not being initialized
11
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2, state_dict=model.state_dict() if not isinstance(model, type(model)) else None)
12
+
13
+ # Define the prediction function
14
+ def predict(text):
15
+ # If a single example is provided, convert it to a list
16
+ if isinstance(text, str):
17
+ text = [text]
18
+
19
+ # Encode the text into tokens
20
+ encoded_text = tokenizer(text, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
21
+ input_ids = encoded_text['input_ids']
22
+ attention_mask = encoded_text['attention_mask']
23
+
24
+ # Run the text through the model
25
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
26
+ logits = outputs.logits
27
+
28
+ # Get the probability of hate speech
29
+ hate_speech_probability = torch.softmax(logits, dim=1)[:, 1].tolist()
30
+
31
+ # Determine the predictions
32
+ predictions = ["Hate speech" if prob > 0.5 else "Not hate speech" for prob in hate_speech_probability]
33
+
34
+ return predictions[0] if len(predictions) == 1 else predictions
35
+
36
+ # Custom CSS styles
37
+ custom_css = """
38
+ <style>
39
+ .stTextInput {
40
+ width: 100%;
41
+ padding: 10px;
42
+ border: 1px solid #ddd;
43
+ border-radius: 5px;
44
+ margin-top: 10px;
45
+ }
46
+
47
+ .styled-button {
48
+ background-color: #4CAF50;
49
+ color: white;
50
+ padding: 10px 20px;
51
+ text-align: center;
52
+ text-decoration: none;
53
+ display: inline-block;
54
+ font-size: 16px;
55
+ cursor: pointer;
56
+ border-radius: 5px;
57
+ margin-top: 10px;
58
+ }
59
+
60
+ .styled-button:hover {
61
+ background-color: #45a049;
62
+ }
63
+
64
+ .stButton button {
65
+ background-color: #4CAF50;
66
+ color: white;
67
+ padding: 10px 20px;
68
+ text-align: center;
69
+ text-decoration: none;
70
+ display: inline-block;
71
+ font-size: 16px;
72
+ cursor: pointer;
73
+ border-radius: 5px;
74
+ }
75
+
76
+ .stButton button:hover {
77
+ background-color: #45a049;
78
+ }
79
+
80
+ .stRadio {
81
+ padding: 10px;
82
+ border: 1px solid #ddd;
83
+ border-radius: 5px;
84
+ margin-top: 10px;
85
+ }
86
+ </style>
87
+ """
88
+
89
+ # Inject custom CSS
90
+ st.markdown(custom_css, unsafe_allow_html=True)
91
+
92
+ # Create the Streamlit app with a navigation bar
93
+ st.title("Hate Speech Detector")
94
+
95
+ # Sidebar for navigation
96
+ nav_option = st.sidebar.radio("Navigation", ["Text Input", "CSV Upload"])
97
+
98
+ # Check the chosen navigation option
99
+ if nav_option == "Text Input":
100
+ # Option to input text directly
101
+ text_input = st.text_area("Enter your text here:")
102
+
103
+ if st.button("Predict"):
104
+ # If text is entered, use that for prediction
105
+ if text_input:
106
+ prediction = predict(text_input)
107
+ st.subheader("Prediction:")
108
+ st.write(prediction)
109
+ else:
110
+ st.warning("Please enter text before clicking 'Predict'.")
111
+
112
+ elif nav_option == "CSV Upload":
113
+ # Option to upload a CSV file
114
+ uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
115
+
116
+ if st.button("Predict"):
117
+ # If a CSV file is uploaded, use the first column for prediction
118
+ if uploaded_file:
119
+ df = pd.read_csv(uploaded_file)
120
+ if not df.empty and not df.columns.empty:
121
+ text_column = df.columns[0]
122
+ predictions = df[text_column].apply(predict)
123
+ st.subheader("Predictions:")
124
+ st.write(predictions)
125
+ else:
126
+ st.warning("The CSV file is empty or does not have a valid column.")
127
+ else:
128
+ st.warning("Please upload a CSV file before clicking 'Predict'.")