Pclanglais commited on
Commit
cd9ce00
·
verified ·
1 Parent(s): d731e09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -63
app.py CHANGED
@@ -16,16 +16,19 @@ import difflib
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # OCR Correction Model
19
- ocr_model_name = "PleIAs/OCRonos"
20
- ocr_llm = LLM(ocr_model_name, max_model_len=8128)
21
 
22
- # Editorial Segmentation Model
23
- editorial_model = "PleIAs/Segmentext"
24
- token_classifier = pipeline(
25
- "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
26
- )
 
 
27
 
28
- tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
 
 
29
 
30
  # CSS for formatting
31
  css = """
@@ -163,30 +166,25 @@ def split_text(text, max_tokens=500):
163
 
164
  return chunks
165
 
166
- def transform_chunks(marianne_segmentation):
167
- marianne_segmentation = pd.DataFrame(marianne_segmentation)
168
- marianne_segmentation = marianne_segmentation[marianne_segmentation['entity_group'] != 'separator']
169
- marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).str.replace('¶', '\n', regex=False)
170
- marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).apply(preprocess_text)
171
- marianne_segmentation = marianne_segmentation[marianne_segmentation['word'].notna() & (marianne_segmentation['word'] != '') & (marianne_segmentation['word'] != ' ')]
172
 
173
- html_output = []
174
- for _, row in marianne_segmentation.iterrows():
175
- entity_group = row['entity_group']
176
- result_entity = "[" + entity_group.capitalize() + "]"
177
- word = row['word']
178
-
179
- if entity_group == 'title':
180
- html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content title-content"><h2>{word}</h2></div></div>')
181
- elif entity_group == 'bibliography':
182
- html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content bibliography-content">{word}</div></div>')
183
- elif entity_group == 'paratext':
184
- html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content paratext-content">{word}</div></div>')
185
- else:
186
- html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content">{word}</div></div>')
187
 
188
- final_html = '\n'.join(html_output)
189
- return final_html
 
 
 
 
190
 
191
  # OCR Correction Class
192
  class OCRCorrector:
@@ -195,53 +193,25 @@ class OCRCorrector:
195
 
196
  def correct(self, user_message):
197
  sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
198
- detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
199
- prompts = [detailed_prompt]
200
- outputs = ocr_llm.generate(prompts, sampling_params, use_tqdm=False)
201
- generated_text = outputs[0].outputs[0].text
202
  html_diff = generate_html_diff(user_message, generated_text)
203
  return generated_text, html_diff
204
 
205
- # Editorial Segmentation Class
206
- class EditorialSegmenter:
207
- def segment(self, text):
208
- editorial_text = re.sub("\n", " ¶ ", text)
209
- num_tokens = len(tokenizer.tokenize(editorial_text))
210
-
211
- if num_tokens > 500:
212
- batch_prompts = split_text(editorial_text, max_tokens=500)
213
- else:
214
- batch_prompts = [editorial_text]
215
-
216
- out = token_classifier(batch_prompts)
217
- classified_list = []
218
- for classification in out:
219
- df = pd.DataFrame(classification)
220
- classified_list.append(df)
221
-
222
- classified_list = pd.concat(classified_list)
223
- out = transform_chunks(classified_list)
224
- return out
225
-
226
  # Combined Processing Class
227
  class TextProcessor:
228
  def __init__(self):
229
  self.ocr_corrector = OCRCorrector()
230
- self.editorial_segmenter = EditorialSegmenter()
231
 
232
  @spaces.GPU(duration=120)
233
  def process(self, user_message):
234
- # Step 1: OCR Correction
235
  corrected_text, html_diff = self.ocr_corrector.correct(user_message)
236
 
237
- # Step 2: Editorial Segmentation
238
- segmented_text = self.editorial_segmenter.segment(corrected_text)
239
-
240
  # Combine results
241
  ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
242
- editorial_result = f'<h2 style="text-align:center">Editorial Segmentation</h2>\n<div class="generation">{segmented_text}</div>'
243
 
244
- final_output = f"{css}{ocr_result}<br><br>{editorial_result}"
245
  return final_output
246
 
247
  # Create the TextProcessor instance
@@ -249,7 +219,7 @@ text_processor = TextProcessor()
249
 
250
  # Define the Gradio interface
251
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
252
- gr.HTML("""<h1 style="text-align:center">PleIAs Editor</h1>""")
253
  text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
254
  process_button = gr.Button("Process Text")
255
  text_output = gr.HTML(label="Processed text")
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  # OCR Correction Model
19
+ ocr_model_name = "PleIAs/OCRonos-Vintage"
 
20
 
21
+ import torch
22
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
23
+
24
+ # Load pre-trained model and tokenizer
25
+ model_name = "PleIAs/OCRonos-Vintage"
26
+ model = GPT2LMHeadModel.from_pretrained(model_name)
27
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
28
 
29
+ # Set the device to GPU if available, otherwise use CPU
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ model.to(device)
32
 
33
  # CSS for formatting
34
  css = """
 
166
 
167
  return chunks
168
 
 
 
 
 
 
 
169
 
170
+ # Function to generate text
171
+ def ocr_correction(prompt, max_new_tokens=600):
172
+
173
+ prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
174
+ input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
175
+
176
+ # Generate text
177
+ output = model.generate(input_ids,
178
+ max_new_tokens=max_new_tokens,
179
+ pad_token_id=tokenizer.eos_token_id,
180
+ top_k=50)
 
 
 
181
 
182
+ # Decode and return the generated text
183
+ result = tokenizer.decode(output[0], skip_special_tokens=True)
184
+
185
+ result = result.split("### Correction ###\n")[1]
186
+
187
+ return result
188
 
189
  # OCR Correction Class
190
  class OCRCorrector:
 
193
 
194
  def correct(self, user_message):
195
  sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
196
+ detailed_prompt = f"### Text ###\n{user_message}\n\n### Correction ###\n"
197
+ generated_text = ocr_correction(detailed_prompt)
 
 
198
  html_diff = generate_html_diff(user_message, generated_text)
199
  return generated_text, html_diff
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  # Combined Processing Class
202
  class TextProcessor:
203
  def __init__(self):
204
  self.ocr_corrector = OCRCorrector()
 
205
 
206
  @spaces.GPU(duration=120)
207
  def process(self, user_message):
208
+ #OCR Correction
209
  corrected_text, html_diff = self.ocr_corrector.correct(user_message)
210
 
 
 
 
211
  # Combine results
212
  ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
 
213
 
214
+ final_output = f"{css}{ocr_result}"
215
  return final_output
216
 
217
  # Create the TextProcessor instance
 
219
 
220
  # Define the Gradio interface
221
  with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
222
+ gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
223
  text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
224
  process_button = gr.Button("Process Text")
225
  text_output = gr.HTML(label="Processed text")