# OCRonos-TextGen / app.py
# Tonic's picture
# Update app.py
# 0b58927 verified
import spaces
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy
# Markdown header rendered at the top of the Gradio UI via gr.Markdown(title).
title = """
# 🙋🏻‍♂️Welcome to 🌟Tonic's 🎅🏻⌚OCRonos Vintage Text Gen
This app generates historical-style text using the OCRonos-Vintage model. You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse. You can see a tokenized visualisation of the output and your input, and learn english using the visualization for the output text!
### Join us :
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""
# Hugging Face model id of the OCRonos-Vintage checkpoint (loaded as a GPT-2 LM).
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# The GPT-2 tokenizer defines no pad token by default; reuse EOS so that
# padding=True in historical_generation's tokenizer call is accepted.
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# spaCy English pipeline used for POS tags and dependency-parse rendering.
SPACY_MODEL = "en_core_web_sm"
try:
    nlp = spacy.load(SPACY_MODEL)
except OSError:
    # Pipeline package not installed yet: download it once (what
    # `python -m spacy download` does, but with the running interpreter and
    # without a shell), refresh import caches so the freshly installed
    # package is importable in this process, then load it.
    import importlib
    from spacy.cli import download as spacy_download

    spacy_download(SPACY_MODEL)
    importlib.invalidate_caches()
    nlp = spacy.load(SPACY_MODEL)
@spaces.GPU
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate historical-style text with OCRonos-Vintage and tokenize the result.

    The user prompt is wrapped in the model's ``### Text ###`` header before
    sampling. If the decoded output contains a ``### Correction ###`` marker,
    only the text after the first marker (up to a second marker, if any) is
    kept, stripped of surrounding whitespace.

    Args:
        prompt: Passage the model should continue.
        max_new_tokens: Requested number of sampled tokens. It is clamped so
            that prompt plus generation fits the model's position window.
        top_k, temperature, top_p, repetition_penalty: Sampling parameters
            forwarded to ``model.generate`` (``do_sample=True``).

    Returns:
        ``(highlighted_text, generated_text)``. ``highlighted_text`` is a list
        of ``(token, label)`` pairs for ``gr.HighlightedText``. Each label is
        the token itself with GPT-2's ``Ġ`` space marker removed.
        ``generated_text`` is the decoded string.
    """
    # Gradio sliders may deliver floats; generate() needs integer counts/top_k.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # GPT-2 models have a fixed number of learned positions. Prompt + new
    # tokens beyond that window overflow the position embeddings, so reserve
    # room for generation when truncating the prompt.
    context_window = getattr(model.config, "n_positions", 1024)
    prompt_budget = max(1, context_window - max_new_tokens)

    formatted_prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True, truncation=True, max_length=prompt_budget)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    # Clamp again in case the requested length alone exceeded the window.
    max_new_tokens = max(1, min(max_new_tokens, context_window - input_ids.shape[1]))

    # generate() already runs without autograd, so no torch.no_grad() wrapper.
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    tokens = tokenizer.tokenize(generated_text)
    highlighted_text = []
    for token in tokens:
        # "Ġ" is GPT-2's byte-level marker for a leading space.
        clean_token = token.replace("Ġ", "")
        # tokenize() only yields vocabulary entries, so a token -> id -> token
        # round trip returns the same string; label each token with itself.
        highlighted_text.append((clean_token, clean_token))

    # Drop tensor references before releasing cached GPU memory.
    del inputs, input_ids, attention_mask, output, tokens
    torch.cuda.empty_cache()
    return highlighted_text, generated_text
def text_analysis(text):
    """Parse *text* with spaCy and build the values the UI displays.

    Runs on CPU: nothing in this file switches spaCy to GPU mode, so it is not
    wrapped in ``@spaces.GPU``. Wrapping it would request a GPU slot for every
    parse.

    Returns:
        ``(pos_tokens, pos_count, html)``:
        - ``pos_tokens``: a list of ``(token text, coarse POS tag)`` pairs.
        - ``pos_count``: a dict with ``char_count`` and ``token_count``.
        - ``html``: the displaCy dependency-parse markup wrapped in a
          scrollable ``<div>``.
    """
    doc = nlp(text)
    # page=True makes displaCy emit a complete standalone HTML page.
    parse_markup = displacy.render(doc, style="dep", page=True)
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + parse_markup
        + "</div>"
    )
    pos_count = {
        "char_count": len(text),
        "token_count": len(doc),
    }
    pos_tokens = [(token.text, token.pos_) for token in doc]
    return pos_tokens, pos_count, html
def generate_dependency_parse(generated_text):
    """Return only the dependency-parse HTML that text_analysis builds for *generated_text*."""
    *_, parse_html = text_analysis(generated_text)
    return parse_html
def display_dependency_parse(generated_text):
    """Send-button handler: render the dependency parse of the generated text as HTML."""
    parse_html = generate_dependency_parse(generated_text)
    return parse_html
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Handle the Generate button: sample text, then build both visualizations.

    Returns an 8-tuple in the order of ``generate_button.click``'s outputs:

    1. generated text
    2. token highlights
    3. input char/token counts
    4. input dependency parse
    5. send-button update
    6. generated-text dependency parse
    7. generate-button update
    8. reset-button update
    """
    # Sample historical-style text and its tokenized highlight pairs.
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    # Only the counts and rendered parse of the input prompt are displayed.
    _, pos_count_input, html_input = text_analysis(prompt)
    dependency_parse_generated_html = generate_dependency_parse(generated_text)
    # Result panes carry visible=True so they reappear even if "Start Again"
    # hid them. The reset button is revealed so its handler becomes
    # reachable (it is created hidden).
    return (
        gr.update(value=generated_text, visible=True),
        gr.update(value=generated_highlight, visible=True),
        gr.update(value=pos_count_input, visible=True),
        gr.update(value=html_input, visible=True),
        gr.update(visible=True),  # send_button
        gr.update(value=dependency_parse_generated_html, visible=True),
        gr.update(visible=True),  # generate_button
        gr.update(visible=True),  # reset_button
    )
def reset_interface():
    """Handle "Start Again": restore the initial button state and clear results.

    Returns one update per component in ``reset_button.click``'s outputs, in
    order: generate button, send button, reset button, then the five result
    panes (generated text, token highlights, tokenizer info, input parse,
    generated parse). Gradio rejects a tuple shorter than the outputs list,
    so exactly eight updates are returned.
    """
    button_updates = (
        gr.update(visible=True),   # generate_button
        gr.update(visible=False),  # send_button
        gr.update(visible=False),  # reset_button
    )
    # Clear the result panes rather than hiding them, so the next generation
    # shows fresh output in the same place.
    cleared_results = tuple(gr.update(value=None) for _ in range(5))
    return button_updates + cleared_results
# Gradio UI. Components render in the order they are created inside this
# context, so statement order here defines the page layout.
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown(title)
    # User prompt; historical_generation prefixes it with "### Text ###\n".
    # NOTE(review): the placeholder has an unmatched closing quote after "Seine".
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine' he said", lines=2)
    # Sampling controls, passed positionally through full_interface to
    # historical_generation. Their defaults override that function's own defaults.
    max_new_tokens = gr.Slider(label="📏Length", minimum=50, maximum=1000, step=5, value=320)
    top_k = gr.Slider(label="🧪Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="🎨Creativity", minimum=0.1, maximum=1, step=0.05, value=0.3)
    top_p = gr.Slider(label="👌🏻Quality", minimum=0.1, maximum=0.99, step=0.01, value=0.97)
    repetition_penalty = gr.Slider(label="🔴Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.3)
    # Result panes filled by full_interface.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    # Character/token counts of the input prompt (from text_analysis).
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    dependency_parse_generated = gr.HTML(label="🎅🏻⌚Dependency Parse Visualization (Generated Text)")
    # Created hidden; full_interface makes it visible after a generation.
    # Clicking it re-renders the parse of the generated text.
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    # "Start Again" button, created hidden (visible=False).
    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    # The outputs order must match the 8-tuple returned by full_interface.
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
    )
    send_button.click(
        display_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated]
    )
    # reset_interface must return one update per component listed here (8).
    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button, generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, dependency_parse_generated]
    )
# Start the Gradio server (runs at import time, as Spaces expects for app.py).
iface.launch()