Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
import os | |
import re | |
import torch | |
from transformers import ( | |
AutoTokenizer, | |
AutoModel, | |
T5ForConditionalGeneration, | |
MBartForConditionalGeneration, | |
AutoModelForSeq2SeqLM, | |
) | |
from tqdm.auto import tqdm | |
import streamlit as st | |
from typing import Dict, List | |
def load_model(model_name, device):
    """Instantiate a seq2seq model and overwrite its weights with a local checkpoint.

    The architecture (and its initial weights) comes from
    ``AutoModelForSeq2SeqLM.from_pretrained(model_name)``, cached under
    ``./cache``. The fine-tuned parameters are then read from
    ``models/<short-name>-best_loss.bin``, where ``<short-name>`` is the part
    of ``model_name`` after its last "/".

    Args:
        model_name: Hugging Face model id or path, e.g. ``"org/model"``.
        device: target device; passed to ``Module.to`` and, via
            ``torch.device``, used as ``map_location`` for the checkpoint.

    Returns:
        The model, moved to ``device``, with the checkpoint state dict loaded.
    """
    print(f"Using model {model_name}")
    os.makedirs("cache", exist_ok=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="cache")
    seq2seq.to(device)

    # Checkpoints are stored under the model's short name (no org prefix).
    short_name = model_name.rpartition("/")[2]
    checkpoint_path = os.path.join("models", f"{short_name}-best_loss.bin")
    print(f"Loading model from {checkpoint_path}")
    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # installed torch's weights_only default) -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    seq2seq.load_state_dict(state_dict)
    return seq2seq
def load_tokenizer(model_name):
    """Load the tokenizer for ``model_name`` through ``AutoTokenizer``.

    For mBART checkpoints (a case-insensitive "mbart" substring in the name),
    both the source and the target language are set to ``"vi_VN"``. Any other
    model is loaded with the tokenizer's default settings.

    Args:
        model_name: Hugging Face model id or path.

    Returns:
        The loaded tokenizer.
    """
    print(f"Loading tokenizer {model_name}")
    language_kwargs = (
        {"src_lang": "vi_VN", "tgt_lang": "vi_VN"}
        if "mbart" in model_name.lower()
        else {}
    )
    return AutoTokenizer.from_pretrained(model_name, **language_kwargs)
def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize a batch of source texts (plus targets when training) onto ``device``.

    Args:
        batch: mapping with a ``"src"`` entry; its ``"tgt"`` entry is read only
            when ``is_train`` is true.
        tokenizer: callable tokenizer. It receives the source texts positionally
            plus ``padding="longest"``, ``max_length=max_len``,
            ``truncation=True`` and ``return_tensors="pt"``.
        max_len: maximum token length; longer inputs are truncated.
        is_train: when true, ``batch["tgt"]`` is passed as ``text_target``;
            otherwise ``text_target=None`` is passed.
        device: device every value of the encoding is moved to.

    Returns:
        The object returned by ``tokenizer``, updated in place so that each of
        its values has been moved to ``device``.
    """
    source_texts = batch["src"]
    target_texts = batch["tgt"] if is_train else None
    encoded = tokenizer(
        source_texts,
        text_target=target_texts,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    # Snapshot the keys, then overwrite each value with its on-device copy.
    for name in list(encoded):
        encoded[name] = encoded[name].to(device)
    return encoded
def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize inference input (no targets) and move it onto ``device``.

    Args:
        src: text (or list of texts) handed to ``tokenizer`` positionally.
        tokenizer: callable tokenizer. It is called with ``padding="longest"``,
            ``max_length=max_len``, ``truncation=True`` and
            ``return_tensors="pt"``; no ``text_target`` is passed.
        max_len: maximum token length; longer inputs are truncated.
        device: device every value of the encoding is moved to.

    Returns:
        The object returned by ``tokenizer``, updated in place so that each of
        its values has been moved to ``device``.
    """
    encoded = tokenizer(
        src,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )
    # Snapshot the keys, then overwrite each value with its on-device copy.
    for name in list(encoded):
        encoded[name] = encoded[name].to(device)
    return encoded
