# PleIAs-Editor / app.py
# Last commit by Pclanglais: "Update app.py" (d731e09, verified), 8.38 kB.
import spaces
import transformers
import re
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import pandas as pd
import difflib
# Device for the Hugging Face pipeline below: "cuda" when PyTorch can see a
# GPU, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# OCR correction model, loaded once at import time through vLLM and shared by
# every request (OCRCorrector.correct reads the `ocr_llm` global).
ocr_model_name = "PleIAs/OCRonos"
# NOTE(review): 8128 is an unusual context length (8192?) — confirm against the
# model config. The prompt plus up to 4000 generated tokens (see
# OCRCorrector.correct) must fit inside it.
ocr_llm = LLM(ocr_model_name, max_model_len=8128)
# Editorial segmentation model: a token-classification pipeline whose
# per-token predictions are merged into labelled spans
# (aggregation_strategy="simple"); each span carries an `entity_group` and `word`.
editorial_model = "PleIAs/Segmentext"
token_classifier = pipeline(
    "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
)
# Tokenizer for the same model, used only to count tokens when deciding how to
# chunk long inputs (split_text, EditorialSegmenter.segment).
tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
# CSS for formatting. This <style> block is prepended to every HTML result
# returned by TextProcessor.process, so the rules apply to both result panels:
#   .generation                       - margins/font size of each panel
#   .deleted / .inserted              - diff highlight styles (generate_html_diff
#                                       currently uses inline styles instead)
#   .manuscript / .annotation /
#   .content / .*-content / h2        - layout of each labelled segment emitted
#                                       by transform_chunks
# NOTE(review): .source, .tooltip and :target are not referenced by any code
# in this file; they may be leftovers from another app.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text, show_deleted=False):
    """Render a word-level diff between two texts as an HTML fragment.

    Both texts are split on whitespace and compared with ``difflib.Differ``.
    Words common to both are emitted as plain text, and words that only appear
    in *new_text* are wrapped in a green-highlighted ``<span>``. When
    *show_deleted* is true, words that only appear in *old_text* are wrapped in
    a red, struck-through ``<span>``. Differ's ``"? "`` hint lines are ignored.

    Every word is HTML-escaped. The texts come from user input and model
    output, and the result is rendered as HTML, so an unescaped ``<`` or ``&``
    would be interpreted as markup.

    Args:
        old_text: The original text (e.g. the raw OCR input).
        new_text: The revised text (e.g. the model's correction).
        show_deleted: Also render removed words. Defaults to False, which
            shows only unchanged and inserted words.

    Returns:
        The rendered words joined by single spaces. Original line breaks and
        spacing are not preserved.
    """
    import html  # local import: leaves the module's top-level import order untouched

    parts = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Differ prefixes each entry with a two-character code: "  ", "+ ", "- " or "? ".
        tag = entry[:2]
        word = html.escape(entry[2:], quote=False)
        if tag == '  ':
            parts.append(word)
        elif tag == '+ ':
            parts.append(f'<span style="background-color: #90EE90;">{word}</span>')
        elif tag == '- ' and show_deleted:
            parts.append(
                f'<span style="background-color: #ffcccb; text-decoration: line-through;">{word}</span>'
            )
    return ' '.join(parts)
def preprocess_text(text):
    """Strip anything that looks like an HTML tag, then normalise whitespace.

    Every run of whitespace (newlines included) becomes a single space, and
    leading/trailing whitespace is removed.
    """
    untagged = re.sub(r'<[^>]+>', '', text)
    # str.split() with no argument splits on the same Unicode whitespace as
    # the regex \s and drops empty fields, so this also trims both ends.
    return ' '.join(untagged.split())
