import spaces import transformers import re from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline from vllm import LLM, SamplingParams import torch import gradio as gr import json import os import shutil import requests import pandas as pd import difflib from concurrent.futures import ThreadPoolExecutor # OCR Correction Model ocr_model_name = "PleIAs/OCRonos-Vintage" import torch from transformers import GPT2LMHeadModel, GPT2Tokenizer device = "cuda" # Load pre-trained model and tokenizer model_name = "PleIAs/OCRonos-Vintage" model = GPT2LMHeadModel.from_pretrained(model_name) tokenizer = GPT2Tokenizer.from_pretrained(model_name) model.to(device) # CSS for formatting css = """ """ # Helper functions def generate_html_diff(old_text, new_text): d = difflib.Differ() diff = list(d.compare(old_text.split(), new_text.split())) html_diff = [] for word in diff: if word.startswith(' '): html_diff.append(word[2:]) elif word.startswith('+ '): html_diff.append(f'{word[2:]}') return ' '.join(html_diff) def preprocess_text(text): text = re.sub(r'<[^>]+>', '', text) text = re.sub(r'\n', ' ', text) text = re.sub(r'\s+', ' ', text) return text.strip() def split_text(text, max_tokens=500): parts = text.split("\n") chunks = [] current_chunk = "" for part in parts: if current_chunk: temp_chunk = current_chunk + "\n" + part else: temp_chunk = part num_tokens = len(tokenizer.tokenize(temp_chunk)) if num_tokens <= max_tokens: current_chunk = temp_chunk else: if current_chunk: chunks.append(current_chunk) current_chunk = part if current_chunk: chunks.append(current_chunk) if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens: long_text = chunks[0] chunks = [] while len(tokenizer.tokenize(long_text)) > max_tokens: split_point = len(long_text) // 2 while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]): split_point += 1 if split_point >= len(long_text): split_point = len(long_text) - 1 chunks.append(long_text[:split_point].strip()) long_text = long_text[split_point:].strip() if long_text: chunks.append(long_text) return chunks # Function to generate text @spaces.GPU def ocr_correction(prompt, max_new_tokens=500): prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n""" input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device) # Set the number of threads for PyTorch torch.set_num_threads(num_threads) # Generate text output = model.generate, input_ids, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, top_k=50, num_return_sequences=1, do_sample=True, temperature=0.7 ) # Decode and return the generated text result = tokenizer.decode(output[0], skip_special_tokens=True) print(result) result = result.split("### Correction ###")[1] return result # OCR Correction Class class OCRCorrector: def __init__(self, system_prompt="Le dialogue suivant est une conversation"): self.system_prompt = system_prompt def correct(self, user_message): generated_text = ocr_correction(user_message) html_diff = generate_html_diff(user_message, generated_text) return generated_text, html_diff # Combined Processing Class class TextProcessor: def __init__(self): self.ocr_corrector = OCRCorrector() @spaces.GPU(duration=120) def process(self, user_message): #OCR Correction corrected_text, html_diff = self.ocr_corrector.correct(user_message) # Combine results ocr_result = f'