# vedAi / model.py
# Uploaded by randomshit11 in commit cfdcae4 ("Upload 3 files").
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import os
import random
class ResNet50(nn.Module):
    """Torchvision ResNet-50 with a frozen body and a freshly initialized linear head.

    The backbone parameters are frozen (``requires_grad = False``) and the final
    ``fc`` layer is replaced after freezing, so by default only the new head
    receives gradients.

    Args:
        num_classes: Number of outputs of the new head. Defaults to 2, the value
            that was previously hard-coded.
        pretrained: If True (default), load the ImageNet-1k V1 weights. These are
            the weights the legacy ``pretrained=True`` flag loaded. If False, the
            backbone is randomly initialized, which is useful offline or when a
            checkpoint will be loaded afterwards.
        freeze_backbone: If True (default), disable gradients for every backbone
            parameter so only the head is trained.

    Requires torchvision >= 0.13 for the ``weights=`` API.
    """

    def __init__(
        self,
        num_classes: int = 2,
        *,
        pretrained: bool = True,
        freeze_backbone: bool = True,
    ):
        super().__init__()
        # ``pretrained=`` has been deprecated since torchvision 0.13 in favour of
        # the ``weights`` enum. IMAGENET1K_V1 matches the old ``pretrained=True``.
        weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet50(weights=weights)

        if freeze_backbone:
            # NOTE: requires_grad=False only stops gradient updates. BatchNorm
            # running statistics still change while the module is in train() mode.
            for param in self.resnet.parameters():
                param.requires_grad = False

        # Read the feature width from the stock head instead of hard-coding 2048.
        in_features = self.resnet.fc.in_features
        # The new head is created after the freeze loop, so it stays trainable.
        # The single-layer Sequential is kept on purpose: state_dict keys remain
        # ``resnet.fc.0.weight`` / ``resnet.fc.0.bias`` for existing checkpoints.
        self.resnet.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the image batch through ResNet-50 and the new head.

        Args:
            x: Input batch. Presumably (N, 3, H, W), pre-normalized the way the
                chosen weights expect. TODO: confirm against the caller's transforms.

        Returns:
            Raw, unnormalized head outputs of shape (N, num_classes). No softmax
            is applied.
        """
        return self.resnet(x)