"""Gradio demo: sample historical-style text from OCRonos-Vintage and show a
tokenized view plus spaCy dependency parses of the prompt and the output."""

import os

import gradio as gr
import spacy
import torch
from spacy import displacy
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# --- Model setup (executed once, at import time) -----------------------------

model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so the tokenizer can emit an attention mask.
tokenizer.pad_token = tokenizer.eos_token
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Number of positions the model can embed; prompt + continuation must fit in it.
# 1024 is the fallback because that is the prompt cap this script used before.
_CONTEXT_WINDOW = getattr(model.config, "n_positions", 1024)

SPACY_MODEL = "en_core_web_sm"
try:
    nlp = spacy.load(SPACY_MODEL)
except OSError:
    # The pipeline is not installed yet. Download it only in that case,
    # rather than shelling out to whatever `python` is first on PATH on every start.
    from spacy.cli import download

    download(SPACY_MODEL)
    nlp = spacy.load(SPACY_MODEL)


# --- Generation and analysis -------------------------------------------------

def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7,
                          top_p=0.95, repetition_penalty=1.0):
    """Sample a continuation of *prompt* from OCRonos-Vintage.

    The prompt is prefixed with a ``### Text ###`` header before encoding.
    If the decoded output contains ``### Correction ###``, only the text between
    the first marker and the next one (if any) is kept.

    Args:
        prompt: user-supplied seed text.
        max_new_tokens: upper bound on generated tokens. Cast to int and clamped
            so at least one prompt token still fits in the context window.
        top_k: top-k sampling cutoff. Cast to int because generation expects an
            integer, and Gradio sliders may deliver floats.
        temperature: sampling temperature passed to ``model.generate``.
        top_p: nucleus-sampling threshold passed to ``model.generate``.
        repetition_penalty: passed to ``model.generate``.

    Returns:
        ``(highlighted_text, generated_text)``. The first item is a list of
        ``(text, label)`` pairs for ``gr.HighlightedText``; the second is the
        decoded string.
    """
    max_new_tokens = max(1, min(int(max_new_tokens), _CONTEXT_WINDOW - 1))
    top_k = int(top_k)

    prompt = f"### Text ###\n{prompt}"
    # Budget the prompt so prompt + continuation never exceeds the position window.
    max_prompt_tokens = _CONTEXT_WINDOW - max_new_tokens
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=max_prompt_tokens)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Each BPE token is shown with itself as the label. The "Ġ" marker (which
    # the tokenizer uses for a leading space) is stripped for display.
    highlighted_text = []
    for token in tokenizer.tokenize(generated_text):
        label = token.replace("Ġ", "")
        highlighted_text.append((label, label))
    return highlighted_text, generated_text


def text_analysis(text):
    """Run spaCy over *text*.

    Returns:
        ``(pos_tokens, pos_count, html)``, where:
        - ``pos_tokens`` is a list of ``(token_text, coarse_POS)`` pairs.
        - ``pos_count`` holds the character and token counts.
        - ``html`` is the displaCy dependency rendering wrapped in a scrollable div.
    """
    doc = nlp(text)
    html = displacy.render(doc, style="dep", page=True)
    # Wide parse trees would otherwise overflow the page; let them scroll instead.
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + html
        + "</div>"
    )
    pos_count = {
        "char_count": len(text),
        "token_count": len(doc),
    }
    pos_tokens = [(token.text, token.pos_) for token in doc]
    return pos_tokens, pos_count, html


def generate_dependency_parse(generated_text):
    """Return only the dependency-parse HTML for *generated_text*."""
    _, _, html_generated = text_analysis(generated_text)
    return html_generated


def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Generate text from *prompt* and analyse the prompt (not the output).

    Returns eight values, in the order of ``generate_button.click``'s outputs:
    - the generated text and the token highlights;
    - the prompt's counts and the prompt's parse HTML;
    - an update showing the send button;
    - an empty string that clears the previous generated-text parse;
    - updates hiding the generate button and showing the reset button.
    """
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )
    _, pos_count_input, html_input = text_analysis(prompt)
    return (
        generated_text,
        generated_highlight,
        pos_count_input,
        html_input,
        gr.update(visible=True),
        "",  # stale parse is cleared; it is rendered on demand via send_button
        gr.update(visible=False),
        gr.update(visible=True),
    )


def reset_interface():
    """Show the generate button again and hide the send and reset buttons."""
    return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)


# --- UI ----------------------------------------------------------------------

with gr.Blocks(theme=gr.themes.Base()) as iface:
    gr.Markdown("""
    # Historical Text Generator with Dependency Parse

    This app generates historical-style text using the OCRonos-Vintage model.
    You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
    """)
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:", lines=3)
    max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
    # top_k is an integer count, so step in whole units.
    top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
    top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
    repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)

    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    dependency_parse_generated = gr.HTML(label="Dependency Parse Visualization (Generated Text)")

    # Parse the generated *text*, not the HTML pane's own contents.
    send_button.click(
        generate_dependency_parse,
        inputs=[generated_text_output],
        outputs=[dependency_parse_generated],
    )

    reset_button = gr.Button(value="♻️Start Again", visible=False)
    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")

    generate_button.click(
        full_interface,
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input,
                 send_button, dependency_parse_generated, generate_button, reset_button],
    )

    reset_button.click(
        reset_interface,
        inputs=None,
        outputs=[generate_button, send_button, reset_button],
    )

iface.launch()