import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from torchvision import models from PIL import Image import os import random class ResNet50(nn.Module): def __init__(self): super(ResNet50, self).__init__() self.resnet = models.resnet50(pretrained=True) for param in self.resnet.parameters(): param.requires_grad = False self.resnet.fc = nn.Sequential( nn.Linear(2048, 2) ) def forward(self, x): x = self.resnet(x) return x