cat_classifiter / app.py
zhijian12345's picture
Update app.py
2435e92
raw
history blame
2.35 kB
!pip install --upgrade pip
!pip install --no-cache-dir -r requirements.txt
import os
from transformers import pipeline
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from math import sqrt
import gradio as gr
import numpy as np
model_info = """
**模型名称**: Google/vit-base-patch16-224
**模型介绍**: 本程序根据huggingface上Google开源模型vit,在猫狗图片数据上进行微调,上传一张图片,将会预测其类别并显示结果。模型官网:https://huggingface.co/google/vit-base-patch16-224
**程序作者**: 计科三班 王志建、计科三班 罗楷轩
**特别支持**: 计科三班 黄成栋
"""
# 加载图像分类模型
checkpoint_dir = "./checkpoint/checkpoint-181" # 模型检查点目录
classifier = pipeline("image-classification", model=checkpoint_dir) # 创建图像分类器模型
vitclassifier = pipeline("image-classification",model="google/vit-base-patch16-224")
demo = gr.Blocks()
# 定义推理函数
def flip_myvit(image):
# 图像预处理
image = Image.fromarray(image.astype('uint8'), 'RGB')
# 进行图像分类
result = classifier(image)
# 返回分类结果
text = "{:.3f}%".format(result[0]['score'] * 100)
return result[0]['label'],text
def flip_vit(image):
# 图像预处理
image = Image.fromarray(image.astype('uint8'), 'RGB')
# 进行图像分类
result = vitclassifier(image)
# 返回分类结果
text = "{:.3f}%".format(result[0]['score'] * 100)
return result[0]['label'],text
with demo:
gr.Markdown(model_info)
with gr.Tabs():
with gr.TabItem("myvit"):
myvit_input = gr.Image()
myvit_output1 = gr.Textbox(label="预测结果")
myvit_output2 = gr.Textbox(label="准确度")
myvit_button = gr.Button("开始")
with gr.TabItem("vit"):
vit_input = gr.Image()
vit_output1 = gr.Textbox(label="预测结果")
vit_output2 = gr.Textbox(label="准确度")
vit_button = gr.Button("开始")
myvit_button.click(flip_myvit, inputs=myvit_input, outputs=[myvit_output1,myvit_output2])
vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1,vit_output2])
demo.title="猫狗分类器"
demo.launch()