Tonic committed on
Commit
2f70cad
1 Parent(s): be6c757

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -44
app.py CHANGED
@@ -5,30 +5,24 @@ import os
5
  import spacy
6
  from spacy import displacy
7
 
8
- # Load pre-trained model and tokenizer
9
  model_name = "PleIAs/OCRonos-Vintage"
10
  model = GPT2LMHeadModel.from_pretrained(model_name)
11
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
12
 
13
- # Set the pad token to be the same as the eos token
14
  tokenizer.pad_token = tokenizer.eos_token
15
 
16
- # Set the device to GPU if available, otherwise use CPU
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  model.to(device)
19
 
20
- # Load spaCy model for dependency parsing
21
  os.system('python -m spacy download en_core_web_sm')
22
  nlp = spacy.load("en_core_web_sm")
23
 
24
- # Function for generating text and tokenizing
25
  def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
26
  prompt = f"### Text ###\n{prompt}"
27
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
28
  input_ids = inputs["input_ids"].to(device)
29
  attention_mask = inputs["attention_mask"].to(device)
30
 
31
- # Generate text with customizable parameters
32
  output = model.generate(
33
  input_ids,
34
  attention_mask=attention_mask,
@@ -43,26 +37,21 @@ def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7,
43
  eos_token_id=tokenizer.eos_token_id
44
  )
45
 
46
- # Decode the generated text
47
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
48
 
49
- # Extract text after "### Correction ###"
50
  if "### Correction ###" in generated_text:
51
  generated_text = generated_text.split("### Correction ###")[1].strip()
52
 
53
- # Tokenize the generated text
54
  tokens = tokenizer.tokenize(generated_text)
55
 
56
- # Create highlighted text output, remove "Ġ" from both the token and token_type
57
  highlighted_text = []
58
  for token in tokens:
59
- clean_token = token.replace("Ġ", "") # Remove "Ġ"
60
  token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
61
  highlighted_text.append((clean_token, token_type))
62
 
63
- return highlighted_text, generated_text # Return both tokenized and raw generated text
64
 
65
- # Function for dependency parsing using spaCy
66
  def text_analysis(text):
67
  doc = nlp(text)
68
  html = displacy.render(doc, style="dep", page=True)
@@ -79,63 +68,63 @@ def text_analysis(text):
79
 
80
  return pos_tokens, pos_count, html
81
 
82
- # Function to generate dependency parse for generated text on button click
83
  def generate_dependency_parse(generated_text):
84
  tokens_generated, pos_count_generated, html_generated = text_analysis(generated_text)
85
  return html_generated
86
 
87
- # Full interface combining text generation and analysis, split across steps
 
 
 
88
  def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
89
  generated_highlight, generated_text = historical_generation(
90
  prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
91
  )
92
 
93
- # Dependency parse of input text
94
  tokens_input, pos_count_input, html_input = text_analysis(prompt)
 
95
 
96
- # The "Send" button should now appear after these outputs are generated
97
- return generated_highlight, pos_count_input, html_input, gr.update(visible=True), generated_text, gr.update(visible=False), gr.update(visible=True)
98
-
99
- # Reset function to restore button states
100
  def reset_interface():
101
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
102
 
103
- # Gradio interface components
104
- with gr.Blocks() as iface:
105
- prompt = gr.Textbox(label="Prompt", placeholder="Enter a prompt for historical text generation...", lines=3)
106
 
107
- # Slider for model parameters
108
- max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=50, value=600)
109
- top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
110
- temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.7)
111
- top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.95)
112
- repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.1, value=1.0)
 
 
 
 
 
 
 
 
 
 
113
 
114
  # Output components
115
- highlighted_text = gr.HighlightedText(label="Generated Historical Text", combine_adjacent=True, show_legend=True)
116
- tokenizer_info = gr.JSON(label="Tokenizer Info (Input Text)")
117
- dependency_parse_input = gr.HTML(label="Dependency Parse Visualization (Input Text)")
 
118
 
119
  # Hidden button and final output for dependency parse visualization
120
- send_button = gr.Button(value="Generate Dependency Parse for Generated Text", visible=False)
121
- dependency_parse_generated = gr.HTML(label="Dependency Parse Visualization (Generated Text)")
122
 
123
  # Reset button, hidden initially
124
- reset_button = gr.Button(value="Start Again", visible=False)
125
-
126
- # Button behavior for generating final parse visualization
127
- send_button.click(
128
- generate_dependency_parse,
129
- inputs=[dependency_parse_generated],
130
- outputs=[dependency_parse_generated]
131
- )
132
 
133
  # Main interface logic: when clicked, "Generate" button hides itself and shows the reset button
