vluz commited on
Commit
27ad3e1
1 Parent(s): de17e7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -7,28 +7,35 @@ from tensorflow.keras.layers import TextVectorization
7
 
8
  @st.cache_resource
9
  def load_model():
10
- model = tf.keras.models.load_model(os.path.join("model", "toxmodel.keras"))
11
- return model
12
 
13
 
14
  @st.cache_resource
15
  def load_vectorizer():
16
- from_disk = pickle.load(open(os.path.join("model", "vectorizer.pkl"), "rb"))
17
- new_v = TextVectorization.from_config(from_disk['config'])
18
- new_v.adapt(tf.data.Dataset.from_tensor_slices(["xyz"])) # Keras bug
19
- new_v.set_weights(from_disk['weights'])
20
- return new_v
21
 
22
 
23
  st.title("Toxic Comment Test")
24
  st.divider()
25
  model = load_model()
26
  vectorizer = load_vectorizer()
27
- input_text = st.text_area("Comment:", "I love you man, but fuck you!", height=150)
 
28
  if st.button("Test"):
29
- with st.spinner("Testing..."):
30
- inputv = vectorizer([input_text])
31
- output = model.predict(inputv)
32
- res = (output > 0.5)
33
- st.write(["toxic","severe toxic","obscene","threat","insult","identity hate"], res)
34
- st.write(output)
 
 
 
 
 
 
 
7
 
8
  @st.cache_resource
9
  def load_model():
10
+     model = tf.keras.models.load_model(os.path.join("model", "toxmodel.keras"))
11
+     return model
12
 
13
 
14
  @st.cache_resource
15
  def load_vectorizer():
16
+     from_disk = pickle.load(open(os.path.join("model", "vectorizer.pkl"), "rb"))
17
+     new_v = TextVectorization.from_config(from_disk['config'])
18
+     new_v.adapt(tf.data.Dataset.from_tensor_slices(["xyz"])) # fix for Keras bug
19
+     new_v.set_weights(from_disk['weights'])
20
+     return new_v
21
 
22
 
23
  st.title("Toxic Comment Test")
24
  st.divider()
25
  model = load_model()
26
  vectorizer = load_vectorizer()
27
+ default_prompt = "I love you man, but fuck you!"
28
+ input_text = st.text_area("Comment:", default_prompt, height=150).lower()
29
  if st.button("Test"):
30
+     if not input_text:
31
+         st.write("⚠ Warning: Empty prompt.")
32
+     elif len(input_text) < 15:
33
+         st.write("⚠ Warning: Model is far less accurate with a small prompt.")
34
+     if input_text == default_prompt:
35
+         st.write("Expected results from default prompt are positive for 0 and 2")
36
+     with st.spinner("Testing..."):
37
+         inputv = vectorizer([input_text])
38
+         output = model.predict(inputv)
39
+         res = (output > 0.5)
40
+     st.write(["toxic","severe toxic","obscene","threat","insult","identity hate"], res)
41
+     st.write(output)