OCRonos-TextGen / app.py
Tonic's picture
Update app.py
48277a2 verified
raw
history blame
5.47 kB
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy
# Hugging Face Hub identifier of the checkpoint used for text generation.
model_name = "PleIAs/OCRonos-Vintage"

# Load the GPT-2 style language model and its tokenizer once at start-up.
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Padding reuses the end-of-sequence token so `padding=True` has a pad token.
tokenizer.pad_token = tokenizer.eos_token

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the spaCy English pipeline used for the dependency-parse view.
# The model is downloaded only when it is not installed yet, and the download
# goes through spaCy's own API so it targets the running interpreter instead
# of whichever `python` executable happens to be first on PATH.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    from spacy.cli import download as _spacy_download

    _spacy_download("en_core_web_sm")
    nlp = spacy.load("en_core_web_sm")
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate text from *prompt* with the module-level model and tokenizer.

    The prompt is wrapped in the ``### Text ###`` header before generation.
    If the decoded output contains ``### Correction ###``, only the part
    following the first such marker is kept.

    Args:
        prompt: User-supplied text.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_k: Top-k sampling cut-off.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty passed through to ``model.generate``.

    Returns:
        A ``(highlighted_text, generated_text)`` tuple, where
        ``highlighted_text`` is a list of ``(token, label)`` string pairs with
        the ``Ġ`` marker removed, and ``generated_text`` is the final string.
    """
    # UI sliders may deliver floats (e.g. 50.0); these two are integer
    # quantities, so coerce them before handing them to the sampler.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Re-tokenize the final text for the highlighted view. Each label is the
    # token after a token -> id -> token round trip; the conversion is done
    # once for the whole sequence instead of once per token.
    tokens = tokenizer.tokenize(generated_text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)
    labels = tokenizer.convert_ids_to_tokens(token_ids)
    highlighted_text = [
        (token.replace("Ġ", ""), label.replace("Ġ", ""))
        for token, label in zip(tokens, labels)
    ]
    return highlighted_text, generated_text
def text_analysis(text):
    """Parse *text* with the module-level spaCy pipeline.

    Returns:
        A ``(pos_tokens, pos_count, html)`` tuple:

        * ``pos_tokens`` - list of ``(token text, coarse POS tag)`` pairs;
        * ``pos_count`` - dict with ``char_count`` (length of *text*) and
          ``token_count`` (number of tokens in the parsed document);
        * ``html`` - displaCy dependency rendering wrapped in a scrollable
          ``<div>``.
    """
    parsed = nlp(text)

    # Render the dependency tree and put it in a size-limited, scrollable box.
    rendered = displacy.render(parsed, style="dep", page=True)
    wrapper_open = "<div style='max-width:100%; max-height:360px; overflow:auto'>"
    wrapper_close = "</div>"
    boxed_html = wrapper_open + rendered + wrapper_close

    stats = {}
    stats["char_count"] = len(text)
    stats["token_count"] = sum(1 for _ in parsed)

    tagged = []
    for tok in parsed:
        tagged.append((tok.text, tok.pos_))

    return tagged, stats, boxed_html
def generate_dependency_parse(generated_text):
    """Return the dependency-parse HTML that ``text_analysis`` builds for *generated_text*.

    Only the HTML element of the analysis result is used; the token list and
    the counts are discarded.
    """
    _, _, html_generated = text_analysis(generated_text)
    return html_generated
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Run generation on *prompt* and analyse the prompt for the UI.

    Returns an 8-tuple, in this order:

    1. the generated text,
    2. the ``(token, label)`` pairs for the highlighted view,
    3. the character/token counts of the input prompt,
    4. the dependency-parse HTML of the input prompt,
    5. an update making a component visible,
    6. the generated text again,
    7. an update hiding a component,
    8. an update making a component visible.
    """
    highlight_pairs, text_out = historical_generation(
        prompt,
        max_new_tokens,
        top_k,
        temperature,
        top_p,
        repetition_penalty,
    )

    # Only the counts and the rendered HTML of the prompt analysis are used.
    _, prompt_counts, prompt_html = text_analysis(prompt)

    reveal_first = gr.update(visible=True)
    conceal = gr.update(visible=False)
    reveal_second = gr.update(visible=True)

    return (
        text_out,
        highlight_pairs,
        prompt_counts,
        prompt_html,
        reveal_first,
        text_out,
        conceal,
        reveal_second,
    )
def reset_interface():
    """Return three visibility updates: visible, hidden, hidden (in that order)."""
    visibility_flags = (True, False, False)
    return tuple(gr.update(visible=flag) for flag in visibility_flags)
import gradio as gr
# Build the Gradio UI: prompt + sampling controls at the top, result panels
# below, and three buttons (generate / visualize generated text / start again).
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown(
        "\n"
        "# Historical Text Generator with Dependency Parse\n"
        "This app generates historical-style text using the OCRonos-Vintage model.\n"
        "You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.\n"
    )

    # --- Inputs -----------------------------------------------------------
    prompt = gr.Textbox(
        label="Add a passage in the style of historical texts",
        placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:",
        lines=3,
    )
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
    # top-k is a count of candidate tokens, so the slider moves in whole steps.
    top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
    top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)

    # --- Outputs ----------------------------------------------------------
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")

    # The button is only constructed here; its handler is attached below,
    # once the component it writes to exists.
    send_button = gr.Button(value="👁️Visualize Generated Text", visible=False)
    dependency_parse_generated = gr.HTML(label="👁️Visualization (Generated Text)")
    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")

    # --- Event wiring -----------------------------------------------------
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[
            generated_text_output,
            highlighted_text,
            tokenizer_info,
            dependency_parse_input,
            send_button,
            dependency_parse_generated,
            generate_button,
            reset_button,
        ],
    )

    # Parse the generated text from the text box (not from the HTML panel), so
    # clicking twice never feeds previously rendered HTML back into the parser.
    send_button.click(
        generate_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated],
    )

    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button],
    )

iface.launch()