import base64
import os

import spaces

# Maps every supported ``got_mode`` to a triple of
# (model method to call, ``ocr_type`` argument, forward box/color hints?).
_GOT_MODES = {
    "plain texts OCR": ("chat", "ocr", False),
    "format texts OCR": ("chat", "format", False),
    "plain multi-crop OCR": ("chat_crop", "ocr", False),
    "format multi-crop OCR": ("chat_crop", "format", False),
    "plain fine-grained OCR": ("chat", "ocr", True),
    "format fine-grained OCR": ("chat", "format", True),
}


@spaces.GPU()
def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR",
            fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run OCR on an image in the requested mode.

    Args:
        model: Object exposing ``chat`` and ``chat_crop`` methods that accept
            ``(tokenizer, image_path, **options)``.
        tokenizer: Tokenizer passed through to the model call unchanged.
        image_path: Path of the image to read. For "format" modes the
            rendered HTML is saved next to it as ``<stem>_result.html``.
        got_mode: One of the keys of ``_GOT_MODES``.
        fine_grained_mode: Currently unused; kept so existing callers that
            pass it keep working.
        ocr_color: Color hint forwarded to the model in fine-grained modes.
        ocr_box: Box hint forwarded to the model in fine-grained modes.

    Returns:
        A ``(result, encoded_html)`` tuple. ``encoded_html`` is the
        base64-encoded content of the rendered HTML file for "format" modes
        when that file exists, otherwise ``None``. If anything fails
        (including an unsupported ``got_mode``), ``result`` is an error
        message prefixed with "错误: " and ``encoded_html`` is ``None``;
        no exception is propagated to the caller.
    """
    try:
        mode_spec = _GOT_MODES.get(got_mode)
        if mode_spec is None:
            # Previously an unknown mode surfaced as a confusing
            # "name 'res' is not defined" error; report the real cause.
            raise ValueError(f"unsupported got_mode: {got_mode!r}")
        method_name, ocr_type, use_fine_grained = mode_spec

        call_kwargs = {"ocr_type": ocr_type}
        if use_fine_grained:
            call_kwargs["ocr_box"] = ocr_box
            call_kwargs["ocr_color"] = ocr_color

        result_path = None
        if ocr_type == "format":
            # "format" modes also ask the model to render an HTML file.
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            call_kwargs["render"] = True
            call_kwargs["save_render_file"] = result_path

        res = getattr(model, method_name)(tokenizer, image_path, **call_kwargs)

        # Plain modes, or a render that produced no file: text result only.
        if result_path is None or not os.path.exists(result_path):
            return res, None

        # NOTE(review): the file is read with the platform default encoding,
        # as before — confirm this matches how the model writes it.
        with open(result_path, "r") as f:
            html_content = f.read()
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        return res, encoded_html
    except Exception as e:
        # Deliberate best-effort boundary: callers receive an error string
        # rather than an exception.
        return f"错误: {str(e)}", None


# Usage example
if __name__ == "__main__":
    import torch
    from transformers import AutoModel, AutoTokenizer

    # Initialise the model and tokenizer.
    model_name = "stepfun-ai/GOT-OCR2_0"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        # Follow the detected device instead of hard-coding "cuda", so the
        # CPU fallback chosen above is actually honoured.
        device_map=device,
        use_safetensors=True,
    )
    model = model.eval().to(device)
    model.config.pad_token_id = tokenizer.eos_token_id

    image_path = "path/to/your/image.png"
    result, html = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
    print("OCR结果:", result)
    if html:
        print("HTML结果可用")