def make_input_sentence_from_strings(data):
    """Render one KPI record as the flat ``"key": value`` string fed to the model.

    ``data`` must contain these fixed keys:

    * ``"CHỈ TIÊU"``, ``"ĐƠN VỊ"``, ``"ĐIỀU KIỆN"``, ``"Đánh giá"`` -- rendered
      inside double quotes.
    * ``"KPI mục tiêu tháng"`` -- rendered unquoted.
    * ``"Thời gian báo cáo"``, ``"Previous month"``, ``"Previous year"`` --
      two-element sequences rendered as ``"<item[1]>.<item[0]>"``. In the
      example record this reads like ``(year, month)``; TODO confirm with
      the caller.
    * ``"Previous month value key"`` / ``"Previous year value key"`` and
      ``"Previous month compare key"`` / ``"Previous year compare key"`` --
      each holds the *name* of another entry in ``data``. All four named
      entries are read, including the previous-year comparison value.

    The record must also contain the real-value key
    ``f"T{report[1]}.{report[0]} thực tế"``, where ``report`` is
    ``data["Thời gian báo cáo"]``.

    Output labels for the previous month and year are rebuilt from the
    ``"Previous month"`` / ``"Previous year"`` pairs, not copied from the
    key strings. Values are inserted with plain ``str`` formatting, without
    JSON escaping. Unlike a dict literal, the result has no enclosing braces.
    Example shape::

        "CHỈ TIÊU": "...", "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=", "KPI mục tiêu tháng": 95.0,
        "Tháng 9.2022": 97.5, "Đánh giá": "Đạt", "T8.2022": 96.6,
        "So sánh T8.2022 Tăng giảm": 1.0, "T9.2021": 96.8, "So sánh T9.2021 Tăng giảm": 0.8

    (The example is shown wrapped; the real output is a single line.)
    """
    # Indirection entries: their values are the names of other entries.
    month_value_key = data["Previous month value key"]
    year_value_key = data["Previous year value key"]

    objective_name = data["CHỈ TIÊU"]
    unit = data["ĐƠN VỊ"]
    condition = data["ĐIỀU KIỆN"]
    kpi_target = data["KPI mục tiêu tháng"]
    report_time = data["Thời gian báo cáo"]
    real_value = data[f"T{report_time[1]}.{report_time[0]} thực tế"]
    evaluation_value = data["Đánh giá"]

    month_value = data[month_value_key]
    year_value = data[year_value_key]
    month_compare_key = data["Previous month compare key"]
    year_compare_key = data["Previous year compare key"]
    month_compare = data[month_compare_key]
    year_compare = data[year_compare_key]

    prev_month = data["Previous month"]
    prev_year = data["Previous year"]
    month_label = f"T{prev_month[1]}.{prev_month[0]}"
    year_label = f"T{prev_year[1]}.{prev_year[0]}"

    fields = [
        f'"CHỈ TIÊU": "{objective_name}"',
        f'"ĐƠN VỊ": "{unit}"',
        f'"ĐIỀU KIỆN": "{condition}"',
        f'"KPI mục tiêu tháng": {kpi_target}',
        f'"Tháng {report_time[1]}.{report_time[0]}": {real_value}',
        f'"Đánh giá": "{evaluation_value}"',
        f'"{month_label}": {month_value}',
        f'"So sánh {month_label} Tăng giảm": {month_compare}',
        f'"{year_label}": {year_value}',
        f'"So sánh {year_label} Tăng giảm": {year_compare}',
    ]
    return ", ".join(fields)
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate decoded text for ``input_string`` using beam search.

    Args:
        input_string: source text (or list of texts) to encode.
        model: seq2seq model exposing ``eval``, ``to`` and ``generate``.
        tokenizer: tokenizer matching ``model``; also used to decode outputs.
        device: device the model and the encoded inputs are moved to.
        max_len: used both as the input truncation length and as the
            generation ``max_length``.
        model_name: only checked for a case-insensitive "mbart" substring,
            which enables forcing the ``vi_VN`` language token.
        beam_size: number of beams (``num_beams``) for generation.

    Returns:
        list[str]: one decoded string per generated sequence, with special
        tokens removed and tokenization spaces cleaned up.

    Raises:
        KeyError: for mBART models whose tokenizer has no ``"vi_VN"`` token.
    """
    model.eval()
    model = model.to(device)
    inputs = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )

    # Keep generation options separate from the encoder inputs returned by the tokenizer.
    generation_kwargs = {"max_length": max_len, "num_beams": beam_size}
    if "mbart" in model_name.lower():
        # Force the first generated token to be the vi_VN language code.
        # convert_tokens_to_ids is base tokenizer API; lang_code_to_id exists
        # only on some tokenizer classes and has been deprecated on some.
        vi_token_id = tokenizer.convert_tokens_to_ids("vi_VN")
        # Unknown tokens map to the unk id (or None); fail loudly, as the
        # previous dict lookup did, instead of forcing a wrong start token.
        if vi_token_id is None or vi_token_id == tokenizer.unk_token_id:
            raise KeyError("vi_VN")
        generation_kwargs["forced_bos_token_id"] = vi_token_id

    outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.batch_decode(
        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )