# MedChattRe / app.py
# Last commit: 91aa056 "Add vn lang"
#   by FIXED-TERM Nguyen Trung Tam (MS/EMC25-XC)
# (header recovered from the repository file view; file size 4.1 kB)
import hmac
import os
import re
import subprocess

import requests
from flask import Flask, request
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer
# Hugging Face checkpoints, loaded once at import time.
_TOKENIZER_CHECKPOINT = "facebook/bart-base"
_DEFAULT_CHECKPOINT = "GuysTrans/bart-base-finetuned-xsum"
_VN_CHECKPOINT = "GuysTrans/bart-base-vn-ehealth"

# A single tokenizer serves both models: generate_summary() always encodes
# and decodes with it.
# NOTE(review): the Vietnamese checkpoint is tokenized with the bart-base
# vocabulary -- confirm it was fine-tuned with that same tokenizer.
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_DEFAULT_CHECKPOINT)  # default
vn_model = AutoModelForSeq2SeqLM.from_pretrained(_VN_CHECKPOINT)  # Vietnamese
# Greeting phrases that post_process() replaces with their mapped value
# (every value is "", i.e. the phrase is deleted). Insertion order matters:
# substitutions run in this order, so the longer phrases ("Hello,") come
# before their shorter prefixes ("Hello").
map_words = dict.fromkeys(
    [
        "Hello and Welcome to 'Ask A Doctor' service",
        "Hello,",
        "Hi,",
        "Hello",
        "Hi",
    ],
    "",
)

