import json

import numpy as np
import pytesseract
from PIL import Image, ImageEnhance
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# TrOCR checkpoint fine-tuned for pre-reform orthography, paired with the
# processor of the base printed-text model it was derived from.
hf_model = VisionEncoderDecoderModel.from_pretrained("Serovvans/trocr-prereform-orthography")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")

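# Optional note (an assumption, not part of the original pipeline): inference
# also works on GPU. Moving the model with hf_model.to("cuda") would require
# moving pixel_values to the same device inside recognize_row below.
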
def remove_bleed_through(image_path, brightness_factor=1.5):
    """Suppress ink bleed-through from the reverse side of a scanned page."""
    pil_image = Image.open(image_path).convert('RGB')
    img = np.array(pil_image)

    # Linear contrast stretch (pixel -> alpha * pixel + beta): genuine ink
    # saturates to black while faint bleed-through is pushed toward white.
    alpha = 1.7
    beta = -130

    result = alpha * img + beta
    result = np.clip(result, 0, 255).astype(np.uint8)

    pil_result = Image.fromarray(result)

    # A final brightness boost washes out the remaining background noise.
    enhancer_brightness = ImageEnhance.Brightness(pil_result)
    bright_image = enhancer_brightness.enhance(brightness_factor)

    return bright_image

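# A quick sanity check of the constants above (pure arithmetic): a faint
# bleed-through pixel at 200 maps to 1.7 * 200 - 130 = 210 (lighter), while
# genuine ink at 60 maps to 1.7 * 60 - 130 = -28, which clips to 0 (black).
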
def recognize_row(image):
    """Run TrOCR on a single word/line crop and return the decoded text."""
    image = image.convert("RGB")
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = hf_model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text

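# Usage sketch for recognize_row ("word_crop.png" is a placeholder path, not a
# file that ships with this code):
#
#     crop = Image.open("word_crop.png")
#     print(recognize_row(crop))
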
def recognize_page(image_path, text_output_path=None):
    """OCR a full page: Tesseract locates word boxes, TrOCR reads each crop."""
    image = remove_bleed_through(image_path)

    # Tesseract is used only for layout analysis here; --psm 6 assumes a
    # single uniform block of text. Dump the raw layout data alongside the
    # output (useful for debugging).
    data = pytesseract.image_to_data(image, config="--psm 6",
                                     output_type=pytesseract.Output.DICT, lang='ukr+eng')
    with open("rec_data.json", "w", encoding="utf-8") as json_file:
        json.dump(data, json_file)

    # Padding around each word box, proportional to the page height.
    pad = int(0.0042 * image.size[1])

    fragments = []
    for i in range(len(data['text'])):
        # Skip layout-only entries (conf == -1) and boxes too small to hold a word.
        if data['conf'][i] > -1 and data['width'][i] * data['height'][i] > pad ** 2:
            x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
            # Boxes may extend past the page edge; PIL fills the overhang with black.
            fragment_image = image.crop((x - pad, y - pad, x + w + pad, y + h + pad))
            text = recognize_row(fragment_image).strip()

            # Keep Tesseract's token for em dashes, which TrOCR tends to misread.
            if data['text'][i].strip() == "\u2014":
                text = "\u2014"

            # Likewise trust Tesseract for numerals it read with high confidence.
            try:
                int(data['text'][i].strip())
                if data['conf'][i] > 85:
                    text = data['text'][i].strip()
            except ValueError:
                pass

            fragments.append({
                'block_num': data['block_num'][i],
                'par_num': data['par_num'][i],
                'line_num': data['line_num'][i],
                'word_num': data['word_num'][i],
                'text': text,
                'image': fragment_image
            })

    # Restore reading order: block, then paragraph, then line, then word.
    fragments = sorted(fragments, key=lambda f: (f['block_num'], f['par_num'], f['line_num'], f['word_num']))

    # Stitch words back into lines; a new line starts whenever the
    # (block, paragraph, line) triple changes.
    result_lines = []
    current_line_num = 0
    curr_block_num = 0
    curr_par_num = 0
    current_line = []

    for fragment in fragments:
        if (fragment['line_num'] != current_line_num
                or fragment['block_num'] != curr_block_num
                or fragment['par_num'] != curr_par_num):
            if current_line:  # avoid emitting an empty line before the first word
                result_lines.append(" ".join(current_line))
            current_line = []
            current_line_num = fragment['line_num']
            curr_block_num = fragment['block_num']
            curr_par_num = fragment['par_num']

        current_line.append(fragment['text'])

    if current_line:
        result_lines.append(" ".join(current_line))

    final_text = "\n".join(result_lines)

    if text_output_path:
        with open(text_output_path, "w", encoding="utf-8") as text_file:
            text_file.write(final_text)

    return final_text

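if __name__ == "__main__":
    # Minimal usage sketch; "page.png" and "page.txt" are placeholder paths,
    # not files that ship with this code.
    print(recognize_page("page.png", text_output_path="page.txt"))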