# OCRonos-TextGen / app.py
# Committed by: Tonic
# Last commit: 0b58927 (verified) "Update app.py"
# File size: 6.91 kB
import spaces
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy
# Markdown header rendered at the top of the page via gr.Markdown(title) in the UI block.
title = """
# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎅🏻⌚OCRonos Vintage Text Gen
This app generates historical-style text using the OCRonos-Vintage model. You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse. You can see a tokenized visualisation of the output and your input, and learn english using the visualization for the output text!
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Load the OCRonos-Vintage checkpoint and its tokenizer once, at import time.
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 tokenizers typically have no pad token; reuse EOS so padding=True works.
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# spaCy pipeline used for POS tags and dependency-parse rendering.
# Download the model only when it is missing. This avoids re-running pip on
# every start and surfaces download failures instead of ignoring them.
# spacy.cli.download also installs into the interpreter running this app,
# rather than whichever `python` is first on PATH.
spacy_model_name = "en_core_web_sm"
try:
    nlp = spacy.load(spacy_model_name)
except OSError:  # spacy.load raises OSError when the package is not installed
    import importlib
    import spacy.cli

    spacy.cli.download(spacy_model_name)
    # Make the freshly installed package importable in this process.
    importlib.invalidate_caches()
    nlp = spacy.load(spacy_model_name)
@spaces.GPU
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate historical-style text with OCRonos-Vintage.

    Args:
        prompt: User passage. It is prefixed with a ``### Text ###`` header
            before tokenization.
        max_new_tokens: Requested number of tokens to sample. The value is
            capped so that prompt tokens plus new tokens fit in the model's
            context window (``model.config.n_positions``).
        top_k: Top-k sampling cutoff, forwarded to ``model.generate``.
        temperature: Sampling temperature, forwarded to ``model.generate``.
        top_p: Nucleus sampling cutoff, forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Returns:
        ``(highlighted_text, generated_text)``:
        - ``highlighted_text`` is a list of ``(token, label)`` pairs for
          ``gr.HighlightedText``. The label is the token itself with the "Ġ"
          space marker removed.
        - ``generated_text`` is the decoded output. If it contains
          ``### Correction ###``, only the segment between the first marker
          and the next one (if any) is kept, stripped of surrounding
          whitespace.
    """
    prompt = f"### Text ###\n{prompt}"

    # The model has a fixed number of position embeddings. Prompt + new tokens
    # must fit in it, otherwise generate() indexes past the embedding table and
    # crashes (e.g. a ~25-token prompt with the 1000-token slider maximum).
    context_size = getattr(model.config, "n_positions", 1024)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=context_size - 1)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    # int() guards against slider values arriving as floats.
    # Truncation leaves at least one free position, so this is always >= 1.
    max_new_tokens = min(int(max_new_tokens), context_size - input_ids.shape[-1])

    # generate() already runs without gradient tracking, so no explicit
    # torch.no_grad() is needed.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=int(top_k),
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Tokens returned by tokenize() are vocabulary entries, so mapping them to
    # ids and back yields the same token. The label is therefore just the
    # cleaned token.
    highlighted_text = []
    for token in tokenizer.tokenize(generated_text):
        clean_token = token.replace("Ġ", "")
        highlighted_text.append((clean_token, clean_token))

    del inputs, input_ids, attention_mask, output
    torch.cuda.empty_cache()
    return highlighted_text, generated_text
def text_analysis(text):
    """Analyse ``text`` with spaCy for the visualization panels.

    Only the spaCy pipeline ``nlp`` is used here (no torch/CUDA work). The
    function is therefore not wrapped in ``@spaces.GPU``: requesting a GPU
    slot for it would only spend GPU quota and add queueing delay.

    Args:
        text: Text to analyse.

    Returns:
        ``(pos_tokens, pos_count, html)``:
        - ``pos_tokens`` is a list of ``(token_text, pos_tag)`` pairs using
          spaCy's ``token.pos_``.
        - ``pos_count`` is a dict with ``"char_count"`` (characters in
          ``text``) and ``"token_count"`` (number of spaCy tokens).
        - ``html`` is the displaCy dependency-parse markup, wrapped in a
          scrollable ``<div>``.
    """
    doc = nlp(text)
    # page=True asks displaCy for a full standalone HTML page.
    dependency_html = displacy.render(doc, style="dep", page=True)
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + dependency_html
        + "</div>"
    )
    pos_count = {
        "char_count": len(text),
        "token_count": len(doc),
    }
    pos_tokens = [(token.text, token.pos_) for token in doc]
    return pos_tokens, pos_count, html
