import json
import os
from collections import namedtuple

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Inference configuration bundled into a lightweight namedtuple.
fields = [
    "device",
    "load_model_path",
    "model_name",
    "max_source_length",
    "max_target_length",
    "beam_size",
]
params = namedtuple("params", field_names=fields)
args = params(
    device="cuda" if torch.cuda.is_available() else "cpu",
    load_model_path="/content/drive/MyDrive/Achatbot/output/model.bin",  # kept from the original script (unused here)
    model_name="facebook/mbart-large-50-many-to-many-mmt",
    max_source_length=256,
    max_target_length=256,
    beam_size=1,
)

# Load the fine-tuned checkpoint from the Hugging Face Hub.
# NOTE: the original script hard-coded a personal access token here;
# it has been redacted and replaced with an environment variable.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "VDTchatbot/db_retrieval",
    use_auth_token=os.environ.get("HF_TOKEN"),
)
model.to(args.device)
model.eval()

# mBART tokenizers need explicit source/target language codes.
if "mbart" in args.model_name.lower():
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, src_lang="vi_VN", tgt_lang="vi_VN"
    )
else:
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)


def prepare_model_inputs(batch, tokenizer, is_training, args):
    """Tokenize a batch of source sentences into model-ready tensors.

    NOTE: this helper was called but never defined in the original script;
    the body below is a minimal sketch of the assumed behaviour
    (is_training is accepted for signature compatibility but unused here).
    """
    inputs = tokenizer(
        batch["src"],
        max_length=args.max_source_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return {k: v.to(args.device) for k, v in inputs.items()}


def text_analysis(text):
    text = text.lower()
    batch = {"src": [text]}
    inputs = prepare_model_inputs(batch, tokenizer, False, args)
    # Force Vietnamese as the first generated token for mBART models.
    if "mbart" in args.model_name.lower():
        inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=args.max_target_length,
            num_beams=args.beam_size,
            early_stopping=True,
        )
    output_sentences = tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    # The model is expected to emit the inner fields of a JSON object,
    # so wrap the decoded text in braces before parsing.
    out = json.loads("{" + output_sentences[0] + "}")
    return out


demo = gr.Interface(
    text_analysis,
    gr.Textbox(placeholder="Enter sentence here..."),
    ["json"],
    examples=[
        ["Mở dashboard VTC ngày hôm qua"],  # "Open the VTC dashboard for yesterday"
        ["Mở biểu đồ cột td tháng này"],  # "Open the td bar chart for this month"
    ],
)

demo.launch()
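
# Usage sketch (an assumption, not part of the original script): once the app
# is running, the same endpoint can be queried programmatically with the
# gradio_client package; "/predict" is the default api_name for a gr.Interface.
#
#     from gradio_client import Client
#
#     client = Client("http://127.0.0.1:7860/")
#     result = client.predict("Mở dashboard VTC ngày hôm qua", api_name="/predict")
#     print(result)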