# shopify_csv_qa / app/tapex.py
# Author: patrawtf
# Last commit: efb54e6 ("Update app/tapex.py")
# Size at that commit: 872 bytes
import datetime
import functools

import pandas as pd
import torch
from transformers import TapasTokenizer, TapexTokenizer, BartForConditionalGeneration
@functools.lru_cache(maxsize=1)
def _load_tapex(model_name):
    """Load and memoize the TAPEX model and tokenizer for ``model_name``.

    ``from_pretrained`` reads the whole checkpoint from disk (downloading it
    on first use), which is far more expensive than answering one query.
    Memoizing lets repeated ``execute_query`` calls reuse the loaded pair;
    ``maxsize=1`` keeps at most one model resident in memory.
    """
    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = TapexTokenizer.from_pretrained(model_name)
    return model, tokenizer


def execute_query(query, csv_file, model_name="microsoft/tapex-large-finetuned-wtq"):
    """Answer a natural-language question about a CSV table using TAPEX.

    Args:
        query: Question text to ask about the table.
        csv_file: Object whose ``.name`` attribute is a path to a
            comma-delimited file. Presumably an uploaded-file wrapper from
            the UI layer (TODO confirm against the caller).
        model_name: Hugging Face id of a TAPEX seq2seq checkpoint. The
            default preserves the original hard-coded model.

    Returns:
        A tuple ``(query_result, table)``. ``query_result`` is
        ``{"query": query, "answer": <first decoded answer>}``. ``table`` is
        the parsed DataFrame with every cell converted to ``str``.
    """
    start = datetime.datetime.now()

    table = pd.read_csv(csv_file.name, delimiter=",")
    # The TAPEX tokenizer flattens the table into text, so cast all cells
    # (including numeric columns) to str before encoding.
    table = table.astype(str)

    # Cached after the first call; avoids reloading the checkpoint per query.
    model, tokenizer = _load_tapex(model_name)

    encoding = tokenizer(
        table=table,
        query=[query],
        padding=True,
        return_tensors="pt",
        truncation=True,
    )
    outputs = model.generate(**encoding)
    answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    query_result = {
        "query": query,
        "answer": answers[0],
    }

    # Elapsed wall-clock time for the request, printed for ad-hoc profiling.
    print(datetime.datetime.now() - start)
    return query_result, table