# Legacy prototype, kept commented out for reference.
# import os
# import gradio as gr
# from transformers import BlipProcessor, BlipForConditionalGeneration
# from PIL import Image
# from transformers import CLIPProcessor, ChineseCLIPVisionModel, AutoProcessor
#
# # Set the HF_HOME and HF_ENDPOINT environment variables
# # os.environ['HF_HOME'] = 'D:/AI/OCR/img2text/models'
# # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
#
#
# # model = ChineseCLIPVisionModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
# # processor = AutoProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
# # Load the model and processor
# # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
# # model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# processor = BlipProcessor.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
# model = BlipForConditionalGeneration.from_pretrained("IDEA-CCNL/Taiyi-BLIP-750M-Chinese")
# def generate_caption(image):
#     # Ensure image is a PIL.Image
#     if not isinstance(image, Image.Image):
#         raise ValueError("Input must be a PIL.Image")
#
#     # For image-only input, BlipProcessor returns pixel_values (not input_ids)
#     inputs = processor(image, return_tensors="pt")
#     pixel_values = inputs.get("pixel_values")
#     if pixel_values is None:
#         raise ValueError("Processor did not return pixel_values")
#
#     outputs = model.generate(pixel_values=pixel_values, max_length=50)
#     description = processor.decode(outputs[0], skip_special_tokens=True)
#     return description
#
# # Create the Gradio interface
# gradio_app = gr.Interface(
#     fn=generate_caption,
#     inputs=gr.Image(type="pil"),
#     outputs="text",
#     title="图片描述生成器",  # "Image Caption Generator"
#     description="上传一张图片,生成相应的描述。"  # "Upload an image to generate a caption."
# )
#
# if __name__ == "__main__":
#     gradio_app.launch()
import os

import gradio as gr
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor, GenerationConfig

print(torch.__version__)  # log the installed PyTorch version
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_MODEL_PATH = 'IDEA-CCNL/Taiyi-BLIP-750M-Chinese'
HF_TOKEN = os.getenv('HF_TOKEN')  # optional Hugging Face access token

# Load the processor and model once at startup; run inference in eval mode.
processor = BlipProcessor.from_pretrained(_MODEL_PATH, token=HF_TOKEN)
model = BlipForConditionalGeneration.from_pretrained(
    _MODEL_PATH, token=HF_TOKEN).eval().to(device)
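
# Optional variant (a sketch, not verified on this model): load the weights in
# float16 on a GPU to roughly halve memory use. torch_dtype is a standard
# from_pretrained argument; keep the default float32 when running on CPU.
#
# model = BlipForConditionalGeneration.from_pretrained(
#     _MODEL_PATH, token=HF_TOKEN, torch_dtype=torch.float16).eval().to(device)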

def inference(raw_image, model_n, strategy):
    if model_n == 'Image Captioning':
        inputs = processor(raw_image, return_tensors="pt").to(device)
        with torch.no_grad():
            if strategy == "Beam search":
                # Beam search: keep the num_beams highest-scoring partial
                # sequences at each step and expand them until generation
                # ends; deterministic and tends to be more accurate.
                config = GenerationConfig(
                    do_sample=False,
                    num_beams=3,
                    max_length=50,
                    min_length=5,
                )
            else:
                # Nucleus (top-p) sampling: keep the smallest set of tokens
                # whose cumulative probability exceeds p, renormalize, and
                # sample from that set; yields more diverse captions.
                config = GenerationConfig(
                    do_sample=True,
                    top_p=0.8,
                    max_length=50,
                    min_length=5,
                )
            captions = model.generate(**inputs, generation_config=config)
            caption = processor.decode(captions[0], skip_special_tokens=True)
            # Drop the spaces the tokenizer inserts between Chinese characters.
            caption = caption.replace(' ', '')
            print(caption)
            return caption
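
# A minimal sketch of calling inference() directly, bypassing the Gradio UI.
# Assumes a local image file 'demo.jpg' (an illustrative path, matching the
# example wired into the interface below):
#
# from PIL import Image
# sample = Image.open('demo.jpg').convert('RGB')
# print(inference(sample, 'Image Captioning', 'Beam search'))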

inputs = [
    gr.Image(type='pil', label="Upload Image"),
    # Task selector; image captioning is currently the only task.
    gr.Radio(choices=['Image Captioning'], value="Image Captioning", label="Task"),
    # Two decoding strategies: Beam search is more accurate, Nucleus sampling more diverse.
    gr.Radio(choices=['Beam search', 'Nucleus sampling'], value="Nucleus sampling", label="Caption Decoding Strategy")
]
outputs = gr.Textbox(label="Output")

title = "图片描述生成器"  # "Image Caption Generator"

gradio_app = gr.Interface(inference, inputs, outputs, title=title, examples=[
    ['demo.jpg', "Image Captioning", "Nucleus sampling"]
])
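
# launch() starts a local Gradio server (default http://127.0.0.1:7860);
# pass share=True to launch() for a temporary public link.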

if __name__ == "__main__":
    gradio_app.launch()