import base64
import html
import os
import tempfile

import fitz
import gradio as gr
import spaces
import torch
from PIL import Image, ImageEnhance
from transformers import AutoModel, AutoTokenizer
model_name = "ucaslcl/GOT-OCR2_0"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# trust_remote_code=True: the chat / chat_crop methods used further down are
# not standard transformers APIs; they come from the model repo's own code.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map=device,
).eval().to(device)
def pdf_to_images(pdf_path, zoom=10, contrast=1.5):
    """Rasterize every page of a PDF into a contrast-enhanced RGB image.

    Args:
        pdf_path: Path to the PDF file to render.
        zoom: Scale factor applied to both axes when rasterizing. PyMuPDF's
            base resolution is 72 dpi, so the default of 10 renders at
            roughly 720 dpi (large images; lower it to save memory).
        contrast: Factor passed to ``ImageEnhance.Contrast.enhance``; the
            default 1.5 raises contrast by 50%, 1.0 leaves pages unchanged.

    Returns:
        A list of ``PIL.Image.Image`` objects, one per page, in page order.
    """
    images = []
    # The transform is identical for every page, so build it once.
    matrix = fitz.Matrix(zoom, zoom)
    # The context manager closes the document even if rendering a page
    # raises; previously close() only ran on the success path.
    with fitz.open(pdf_path) as pdf_document:
        for page in pdf_document:
            pix = page.get_pixmap(matrix=matrix, alpha=False)
            # alpha=False yields packed RGB samples, matching mode "RGB".
            img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
            images.append(ImageEnhance.Contrast(img).enhance(contrast))
    return images
# NOTE(review): this handler does no GPU work (PDF rasterization only), so the
# GPU allocation requested by @spaces.GPU() may be unnecessary; kept as-is.
@spaces.GPU()
def convert_pdf_to_images(file):
    """Gradio handler: render an uploaded PDF into per-page PNG files.

    Args:
        file: Value from the ``gr.File`` input. This is either an object with
            a ``.name`` attribute holding the upload's path (e.g. a tempfile
            wrapper) or the path itself as a string.

    Returns:
        A ``(status_message, image_paths)`` tuple. On success ``image_paths``
        lists the PNG files in page order; on any error it is ``None`` and the
        message starts with "错误".
    """
    if file is None:
        return "错误:未提供文件", None
    try:
        # Accept both wrapper objects (with .name) and plain path strings.
        pdf_path = getattr(file, "name", file)
        if not str(pdf_path).lower().endswith(".pdf"):
            return "错误:请上传PDF文件", None
        images = pdf_to_images(pdf_path)
        # The directory must outlive this call: the gallery reads these files
        # after the handler returns. A TemporaryDirectory context would delete
        # them before the paths were ever used.
        output_dir = tempfile.mkdtemp(prefix="pdf_pages_")
        image_paths = []
        for page_no, image in enumerate(images, start=1):
            img_path = os.path.join(output_dir, f"page_{page_no}.png")
            image.save(img_path, "PNG")
            image_paths.append(img_path)
        return "PDF转换为图片成功", image_paths
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return f"错误: {str(e)}", None
@spaces.GPU()
def ocr_process(image, got_mode, ocr_color="", ocr_box="", progress=gr.Progress()):
    """Gradio handler: run OCR on one image, reporting progress to the UI.

    Args:
        image: The selected image. It is expected to be a file path, since
            it is forwarded to process_single_image as ``image_path``.
        got_mode: OCR mode label chosen in the dropdown.
        ocr_color: Color hint forwarded to the OCR call.
        ocr_box: Bounding-box hint forwarded to the OCR call.
        progress: Gradio progress tracker injected by the framework.

    Returns:
        The result string from process_single_image, or an error message
        starting with "错误" when no image is given or processing raises.
    """
    # Guard clause: nothing to do without an image.
    if image is None:
        return "错误:未选择图片"

    try:
        progress(0, desc="开始处理...")
        output = process_single_image(image, got_mode, ocr_color, ocr_box)
        progress(1, desc="处理完成")
    except Exception as exc:
        # UI boundary: report the failure text instead of raising.
        return f"错误: {str(exc)}"
    return output
def _text_to_html(text):
    """Escape OCR text and wrap it in <pre> so the gr.HTML pane keeps its layout."""
    # Without escaping, '<' / '&' in recognized text would be parsed as markup,
    # and without <pre> the HTML pane would collapse every newline.
    return f'<pre style="white-space: pre-wrap;">{html.escape(str(text))}</pre>'


