# mask-detection / app.py
# Commit 51c7034 by bnsapa: "add code to app.py" (1.5 kB).
import gradio as gr
from torchvision import models
import torch.nn as nn
import torch
import os
from PIL import Image
from torchvision.transforms import transforms
from dotenv import load_dotenv
# Load variables from a local .env file (python-dotenv) into os.environ.
load_dotenv()

# Whether to request a public Gradio share link. os.getenv returns strings, and
# any non-empty string (including "False" or "0") is truthy, so parse it into a
# real bool instead of using the raw value.
share = os.getenv("SHARE", "").strip().lower() in {"1", "true", "yes", "on"}

# ImageNet-pretrained VGG19 used as the classifier backbone. Presumably these
# weights are replaced by the fine-tuned ones when mask_detection.pth (saved
# from NeuralNet) is loaded — TODO confirm against the training script.
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 in favour of
# `weights=`; switch once the minimum torchvision version is pinned.
pretrained_model = models.vgg19(pretrained=True)
class NeuralNet(nn.Module):
    """Binary classifier: a 1000-feature backbone followed by one sigmoid unit.

    The backbone output is flattened, mapped to a single value by
    ``nn.Linear(1000, 1)`` and squashed into (0, 1) by ``nn.Sigmoid``.

    Args:
        backbone: Module producing 1000 features per sample. Defaults to the
            module-level ``pretrained_model`` (VGG19). Every instance built
            with the default reuses that same module object, and therefore
            shares its parameters; pass a fresh module to avoid that.
    """

    def __init__(self, backbone: "nn.Module | None" = None):
        super().__init__()
        if backbone is None:
            # Resolved at construction time to keep the historical default.
            backbone = pretrained_model
        # The Sequential indices (0..3) determine state_dict keys such as
        # "model.0.*" and "model.2.*"; keep this layout so existing
        # checkpoints still load.
        self.model = nn.Sequential(
            backbone,
            nn.Flatten(),
            nn.Linear(1000, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample scores of shape (batch, 1), each in (0, 1)."""
        return self.model(x)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the classifier and restore its trained weights. load_state_dict is
# strict by default, so every key in the model must be present in the file.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model = NeuralNet()
model.load_state_dict(torch.load("mask_detection.pth", map_location=device))
model = model.to(device)
# Switch to inference mode. This disables training-only behaviour such as the
# dropout in torchvision's VGG classifier head, so predictions are deterministic.
model.eval()
# Inference-time preprocessing: resize to 150x150, convert to a [0, 1] tensor,
# then normalise each channel to [-1, 1] via (x - 0.5) / 0.5. Presumably this
# mirrors the training pipeline — TODO confirm against the training script.
# No random augmentation (e.g. RandomHorizontalFlip) is applied here, so the
# same image always yields the same prediction.
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
def greet(image):
    """Classify whether the person in ``image`` is wearing a face mask.

    Args:
        image: Array from the Gradio "image" input, or None when no image
            was supplied. Presumably an (H, W, 3) RGB array — verify against
            the Gradio version in use.

    Returns:
        A human-readable verdict string.
    """
    if image is None:
        return "Please upload an image."

    # Convert in memory instead of round-tripping through a shared file on
    # disk (concurrent requests would overwrite each other's "input.png").
    # fromarray infers the mode from the array; convert("RGB") also copes
    # with grayscale or RGBA uploads.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = transform(pil_image).unsqueeze(0).to(device)

    # No gradients are needed for inference; skip building the autograd graph.
    with torch.no_grad():
        # Sigmoid output in (0, 1); scores below 0.5 are reported as "mask".
        score = model(batch).item()

    if score < 0.5:
        return "Person in the pic has mask"
    return "Person in the pic does not have mask"
iface = gr.Interface(fn=greet, inputs="image", outputs="text")

# Honour the SHARE setting, which was previously read but never used. It may
# be a raw environment string (e.g. "False", which is truthy) or a bool, so
# normalise it here. When sharing is not enabled, pass None rather than False
# so Gradio's own default for this argument still applies.
_share_link = str(share).strip().lower() in {"1", "true", "yes", "on"}
iface.launch(share=_share_link or None)