trocr-prereform-orthography / recognize_page.py
Serovvans's picture
Upload 5 files
caa85f5 verified
import pytesseract
import json
import numpy as np
from PIL import Image, ImageEnhance
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
# Recognition model, loaded once at import time. `from_pretrained` resolves
# the identifier on the Hugging Face Hub (or in the local cache), so merely
# importing this module triggers the load.
# NOTE(review): presumably a TrOCR checkpoint fine-tuned for pre-reform
# orthography, judging by the repository name -- confirm.
hf_model = VisionEncoderDecoderModel.from_pretrained("Serovvans/trocr-prereform-orthography")
# The processor (used below for image preprocessing and for decoding token
# ids) is taken from the base "printed" TrOCR checkpoint, not from the model
# repository above.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
def remove_bleed_through(image_path, brightness_factor=1.5, *, alpha=1.7, beta=-130):
    """Increase the contrast and brightness of a page image before OCR.

    Every RGB channel value ``p`` is mapped to ``alpha * p + beta`` and
    clipped to ``[0, 255]``; the result is then passed through
    ``ImageEnhance.Brightness``.  With the default ``alpha``/``beta`` any
    value below roughly 76 becomes 0 and any value above roughly 226
    becomes 255.
    NOTE(review): going by the function name, the purpose is to wash out
    faint show-through from the reverse side of the page -- confirm.

    Args:
        image_path: Anything ``PIL.Image.open`` accepts (path or file object).
        brightness_factor: Factor given to ``ImageEnhance.Brightness.enhance``
            (1.0 leaves brightness unchanged).
        alpha: Multiplicative gain of the linear contrast transform.
        beta: Additive offset of the linear contrast transform.

    Returns:
        A new RGB ``PIL.Image.Image`` holding the processed page.
    """
    # The context manager releases the file handle as soon as the pixels are
    # copied out, instead of leaving that to garbage collection.
    with Image.open(image_path) as source:
        # float64 keeps the arithmetic below out of uint8, so integer
        # alpha/beta values cannot wrap around or overflow.
        pixels = np.array(source.convert('RGB'), dtype=np.float64)

    stretched = np.clip(alpha * pixels + beta, 0, 255).astype(np.uint8)

    contrast_image = Image.fromarray(stretched)
    return ImageEnhance.Brightness(contrast_image).enhance(brightness_factor)
# Функция распознавания текста с TrOCR
def recognize_row(image):
    """Recognize the text in a single cropped image with the TrOCR model.

    Args:
        image: A ``PIL.Image.Image``; it is converted to RGB before being
            handed to the module-level ``processor``.

    Returns:
        The decoded string for the image, with special tokens removed.
    """
    rgb_image = image.convert("RGB")
    encoded = processor(images=rgb_image, return_tensors="pt")
    token_ids = hf_model.generate(encoded.pixel_values)
    # A single image was passed in, so the decoded batch has one entry.
    decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded_batch[0]
# Основная функция распознавания текста с разделением на слова/фразы
def _to_confidence(raw):
    """Return a Tesseract ``conf`` cell as a float, or -1.0 if it is not numeric."""
    try:
        return float(raw)
    except (TypeError, ValueError):
        return -1.0


def _is_integer_token(token):
    """Tell whether ``int()`` accepts *token* (i.e. it reads as a whole number)."""
    try:
        int(token)
    except ValueError:
        return False
    return True


def _join_fragments_into_lines(fragments):
    """Turn reading-order-sorted fragments into one joined string per text line."""
    grouped = []
    current_key = None
    for fragment in fragments:
        key = (fragment['block_num'], fragment['par_num'], fragment['line_num'])
        # Open a new line only when the (block, paragraph, line) key changes;
        # starting from None means no empty line is emitted before the first
        # fragment.
        if key != current_key:
            grouped.append([])
            current_key = key
        grouped[-1].append(fragment['text'])
    return [" ".join(words) for words in grouped]


def recognize_page(image_path, text_output_path=False, *, debug_json_path="rec_data.json"):
    """Recognize a full page: Tesseract finds the fragments, TrOCR reads them.

    The page is preprocessed with ``remove_bleed_through``, segmented with
    ``pytesseract.image_to_data`` (``--psm 6``, languages ``ukr+eng``), and
    every sufficiently large fragment is cropped with a small margin and
    recognized by ``recognize_row``.  Tesseract's own reading is kept in two
    cases: an em dash, and an integer reported with confidence above 85.
    Fragments are ordered by block, paragraph, line and word number, joined
    with spaces inside a line and with newlines between lines.

    Args:
        image_path: Page image, passed to ``remove_bleed_through``.
        text_output_path: If truthy, the recognized text is also written to
            this path as UTF-8.
        debug_json_path: Where the raw Tesseract data is dumped as JSON.  The
            default writes ``rec_data.json`` into the current working
            directory; pass ``None`` to skip the dump.

    Returns:
        The recognized page text, one line per Tesseract text line, without
        a leading blank line.
    """
    image = remove_bleed_through(image_path)

    # Tesseract is used for layout: one row per detected box.
    data = pytesseract.image_to_data(
        image,
        config="--psm 6",
        output_type=pytesseract.Output.DICT,
        lang='ukr+eng',
    )
    if debug_json_path:
        with open(debug_json_path, "w", encoding="utf-8") as json_file:
            json.dump(data, json_file)

    page_width, page_height = image.size
    # Crop margin, proportional to the page height; its square is also the
    # minimum box area below which a fragment is ignored.
    pad = int(0.0042 * page_height)
    min_area = pad ** 2

    fragments = []
    for i, raw_text in enumerate(data['text']):
        # Some pytesseract versions return `conf` as strings; coerce so the
        # numeric comparisons below cannot raise TypeError.
        confidence = _to_confidence(data['conf'][i])
        x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
        # A confidence of -1 marks a row that is not a recognized fragment.
        if confidence <= -1 or w * h <= min_area:
            continue

        tesseract_text = str(raw_text).strip()
        if tesseract_text == "\u2014":
            # Keep the em dash exactly as Tesseract reported it.
            text = "\u2014"
        elif confidence > 85 and _is_integer_token(tesseract_text):
            # Confidently recognized numbers are taken from Tesseract as-is.
            text = tesseract_text
        else:
            # Clamp the padded box to the page: PIL fills any area outside
            # the image with black pixels, which would end up in the crop.
            box = (
                max(0, x - pad),
                max(0, y - pad),
                min(page_width, x + w + pad),
                min(page_height, y + h + pad),
            )
            text = recognize_row(image.crop(box)).strip()

        fragments.append({
            'block_num': data['block_num'][i],
            'par_num': data['par_num'][i],
            'line_num': data['line_num'][i],
            'word_num': data['word_num'][i],
            'text': text,
        })

    # Restore reading order before grouping fragments into lines.
    fragments.sort(key=lambda f: (f['block_num'], f['par_num'], f['line_num'], f['word_num']))

    final_text = "\n".join(_join_fragments_into_lines(fragments))

    if text_output_path:
        with open(text_output_path, "w", encoding="utf-8") as text_file:
            text_file.write(final_text)

    return final_text