134
- generate_button = gr.Button(value="Generate Text and Initial Outputs")
135
  generate_button.click(
136
  full_interface,
137
  inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
138
- outputs=[highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
139
  )
140
 
141
  # Reset button logic: hide itself and re-show the "Generate" button
 
5
  import spacy
6
  from spacy import displacy
7
 
 
8
  model_name = "PleIAs/OCRonos-Vintage"
9
  model = GPT2LMHeadModel.from_pretrained(model_name)
10
  tokenizer = GPT2Tokenizer.from_pretrained(model_name)
11
 
 
12
  tokenizer.pad_token = tokenizer.eos_token
13
 
 
14
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
  model.to(device)
16
 
 
17
  os.system('python -m spacy download en_core_web_sm')
18
  nlp = spacy.load("en_core_web_sm")
19
 
 
20
  def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
21
  prompt = f"### Text ###\n{prompt}"
22
  inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
23
  input_ids = inputs["input_ids"].to(device)
24
  attention_mask = inputs["attention_mask"].to(device)
25
 
 
26
  output = model.generate(
27
  input_ids,
28
  attention_mask=attention_mask,
 
37
  eos_token_id=tokenizer.eos_token_id
38
  )
39
 
 
40
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
41
 
 
42
  if "### Correction ###" in generated_text:
43
  generated_text = generated_text.split("### Correction ###")[1].strip()
44
 
 
45
  tokens = tokenizer.tokenize(generated_text)
46
 
 
47
  highlighted_text = []
48
  for token in tokens:
49
+ clean_token = token.replace("Ġ", "")
50
  token_type = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(token)])[0].replace("Ġ", "")
51
  highlighted_text.append((clean_token, token_type))
52
 
53
+ return highlighted_text, generated_text
54
 
 
55
  def text_analysis(text):
56
  doc = nlp(text)
57
  html = displacy.render(doc, style="dep", page=True)
 
68
 
69
  return pos_tokens, pos_count, html
70
 
 
71
  def generate_dependency_parse(generated_text):
72
  tokens_generated, pos_count_generated, html_generated = text_analysis(generated_text)
73
  return html_generated
74
 
75
+ def generate_dependency_parse(generated_text):
76
+ tokens_generated, pos_count_generated, html_generated = text_analysis(generated_text)
77
+ return html_generated
78
+
79
  def full_interface(prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty):
80
  generated_highlight, generated_text = historical_generation(
81
  prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty
82
  )
83
 
 
84
  tokens_input, pos_count_input, html_input = text_analysis(prompt)
85
+ return generated_text, generated_highlight, pos_count_input, html_input, gr.update(visible=True), generated_text, gr.update(visible=False), gr.update(visible=True)
86
 
 
 
 
 
87
  def reset_interface():
88
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False)
89
 
90
+ import gradio as gr
 
 
91
 
92
+ with gr.Blocks(theme=gr.themes.Base()) as iface:
93
+
94
+ gr.Markdown("""
95
+ # Historical Text Generator with Dependency Parse
96
+ This app generates historical-style text using the OCRonos-Vintage model.
97
+ You can customize the generation parameters using the sliders and visualize the tokenized output and dependency parse.
98
+ """)
99
+
100
+ prompt = gr.Textbox(label="Add a passage in the style of historical texts", placeholder="Hi there my name is Tonic and I ride my bicycle along the river Seine:", lines=3)
101
+
102
+ # Sliders for model parameters
103
+ max_new_tokens = gr.Slider(label="Max New Tokens", minimum=50, maximum=1000, step=10, value=140)
104
# Top-k is a token count, so the slider moves in whole-number steps.
top_k = gr.Slider(label="Top-k Sampling", minimum=1, maximum=100, step=1, value=50)
105
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.05, value=0.3)
106
+ top_p = gr.Slider(label="Top-p (Nucleus Sampling)", minimum=0.1, maximum=1.0, step=0.005, value=0.95)
107
+ repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.05, value=1.0)
108
 
109
  # Output components
110
# Display-only textbox: `interactive=False` is the Gradio way to make it read-only.
generated_text_output = gr.Textbox(label="🎅🏻⌚OCRonos-Vintage", interactive=False)
111
+ highlighted_text = gr.HighlightedText(label="🎅🏻⌚Tokenized", combine_adjacent=True, show_legend=True)
112
+ tokenizer_info = gr.JSON(label="📉Tokenizer Info (Input Text)")
113
+ dependency_parse_input = gr.HTML(label="👁️Visualization")
114
 
115
  # Hidden button and final output for dependency parse visualization
116
+ send_button = gr.Button(value="👁️Visualize Generated Text", visible=False)
117
# HTML target for the dependency parse of the generated text.
dependency_parse_generated = gr.HTML(label="👁️Visualization (Generated Text)")
118
 
119
  # Reset button, hidden initially
120
+ reset_button = gr.Button(value="♻️Start Again", visible=False)
 
 
 
 
 
 
 
121
 
122
  # Main interface logic: when clicked, "Generate" button hides itself and shows the reset button
123
+ generate_button = gr.Button(value="🎅🏻⌚Generate Historical Text")
124
  generate_button.click(
125
  full_interface,
126
  inputs=[prompt, max_new_tokens, top_k, temperature, top_p, repetition_penalty],
127
+ outputs=[generated_text_output, highlighted_text, tokenizer_info, dependency_parse_input, send_button, dependency_parse_generated, generate_button, reset_button]
128
  )
129
 
130
  # Reset button logic: hide itself and re-show the "Generate" button