# Lower-case words marking greeting / sign-off lines; post_process() drops
# every output line that contains one of them.
word_remove_sentence = ["hello", "hi", "regards", "dr.", "physician"]
def generate_summary(question, model):
    """Generate text from *question* with *model* and decode it.

    The input is encoded with the module-level ``tokenizer``: truncated and
    padded to exactly 512 tokens, with the attention mask passed along so
    padding is ignored. Up to 512 new tokens are generated.

    Args:
        question: Input text (presumably a single ``str``).
        model: A seq2seq model exposing ``.device`` and ``.generate``; this
            parameter shadows the module-level ``model``.

    Returns:
        ``(outputs, output_str)``: the raw generated token-id tensor and the
        list of decoded strings (special tokens stripped).
    """
    encoded = tokenizer(
        question,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    device = model.device
    generated_ids = model.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
        max_new_tokens=512,
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_ids, decoded
app = Flask(__name__)

# Messenger Send API endpoint used by send_message().
# NOTE(review): v2.6 is a very old Graph API version -- confirm it is still
# served before relying on it.
FB_API_URL = 'https://graph.facebook.com/v2.6/me/messages'

# Shared secret that verify_webhook() compares with the `hub.verify_token`
# query parameter. It is read from the environment so it can be rotated
# without a code change. The literal fallback keeps existing deployments
# working, but it has been committed to source control: treat it as exposed
# and replace it by setting VERIFY_TOKEN.
VERIFY_TOKEN = os.environ.get(
    'VERIFY_TOKEN', '5rApTs/BRm6jtiwApOpIdjBHe73ifm6mNGZOsYkwwAw=')

# Page access token sent as a query parameter on every Send API call.
# It must be present in the environment; if it is missing, importing this
# module raises KeyError.
PAGE_ACCESS_TOKEN = os.environ['PAGE_ACCESS_TOKEN']
def get_bot_response(message):
    """Return the chatbot's cleaned-up reply to *message*.

    The language of *message* is detected with langdetect. Vietnamese text
    is answered by ``vn_model``; everything else, including text whose
    language cannot be detected, goes to the default ``model``. The raw
    generation is then passed through post_process().
    """
    try:
        lang = detect(message)
    except LangDetectException:
        # detect() raises on input with no detectable features (empty or
        # digits/punctuation only); fall back to the default model.
        lang = None
    # langdetect reports ISO 639-1 codes: Vietnamese is "vi", not "vn".
    model_use = vn_model if lang == "vi" else model
    return post_process(generate_summary(message, model_use)[1][0])
def verify_webhook(req):
    """Answer Facebook's webhook verification (GET) request.

    Args:
        req: The incoming Flask request; only its ``args`` are read.

    Returns:
        The ``hub.challenge`` query parameter when ``hub.verify_token``
        equals VERIFY_TOKEN, otherwise the string "incorrect".
    """
    supplied = req.args.get("hub.verify_token")
    # The token is attacker-controlled input compared with a secret. Use a
    # constant-time comparison so response timing does not reveal how much
    # of the secret matched. Compare bytes so non-ASCII input cannot raise.
    if supplied is not None and hmac.compare_digest(
            supplied.encode("utf-8"), VERIFY_TOKEN.encode("utf-8")):
        return req.args.get("hub.challenge")
    return "incorrect"
def respond(sender, message):
    """Build a reply to *message* and deliver it to *sender*.

    Args:
        sender: Recipient id passed through to send_message().
        message: The user's text.

    Returns:
        The reply text that was sent.
    """
    reply = get_bot_response(message)
    send_message(sender, reply)
    return reply
def is_user_message(message):
    """Tell whether a webhook messaging event is a text message from a user.

    Events with no ``message`` object, no ``text`` field, or with the
    ``is_echo`` flag set are rejected. Like the ``and`` chain it mirrors,
    the function returns the first falsy value it meets (e.g. ``None``,
    ``{}`` or ``""``), and otherwise returns ``True``/``False``. Callers
    should therefore only test its truthiness.
    """
    msg = message.get('message')
    if not msg:
        return msg
    text = msg.get('text')
    if not text:
        return text
    return not msg.get("is_echo")
@app.route("/webhook", methods=['GET', 'POST'])
def listen():
    """Serve the Messenger webhook at ``/webhook``.

    GET requests are the verification handshake (see verify_webhook).
    POST requests carry a JSON payload with an ``entry`` list; every
    messaging event in every entry that is a user text message gets a
    reply via respond(). Always answers "ok" so Facebook sees the delivery
    as accepted.

    NOTE(review): the payload signature (X-Hub-Signature-256) is not
    verified here, so anyone who can reach this URL can trigger replies.
    """
    if request.method == 'GET':
        return verify_webhook(request)
    if request.method == 'POST':
        # Tolerate a missing or non-JSON body instead of raising.
        payload = request.get_json(silent=True) or {}
        # Process every entry, not just the first, so batched deliveries
        # are not silently dropped.
        for entry in payload.get('entry', []):
            for event in entry.get('messaging', []):
                if is_user_message(event):
                    text = event['message']['text']
                    sender_id = event['sender']['id']
                    respond(sender_id, text)
    return "ok"
def send_message(recipient_id, text, timeout=10):
    """Send *text* to *recipient_id* through the Messenger Send API.

    Args:
        recipient_id: Id of the recipient (the sender id of the incoming
            webhook event).
        text: Message body to send.
        timeout: Seconds to wait for the HTTP call. Without a timeout,
            requests can block indefinitely and tie up the worker handling
            the webhook.

    Returns:
        The decoded JSON body of Facebook's response.

    Raises:
        requests.RequestException (e.g. requests.Timeout) on network
        failure.

    NOTE(review): the current Send API reference documents a
    ``messaging_type`` field (e.g. "RESPONSE" for replies); confirm whether
    it should be added to the payload.
    """
    payload = {
        'message': {
            'text': text
        },
        'recipient': {
            'id': recipient_id
        },
        'notification_type': 'regular'
    }
    auth = {
        'access_token': PAGE_ACCESS_TOKEN
    }
    response = requests.post(
        FB_API_URL,
        params=auth,
        json=payload,
        timeout=timeout,
    )
    return response.json()
@app.route("/webhook/chat", methods=['POST'])
def chat():
    """Plain JSON chat endpoint.

    Expects a body of the form ``{"message": "<text>"}`` and responds with
    ``{"message": "<reply>"}``. A malformed body, a missing key, or a
    non-string or blank message is rejected with HTTP 400 instead of
    crashing with a 500.
    """
    payload = request.get_json(silent=True)
    message = payload.get('message') if isinstance(payload, dict) else None
    if not isinstance(message, str) or not message.strip():
        return {"error": "'message' must be a non-empty string"}, 400
    response = get_bot_response(message)
    return {"message": response}
def _whole_phrase_pattern(phrase):
    """Compile a case-insensitive regex matching *phrase* literally as a whole word."""
    # Add a word boundary only on the sides where the phrase begins/ends
    # with a word character: "hi" must not match inside "this", while
    # "dr." or "Hello," (ending in punctuation) still match before any
    # following character.
    prefix = r"\b" if re.match(r"\w", phrase[:1]) else ""
    suffix = r"\b" if re.match(r"\w", phrase[-1:]) else ""
    return re.compile(prefix + re.escape(phrase) + suffix, re.IGNORECASE)


def post_process(output, remove_words=None, replacements=None):
    """Strip greeting and sign-off boilerplate from a generated reply.

    Args:
        output: Generated text, possibly spanning several lines.
        remove_words: Words marking lines to drop entirely; a line is
            dropped if it contains any of them as a whole word
            (case-insensitive). Defaults to ``word_remove_sentence``.
        replacements: Mapping of phrase -> replacement, applied in order to
            the remaining text. Phrases are matched literally, as whole
            words, case-insensitively. Defaults to ``map_words``.

    Returns:
        The cleaned text, with the kept lines joined by "\\n".
    """
    if remove_words is None:
        remove_words = word_remove_sentence
    if replacements is None:
        replacements = map_words

    removal_patterns = [_whole_phrase_pattern(w) for w in remove_words]
    # Build a new list rather than removing from the list being iterated,
    # which would skip the line after each removed one.
    kept = [line for line in output.split("\n")
            if not any(p.search(line) for p in removal_patterns)]
    output = "\n".join(kept)

    for phrase, substitute in replacements.items():
        output = _whole_phrase_pattern(phrase).sub(substitute, output)
    return output
# Expose the local server (port 7860) publicly through a reverse SSH tunnel
# to serveo.net, authenticating with the key file "id_rsa" in the working
# directory. "-R guysmedchatt:80:localhost:7860" forwards remote port 80
# (the bind name presumably selects the serveo subdomain -- confirm) to
# localhost:7860. "-M 0" turns off autossh's separate monitoring port.
# NOTE(review): StrictHostKeyChecking=no accepts any host key for
# serveo.net, leaving the tunnel open to man-in-the-middle attacks; prefer
# pinning the host key in known_hosts.
# NOTE(review): this launches at import time, once per importing process,
# and the Popen handle is discarded, so the tunnel is never stopped or
# monitored from here.
_TUNNEL_CMD = [
    "autossh",
    "-M", "0",
    "-o", "StrictHostKeyChecking=no",
    "-i", "id_rsa",
    "-R", "guysmedchatt:80:localhost:7860",
    "serveo.net",
]
subprocess.Popen(_TUNNEL_CMD)