def _split_oversized(text, max_tokens, count_tokens):
    """Cut *text* into pieces of at most *max_tokens* tokens by repeated bisection.

    An oversized piece is cut at the first whitespace at or after its middle.
    If there is none, it is cut at the last whitespace before the middle, and
    failing that at the middle character itself. Both halves are stripped and
    checked again, so every returned piece respects the limit. The one
    exception is a single character, which cannot be cut further. Order is
    preserved and empty pieces are dropped.
    """
    pieces = []
    pending = [text]  # stack of pieces still to check; the top is the leftmost one
    while pending:
        piece = pending.pop()
        if len(piece) < 2 or count_tokens(piece) <= max_tokens:
            if piece:
                pieces.append(piece)
            continue
        middle = len(piece) // 2  # >= 1 because len(piece) >= 2
        cut = next((i for i in range(middle, len(piece)) if piece[i].isspace()), None)
        if cut is None:
            cut = next((i for i in range(middle - 1, 0, -1) if piece[i].isspace()), middle)
        # 1 <= cut <= len(piece) - 1, so both halves are strictly shorter and the loop terminates.
        left, right = piece[:cut].strip(), piece[cut:].strip()
        # Push the right half first so the left half is processed (and emitted) first.
        pending.append(right)
        pending.append(left)
    return pieces


def split_text(text, max_tokens=500, tokenize=None):
    """Split *text* into chunks of at most *max_tokens* tokens.

    First, newline-separated parts are greedily packed into chunks joined
    with "\\n". Then any chunk that still exceeds the limit is bisected near
    its middle at whitespace, recursively, until it fits. This covers a single
    oversized part among several, and the first half of a bisection.

    Args:
        text: The text to split.
        max_tokens: Maximum number of tokens per chunk.
        tokenize: Callable returning the token list for a string. Defaults to
            the module-level editorial ``tokenizer.tokenize``.

    Returns:
        The list of chunks, in order. It is empty for empty input. Chunks that
        already fit are returned unchanged; pieces produced by bisection are
        stripped.
    """
    if tokenize is None:
        tokenize = tokenizer.tokenize

    def count_tokens(s):
        return len(tokenize(s))

    packed = []
    current_chunk = ""
    for part in text.split("\n"):
        candidate = current_chunk + "\n" + part if current_chunk else part
        if count_tokens(candidate) <= max_tokens:
            current_chunk = candidate
        else:
            if current_chunk:
                packed.append(current_chunk)
            current_chunk = part  # may itself be oversized; handled below
    if current_chunk:
        packed.append(current_chunk)

    chunks = []
    for chunk in packed:
        chunks.extend(_split_oversized(chunk, max_tokens, count_tokens))
    return chunks
def transform_chunks(marianne_segmentation):
    """Render aggregated token-classification spans as labelled HTML rows.

    Args:
        marianne_segmentation: The classifier's aggregated entities, as a
            DataFrame or anything ``pd.DataFrame`` accepts (e.g. a list of
            dicts). Only the ``entity_group`` and ``word`` columns are read.

    Returns:
        One ``<div class="manuscript">`` row per span, joined by newlines. In
        each row the capitalised label (e.g. ``[Title]``) sits in an
        annotation column next to the span's text. Spans labelled
        ``separator``, missing words and words that are empty after
        normalisation are skipped. Returns "" when nothing is left, including
        for empty input.
    """
    import html  # local import: leaves the module's top-level import order untouched

    segments = pd.DataFrame(marianne_segmentation)
    if segments.empty or not {'entity_group', 'word'}.issubset(segments.columns):
        return ''
    # Drop separators and missing words before astype(str), which would turn NaN
    # into the literal "nan". .copy() makes the column assignment below act on an
    # independent frame rather than a filtered view (no SettingWithCopyWarning).
    keep = (segments['entity_group'] != 'separator') & segments['word'].notna()
    segments = segments.loc[keep].copy()
    if segments.empty:
        return ''
    # '¶' is the paragraph marker EditorialSegmenter substitutes for newlines.
    words = segments['word'].astype(str).str.replace('¶', '\n', regex=False)
    segments['word'] = words.apply(preprocess_text)
    segments = segments[segments['word'] != '']  # preprocess_text strips, so ' ' cannot occur

    content_classes = {
        'title': 'content title-content',
        'bibliography': 'content bibliography-content',
        'paratext': 'content paratext-content',
    }
    rows = []
    for entity_group, word in zip(segments['entity_group'], segments['word']):
        # Escape: the words come from user-supplied text and are rendered as HTML.
        label = html.escape("[" + entity_group.capitalize() + "]", quote=False)
        body = html.escape(word, quote=False)
        if entity_group == 'title':
            body = f'<h2>{body}</h2>'
        css_class = content_classes.get(entity_group, 'content')
        rows.append(
            f'<div class="manuscript"><div class="annotation">{label}</div>'
            f'<div class="{css_class}">{body}</div></div>'
        )
    return '\n'.join(rows)
