import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import gradio as gr

# Decision threshold applied to the sigmoid of the model's single output logit.
TUMOR_THRESHOLD = 0.5

# Images are resized to INPUT_SIZE x INPUT_SIZE before inference. With 224,
# the spatial size after SqueezeNet.max_pool3 is 12x12, which is exactly what
# SqueezeNet.avgpool (kernel_size=12) expects.
INPUT_SIZE = 224


class FireModule(nn.Module):
    """SqueezeNet Fire module: a 1x1 "squeeze" conv followed by parallel 1x1 and 3x3 "expand" convs.

    The two expand outputs are concatenated along dim 1 (channels), so the
    module outputs ``e1x1 + e3x3`` channels. The 3x3 branch uses padding=1, so
    spatial size is preserved.

    Args:
        in_channels: Number of input channels.
        s1x1: Output channels of the squeeze conv.
        e1x1: Output channels of the 1x1 expand conv.
        e3x3: Output channels of the 3x3 expand conv.
    """

    def __init__(self, in_channels, s1x1, e1x1, e3x3):
        super().__init__()
        self.squeeze = nn.Conv2d(in_channels=in_channels, out_channels=s1x1, kernel_size=1, stride=1)
        self.expand1x1 = nn.Conv2d(in_channels=s1x1, out_channels=e1x1, kernel_size=1)
        self.expand3x3 = nn.Conv2d(in_channels=s1x1, out_channels=e3x3, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.squeeze(x))
        x = torch.cat((self.expand1x1(x), self.expand3x3(x)), dim=1)
        return F.relu(x)


class SqueezeNet(nn.Module):
    """SqueezeNet-style classifier producing ``out_channels`` raw scores per image.

    The layer layout (and therefore the state_dict keys) must match the saved
    checkpoint, so layers are deliberately left as they were trained.

    Args:
        out_channels: Number of output scores (1 here: a single tumor logit).
    """

    def __init__(self, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire2 = FireModule(in_channels=96, s1x1=16, e1x1=64, e3x3=64)
        self.fire3 = FireModule(in_channels=128, s1x1=16, e1x1=64, e3x3=64)
        self.fire4 = FireModule(in_channels=128, s1x1=32, e1x1=128, e3x3=128)
        self.max_pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire5 = FireModule(in_channels=256, s1x1=32, e1x1=128, e3x3=128)
        self.fire6 = FireModule(in_channels=256, s1x1=48, e1x1=192, e3x3=192)
        self.fire7 = FireModule(in_channels=384, s1x1=48, e1x1=192, e3x3=192)
        self.fire8 = FireModule(in_channels=384, s1x1=64, e1x1=256, e3x3=256)
        self.max_pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.fire9 = FireModule(in_channels=512, s1x1=64, e1x1=256, e3x3=256)
        self.conv10 = nn.Conv2d(in_channels=512, out_channels=out_channels, kernel_size=1, stride=1)
        # kernel_size=12 matches the 12x12 feature map produced from a
        # 224x224 input (224 -> 109 -> 54 -> 26 -> 12); other input sizes
        # would not pool down to 1x1.
        self.avgpool = nn.AvgPool2d(kernel_size=12, stride=1)

    def forward(self, x):
        # NOTE(review): the reference SqueezeNet applies ReLU after conv1; this
        # model does not. Left unchanged because the checkpoint was presumably
        # trained with exactly this forward pass — confirm before altering.
        x = self.max_pool1(self.conv1(x))
        x = self.max_pool2(self.fire4(self.fire3(self.fire2(x))))
        x = self.max_pool3(self.fire8(self.fire7(self.fire6(self.fire5(x)))))
        x = self.avgpool(self.conv10(self.fire9(x)))
        return torch.flatten(x, start_dim=1)


model = SqueezeNet(out_channels=1)
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load("squeezenet.pth", map_location=torch.device("cpu")))
model.eval()

# Deterministic inference preprocessing. Random augmentations (flips,
# autocontrast, sharpness) belong in the training pipeline only; applying them
# here makes the same upload yield different predictions between requests.
transform = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
])


def classify_brain_tumor(image: "Image.Image | None") -> str:
    """Classify an uploaded MRI image as "Tumor" or "No Tumor".

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or None when
            the user submits without uploading an image.

    Returns:
        "Tumor" if sigmoid(logit) >= TUMOR_THRESHOLD, otherwise "No Tumor".
        A short prompt string is returned when no image was provided.
    """
    if image is None:
        return "Please upload an MRI image."
    # conv1 expects exactly 3 input channels; grayscale ("L"), palette ("P")
    # or RGBA uploads would otherwise fail with a channel-mismatch error.
    batch = transform(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logit = model(batch)
    probability = torch.sigmoid(logit).item()
    return "Tumor" if probability >= TUMOR_THRESHOLD else "No Tumor"


interface = gr.Interface(
    fn=classify_brain_tumor,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Brain Tumor Classification",
    description="Upload an MRI image to classify if it has a tumor or not.",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    interface.launch()