# GOT-OCR-Optimize / got_ocr.py
import base64
import os
import spaces


@spaces.GPU()
def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR2.0 inference on an image.

    Returns a tuple (text_result, encoded_html), where encoded_html is a
    base64-encoded rendered HTML page when a "format" mode produced a render
    file, and None otherwise. Note: fine_grained_mode is accepted for interface
    compatibility but is not used here; fine-grained behaviour is driven by
    ocr_box / ocr_color.
    """
    # Run OCR according to the selected mode.
    try:
        if got_mode == "plain texts OCR":
            res = model.chat(tokenizer, image_path, ocr_type="ocr")
            return res, None
        elif got_mode == "format texts OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        elif got_mode == "plain multi-crop OCR":
            res = model.chat_crop(tokenizer, image_path, ocr_type="ocr")
            return res, None
        elif got_mode == "format multi-crop OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat_crop(tokenizer, image_path, ocr_type="format", render=True, save_render_file=result_path)
        elif got_mode == "plain fine-grained OCR":
            res = model.chat(tokenizer, image_path, ocr_type="ocr", ocr_box=ocr_box, ocr_color=ocr_color)
            return res, None
        elif got_mode == "format fine-grained OCR":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            res = model.chat(tokenizer, image_path, ocr_type="format", ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        else:
            return f"Error: unknown OCR mode '{got_mode}'", None

        # For the "format" modes, return the rendered HTML as a base64 string.
        if "format" in got_mode and os.path.exists(result_path):
            with open(result_path, "r", encoding="utf-8") as f:
                html_content = f.read()
            encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
            return res, encoded_html
        else:
            return res, None
    except Exception as e:
        return f"Error: {str(e)}", None
# Usage example
if __name__ == "__main__":
    import torch
    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # Initialise the model and tokenizer.
    model_name = "stepfun-ai/GOT-OCR2_0"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    # Map the model to the detected device instead of hard-coding "cuda",
    # so the example also runs on CPU-only machines.
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True, low_cpu_mem_usage=True, device_map=device, use_safetensors=True)
    model = model.eval().to(device)
    model.config.pad_token_id = tokenizer.eos_token_id

    image_path = "path/to/your/image.png"
    result, html = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
    print("OCR result:", result)
    if html:
        print("Rendered HTML result is available")