File size: 3,143 Bytes
d731e09
0dfb412
f2019a4
 
 
 
b6cc9e1
1fca231
31559f1
0dfb412
208476f
 
b6cc9e1
8482186
b6cc9e1
 
22b51ff
b6cc9e1
0dfb412
21257a3
b6cc9e1
21257a3
bafe915
0dfb412
21257a3
bafe915
b6cc9e1
 
0dfb412
21257a3
b6cc9e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21257a3
 
 
0dfb412
 
 
21257a3
fa86caf
bafe915
21257a3
 
 
 
 
 
 
08d035b
21257a3
b6cc9e1
21257a3
 
 
 
bafe915
cd9ce00
21257a3
0dfb412
21257a3
 
3481362
f2019a4
21257a3
cd9ce00
b100458
21257a3
 
 
7468778
0dfb412
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import spaces
import transformers
import re
import torch
import gradio as gr
import os
import ctranslate2
from concurrent.futures import ThreadPoolExecutor

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CTranslate2 model directory and the Hugging Face tokenizer.
# NOTE(review): "ocronos_ct2" is a relative path, so it resolves against the
# process working directory. It is presumably a CTranslate2 conversion of
# PleIAs/OCRonos-Vintage — confirm the model and tokenizer stay in sync.
model_path = "ocronos_ct2"
generator = ctranslate2.Generator(model_path, device=device)
tokenizer = transformers.AutoTokenizer.from_pretrained("PleIAs/OCRonos-Vintage")

# Stylesheet that TextProcessor.process prepends to every rendered result.
# NOTE(review): the stylesheet body is a placeholder in this copy; the real
# rules (presumably styling the ".generation" div) must be restored.
css = """
<style>
... (your existing CSS)
</style>
"""

# Helper functions
def generate_html_diff(old_text, new_text):
    """Render a word-level diff between *old_text* and *new_text* as HTML.

    Unchanged words are emitted as plain text. Words present only in
    *new_text* are highlighted in green. Words present only in *old_text*
    are highlighted in red with a strike-through. Every word is
    HTML-escaped because both the user's input and the model's output are
    rendered inside a ``gr.HTML`` component.

    Args:
        old_text: The original (noisy) text; ``None`` is treated as ``""``.
        new_text: The corrected text; ``None`` is treated as ``""``.

    Returns:
        An HTML fragment whose words are separated by single spaces, or
        ``""`` when neither input contains any word.
    """
    import difflib
    import html

    inserted = '<span style="background-color: #ccffcc;">{}</span>'
    deleted = (
        '<span style="background-color: #ffcccc; '
        'text-decoration: line-through;">{}</span>'
    )

    old_words = (old_text or "").split()
    new_words = (new_text or "").split()

    # autojunk=False: with the default heuristic, once the second sequence has
    # 200+ items, any word making up more than 1% of it (e.g. "the", "de") is
    # treated as junk and never matched, which fragments diffs of long text.
    matcher = difflib.SequenceMatcher(a=old_words, b=new_words, autojunk=False)

    pieces = []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            pieces.extend(html.escape(word) for word in new_words[j1:j2])
            continue
        # A "replace" opcode is rendered as a deletion followed by an insertion.
        if tag in ("delete", "replace"):
            pieces.extend(
                deleted.format(html.escape(word)) for word in old_words[i1:i2]
            )
        if tag in ("insert", "replace"):
            pieces.extend(
                inserted.format(html.escape(word)) for word in new_words[j1:j2]
            )
    return " ".join(pieces)

def preprocess_text(text):
    """Placeholder for input normalization; currently returns ``None``.

    NOTE(review): the implementation was elided ("unchanged") in this copy,
    and nothing visible in this file calls it. Restore the real body before
    wiring it into the pipeline.
    """
    return None

