Spaces:
Sleeping
Sleeping
import gradio as gr | |
import argparse | |
import logging | |
import os | |
import numpy as np | |
import torch | |
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset | |
from tqdm import tqdm | |
from utils import MODEL_CLASSES, get_intent_labels, get_slot_labels, init_logger, load_tokenizer | |
logger = logging.getLogger(__name__) | |
def get_device(pred_config):
    """Choose the torch device string to run inference on.

    Returns ``"cuda"`` only when CUDA is available and
    ``pred_config.no_cuda`` is falsy; otherwise returns ``"cpu"``.
    """
    if torch.cuda.is_available() and not pred_config.no_cuda:
        return "cuda"
    return "cpu"
def get_args(pred_config):
    """Load the training-time arguments stored alongside the model.

    Reads ``training_args.bin`` from ``pred_config.model_dir``, then
    overwrites two attributes on the loaded object: ``model_dir`` is set to
    ``pred_config.model_dir`` and ``data_dir`` is set to ``'PhoATIS'``.

    Args:
        pred_config: Object exposing a ``model_dir`` attribute.

    Returns:
        The unpickled arguments object with ``model_dir`` and ``data_dir``
        overridden.
    """
    args_path = os.path.join(pred_config.model_dir, "training_args.bin")

    # SECURITY NOTE: this file holds an arbitrary pickled Python object, so
    # loading it executes pickle code. Only point model_dir at trusted data.
    try:
        # PyTorch 2.6 changed the default of `weights_only` to True, which
        # refuses to unpickle non-tensor objects; request full unpickling
        # explicitly so the saved arguments still load.
        args = torch.load(args_path, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the `weights_only` keyword.
        args = torch.load(args_path)

    args.model_dir = pred_config.model_dir
    args.data_dir = 'PhoATIS'
    return args
def load_model(pred_config, args, device):
    """Instantiate the trained model, move it to ``device`` and set eval mode.

    Args:
        pred_config: Object exposing ``model_dir``; the directory must exist.
        args: Training arguments; ``model_type`` selects the entry in
            ``MODEL_CLASSES`` and ``model_dir`` is passed to
            ``from_pretrained``.
        device: Device passed to ``model.to``.

    Returns:
        The loaded model, in evaluation mode.

    Raises:
        Exception: If ``pred_config.model_dir`` does not exist, or if any
            step of loading fails. In the latter case the underlying error
            is attached as ``__cause__``.
    """
    # Check whether model exists
    if not os.path.exists(pred_config.model_dir):
        raise Exception("Model doesn't exists! Train first!")

    try:
        model_cls = MODEL_CLASSES[args.model_type][1]
        model = model_cls.from_pretrained(
            args.model_dir,
            args=args,
            intent_label_lst=get_intent_labels(args),
            slot_label_lst=get_slot_labels(args),
        )
        model.to(device)
        model.eval()
        logger.info("***** Model Loaded *****")
    except Exception as err:
        # Keep the generic message callers may match on, but chain the real
        # failure so the traceback shows what actually went wrong.
        raise Exception("Some model files might be missing...") from err

    return model
def convert_input_file_to_tensor_dataset(
    lines,
    pred_config,
    args,
    tokenizer,
    pad_token_label_id,
    cls_token_segment_id=0,
    pad_token_segment_id=0,
    sequence_a_segment_id=0,
    mask_padding_with_zero=True,
):
    """Encode pre-split sentences as a fixed-length ``TensorDataset``.

    Every element of ``lines`` is an iterable of words. Each word is split
    with ``tokenizer.tokenize``; a word that yields no pieces is replaced by
    the tokenizer's unknown token. The piece sequence is cut to
    ``args.max_seq_len - 2``, wrapped in the CLS/SEP tokens, converted to ids
    and right-padded to ``args.max_seq_len``.

    The slot label mask holds ``pad_token_label_id + 1`` at the first piece
    of each word and ``pad_token_label_id`` everywhere else (continuation
    pieces, special tokens and padding).

    ``pred_config`` is accepted but not used by this function.

    Returns:
        A ``TensorDataset`` of four ``torch.long`` tensors, in order:
        input ids, attention mask, token type ids, slot label mask.
    """
    # Special tokens / ids of the active tokenizer
    cls_piece = tokenizer.cls_token
    sep_piece = tokenizer.sep_token
    unk_piece = tokenizer.unk_token
    pad_id = tokenizer.pad_token_id

    # Attention-mask values for real positions and for padding positions
    attend, ignore = (1, 0) if mask_padding_with_zero else (0, 1)

    rows_ids, rows_attn, rows_seg, rows_slot = [], [], [], []

    for sentence in lines:
        pieces, slot_mask = [], []
        for word in sentence:
            # Fall back to the unknown token for words that tokenize to nothing
            word_pieces = tokenizer.tokenize(word) or [unk_piece]
            pieces += word_pieces
            # Only the first piece of a word is flagged as carrying a label
            slot_mask.append(pad_token_label_id + 1)
            slot_mask += [pad_token_label_id] * (len(word_pieces) - 1)

        # Reserve two positions for [CLS] and [SEP], then wrap the sequence
        budget = args.max_seq_len - 2
        pieces = [cls_piece] + pieces[:budget] + [sep_piece]
        slot_mask = [pad_token_label_id] + slot_mask[:budget] + [pad_token_label_id]
        segment_ids = [cls_token_segment_id] + [sequence_a_segment_id] * (len(pieces) - 1)

        ids = tokenizer.convert_tokens_to_ids(pieces)

        # Right-pad every sequence up to the configured length
        fill = args.max_seq_len - len(ids)
        rows_ids.append(ids + [pad_id] * fill)
        rows_attn.append([attend] * len(ids) + [ignore] * fill)
        rows_seg.append(segment_ids + [pad_token_segment_id] * fill)
        rows_slot.append(slot_mask + [pad_token_label_id] * fill)

    # Change to Tensor
    tensors = [
        torch.tensor(rows, dtype=torch.long)
        for rows in (rows_ids, rows_attn, rows_seg, rows_slot)
    ]
    return TensorDataset(*tensors)
def predict(text):
    """Run intent and slot inference over pre-tokenized sentences.

    Relies on module-level state prepared by the script entry point:
    ``pred_config``, ``args``, ``tokenizer``, ``pad_token_label_id``,
    ``device``, ``model`` and ``slot_label_lst``.

    Args:
        text: List of sentences, each given as a list of word strings.

    Returns:
        A tuple ``(lines, slot_preds_list, intent_preds)`` where ``lines`` is
        the input unchanged, ``slot_preds_list`` holds one list of slot label
        strings per sentence (only positions whose slot mask differs from
        ``pad_token_label_id``), and ``intent_preds`` is an array with the
        argmax intent index per sentence. Empty input yields
        ``(lines, [], <empty int array>)``.
    """
    lines = text
    if not lines:
        # Nothing to batch: without this guard no batch is produced and the
        # argmax below would be applied to None.
        return (lines, [], np.array([], dtype=np.int64))

    dataset = convert_input_file_to_tensor_dataset(lines, pred_config, args, tokenizer, pad_token_label_id)

    # Predict
    sampler = SequentialSampler(dataset)
    data_loader = DataLoader(dataset, sampler=sampler, batch_size=pred_config.batch_size)

    # Collect per-batch arrays and join them once after the loop instead of
    # re-copying the growing result with np.append on every batch.
    intent_chunks = []
    slot_chunks = []
    mask_chunks = []

    for batch in tqdm(data_loader, desc="Predicting"):
        batch = tuple(t.to(device) for t in batch)
        with torch.no_grad():
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "intent_label_ids": None,
                "slot_labels_ids": None,
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = batch[2]
            outputs = model(**inputs)
            _, (intent_logits, slot_logits) = outputs[:2]

            # Intent prediction
            intent_chunks.append(intent_logits.detach().cpu().numpy())

            # Slot prediction
            if args.use_crf:
                # decode() in `torchcrf` returns list with best index directly
                slot_chunks.append(np.array(model.crf.decode(slot_logits)))
            else:
                slot_chunks.append(slot_logits.detach().cpu().numpy())
            mask_chunks.append(batch[3].detach().cpu().numpy())

    intent_preds = np.argmax(np.concatenate(intent_chunks, axis=0), axis=1)

    slot_preds = np.concatenate(slot_chunks, axis=0)
    if not args.use_crf:
        # Without a CRF the model returns logits; pick the best label index.
        slot_preds = np.argmax(slot_preds, axis=2)
    all_slot_label_mask = np.concatenate(mask_chunks, axis=0)

    slot_label_map = dict(enumerate(slot_label_lst))
    # Keep a label only where the slot mask marks a labelled position.
    slot_preds_list = [
        [
            slot_label_map[label_id]
            for label_id, flag in zip(pred_row, mask_row)
            if flag != pad_token_label_id
        ]
        for pred_row, mask_row in zip(slot_preds, all_slot_label_mask)
    ]

    return (lines, slot_preds_list, intent_preds)
def text_analysis(text):
    """Analyse one sentence and format the result for the demo outputs.

    The sentence is stripped and split on whitespace, passed to ``predict``
    as a single-sentence batch, and the per-word slot tags are merged into
    labelled spans.

    Args:
        text: Raw input sentence.

    Returns:
        A tuple ``(slot_tokens, intent_label)``. ``slot_tokens`` is a list of
        ``(text, label)`` pairs in which every span is followed by a
        ``(" ", None)`` separator; ``label`` is ``None`` for words tagged
        ``'O'``. ``intent_label`` is looked up in the module-level
        ``intent_label_lst``.
    """
    text = [text.strip().split()]

    # Run inference exactly once and unpack the single-sentence results;
    # calling predict() separately for each field would repeat the forward pass.
    lines, slot_preds_list, intent_preds = predict(text)
    words, slot_preds, intent_pred = lines[0], slot_preds_list[0], intent_preds[0]

    slot_tokens = []
    for word, pred in zip(words, slot_preds):
        if pred == 'O':
            slot_tokens.extend([(word, None), (" ", None)])
        elif pred[0] == 'I' and slot_tokens:
            # Continuation tag: append the word to the previous span, which
            # sits just before the trailing separator.
            prev_text, prev_label = slot_tokens[-2]
            slot_tokens[-2] = (f'{prev_text} {word}', prev_label)
        else:
            # Any other tag opens a new span labelled with the tag minus its
            # two-character prefix. An 'I' tag with no preceding span also
            # lands here rather than indexing into an empty list.
            slot_tokens.extend([(word, pred[2:]), (" ", None)])

    intent_label = intent_label_lst[intent_pred]
    return slot_tokens, intent_label
if __name__ == "__main__":
    init_logger()

    # Command-line options, declared as (flag, keyword-arguments) pairs.
    cli_options = (
        ("--model_dir", dict(default="./atis_model", type=str, help="Path to save, load model")),
        ("--batch_size", dict(default=32, type=int, help="Batch size for prediction")),
        ("--no_cuda", dict(action="store_true", help="Avoid using CUDA when available")),
    )
    parser = argparse.ArgumentParser()
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    pred_config = parser.parse_args()

    # Module-level state read by predict() and text_analysis().
    args = get_args(pred_config)
    device = get_device(pred_config)
    model = load_model(pred_config, args, device)
    logger.info(args)

    intent_label_lst = get_intent_labels(args)
    slot_label_lst = get_slot_labels(args)

    # Label id used to mark positions that carry no slot label.
    pad_token_label_id = args.ignore_index
    tokenizer = load_tokenizer(args)

    # Sample sentences offered in the demo UI.
    examples = [
        "tôi muốn bay một chuyến khứ_hồi từ đà_nẵng đến đà_lạt",
        (
            "giá vé khứ_hồi từ đà_nẵng đến vinh dưới 2 triệu đồng giá vé khứ_hồi từ quy nhơn đến vinh dưới 3 triệu đồng giá vé khứ_hồi từ"
            " buôn_ma_thuột đến vinh dưới 4 triệu rưỡi"
        ),
        "cho tôi biết các chuyến bay đến đà_nẵng vào ngày 14 tháng sáu",
        "những chuyến bay nào khởi_hành từ thành_phố hồ_chí_minh bay đến frankfurt mà nối chuyến ở singapore và hạ_cánh trước 9 giờ tối",
    ]

    # One text input; outputs are the highlighted slot spans and the intent.
    sentence_box = gr.Textbox(placeholder="Enter sentence here...", label="Input")
    result_views = [
        gr.HighlightedText(label='Highlighted Output'),
        gr.Textbox(label='Intent Label'),
    ]
    demo = gr.Interface(
        fn=text_analysis,
        inputs=sentence_box,
        outputs=result_views,
        examples=examples,
    )
    demo.launch(share=True)