Spaces:
Sleeping
Sleeping
File size: 4,927 Bytes
d731e09 0dfb412 f2019a4 b6cc9e1 d55b86a 1fca231 31559f1 0dfb412 208476f b6cc9e1 8482186 b6cc9e1 22b51ff b6cc9e1 eed441d 0dfb412 21257a3 eed441d 21257a3 bafe915 0dfb412 21257a3 bafe915 eed441d 0dfb412 21257a3 eed441d b6cc9e1 21257a3 0dfb412 21257a3 fa86caf bafe915 21257a3 08d035b 21257a3 b6cc9e1 21257a3 bafe915 cd9ce00 21257a3 0dfb412 21257a3 3481362 f2019a4 21257a3 cd9ce00 b100458 21257a3 7468778 0dfb412 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
import difflib
import shutil
import requests
from concurrent.futures import ThreadPoolExecutor
# Pick the execution device: CUDA when torch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# The model is loaded by CTranslate2 from a local directory, while the
# tokenizer is fetched via `transformers` (presumably the Hugging Face Hub
# repo the CTranslate2 model was converted from — confirm).
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")
# Stylesheet prepended to every rendered result (see `TextProcessor.process`,
# which builds its output as f"{css}...").
# NOTE(review): `generate_html_diff` uses an inline style for insertions and
# does not reference the `.inserted` / `.deleted` classes defined here; several
# other selectors (e.g. `.tooltip`, `.manuscript`) are not used in the visible
# part of this file either — confirm whether they are still needed.
css = """
<style>
.generation {
margin-left: 2em;
margin-right: 2em;
font-size: 1.2em;
}
:target {
background-color: #CCF3DF;
}
.source {
float: left;
max-width: 17%;
margin-left: 2%;
}
.tooltip {
position: relative;
cursor: pointer;
font-variant-position: super;
color: #97999b;
}
.tooltip:hover::after {
content: attr(data-text);
position: absolute;
left: 0;
top: 120%;
white-space: pre-wrap;
width: 500px;
max-width: 500px;
z-index: 1;
background-color: #f9f9f9;
color: #000;
border: 1px solid #ddd;
border-radius: 5px;
padding: 5px;
display: block;
box-shadow: 0 4px 8px rgba(0,0,0,0.1);
}
.deleted {
background-color: #ffcccb;
text-decoration: line-through;
}
.inserted {
background-color: #90EE90;
}
.manuscript {
display: flex;
margin-bottom: 10px;
align-items: baseline;
}
.annotation {
width: 15%;
padding-right: 20px;
color: grey !important;
font-style: italic;
text-align: right;
}
.content {
width: 80%;
}
h2 {
margin: 0;
font-size: 1.5em;
}
.title-content h2 {
font-weight: bold;
}
.bibliography-content {
color: darkgreen !important;
margin-top: -5px;
}
.paratext-content {
color: #a4a4a4 !important;
margin-top: -5px;
}
</style>
"""
# Helper functions
def generate_html_diff(old_text, new_text):
    """Render *new_text* as HTML, highlighting words not present in *old_text*.

    Both texts are split on whitespace and compared word by word with
    ``difflib.Differ``. Words common to both texts are emitted unchanged, and
    words that appear only in *new_text* are wrapped in a green-background
    ``<span>``. Words that appear only in *old_text* are dropped, so the
    result reads as the corrected text.

    Args:
        old_text: The original (uncorrected) text.
        new_text: The corrected text.

    Returns:
        An HTML fragment whose words are joined by single spaces. Every word is
        HTML-escaped, so characters such as ``<`` or ``&`` coming from the OCR
        input or the model output are displayed literally instead of being
        interpreted as markup.
    """
    # Local import so this fix needs no change to the module's import block.
    from html import escape

    html_diff = []
    for entry in difflib.Differ().compare(old_text.split(), new_text.split()):
        # Each Differ entry starts with a 2-char code: '  ' (in both),
        # '+ ' (only in new), '- ' (only in old), '? ' (intraline hint line).
        code, word = entry[:2], entry[2:]
        if code == '  ':
            html_diff.append(escape(word))
        elif code == '+ ':
            html_diff.append(
                f'<span style="background-color: #90EE90;">{escape(word)}</span>'
            )
        # '- ' (removed word) and '? ' (hint) entries are intentionally skipped.
    return ' '.join(html_diff)
