import gradio as gr
import os
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoImageProcessor
# import utils
import base64
# from datasets import load_metric
import evaluate
import logging
# Quiet the transformers library: only records at ERROR level or above are
# emitted, which filters out its warnings.
_transformers_logger = logging.getLogger("transformers")
_transformers_logger.setLevel(logging.ERROR)

# Hugging Face Hub checkpoint identifiers used below.
_TROCR_CHECKPOINT = "microsoft/trocr-base-handwritten"
_BULLINGER_CHECKPOINT = "pstroe/bullinger-general-model"

# `processor` is loaded from the TrOCR checkpoint; the image processor and the
# encoder-decoder model are both loaded from the Bullinger checkpoint.
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
image_processor = AutoImageProcessor.from_pretrained(_BULLINGER_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_BULLINGER_CHECKPOINT)
# Create examples
# Get images and respective transcriptions from the examples directory
def get_example_data(folder_path="./examples/"):
    """Collect example images and their transcriptions from a folder.

    Every file in *folder_path* whose name ends with ``.png`` is paired with
    the text file that shares its base name (``page.png`` -> ``page.txt``).

    Args:
        folder_path: Directory holding the ``.png`` images and their ``.txt``
            transcriptions.

    Returns:
        A list of ``[image_path, transcription]`` pairs ordered by file name.
        When an image has no matching ``.txt`` file, its transcription is the
        placeholder ``"Transcription not found."``.

    Raises:
        FileNotFoundError: If *folder_path* does not exist (raised by
            ``os.listdir``).
    """
    image_suffix = ".png"
    example_data = []
    # os.listdir() returns entries in arbitrary order; sort so the examples
    # come out in the same order on every run and platform.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(image_suffix):
            continue
        # Swap only the trailing extension. str.replace() would also rewrite
        # a ".png" in the middle of the name (e.g. "a.png.crop.png").
        text_file_name = file_name[: -len(image_suffix)] + ".txt"
        text_file_path = os.path.join(folder_path, text_file_name)
        try:
            # Decode explicitly as UTF-8: the platform default encoding is not
            # guaranteed to be UTF-8 and transcriptions may be non-ASCII.
            with open(text_file_path, "r", encoding="utf-8") as f:
                transcription = f.read().strip()
        except FileNotFoundError:
            # A missing transcription is expected; fall back to a placeholder.
            transcription = "Transcription not found."
        example_data.append([os.path.join(folder_path, file_name), transcription])
    return example_data
# From pstroe's script
# def compute_metrics(pred):
# labels_ids = pred.label_ids
# pred_ids = pred.predictions
# pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
# labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
# label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
# cer = cer_metric.compute(predictions=pred_str, references=label_str)
# return {"cer": cer}
def process_image(image, ground_truth):
    """Transcribe an image and, when a reference is given, score the result.

    Args:
        image: The image to transcribe, in whatever form the module-level
            ``image_processor`` accepts (presumably a PIL image coming from
            the Gradio input; confirm against the interface definition).
        ground_truth: Reference transcription. ``None`` or a blank string
            means no reference was supplied.

    Returns:
        A ``(generated_text, cer)`` tuple. ``cer`` is the value returned by
        ``cer_metric.compute`` when a reference was supplied, otherwise the
        string ``"Ground truth not provided"``.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Fail with a readable message instead of an error from deep inside
        # the image processor.
        raise gr.Error("Please provide an image to transcribe.")

    # Prepare the image as model input tensors.
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Generate with the model's default generation settings (the original
    # note here read "no beam search").
    generated_ids = model.generate(pixel_values)
    # Decode the first (only) generated sequence to text.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    if not ground_truth or not ground_truth.strip():
        return generated_text, "Ground truth not provided"

    # The lengths below are character counts of the two strings. They usually
    # differ, and that is not an error: the metric is computed on exactly one
    # prediction and one reference.
    logging.getLogger(__name__).debug(
        "Prediction (%d chars): %s | Reference (%d chars): %s",
        len(generated_text),
        generated_text,
        len(ground_truth),
        ground_truth,
    )
    cer = cer_metric.compute(predictions=[generated_text], references=[ground_truth])
    return generated_text, cer
# One way to use .svg files
# logo_url = "https://www.bullinger-digital.ch/bullinger-digital.svg"
# logo_url = "https://www.cl.uzh.ch/docroot/logos/uzh_logo_e_pos.svg"
# header_html = "".format(
# utils.img_to_bytes(".uzh_logo_e_pos.svg")
# )
# Load the logo assets and keep them as base64 text.
def _read_base64(image_path):
    """Return the contents of *image_path* as base64-encoded text."""
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    return base64.b64encode(raw_bytes).decode("utf-8")


logo_html = _read_base64("assets/uzh_logo_mod.png")
# Alternative footer asset tried previously: assets/bullinger-digital.png
footer_html = _read_base64("assets/bullinger_logo.png")
# App header
title = """
TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
# Flexible Techniques for Automatic Text Recognition of Historical Documents
# Bullingers Briefwechsel zugänglich machen: Stand der Handschriftenerkennung
# Bullinger Digital | Institut für Computerlinguistik, Universität Zürich, 2023 #
#BullingerDigital | Institut für Computerlinguistik, Universität Zürich, 2023