JasonTPhillipsJr commited on
Commit
b28bcdc
·
verified ·
1 Parent(s): 5fb8341

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -177,6 +177,14 @@ discriminator_weights = ('data/datasets/discriminator_weights.pth')
177
  discriminator.load_state_dict(torch.load(discriminator_weights,map_location=torch.device('cpu')))
178
  discriminator.eval()
179
 
 
 
 
 
 
 
 
 
180
 
181
  # Function to read reviews from a text file
182
  def load_reviews_from_file(file_path):
@@ -192,8 +200,15 @@ def load_reviews_from_file(file_path):
192
  return reviews
193
 
194
 
 
 
 
 
 
 
 
195
  st.title("SpaGAN Demo")
196
- st.write("Enter a text, and the system will highlight the geo-entities within it.")
197
 
198
  # Define a color map and descriptions for different entity types
199
  COLOR_MAP = {
@@ -230,7 +245,7 @@ if st.button("Highlight Geo-Entities"):
230
  st.write("Concatenated Embedding Shape:", combined_embedding.shape)
231
  st.write("Concatenated Embedding:", combined_embedding)
232
 
233
- prediction = discriminator(combined_embedding)
234
  st.write("Prediction:", prediction)
235
 
236
  # Process the text using spaCy
 
177
  discriminator.load_state_dict(torch.load(discriminator_weights,map_location=torch.device('cpu')))
178
  discriminator.eval()
179
 
180
+ def get_prediction(embeddings):
181
+ with torch.no_grad():
182
+ last_rep, logits, probs = discriminator(embeddings)
183
+
184
+ predicted_labels = torch.argmax(probs,dim=-1)
185
+ predicted_labels = predicted_labels.cpu().numpy()
186
+ return predicted_labels
187
+
188
 
189
  # Function to read reviews from a text file
190
  def load_reviews_from_file(file_path):
 
200
  return reviews
201
 
202
 
203
+
204
+
205
+
206
+
207
+
208
+
209
+ #Demo Section
210
  st.title("SpaGAN Demo")
211
+ st.write("This demo allows you to select from a list of sample reviews that contrain both real and fake reviews. This demo will highlight all the entity types found within the review and display the prediction of the model.")
212
 
213
  # Define a color map and descriptions for different entity types
214
  COLOR_MAP = {
 
245
  st.write("Concatenated Embedding Shape:", combined_embedding.shape)
246
  st.write("Concatenated Embedding:", combined_embedding)
247
 
248
+ prediction = get_prediction(combined_embedding)
249
  st.write("Prediction:", prediction)
250
 
251
  # Process the text using spaCy