File size: 2,190 Bytes
8a57a60 212ebbf 8a57a60 212ebbf 8a57a60 212ebbf 8a57a60 212ebbf 8a57a60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import datetime
import functools

import gradio as gr
import pandas as pd
import torch
from transformers import BartForConditionalGeneration, TapexTokenizer
@functools.lru_cache(maxsize=None)  # key space is the single model name below
def _load_tapex(model_name):
    """Load the TAPEX model and tokenizer for *model_name* once per process.

    ``from_pretrained`` deserializes the full checkpoint (downloading it on
    first use), so doing it per request made every query pay that cost.
    """
    model = BartForConditionalGeneration.from_pretrained(model_name)
    tokenizer = TapexTokenizer.from_pretrained(model_name)
    return model, tokenizer


def execute_query(query, csv_file):
    """Answer a natural-language *query* about the table in *csv_file*.

    Args:
        query: Question text from the "Search query" textbox.
        csv_file: Upload from ``gr.File``. Depending on the Gradio version this
            is either an object whose ``.name`` is the temp-file path or the
            path string itself; both are accepted.

    Returns:
        A ``(query_result, table)`` tuple. ``query_result`` is a dict with the
        original ``"query"`` and the model's ``"answer"``. ``table`` is the
        parsed CSV with every cell cast to ``str``, shown as "All data".

    Raises:
        ValueError: If no CSV file was uploaded.
    """
    started = datetime.datetime.now()
    if csv_file is None:
        raise ValueError("Please upload a CSV file.")
    # Older Gradio passes a tempfile wrapper (path in .name); newer passes a str.
    csv_path = getattr(csv_file, "name", csv_file)
    # Cells are cast to str before tokenization; the same string table is
    # returned for display.
    table = pd.read_csv(csv_path, delimiter=",").astype(str)

    model, tokenizer = _load_tapex("microsoft/tapex-large-finetuned-wtq")
    encoding = tokenizer(
        table=table,
        query=[query],  # batch of one question against the single table
        padding=True,
        return_tensors="pt",
        truncation=True,
    )
    outputs = model.generate(**encoding)
    answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    query_result = {
        "query": query,
        "answer": answers[0],
    }
    print(f"execute_query took {datetime.datetime.now() - started}")
    return query_result, table
def main():
    """Build the TAPEX table-question-answering Gradio UI and launch it."""
    description = "Querying a CSV using the TAPEX model. You can ask a question about tabular data, and the TAPEX model will produce the result. The finetuned TAPEX model runs on data with a maximum of 5000 rows and 20 columns. A sample dataset of Shopify store sales is provided."
    article = "<p style='text-align: center'><a href='https://unscrambl.com/' target='_blank'>Unscrambl</a> | <a href='https://huggingface.co/microsoft/tapex-large-finetuned-wtq' target='_blank'>TAPEX Model</a></p><center><img src='https://visitor-badge.glitch.me/badge?page_id=abaranovskij_tablequery' alt='visitor badge'></center>"
    iface = gr.Interface(
        fn=execute_query,
        inputs=[
            gr.Textbox(label="Search query"),
            gr.File(label="CSV file"),
        ],
        outputs=[
            gr.JSON(label="Result"),
            gr.Dataframe(label="All data"),
        ],
        title="Table Question Answering (TAPEX)",
        description=description,
        article=article,
        allow_flagging='never',
    )
    # Serve requests through a queue so slow model inference is not cut off by
    # request timeouts. This replaces launch(enable_queue=True), which is
    # deprecated in Gradio 3.x and no longer accepted by Gradio 4.
    iface.queue()
    # When running in Docker, bind to all interfaces, e.g.
    # iface.launch(server_name="0.0.0.0", server_port=7000), or set the
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables.
    iface.launch()
if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    main()
|