# OCR Correction Class
class OCRCorrector:
    """Correct OCR noise in a text with the OCRonos model held in ``ocr_llm``."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Stored on the instance but not used by correct().
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return the model's correction of *user_message* and an HTML word diff.

        The message is wrapped in the ``### TEXT ### / ### CORRECTION ###``
        prompt format and sampled once (temperature 0.9, top-p 0.95, at most
        4000 new tokens, stopping at ``#END#``). The output is then diffed
        word by word against the input.

        Returns:
            A ``(corrected_text, html_diff)`` tuple.
        """
        prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
        sampling = SamplingParams(
            temperature=0.9,
            top_p=0.95,
            max_tokens=4000,
            presence_penalty=0,
            stop=["#END#"],
        )
        request_outputs = ocr_llm.generate([prompt], sampling, use_tqdm=False)
        corrected = request_outputs[0].outputs[0].text
        return corrected, generate_html_diff(user_message, corrected)
# Editorial Segmentation Class
class EditorialSegmenter:
    """Label the editorial parts of a text with the Segmentext classifier."""

    def __init__(self, max_tokens=500):
        # Inputs longer than this many editorial-tokenizer tokens are chunked
        # before classification. The default of 500 stays under the
        # tokenizer's model_max_length of 512 (presumably leaving room for
        # special tokens).
        self.max_tokens = max_tokens

    def segment(self, text):
        """Classify *text* and return the labelled spans rendered as HTML.

        Newlines are replaced by `` ¶ `` paragraph markers, which
        transform_chunks turns back into line breaks. Long inputs are split
        with split_text so each classifier call stays within
        ``self.max_tokens``.

        Returns:
            The HTML produced by transform_chunks, or "" for empty or
            whitespace-only text. Such text is never sent to the classifier,
            whose empty output could not be rendered.
        """
        if not text or not text.strip():
            return ''
        editorial_text = text.replace("\n", " ¶ ")
        if len(tokenizer.tokenize(editorial_text)) > self.max_tokens:
            batch_prompts = split_text(editorial_text, max_tokens=self.max_tokens)
        else:
            batch_prompts = [editorial_text]
        classifications = token_classifier(batch_prompts)
        frames = [pd.DataFrame(classification) for classification in classifications]
        if not frames:  # pd.concat raises on an empty list
            return ''
        return transform_chunks(pd.concat(frames))
# Combined Processing Class
class TextProcessor:
    """Run OCR correction then editorial segmentation, and render both as HTML."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()
        self.editorial_segmenter = EditorialSegmenter()

    # NOTE(review): spaces.GPU presumably reserves a Hugging Face ZeroGPU slot
    # for up to 120 s per call — confirm with the `spaces` package docs.
    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Return one HTML document with the CSS, the OCR diff and the segmentation.

        The OCR-corrected text (not the raw input) is what gets segmented.
        """
        corrected_text, html_diff = self.ocr_corrector.correct(user_message)
        segmented_text = self.editorial_segmenter.segment(corrected_text)
        panels = [
            f'<h2 style="text-align:center">{heading}</h2>\n<div class="generation">{content}</div>'
            for heading, content in (
                ("OCR Correction", html_diff),
                ("Editorial Segmentation", segmented_text),
            )
        ]
        return f"{css}{'<br><br>'.join(panels)}"
# Single shared TextProcessor; its process method handles every button click.
text_processor = TextProcessor()
# Define the Gradio interface. Components appear in the order they are created
# inside the Blocks context: title, input box, button, HTML output.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">PleIAs Editor</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    # Receives the single HTML string (CSS + OCR diff + segmentation) from process().
    text_output = gr.HTML(label="Processed text")
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
if __name__ == "__main__":
    # Requests go through Gradio's queue before the server is launched.
    demo.queue().launch()