Spaces:
Sleeping
Sleeping
import os | |
from transformers import pipeline | |
from tqdm import tqdm | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
from math import sqrt | |
import gradio as gr | |
import numpy as np | |
model_info = """ | |
**模型名称**: Google/vit-base-patch16-224 | |
**模型介绍**: 本程序根据huggingface上Google开源模型vit,在猫狗图片数据上进行微调,上传一张图片,将会预测其类别并显示结果。模型官网:https://huggingface.co/google/vit-base-patch16-224 | |
**程序作者**: 计科三班 王志建、计科三班 罗楷轩 | |
**特别支持**: 计科三班 黄成栋 | |
""" | |
# 加载图像分类模型 | |
checkpoint_dir = "./checkpoint/checkpoint-181" # 模型检查点目录 | |
classifier = pipeline("image-classification", model=checkpoint_dir) # 创建图像分类器模型 | |
vitclassifier = pipeline("image-classification",model="google/vit-base-patch16-224") | |
demo = gr.Blocks() | |
# 定义推理函数 | |
def flip_myvit(image): | |
# 图像预处理 | |
image = Image.fromarray(image.astype('uint8'), 'RGB') | |
# 进行图像分类 | |
result = classifier(image) | |
# 返回分类结果 | |
text = "{:.3f}%".format(result[0]['score'] * 100) | |
return result[0]['label'],text | |
def flip_vit(image): | |
# 图像预处理 | |
image = Image.fromarray(image.astype('uint8'), 'RGB') | |
# 进行图像分类 | |
result = vitclassifier(image) | |
# 返回分类结果 | |
text = "{:.3f}%".format(result[0]['score'] * 100) | |
return result[0]['label'],text | |
with demo: | |
gr.Markdown(model_info) | |
with gr.Tabs(): | |
with gr.TabItem("myvit"): | |
myvit_input = gr.Image() | |
myvit_output1 = gr.Textbox(label="预测结果") | |
myvit_output2 = gr.Textbox(label="准确度") | |
myvit_button = gr.Button("开始") | |
with gr.TabItem("vit"): | |
vit_input = gr.Image() | |
vit_output1 = gr.Textbox(label="预测结果") | |
vit_output2 = gr.Textbox(label="准确度") | |
vit_button = gr.Button("开始") | |
myvit_button.click(flip_myvit, inputs=myvit_input, outputs=[myvit_output1,myvit_output2]) | |
vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1,vit_output2]) | |
demo.title="猫狗分类器" | |
demo.launch() | |