|
from flask import Flask, request, render_template, jsonify |
|
from transformers import BertTokenizer, BertModel |
|
import torch |
|
import numpy as np |
|
|
|
# The Flask application object; the import name tells Flask where to find
# templates such as the index.html rendered by the route below.
app = Flask(import_name=__name__)
|
|
|
|
|
# Run the model on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
|
|
class BERTClass(torch.nn.Module):
    """BERT encoder followed by dropout and a linear head producing one logit per label.

    The head is applied to BERT's ``pooler_output``. The module returns raw
    logits; the caller in this file applies ``torch.sigmoid`` to them
    element-wise to obtain per-label scores.

    Args:
        num_labels: Number of output logits. The default of 8 matches the
            architecture the saved checkpoint was built with.
        dropout: Dropout probability applied to the pooled output.
        pretrained_name: Hugging Face model id the encoder is loaded from.
    """

    def __init__(self, num_labels=8, dropout=0.3, pretrained_name='bert-base-uncased'):
        super().__init__()
        # Attribute names (bert_model / dropout / linear) define the state-dict
        # keys, so they must stay unchanged for saved checkpoints to load.
        self.bert_model = BertModel.from_pretrained(pretrained_name, return_dict=True)
        self.dropout = torch.nn.Dropout(dropout)
        # Take the input width from the encoder config instead of hard-coding
        # 768, so differently sized BERT variants also work.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return per-label logits for a batch of tokenized inputs.

        Args:
            input_ids: Token ids; presumably shaped (batch, seq_len) — TODO confirm.
            attn_mask: Attention mask aligned with ``input_ids``.
            token_type_ids: Segment ids aligned with ``input_ids``.

        Returns:
            The linear head's output with ``num_labels`` logits per example.
        """
        encoded = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        pooled = self.dropout(encoded.pooler_output)
        return self.linear(pooled)
|
|
|
|
|
# Tokenizer and classifier are built once at import time; the request handler
# below uses these module-level objects.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

model = BERTClass()
# map_location remaps stored tensors onto the selected device when loading.
# NOTE(review): torch.load unpickles the file; prefer weights_only=True if the
# installed torch version supports it and the checkpoint is a plain state dict.
checkpoint_state = torch.load("model/MLTC_model_state.bin", map_location=device)
model.load_state_dict(checkpoint_state)
# Drop the temporary so the loaded tensors are not kept alive next to the model.
del checkpoint_state
# Both .to() and .eval() return the module itself, so the calls can be chained;
# eval() switches dropout to inference behaviour.
model = model.to(device).eval()
|
|
|
|
|
# Passed as max_length when tokenizing request text.
MAX_LEN: int = 256

# Intended decision threshold for the per-label sigmoid scores.
THRESHOLD: float = 0.5

# Label names, indexed by the position of the corresponding classifier logit.
target_list = [
    'Price',
    'Packaging',
    'Product',
    'Rider',
    'Delivery',
    'Shelf',
    'Service',
    'Seller',
]
|
|
|
def _predict_labels(text):
    """Return the names of the labels whose sigmoid score is at least ``THRESHOLD``.

    Tokenizes *text* to ``MAX_LEN`` tokens, runs the module-level ``model`` on
    ``device``, applies a sigmoid to each logit independently, and maps
    positions to names via ``target_list``.
    """
    encoded = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        add_special_tokens=True,
        return_token_type_ids=True,
        # padding/truncation replace the deprecated pad_to_max_length flag;
        # explicit truncation keeps long inputs within MAX_LEN instead of
        # relying on the tokenizer's implicit, warning-emitting default.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    # Inference only: no_grad disables autograd tracking, so no .detach() is needed.
    with torch.no_grad():
        logits = model(input_ids, attention_mask, token_type_ids)
    scores = torch.sigmoid(logits).cpu().flatten().tolist()

    # Use the configured THRESHOLD rather than rounding, which silently
    # hard-coded a 0.5 cut-off.
    return [label for label, score in zip(target_list, scores) if score >= THRESHOLD]


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the classifier page; on POST, classify the submitted text.

    GET renders ``index.html`` with empty text and no predictions. POST reads
    the ``text`` form field and renders the page with the predicted labels.

    Returns:
        The rendered ``index.html``, or a JSON error body with status 400 when
        the ``text`` field is missing, empty, or whitespace-only.
    """
    raw_text = ""
    predictions = []

    if request.method == 'POST':
        # .get() so a missing field produces the same JSON 400 as an empty one
        # instead of Flask's generic BadRequest page.
        raw_text = request.form.get('text', '')
        if not raw_text.strip():
            return jsonify({'error': 'Please enter some text'}), 400
        predictions = _predict_labels(raw_text)

    return render_template('index.html', text=raw_text, predictions=predictions)
|
|
|
if __name__ == '__main__':
    import os

    # The Werkzeug debugger enabled by debug=True can execute arbitrary Python
    # sent from the browser, so make it opt-in (FLASK_DEBUG=1) instead of always on.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled)
|
|
|
|