def preprocess_text(text):
    """Strip tag-like ``<...>`` spans and normalise whitespace in *text*.

    Newlines become spaces and every run of whitespace collapses to a single
    space. The result has no leading or trailing whitespace.
    """
    # Applied in order: drop tags first so their removal can leave whitespace
    # runs that the later rules then collapse.
    substitutions = (
        (r'<[^>]+>', ''),
        (r'\n', ' '),
        (r'\s+', ' '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
def split_text(text, max_tokens=400):
    """Cut *text* into consecutive pieces of at most ``max_tokens`` tokens.

    The text is tokenized once with the module-level ``tokenizer``. The id
    sequence is sliced into windows of ``max_tokens`` ids, and each window is
    decoded back to a string. Cuts are made purely by token count, so a piece
    may start or end mid-word.

    Returns:
        The decoded pieces in original order (an empty list when the tokenizer
        yields no ids).
    """
    token_ids = tokenizer.encode(text)
    window_starts = range(0, len(token_ids), max_tokens)
    return [
        tokenizer.decode(token_ids[start:start + max_tokens])
        for start in window_starts
    ]
# Function to generate text using CTranslate2
def ocr_correction(prompt, max_new_tokens=600):
    """Run the OCR-correction model over *prompt* and return the corrected text.

    The input is cut into pieces of at most 400 tokens (``split_text``). Each
    piece is wrapped in the ``### Text ### / ### Correction ###`` prompt
    template, tokenized, and sent to the CTranslate2 ``generator`` one piece
    per call. Only the generated continuation is kept
    (``include_prompt_in_result=False``).

    Args:
        prompt: Raw (possibly noisy) text to correct.
        max_new_tokens: Forwarded to ``generate_batch`` as ``max_length``.

    Returns:
        The decoded corrections of all pieces, in order, joined by a single
        space (an empty string when ``split_text`` yields no pieces).

    Note:
        Decoding uses top-k sampling (k=20, temperature 0.7), so repeated
        calls on the same input can return different text.
    """
    def _correct_chunk(chunk):
        # Build the prompt template, tokenize it, generate, and decode the
        # continuation for a single piece.
        templated = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"
        token_strings = tokenizer.convert_ids_to_tokens(tokenizer.encode(templated))
        batch_results = generator.generate_batch(
            [token_strings],
            max_length=max_new_tokens,
            sampling_temperature=0.7,
            sampling_topk=20,
            include_prompt_in_result=False,
        )
        return tokenizer.decode(batch_results[0].sequences_ids[0])

    chunks = split_text(prompt, max_tokens=400)
    return " ".join(_correct_chunk(chunk) for chunk in chunks)
# OCR Correction Class
class OCRCorrector:
    """Pair the model's OCR correction with a word-level HTML diff."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # NOTE(review): `system_prompt` is stored but never read by `correct`.
        # The French default means "The following dialogue is a conversation".
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*.

        ``html_diff`` highlights the words of the correction that were not in
        the original message.
        """
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)
# Combined Processing Class
class TextProcessor:
    """Gradio-facing pipeline that renders an OCR correction as styled HTML."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    # `spaces.GPU` allocates a GPU for the call; `duration` is presumably a
    # per-call limit in seconds (per the `spaces` package docs — confirm).
    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and return it as an HTML fragment.

        The fragment consists of the module-level ``css`` followed by a
        centred "OCR Correction" heading and the highlighted word diff. The
        plain corrected text is not included in the output.
        """
        _, diff_markup = self.ocr_corrector.correct(user_message)
        ocr_section = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{diff_markup}</div>'
        return css + ocr_section
# One shared pipeline instance, used as the button's callback below.
text_processor = TextProcessor()

# Gradio UI: an input textbox, a trigger button, and an HTML pane for the
# rendered correction.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue before starting the server. The stray
    # trailing " |" that followed this call was a SyntaxError and is removed.
    demo.queue().launch()