# Nos_D2T-gl / generate_text.py
# Uploaded by gcjavi ("Upload 8 files", commit bbaf732, 1.97 kB).
# (Header text captured from the repository's web file viewer.)
from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer
import pandas as pd
import os
import nltk
import string
import math
import sys
import argparse
import random
"""# Modelo T5
Importamos o modelo preadestrado
"""
"""# Corpus
#J# Leemos nuestro dataset.
"""
test_split = pd.read_csv('./test-dataset.csv', encoding="latin-1")
test_split= test_split.reset_index()
def generate(text, *, num_beams=10, max_new_tokens=50):
    """Generate a text description for ``text`` with the loaded T5 model.

    Relies on the module-level ``model`` and ``tokenizer`` globals, which this
    script creates after argument parsing, so it must only be called once they
    have been loaded.

    Args:
        text: Input sequence to describe. In this script it is a value from
            the test split's ``table`` column.
        num_beams: Beam width used for beam search (default 10).
        max_new_tokens: Maximum number of tokens to generate (default 50).

    Returns:
        The decoded first generated sequence, with special tokens stripped.
    """
    print("Tokenizing sequence...")
    # padding=True only matters if a list of inputs is passed; the batch is
    # moved to whatever device the model lives on.
    inputs = tokenizer(text, return_tensors='pt', padding=True).to(model.device)
    print("Generating description...")
    # Beam search without sampling, so the output does not vary between runs
    # because of random sampling.
    out = model.generate(
        **inputs,
        do_sample=False,
        num_beams=num_beams,
        max_new_tokens=max_new_tokens,
    )
    # Only the first sequence is decoded (for a single input string it is the
    # only one returned).
    return tokenizer.decode(out[0], skip_special_tokens=True)
# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_table", type=int, default=280, required=False, help="Specify data ID")
parser.add_argument("-o", "--output", type=str, default="./", required=False, help="Specify output path")
args = parser.parse_args()
data_id = args.input_table  # already an int thanks to type=int
output_path = args.output

# Valid test IDs are the row labels of test_split (a 0..n-1 index after
# reset_index()). Derive the bound from the data rather than hard-coding it,
# so this check and the lookups below can never disagree.
num_rows = len(test_split)
if data_id not in range(0, num_rows):
    sys.exit("ERROR: ID must be in the range [0,{}] (testing IDs)".format(num_rows - 1))

# Load the chosen pretrained model together with its tokenizer.
print("Loading model...")
model = T5ForConditionalGeneration.from_pretrained('data2text_gl_v1')
tokenizer = T5Tokenizer.from_pretrained("data2text_gl_v1")

# --- Generation -------------------------------------------------------------
img_id = str(test_split.id[data_id])
print("Loading data... (dataset-id: " + img_id + ")")
data = test_split.table[data_id]
gold = test_split.caption[data_id]
generation = generate(data)

pattern = "- Test ID: {} (DB id: {})\n- Data table: {}\n- Generated text: {}\n- Gold text: {}"
# The console copy shows only the first 100 characters of the table.
print(pattern.format(data_id, img_id, data[0:100] + "... </table>", generation, gold))

# os.path.join adds the separator when --output has no trailing slash (plain
# concatenation turned "-o out" into "outgenerated_<id>.txt"). The encoding is
# explicit so the result does not depend on the platform's default encoding.
output_file_path = os.path.join(output_path, "generated_" + str(data_id) + ".txt")
with open(output_file_path, "w", encoding="utf-8") as output_file:
    output_file.write(pattern.format(data_id, img_id, data, generation, gold))