"""Flask front-end for a multi-label BERT text classifier.

At import time this module loads a fine-tuned ``bert-base-uncased`` classifier
from ``model/MLTC_model_state.bin``. It then serves a single page that tags
submitted text with zero or more labels from ``target_list``.
"""
from flask import Flask, request, render_template, jsonify
from transformers import BertTokenizer, BertModel
import torch
import numpy as np

app = Flask(__name__)

# Run inference on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class BERTClass(torch.nn.Module):
    """BERT encoder -> dropout -> linear head producing one logit per label."""

    def __init__(self, num_labels=8):
        """Build the network.

        Args:
            num_labels: Width of the output layer. Defaults to 8, the width the
                saved checkpoint loaded below expects (one per ``target_list``
                entry).
        """
        super().__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True)
        self.dropout = torch.nn.Dropout(0.3)
        # Take the encoder's hidden size from its config instead of hard-coding 768.
        self.linear = torch.nn.Linear(self.bert_model.config.hidden_size, num_labels)

    def forward(self, input_ids, attn_mask, token_type_ids):
        """Return raw (pre-sigmoid) logits, one per label, for each input row."""
        output = self.bert_model(
            input_ids,
            attention_mask=attn_mask,
            token_type_ids=token_type_ids,
        )
        output_dropout = self.dropout(output.pooler_output)
        return self.linear(output_dropout)


# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BERTClass()
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("model/MLTC_model_state.bin", map_location=device))
model = model.to(device)
model.eval()  # disables dropout for inference

# Hyperparameters
MAX_LEN = 256
THRESHOLD = 0.5
target_list = ['price', 'packaging', 'product', 'rider', 'delivery', 'shelf', 'service', 'seller']


def _predict_labels(text):
    """Return the labels from ``target_list`` whose sigmoid score is >= THRESHOLD.

    Args:
        text: Raw input text. It is padded/truncated to ``MAX_LEN`` tokens.

    Returns:
        List of label strings, in ``target_list`` order. It may be empty.
    """
    encoded_text = tokenizer.encode_plus(
        text,
        max_length=MAX_LEN,
        add_special_tokens=True,
        return_token_type_ids=True,
        padding='max_length',
        # Truncate explicitly. The old pad_to_max_length path relied on a
        # deprecated, warned-about implicit truncation default.
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded_text['input_ids'].to(device)
    attention_mask = encoded_text['attention_mask'].to(device)
    token_type_ids = encoded_text['token_type_ids'].to(device)

    with torch.no_grad():
        logits = model(input_ids, attention_mask, token_type_ids)
    # There is a single input, so flattening gives one score per label.
    probs = torch.sigmoid(logits).flatten().cpu().tolist()
    # Apply the configured THRESHOLD, which round() == 1 used to ignore.
    return [label for label, p in zip(target_list, probs) if p >= THRESHOLD]


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the classification page.

    GET renders an empty form. POST reads the ``text`` form field, classifies
    it, and re-renders ``index.html`` with the text and the predicted labels.
    A missing or blank ``text`` field returns a JSON error with status 400.
    """
    raw_text = ""
    predictions = []
    if request.method == 'POST':
        # .get() keeps a missing field on the same JSON-400 path as an empty one,
        # instead of Flask's generic BadRequestKeyError page.
        raw_text = request.form.get('text', '')
        if not raw_text.strip():
            return jsonify({'error': 'Please enter some text'}), 400
        predictions = _predict_labels(raw_text)
    return render_template('index.html', text=raw_text, predictions=predictions)


if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which allows code execution. Never run it this way in deployment.
    app.run(debug=True)