File size: 1,915 Bytes
0355756
 
136aa0e
 
 
 
 
 
 
 
 
 
3bdb19a
136aa0e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# pip install --upgrade pip
# pip install --no-cache-dir -r requirements.txt
import os
from transformers import pipeline
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from math import sqrt
import gradio as gr
import numpy as np  
    
# --- Image-classification models -------------------------------------------
# Local checkpoint directory holding the model served on the "myvit" tab.
checkpoint_dir = "./checkpoint-905"

# Pipeline built from the local checkpoint above.
classifier = pipeline(task="image-classification", model=checkpoint_dir)

# Pipeline built from the pretrained google/vit-base-patch16-224 model,
# served on the "vit" tab for comparison.
vitclassifier = pipeline(
    task="image-classification",
    model="google/vit-base-patch16-224",
)

# Top-level Blocks container; the UI is assembled into it further down.
demo = gr.Blocks()

# 定义推理函数
def flip_myvit(image):
    """Classify an uploaded image with the local-checkpoint model (``classifier``).

    Args:
        image: the array delivered by the ``gr.Image`` input, or ``None`` when
            the button is pressed before any picture has been uploaded.

    Returns:
        A ``(label, score_text)`` pair of strings for the two output textboxes.
        ``score_text`` is the first pipeline result's score as a percentage
        with three decimals (e.g. ``"98.765%"``). For a missing image, a
        prompt message and an empty string are returned instead.
    """
    # gr.Image yields None when nothing is uploaded; calling .astype on it
    # used to raise AttributeError.
    if image is None:
        return "请先上传图片", ""
    # Let Pillow infer the mode from the array shape (L / RGB / RGBA) and then
    # normalise to RGB. Forcing mode='RGB' misreads grayscale or 4-channel data.
    pil_image = Image.fromarray(np.asarray(image).astype("uint8")).convert("RGB")
    result = classifier(pil_image)
    # The first entry is presumably the highest-scoring label (the pipeline
    # returns results sorted by score).
    top = result[0]
    text = "{:.3f}%".format(top["score"] * 100)
    return top["label"], text
 
  
def flip_vit(image):
    """Classify an uploaded image with the pretrained ViT model (``vitclassifier``).

    Args:
        image: the array delivered by the ``gr.Image`` input, or ``None`` when
            the button is pressed before any picture has been uploaded.

    Returns:
        A ``(label, score_text)`` pair of strings for the two output textboxes.
        ``score_text`` is the first pipeline result's score as a percentage
        with three decimals (e.g. ``"98.765%"``). For a missing image, a
        prompt message and an empty string are returned instead.
    """
    # gr.Image yields None when nothing is uploaded; calling .astype on it
    # used to raise AttributeError.
    if image is None:
        return "请先上传图片", ""
    # Let Pillow infer the mode from the array shape (L / RGB / RGBA) and then
    # normalise to RGB. Forcing mode='RGB' misreads grayscale or 4-channel data.
    pil_image = Image.fromarray(np.asarray(image).astype("uint8")).convert("RGB")
    result = vitclassifier(pil_image)
    # The first entry is presumably the highest-scoring label (the pipeline
    # returns results sorted by score).
    top = result[0]
    text = "{:.3f}%".format(top["score"] * 100)
    return top["label"], text
  
# Description rendered at the top of the page. The original code passed the
# undefined name ``model_info`` to gr.Markdown, raising NameError at startup.
model_info = (
    "## 猫狗分类器\n"
    "- **myvit**：本地检查点 `./checkpoint-905` 中的模型\n"
    "- **vit**：预训练模型 `google/vit-base-patch16-224`\n"
)

# Build the UI: one tab per model. Each tab has an image input, two result
# textboxes (predicted label and its score) and a button that runs inference.
with demo:
    gr.Markdown(model_info)
    with gr.Tabs():
        with gr.TabItem("myvit"):
            myvit_input = gr.Image()
            myvit_output1 = gr.Textbox(label="预测结果")
            # The value shown is the model's score for the predicted label
            # (confidence), not an accuracy figure, so label it accordingly.
            myvit_output2 = gr.Textbox(label="置信度")
            myvit_button = gr.Button("开始")
        with gr.TabItem("vit"):
            vit_input = gr.Image()
            vit_output1 = gr.Textbox(label="预测结果")
            vit_output2 = gr.Textbox(label="置信度")
            vit_button = gr.Button("开始")

    # Wire each button to its inference function. The two returned strings
    # fill the label textbox and the score textbox, in that order.
    myvit_button.click(flip_myvit, inputs=myvit_input, outputs=[myvit_output1, myvit_output2])
    vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1, vit_output2])

# Page title, then launch the app.
demo.title = "猫狗分类器"
demo.launch()