File size: 6,913 Bytes
19a489e
4929bc6
 
fa951d7
65bec20
 
 
fa951d7
f2102a7
 
 
 
 
 
 
4929bc6
 
 
fa951d7
8c1d821
 
4929bc6
 
 
d153be5
3582cae
65bec20
19a489e
be6c757
9641cfa
 
 
 
 
8ff5503
9641cfa
 
 
 
 
 
 
 
 
 
 
 
 
4929bc6
9641cfa
4929bc6
9641cfa
 
2ca0200
9641cfa
4929bc6
9641cfa
 
 
 
 
65bec20
9641cfa
 
2ca0200
9641cfa
2ca0200
19a489e
2ca0200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3988e91
 
 
 
7aa880c
aaab817
2f70cad
be6c757
857ca0a
be6c757
 
 
65bec20
857ca0a
65bec20
857ca0a
 
 
 
aaab817
857ca0a
 
6bb1546
65bec20
3582cae
7aa880c
3988e91
2f70cad
 
f2102a7
2f70cad
0b58927
2f70cad
0b58927
 
 
 
 
3988e91
3a49793
2f70cad
 
 
f5eae85
77b8351
aaab817
2f70cad
3988e91
2f70cad
3582cae
3988e91
be6c757
aaab817
3582cae
 
aaab817
 
 
 
 
 
3582cae
 
 
7aa880c
3988e91
 
6dc9b9c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import spaces
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import os
import spacy
from spacy import displacy

# Markdown header rendered at the top of the app via gr.Markdown(title).
title = """
    # 🙋🏻‍♂️Welcome to 🌟Tonic's 🎅🏻⌚OCRonos Vintage Text Gen
    This app generates historical-style text using the OCRonos-Vintage model. You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse. You can see a tokenized visualisation of the output and your input, and learn english using the visualization for the output text!
    ### Join us : 
    🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
    """

model_name = "PleIAs/OCRonos-Vintage"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# Use EOS as the padding token so the tokenizer call with padding=True has a
# pad id to work with.
tokenizer.pad_token = tokenizer.eos_token

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Load the spaCy English pipeline. spacy.load raises OSError when the package
# is not installed; only in that case pay for the download, instead of
# re-running it on every start-up.
try:
    nlp = spacy.load("en_core_web_sm")
except OSError:
    os.system('python -m spacy download en_core_web_sm')
    nlp = spacy.load("en_core_web_sm")