def generate_dependency_parse(generated_text):
    """Return only the dependency-parse HTML that text_analysis builds for ``generated_text``."""
    _pos_tokens, _pos_counts, parse_html = text_analysis(generated_text)
    return parse_html
def display_dependency_parse(generated_text):
    """Callback for the visualization button: re-render the parse of ``generated_text`` as HTML."""
    parse_html = generate_dependency_parse(generated_text)
    return parse_html
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Callback for the Generate button.

    Generates text, analyses the prompt, and renders the generated text's
    dependency parse.

    Returns:
        An 8-tuple of ``gr.update`` objects, in the order of the
        ``generate_button.click`` outputs list:
        generated text, token highlights, prompt char/token counts, prompt
        dependency parse, send button, generated-text dependency parse,
        generate button, reset button.
    """
    # Generate historical-style text and its tokenized view.
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    # Analyse the input text (counts + dependency parse visualization).
    _input_tokens, pos_count_input, html_input = text_analysis(prompt)
    # Dependency parse for the generated text.
    dependency_parse_generated_html = generate_dependency_parse(generated_text)
    # Values are sent with visible=True so the result panels reappear even
    # after "Start Again" has hidden them.
    return (
        gr.update(value=generated_text, visible=True),
        gr.update(value=generated_highlight, visible=True),
        gr.update(value=pos_count_input, visible=True),
        gr.update(value=html_input, visible=True),
        gr.update(visible=True),  # send_button
        gr.update(value=dependency_parse_generated_html, visible=True),
        gr.update(visible=True),  # generate_button
        # reset_button: previously forced hidden, so "Start Again" never appeared.
        gr.update(visible=True),
    )
def reset_interface():
    """Callback for "Start Again": show only the Generate button and hide results.

    Returns:
        Visibility updates in the order of the ``reset_button.click`` outputs
        list.
    """
    return (
        gr.update(visible=True),   # generate_button
        gr.update(visible=False),  # send_button
        gr.update(visible=False),  # reset_button
        gr.update(visible=False),  # generated_text_output
        gr.update(visible=False),  # highlighted_text
        gr.update(visible=False),  # tokenizer_info
        gr.update(visible=False),  # dependency_parse_input
        # dependency_parse_generated: was left visible, showing the parse of
        # text that had just been hidden; hide it like the other results.
        gr.update(visible=False),
    )
# Build the Gradio UI. Inside gr.Blocks, components are laid out in the order
# they are created, so the statement order below defines the page layout.
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown(title)
    # Prompt and sampling controls; passed positionally to full_interface
    # (see the generate_button.click inputs list).
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine' he said", lines=2)
    max_new_tokens = gr.Slider(label="📏Length", minimum=50, maximum=1000, step=5, value=320)
    top_k = gr.Slider(label="🧪Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="🎨Creativity", minimum=0.1, maximum=1, step=0.05, value=0.3)
    top_p = gr.Slider(label="👌🏻Quality", minimum=0.1, maximum=0.99, step=0.01, value=0.97)
    repetition_penalty = gr.Slider(label="🔴Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.3)
    # Output panels filled by the callbacks.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    dependency_parse_generated = gr.HTML(label="🎅🏻⌚Dependency Parse Visualization (Generated Text)")
    # Buttons. send/reset start hidden; the callbacks toggle their visibility.
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    # Generate: full_interface returns one value/update per entry in `outputs`, in this order.
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
    )
    # Re-render the dependency parse of whatever is currently in the generated-text box.
    send_button.click(
        display_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated]
    )
    # Start Again: reset_interface returns visibility updates for these components, in this order.
    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button, generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, dependency_parse_generated]
    )
iface.launch()