def _render_file_to_html(result_path):
    """Embed a rendered HTML result file as an inline preview plus a download link."""
    with open(result_path, "rb") as f:
        encoded_html = base64.b64encode(f.read()).decode("ascii")
    data_uri = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    # The markup here was previously empty, which left data_uri unused and
    # showed nothing. Both elements now point at the same data URI.
    preview = f'<iframe src="{data_uri}" width="100%" height="600px" style="border: none;"></iframe>'
    download_link = f'<a href="{data_uri}" download="result.html">下载完整结果</a>'
    return f"{download_link}\n\n{preview}"


def process_single_image(image_path, got_mode, ocr_color, ocr_box):
    """Run GOT-OCR on one image file and return HTML for the result pane.

    The mode is selected by substring match on ``got_mode``:
    "plain" runs text OCR, "format" runs formatted OCR with HTML rendering,
    and "multi-crop" switches to ``model.chat_crop`` (which does not receive
    ``ocr_box`` / ``ocr_color``).

    Args:
        image_path: Path to the image file to recognize. In format mode a
            ``<stem>_result.html`` render file is written next to it.
        got_mode: Mode label, e.g. "plain texts OCR" or "format multi-crop OCR".
        ocr_color: Color hint forwarded to ``model.chat``.
        ocr_box: Bounding-box hint forwarded to ``model.chat``.

    Returns:
        An HTML string. Plain mode returns the escaped text. Format mode
        returns a preview plus download link when a render file was
        produced, and the escaped text otherwise. An unknown mode returns
        "错误: 未知的OCR模式".
    """
    if "plain" in got_mode:
        if "multi-crop" in got_mode:
            res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
        else:
            res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
        return _text_to_html(res)

    if "format" in got_mode:
        result_path = f"{os.path.splitext(image_path)[0]}_result.html"
        if "multi-crop" in got_mode:
            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        else:
            res = model.chat(
                tokenizer,
                image_path,
                ocr_type="format",
                ocr_box=ocr_box,
                ocr_color=ocr_color,
                render=True,
                save_render_file=result_path,
            )
        if os.path.exists(result_path):
            return _render_file_to_html(result_path)
        # No render file was written: show the recognized text rather than
        # falling through to the misleading "unknown mode" error.
        return _text_to_html(res)

    return "错误: 未知的OCR模式"
with gr.Blocks() as demo:
    gr.Markdown("# OCR 图像识别")

    with gr.Tab("PDF转图片"):
        # Restrict the file picker to PDFs; the handler still checks the
        # extension on the server side.
        pdf_input = gr.File(label="上传PDF文件", file_types=[".pdf"])
        convert_button = gr.Button("转换PDF为图片")
        pdf_output = gr.Textbox(label="转换结果")
        # Component.style() was removed in Gradio 4; the `columns`
        # constructor argument replaces .style(grid=3).
        image_gallery = gr.Gallery(label="图片预览", columns=3)

    with gr.Tab("OCR处理"):
        # type="filepath": the OCR handler derives an output path with
        # os.path.splitext and passes the value to model.chat as an image
        # file, so it needs a path string rather than the default numpy array.
        image_input = gr.Image(label="选择或上传图片", type="filepath")
        got_mode = gr.Dropdown(
            choices=[
                "plain texts OCR",
                "format texts OCR",
                "plain multi-crop OCR",
                "format multi-crop OCR",
                "plain fine-grained OCR",
                "format fine-grained OCR",
            ],
            label="OCR模式",
            value="plain texts OCR",
        )
        with gr.Row():
            ocr_color = gr.Textbox(label="OCR颜色 (仅用于fine-grained模式)")
            ocr_box = gr.Textbox(label="OCR边界框 (仅用于fine-grained模式)")
        ocr_button = gr.Button("开始OCR识别")
        ocr_output = gr.HTML(label="识别结果")

    # Event listeners must be attached inside the Blocks context.
    convert_button.click(convert_pdf_to_images, inputs=[pdf_input], outputs=[pdf_output, image_gallery])
    ocr_button.click(ocr_process, inputs=[image_input, got_mode, ocr_color, ocr_box], outputs=ocr_output)

if __name__ == "__main__":
    demo.launch()