@spaces.GPU
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate historical-style text with OCRonos-Vintage.

    The prompt is wrapped in a ``### Text ###`` header before being passed to
    ``model.generate`` with sampling enabled.

    Args:
        prompt: User passage to continue.
        max_new_tokens: Maximum number of tokens to generate.
        top_k: Top-k sampling cutoff.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Repetition penalty forwarded to ``generate``.

    Returns:
        ``(highlighted_text, generated_text)``: a list of ``(token, label)``
        pairs for ``gr.HighlightedText`` and the decoded output string. When
        the output contains ``### Correction ###``, only the text after that
        marker (up to any next marker) is kept.
    """
    # Slider values may arrive as floats; these two are token counts.
    max_new_tokens = int(max_new_tokens)
    top_k = int(top_k)

    # Prompt + continuation must fit GPT-2's position embeddings, so cap the
    # prompt length to leave room for the requested new tokens. With the old
    # fixed 1024-token cap, a long prompt overflowed during generation.
    context_size = getattr(model.config, "n_positions", 1024)
    max_prompt_tokens = max(1, context_size - max_new_tokens)

    prompt = f"### Text ###\n{prompt}"
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_prompt_tokens,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id
    )

    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Keep only the segment following the "### Correction ###" marker, if any.
    if "### Correction ###" in generated_text:
        generated_text = generated_text.split("### Correction ###")[1].strip()

    # Tokenize the output for the highlight view. Labels come from an id
    # round-trip through the vocabulary, done once for the whole list rather
    # than one call per token. "Ġ" (the byte-level space marker) is stripped
    # for display.
    tokens = tokenizer.tokenize(generated_text)
    labels = tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(tokens))
    highlighted_text = [
        (token.replace("Ġ", ""), label.replace("Ġ", ""))
        for token, label in zip(tokens, labels)
    ]

    # Drop tensor references before asking torch to release cached GPU memory.
    del inputs, input_ids, attention_mask, output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return highlighted_text, generated_text

@spaces.GPU
def text_analysis(text):
    """Run the spaCy pipeline over *text* and build the UI displays.

    Args:
        text: Text to analyse.

    Returns:
        ``(pos_tokens, pos_count, html)``:

        - ``pos_tokens``: list of ``(token_text, coarse_pos_tag)`` tuples.
        - ``pos_count``: dict with ``char_count`` and ``token_count``.
        - ``html``: the displaCy dependency-parse markup, wrapped in a
          size-limited, scrollable ``<div>``.
    """
    doc = nlp(text)

    # Bound the rendered parse and add scrollbars so a long sentence does not
    # blow out the Gradio layout.
    parse_markup = displacy.render(doc, style="dep", page=True)
    html = (
        "<div style='max-width:100%; max-height:360px; overflow:auto'>"
        + parse_markup
        + "</div>"
    )

    pos_count = {
        "char_count": len(text),
        # Doc supports len() directly; no need to build a throwaway list.
        "token_count": len(doc),
    }
    pos_tokens = [(token.text, token.pos_) for token in doc]

    return pos_tokens, pos_count, html

def generate_dependency_parse(generated_text):
    """Return only the dependency-parse HTML produced by ``text_analysis``.

    Args:
        generated_text: Text to parse.

    Returns:
        The wrapped displaCy HTML string (third element of ``text_analysis``).
    """
    # POS tokens and counts are not needed here; discard them.
    *_, html_generated = text_analysis(generated_text)
    return html_generated

def display_dependency_parse(generated_text):
    """Visualization-button callback: re-render the parse HTML for *generated_text*."""
    parse_html = generate_dependency_parse(generated_text)
    return parse_html

def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
    """Generate-button callback: generate text and build every display.

    Returns eight values, in the order of the ``generate_button.click``
    outputs: generated text, token highlights, input char/token counts, input
    dependency-parse HTML, send-button update, generated-text dependency-parse
    HTML, generate-button update, reset-button update.
    """
    # Generate historical-style text and its tokenized highlight list.
    generated_highlight, generated_text = historical_generation(
        prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
    )

    # Analyse the input text; only the counts and the parse HTML are displayed.
    _input_pos_tokens, pos_count_input, html_input = text_analysis(prompt)

    # Dependency parse of the generated text.
    dependency_parse_generated_html = generate_dependency_parse(generated_text)

    # Reveal the follow-up buttons. The reset button starts hidden, so this is
    # the point where "Start Again" becomes reachable.
    return (
        generated_text,
        generated_highlight,
        pos_count_input,
        html_input,
        gr.update(visible=True),  # send_button
        dependency_parse_generated_html,
        gr.update(visible=True),  # generate_button
        gr.update(visible=True),  # reset_button
    )

def reset_interface():
    """Reset-button callback: restore the initial layout and clear the results.

    Must return exactly one value per ``reset_button.click`` output, in order:
    generate_button, send_button, reset_button, generated_text_output,
    highlighted_text, tokenizer_info, dependency_parse_input,
    dependency_parse_generated. Gradio raises an error when a handler returns
    fewer values than it has outputs.
    """
    return (
        gr.update(visible=True),   # generate_button
        gr.update(visible=False),  # send_button
        gr.update(visible=False),  # reset_button
        gr.update(value=None),     # generated_text_output
        gr.update(value=None),     # highlighted_text
        gr.update(value=None),     # tokenizer_info
        gr.update(value=None),     # dependency_parse_input
        gr.update(value=None),     # dependency_parse_generated
    )

# Build the Gradio UI: prompt and sampling controls, output displays, and the
# buttons that drive generation, re-visualisation, and reset.
with gr.Blocks(theme=gr.themes.Base()) as iface:  

    gr.Markdown(title)

    # Input passage handed to full_interface as its first argument.
    prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine' he said", lines=2)

    # Sampling controls; passed to full_interface in this order after the prompt.
    max_new_tokens = gr.Slider(label="📏Length", minimum=50, maximum=1000, step=5, value=320)
    top_k = gr.Slider(label="🧪Sampling", minimum=1, maximum=100, step=1, value=50)
    temperature = gr.Slider(label="🎨Creativity", minimum=0.1, maximum=1, step=0.05, value=0.3)
    top_p = gr.Slider(label="👌🏻Quality", minimum=0.1, maximum=0.99, step=0.01, value=0.97)
    repetition_penalty = gr.Slider(label="🔴Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.3)

    # Output displays filled by full_interface.
    generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage")
    highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
    tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
    dependency_parse_input = gr.HTML(label="👁️Visualization")
    dependency_parse_generated = gr.HTML(label="🎅🏻⌚Dependency Parse Visualization (Generated Text)")
    
    # Follow-up buttons start hidden; full_interface returns visibility updates for them.
    send_button = gr.Button(value="🎅🏻⌚OCRonos-Vintage 👁️Visualization", visible=False)
    reset_button = gr.Button(value="♻️Start Again", visible=False)

    generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
    # The outputs order must match the eight values returned by full_interface.
    generate_button.click(
        full_interface, 
        inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty], 
        outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
    )

    # Re-render the dependency parse of whatever is currently in the output textbox.
    send_button.click(
        display_dependency_parse, 
        inputs=[generated_text_output], 
        outputs=[dependency_parse_generated]
    )
    
    # The outputs order must match the values returned by reset_interface.
    reset_button.click(
        reset_interface, 
        inputs=None, 
        outputs=[generate_button, send_button, reset_button, generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, dependency_parse_generated]
    )

# Start the Gradio server.
iface.launch()