# GOT-OCR-Optimize / app-pdf.py
# Mageia's picture
# fix: process pdf once
# dd08fd0 unverified
import base64
import html
import logging
import os
import tempfile

import fitz
import gradio as gr
import spaces
import torch
from PIL import Image, ImageEnhance
from transformers import AutoModel, AutoTokenizer
# Load the GOT-OCR 2.0 tokenizer and model once at import time and put the
# model in inference mode on the best available device.
model_name = "ucaslcl/GOT-OCR2_0"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# trust_remote_code: the model repo ships its own modelling code (chat/chat_crop).
model = (
    AutoModel.from_pretrained(model_name, trust_remote_code=True, device_map=device)
    .eval()
    .to(device)
)
def pdf_to_images(pdf_path, *, zoom=10, contrast=1.5):
    """Render every page of a PDF to a contrast-boosted RGB PIL image.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.
        zoom: Scale factor applied to both axes when rasterising. PyMuPDF
            renders at 72 DPI at zoom 1, so the default of 10 is ~720 DPI.
        contrast: Factor passed to ``ImageEnhance.Contrast.enhance``; 1.0
            leaves the image unchanged, the default 1.5 boosts contrast by 50%.

    Returns:
        A list with one ``PIL.Image.Image`` per page, in page order.

    Note:
        All pages are held in memory at once. At the default zoom a US-Letter
        page is about 6120x7920 pixels (~145 MB as RGB), so large PDFs can
        use a lot of RAM.
    """
    images = []
    matrix = fitz.Matrix(zoom, zoom)  # same scale for every page; build once
    # The context manager closes the document even if rendering raises.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            # alpha=False gives a 3-channel pixmap, matching PIL's "RGB" mode.
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            images.append(ImageEnhance.Contrast(img).enhance(contrast))
    return images
@spaces.GPU()
def ocr_process(file, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Run GOT-OCR on an uploaded image or PDF and return HTML for display.

    Args:
        file: Value from the ``gr.File`` input. Its ``.name`` attribute is
            used as the file path; a plain path string is also accepted.
        got_mode: OCR mode label from the dropdown, forwarded unchanged to
            ``process_single_image``.
        ocr_color: Colour hint forwarded to the model (fine-grained modes).
        ocr_box: Bounding-box hint forwarded to the model (fine-grained modes).
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The result string for the ``gr.HTML`` output. A PDF yields one section
        per page, joined by blank lines. Any failure is logged and returned as
        an "错误" message instead of being raised, so the UI always shows
        something.
    """
    if file is None:
        return "错误:未提供文件"
    try:
        progress(0, desc="开始处理...")
        # Gradio supplies an object whose .name is the uploaded file's path;
        # fall back to the value itself so a plain path string also works.
        file_path = getattr(file, "name", file)
        if not file_path.lower().endswith(".pdf"):
            final_result = process_single_image(file_path, got_mode, ocr_color, ocr_box)
        else:
            # Page images only need to exist while each page is being OCR'd.
            with tempfile.TemporaryDirectory() as temp_dir:
                images = pdf_to_images(file_path)
                num_pages = len(images)
                results = []
                for i, image in enumerate(images, start=1):
                    # Report the fraction already finished, so the bar only
                    # reaches 100% once every page is done.
                    progress((i - 1) / num_pages, desc=f"处理第 {i}/{num_pages} 页...")
                    img_path = os.path.join(temp_dir, f"page_{i}.png")
                    image.save(img_path, "PNG")
                    result = process_single_image(img_path, got_mode, ocr_color, ocr_box)
                    results.append(f"第 {i} 页结果:\n{result}")
                final_result = "\n\n".join(results)
        progress(1, desc="处理完成")
        return final_result
    except Exception as e:
        # UI boundary: keep the traceback in the server log, and escape the
        # message because it is rendered as HTML.
        logging.getLogger(__name__).exception("OCR processing failed")
        return f"错误: {html.escape(str(e))}"
def process_single_image(image_path, got_mode, ocr_color, ocr_box):
result_path = f"{os.path.splitext(image_path)[0]}_result.html"
if "plain" in got_mode:
if "multi-crop" in got_mode:
res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
else:
res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
return res
elif "format" in got_mode:
if "multi-crop" in got_mode:
res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
else:
res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
if os.path.exists(result_path):
with open(result_path, "r", encoding="utf-8") as f:
html_content = f.read()
encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
preview = f'<iframe src="{data_uri}" width="100%" height="600px"></iframe>'
download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
return f"{download_link}\n\n{preview}"
return "错误: 未知的OCR模式"
# Gradio UI: upload a PDF or image, choose a GOT-OCR mode, optionally give
# fine-grained colour/box hints, and show the result as HTML.
with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")
    file_input = gr.File(label="上传PDF或图片文件")
    got_mode = gr.Dropdown(
        label="OCR模式",
        value="plain texts OCR",
        choices=[
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ],
    )
    with gr.Row():
        # Only used by the fine-grained modes; passed through to the model.
        ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
        ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")
    submit_button = gr.Button("开始OCR识别")
    output = gr.HTML(label="识别结果")
    submit_button.click(
        fn=ocr_process,
        inputs=[file_input, got_mode, ocr_color, ocr_box],
        outputs=output,
    )

if __name__ == "__main__":
    demo.launch()