# Source: Hugging Face Space (status when captured: Running)
import os
import subprocess
import sys

import gradio as gr
import spacy
import torch
from spacy import displacy
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Checkpoint used for generation, loaded once at import time.
model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Use EOS as the pad token: the tokenizer is later called with padding=True.
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# spaCy pipeline used for POS tags and the dependency-parse visualisation.
_SPACY_MODEL = "en_core_web_sm"
try:
    nlp = spacy.load(_SPACY_MODEL)
except OSError:
    # The pipeline is not installed yet. Download it with the interpreter that
    # is running this script (a bare "python" may be a different environment)
    # and fail loudly if the download does not succeed.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", _SPACY_MODEL],
        check=True,
    )
    nlp = spacy.load(_SPACY_MODEL)
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate text from ``prompt`` with the module-level model.

    The prompt is prefixed with the ``### Text ###`` header, sampled from with
    the given decoding parameters, and decoded. If the decoded text contains a
    ``### Correction ###`` marker, only the segment following the first marker
    is kept.

    Args:
        prompt: User-supplied text.
        max_new_tokens: Number of tokens to generate (coerced to ``int``).
        top_k: Top-k sampling cutoff (coerced to ``int``).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Returns:
        A ``(highlighted_text, generated_text)`` tuple. ``highlighted_text`` is
        a list of ``(token, label)`` pairs, one per tokenizer token of
        ``generated_text``, both with the ``Ġ`` marker removed.
    """
    # UI sliders may hand over floats (e.g. 50.0); generate() expects integer
    # values for these two parameters.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    prompt = f"### Text ###\n{prompt}"

    # Keep prompt + generated tokens inside the model's positional window so a
    # long prompt cannot push generation past it. The original 1024-token cap
    # on the prompt is preserved as an upper bound.
    context_window = getattr(model.config, "n_positions", 1024)
    prompt_budget = max(1, min(1024, context_window - max_new_tokens))

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=prompt_budget,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    tokens = tokenizer.tokenize(generated_text)
    # Labels come from a token -> id -> token round trip, done once for the
    # whole list instead of once per token.
    labels = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(tokens))
    highlighted_text = [
        (token.replace("Ġ", ""), label.replace("Ġ", ""))
        for token, label in zip(tokens, labels)
    ]
    return highlighted_text, generated_text
def text_analysis(text):
    """Run the spaCy pipeline on ``text`` and summarise the result.

    Returns:
        A ``(pos_tokens, pos_count, html)`` tuple:
        ``pos_tokens`` is a list of ``(token text, POS tag)`` pairs,
        ``pos_count`` is a dict with ``char_count`` and ``token_count``,
        ``html`` is the displaCy dependency rendering wrapped in a
        scrollable ``<div>``.
    """
    doc = nlp(text)

    # One (surface form, part-of-speech) pair per token.
    pos_tokens = []
    for tok in doc:
        pos_tokens.append((tok.text, tok.pos_))

    pos_count = {
        "char_count": len(text),
        "token_count": len(pos_tokens),
    }

    rendered = displacy.render(doc, style="dep", page=True)
    # Wrap the rendering so wide/tall parses scroll instead of overflowing.
    html = "".join(
        (
            "<div style='max-width:100%; max-height:360px; overflow:auto'>",
            rendered,
            "</div>",
        )
    )
    return pos_tokens, pos_count, html
def generate_dependency_parse(generated_text):
    """Return the dependency-parse HTML for ``generated_text``.

    Thin wrapper around ``text_analysis`` that keeps only the rendered HTML
    (the third element of its result).
    """
    _, _, html_generated = text_analysis(generated_text)
    return html_generated
def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Handle a click on the generate button.

    Generates text from ``prompt``, analyses the *input* prompt with spaCy,
    and returns an 8-item tuple: generated text, highlighted tokens, input
    counts, input dependency HTML, a "show" update, the generated text again,
    a "hide" update, and a "show" update.
    """
    generation = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    generated_highlight, generated_text = generation

    # Only the counts and the rendered HTML of the input analysis are used.
    _, pos_count_input, html_input = text_analysis(prompt)

    reveal_first = gr.update(visible=True)
    conceal = gr.update(visible=False)
    reveal_second = gr.update(visible=True)

    return (
        generated_text,
        generated_highlight,
        pos_count_input,
        html_input,
        reveal_first,
        generated_text,
        conceal,
        reveal_second,
    )
def reset_interface():
    """Return three visibility updates: one "show" followed by two "hide"."""
    shown = gr.update(visible=True)
    hidden_first = gr.update(visible=False)
    hidden_second = gr.update(visible=False)
    return shown, hidden_first, hidden_second
import gradio as gr  # redundant re-import (gradio is already imported above); harmless

# Build the Gradio UI: prompt + sampling sliders in, generated text, token
# highlighting, input statistics and dependency parses out.
with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown("""
    # Historical Text Generator with Dependency Parse
    This app generates historical-style text using the OCRonos-Vintage model.
    You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
    """)
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:", lines=3)

    # Sliders for model parameters
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
    # top-k is a token count, so the slider moves in whole numbers.
    top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
    top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)

    # Output components
    # interactive=False makes the textbox read-only ("readonly" is not a Textbox parameter).
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage", interactive=False)
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")

    # Hidden button and final output for dependency parse visualization
    send_button = gr.Button(value="👁️Visualize Generated Text", visible=False)
    dependency_parse_generated = gr.HTML(label="👁️Visualization (Generated Text)")

    # Reset button, hidden initially
    reset_button = gr.Button(value="♻️Start Again", visible=False)

    # Main interface logic: when clicked, "Generate" button hides itself and shows the reset button
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button],
    )

    # Visualize button logic: render the dependency parse of the generated text
    send_button.click(
        generate_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated],
    )

    # Reset button logic: hide itself and re-show the "Generate" button
    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button],
    )

iface.launch()