GuysRGithub commited on
Commit
a4cc6d9
1 Parent(s): 2d97fd7
Files changed (1) hide show
  1. app.py +44 -5
app.py CHANGED
@@ -1,13 +1,31 @@
1
  from flask import Flask, request
2
  import requests
3
  import os
 
4
  from transformers import AutoModelForSeq2SeqLM
5
  from transformers import AutoTokenizer
6
  import subprocess
7
 
8
  tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
9
 
10
- model = AutoModelForSeq2SeqLM.from_pretrained("GuysTrans/bart-base-finetuned-xsum")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def generate_summary(question, model):
@@ -20,7 +38,8 @@ def generate_summary(question, model):
20
  )
21
  input_ids = inputs.input_ids.to(model.device)
22
  attention_mask = inputs.attention_mask.to(model.device)
23
- outputs = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=512)
 
24
  output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
25
  return outputs, output_str
26
 
@@ -29,11 +48,12 @@ app = Flask(__name__)
29
 
30
  FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'
31
  VERIFY_TOKEN = '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw='
32
- PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN'] # paste your page access token here>"
 
33
 
34
 
35
  def get_bot_response(message):
36
- return generate_summary(message, model)[1][0]
37
 
38
 
39
  def verify_webhook(req):
@@ -101,6 +121,7 @@ def send_message(recipient_id, text):
101
 
102
  return response.json()
103
 
 
104
  @app.route("/webhook/chat", methods=['POST'])
105
  def chat():
106
  payload = request.json
@@ -108,5 +129,23 @@ def chat():
108
  response = get_bot_response(message)
109
  return {"message": response}
110
 
111
- subprocess.Popen(["autossh", "-M", "0" ,"-o", "StrictHostKeyChecking=no", "-i", "id_rsa", "-R", "guysmedchatt:80:localhost:7860", "serveo.net"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # subprocess.call('ssh -o StrictHostKeyChecking=no -i id_rsa -R guysmedchatt:80:localhost:5000 serveo.net', shell=True)
 
1
  from flask import Flask, request
2
  import requests
3
  import os
4
+ import re
5
  from transformers import AutoModelForSeq2SeqLM
6
  from transformers import AutoTokenizer
7
  import subprocess
8
 
9
  tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
10
 
11
+ model = AutoModelForSeq2SeqLM.from_pretrained(
12
+ "GuysTrans/bart-base-finetuned-xsum")
13
+
14
+ map_words = {
15
+ "Hello and Welcome to 'Ask A Doctor' service": "",
16
+ "Hello,": "",
17
+ "Hi,": "",
18
+ "Hello": "",
19
+ "Hi": ""
20
+ }
21
+
22
+ word_remove_sentence = [
23
+ "hello",
24
+ "hi",
25
+ "regards",
26
+ "dr.",
27
+ "physician",
28
+ ]
29
 
30
 
31
  def generate_summary(question, model):
 
38
  )
39
  input_ids = inputs.input_ids.to(model.device)
40
  attention_mask = inputs.attention_mask.to(model.device)
41
+ outputs = model.generate(
42
+ input_ids, attention_mask=attention_mask, max_new_tokens=512)
43
  output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
44
  return outputs, output_str
45
 
 
48
 
49
  FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'
50
  VERIFY_TOKEN = '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw='
51
+ # paste your page access token here>"
52
+ PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
53
 
54
 
55
  def get_bot_response(message):
56
+ return post_process(generate_summary(message, model)[1][0])
57
 
58
 
59
  def verify_webhook(req):
 
121
 
122
  return response.json()
123
 
124
+
125
  @app.route("/webhook/chat", methods=['POST'])
126
  def chat():
127
  payload = request.json
 
129
  response = get_bot_response(message)
130
  return {"message": response}
131
 
132
+ def post_process(output):
133
+
134
+ lines = output.split("\n")
135
+ for line in lines:
136
+ for word in word_remove_sentence:
137
+ if word in line.lower():
138
+ lines.remove(line)
139
+ break
140
+
141
+ output = "\n".join(lines)
142
+ for item in map_words.keys():
143
+ output = re.sub(item, map_words[item], output, re.I)
144
+
145
+ return output
146
+
147
+
148
+
149
+ subprocess.Popen(["autossh", "-M", "0", "-o", "StrictHostKeyChecking=no",
150
+ "-i", "id_rsa", "-R", "guysmedchatt:80:localhost:7860", "serveo.net"])
151
  # subprocess.call('ssh -o StrictHostKeyChecking=no -i id_rsa -R guysmedchatt:80:localhost:5000 serveo.net', shell=True)