def split_text(text, max_tokens=400):
    """Cut *text* into consecutive pieces of at most *max_tokens* tokens.

    The text is tokenized with the module-level ``tokenizer``, and each
    window of ``max_tokens`` ids is decoded back into a string. Windows are
    cut on token count alone. If the tokenizer splits words into sub-word
    tokens, a piece can therefore start or end inside a word.

    Args:
        text: Text to split.
        max_tokens: Window size in tokens. Must be positive; ``0`` makes
            ``range`` raise ``ValueError``.

    Returns:
        A list of decoded text pieces, empty when *text* encodes to no ids.
    """
    ids = tokenizer.encode(text)
    return [
        tokenizer.decode(ids[start:start + max_tokens])
        for start in range(0, len(ids), max_tokens)
    ]

# Function to generate text using CTranslate2
def ocr_correction(prompt, max_new_tokens=600):
    """Correct OCR noise in *prompt* chunk by chunk with the CTranslate2 model.

    The input is cut into pieces of up to 400 tokens by ``split_text``.
    Each piece is wrapped in the "### Text ### / ### Correction ###"
    template, sampled from ``generator`` (temperature 0.7, top-k 20) and
    decoded back to text.

    Args:
        prompt: Raw (possibly noisy) text to correct.
        max_new_tokens: Forwarded to ``generate_batch`` as ``max_length``.

    Returns:
        The per-chunk corrections joined with single spaces, or ``""`` when
        *prompt* tokenizes to nothing.
    """
    def _correct_chunk(chunk):
        # Wrap the chunk in the prompt template the model is queried with.
        templated = f"### Text ###\n{chunk}\n\n\n### Correction ###\n"
        tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(templated))
        # include_prompt_in_result=False: the returned ids are meant to hold
        # only the generated continuation, not the prompt itself.
        batch = generator.generate_batch(
            [tokens],
            max_length=max_new_tokens,
            sampling_temperature=0.7,
            sampling_topk=20,
            include_prompt_in_result=False,
        )
        return tokenizer.decode(batch[0].sequences_ids[0])

    chunks = split_text(prompt, max_tokens=400)
    return " ".join(_correct_chunk(chunk) for chunk in chunks)

# OCR Correction Class
class OCRCorrector:
    """Run OCR correction on a message and render a diff against the input."""

    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        # Kept on the instance; ``correct`` below does not read it.
        self.system_prompt = system_prompt

    def correct(self, user_message):
        """Return ``(corrected_text, html_diff)`` for *user_message*."""
        corrected = ocr_correction(user_message)
        return corrected, generate_html_diff(user_message, corrected)

# Combined Processing Class
class TextProcessor:
    """Gradio-facing wrapper that turns raw input into the rendered HTML result."""

    def __init__(self):
        self.ocr_corrector = OCRCorrector()

    @spaces.GPU(duration=120)
    def process(self, user_message):
        """Correct *user_message* and return the CSS-prefixed HTML panel.

        The call is wrapped in ``spaces.GPU(duration=120)``. This is
        presumably a Hugging Face Spaces GPU allocation with the duration in
        seconds; confirm against the ``spaces`` package. Only the diff markup
        is rendered; the corrected plain text is discarded.
        """
        _, diff_markup = self.ocr_corrector.correct(user_message)
        heading = '<h2 style="text-align:center">OCR Correction</h2>'
        panel = f'{heading}\n<div class="generation">{diff_markup}</div>'
        return css + panel

# Single shared processor whose `process` method is the button's click handler.
text_processor = TextProcessor()

# Build the Gradio UI: a textbox for the noisy input, a button, and an HTML
# pane that receives the rendered correction.
with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
    gr.HTML("""<h1 style="text-align:center">Vintage OCR corrector</h1>""")
    text_input = gr.Textbox(label="Your (bad?) text", type="text", lines=5)
    process_button = gr.Button("Process Text")
    text_output = gr.HTML(label="Processed text")
    # On click, the textbox value is passed to TextProcessor.process and the
    # returned HTML string is shown in text_output.
    process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])

if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server.
    demo.queue().launch()