import os from transformers import pipeline from tqdm import tqdm from PIL import Image import matplotlib.pyplot as plt from math import sqrt import gradio as gr import numpy as np model_info = """ **模型名称**: Google/vit-base-patch16-224 **模型介绍**: 本程序根据huggingface上Google开源模型vit,在猫狗图片数据上进行微调,上传一张图片,将会预测其类别并显示结果。模型官网:https://huggingface.co/google/vit-base-patch16-224 **程序作者**: 计科三班 王志建、计科三班 罗楷轩 **特别支持**: 计科三班 黄成栋 """ # 加载图像分类模型 checkpoint_dir = "./checkpoint/checkpoint-181" # 模型检查点目录 classifier = pipeline("image-classification", model=checkpoint_dir) # 创建图像分类器模型 vitclassifier = pipeline("image-classification",model="google/vit-base-patch16-224") demo = gr.Blocks() # 定义推理函数 def flip_myvit(image): # 图像预处理 image = Image.fromarray(image.astype('uint8'), 'RGB') # 进行图像分类 result = classifier(image) # 返回分类结果 text = "{:.3f}%".format(result[0]['score'] * 100) return result[0]['label'],text def flip_vit(image): # 图像预处理 image = Image.fromarray(image.astype('uint8'), 'RGB') # 进行图像分类 result = vitclassifier(image) # 返回分类结果 text = "{:.3f}%".format(result[0]['score'] * 100) return result[0]['label'],text with demo: gr.Markdown(model_info) with gr.Tabs(): with gr.TabItem("myvit"): myvit_input = gr.Image() myvit_output1 = gr.Textbox(label="预测结果") myvit_output2 = gr.Textbox(label="准确度") myvit_button = gr.Button("开始") with gr.TabItem("vit"): vit_input = gr.Image() vit_output1 = gr.Textbox(label="预测结果") vit_output2 = gr.Textbox(label="准确度") vit_button = gr.Button("开始") myvit_button.click(flip_myvit, inputs=myvit_input, outputs=[myvit_output1,myvit_output2]) vit_button.click(flip_vit, inputs=vit_input, outputs=[vit_output1,vit_output2]) demo.title="猫狗分类器" demo.launch()