File size: 2,716 Bytes
18def71
 
 
763e463
18def71
763e463
 
fffa248
18def71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
546c454
 
18def71
 
546c454
 
 
 
 
 
 
 
18def71
 
ba22048
18def71
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import base64
import os

import spaces


@spaces.GPU()
def got_ocr(model, tokenizer, image_path, got_mode="format texts OCR", fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR on an image in the requested mode.

    Args:
        model: OCR model exposing ``chat`` and ``chat_crop`` methods.
        tokenizer: Tokenizer passed through to the model calls.
        image_path: Path to the input image. For "format" modes, the path
            ``<image_path without extension>_result.html`` is passed to the
            model as ``save_render_file``.
        got_mode: One of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: Currently unused; kept for interface compatibility.
        ocr_color: Forwarded to ``model.chat`` in the fine-grained modes.
        ocr_box: Forwarded to ``model.chat`` in the fine-grained modes.

    Returns:
        A ``(result, encoded_html)`` tuple. ``result`` is the model output.
        ``encoded_html`` is the render file's content as a base64 string when
        the mode is a "format" mode and the file exists, otherwise ``None``.
        Exceptions are not raised: on any failure (including an unsupported
        ``got_mode``) ``result`` is an error message starting with "错误: "
        and ``encoded_html`` is ``None``.
    """
    # got_mode -> (use chat_crop?, ocr_type, forward ocr_box/ocr_color?)
    mode_table = {
        "plain texts OCR": (False, "ocr", False),
        "format texts OCR": (False, "format", False),
        "plain multi-crop OCR": (True, "ocr", False),
        "format multi-crop OCR": (True, "format", False),
        "plain fine-grained OCR": (False, "ocr", True),
        "format fine-grained OCR": (False, "format", True),
    }

    # Run OCR
    try:
        if got_mode not in mode_table:
            # Fail with a clear message instead of an unrelated
            # UnboundLocalError/NameError further down.
            raise ValueError(f"unsupported got_mode: {got_mode!r}")
        use_crop, ocr_type, fine_grained = mode_table[got_mode]
        chat = model.chat_crop if use_crop else model.chat

        chat_kwargs = {"ocr_type": ocr_type}
        if fine_grained:
            chat_kwargs["ocr_box"] = ocr_box
            chat_kwargs["ocr_color"] = ocr_color

        result_path = None
        if ocr_type == "format":
            result_path = f"{os.path.splitext(image_path)[0]}_result.html"
            chat_kwargs["render"] = True
            chat_kwargs["save_render_file"] = result_path

        res = chat(tokenizer, image_path, **chat_kwargs)

        # Formatted result: return the rendered HTML only if the file was produced.
        if result_path is None or not os.path.exists(result_path):
            return res, None
        # NOTE(review): read with the platform default encoding, matching the
        # original behavior; confirm which encoding the renderer writes with.
        with open(result_path, "r") as f:
            html_content = f.read()
        encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("utf-8")
        return res, encoded_html

    except Exception as e:
        # Boundary handler: report the failure as the result instead of raising.
        return f"错误: {str(e)}", None


# Usage example
if __name__ == "__main__":
    import torch
    from transformers import AutoModel, AutoTokenizer

    # Initialize model and tokenizer
    model_name = "stepfun-ai/GOT-OCR2_0"
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    # Load onto the device detected above rather than a hard-coded "cuda", so
    # the load step is consistent with `device` on CPU-only hosts.
    # NOTE(review): the model's remote code may itself require CUDA — confirm.
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        use_safetensors=True,
    )
    model = model.eval().to(device)
    # Use EOS as the padding token id (presumably the tokenizer defines no
    # pad token; TODO confirm).
    model.config.pad_token_id = tokenizer.eos_token_id

    image_path = "path/to/your/image.png"  # placeholder: replace with a real image
    result, html = got_ocr(model, tokenizer, image_path, got_mode="format texts OCR")
    print("OCR结果:", result)
    if html:
        print("HTML结果可用")