Commit 5e4b3a1
Parent(s): c82f96b
add files
Files changed:
- .gitignore +4 -0
- Learn_PyTorch_ImageSegmentation.ipynb +0 -0
- README.md +56 -1
- model.py +30 -0
- requirements.txt +3 -0
- train.py +55 -0
- utils.py +104 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+__pycache__
+flagged
+*.pt
+.DS_Store
Learn_PyTorch_ImageSegmentation.ipynb ADDED
The diff for this file is too large to render. See raw diff.
README.md CHANGED
@@ -9,4 +9,59 @@ app_file: app.py
 pinned: false
 ---
 
-
+# PyTorch Image Segmentation
+
+This repo contains the code for training a U-Net model for image segmentation on the Human Segmentation Dataset.
+
+<a href="https://colab.research.google.com/github/josebenitezg/Pytorch-Image-Segmentation/blob/main/Learn_PyTorch_ImageSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab">
+</a>
+
+## Usage :nut_and_bolt:
+
+1. Clone this repo
+
+```
+git clone https://github.com/josebenitezg/Pytorch-Image-Segmentation
+```
+
+2. Create a virtual environment
+
+```
+python -m venv env
+```
+
+3. Activate the virtual environment
+
+- for Linux
+
+```
+source env/bin/activate
+```
+
+- for Windows
+
+```
+env\Scripts\activate.bat
+```
+
+4. Install requirements
+
+```
+pip install -r requirements.txt
+```
+
+5. Train the model
+
+```
+python train.py
+```
+
+6. Run the Gradio inference app (a minimal sketch follows this diff)
+
+```
+python gradio_inference.py
+```
+
+This repo contains dataset files to train a small model.
+
+Dataset Credit: https://github.com/VikramShenoy97/Human-Segmentation-Datasets
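Note that `gradio_inference.py` is referenced in step 6 but is not part of this commit. A minimal sketch of what such an app could look like, assuming the `best_model.pt` checkpoint written by `train.py` and an `IMAGE_SIZE` matching the training config (both assumptions, not the author's actual script):

```python
# Hypothetical sketch of gradio_inference.py (not included in this commit).
import cv2
import numpy as np
import torch
import gradio as gr

from model import SegmentationModel

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
IMAGE_SIZE = 320  # assumption: must match the IMAGE_SIZE used for training

model = SegmentationModel()
model.load_state_dict(torch.load('best_model.pt', map_location=DEVICE))
model.to(DEVICE)
model.eval()

def segment(image):
    # Gradio delivers the input as an RGB numpy array
    h, w = image.shape[:2]
    x = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    x = np.transpose(x, (2, 0, 1)).astype(np.float32) / 255.0
    x = torch.Tensor(x).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
    # threshold sigmoid probabilities to a binary mask, then restore size
    mask = (torch.sigmoid(logits)[0, 0] > 0.5).cpu().numpy().astype(np.uint8) * 255
    return cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

demo = gr.Interface(fn=segment, inputs=gr.Image(), outputs=gr.Image())

if __name__ == '__main__':
    demo.launch()
```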
model.py ADDED
@@ -0,0 +1,30 @@
+from torch import nn
+import segmentation_models_pytorch as smp
+from segmentation_models_pytorch.losses import DiceLoss
+
+ENCODER = 'timm-efficientnet-b0'
+WEIGHTS = 'imagenet'
+
+class SegmentationModel(nn.Module):
+
+    def __init__(self):
+        super(SegmentationModel, self).__init__()
+
+        # U-Net with an EfficientNet-B0 encoder pretrained on ImageNet
+        self.arc = smp.Unet(
+            encoder_name=ENCODER,
+            encoder_weights=WEIGHTS,
+            in_channels=3,
+            classes=1,
+            activation=None  # return raw logits; both losses below expect logits
+        )
+
+    def forward(self, images, masks=None):
+
+        logits = self.arc(images)
+
+        if masks is not None:  # training/eval path: also return combined loss
+            loss1 = DiceLoss(mode='binary')(logits, masks)
+            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
+            return logits, loss1 + loss2
+
+        return logits
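The forward pass returns raw logits at inference time and a `(logits, loss)` pair when masks are supplied, where the loss is the sum of Dice and BCE-with-logits. A quick smoke test; the 320x320 size is an assumption (smp's U-Net only requires dimensions divisible by 32):

```python
import torch
from model import SegmentationModel

model = SegmentationModel()
images = torch.randn(2, 3, 320, 320)                   # batch of RGB images
masks = torch.randint(0, 2, (2, 1, 320, 320)).float()  # binary masks in {0, 1}

logits, loss = model(images, masks)  # training path: combined Dice + BCE loss
print(logits.shape, loss.item())     # torch.Size([2, 1, 320, 320]), scalar loss

logits = model(images)               # inference path: logits only
probs = torch.sigmoid(logits)        # convert logits to probabilities
```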
requirements.txt ADDED
@@ -0,0 +1,3 @@
+albumentations==1.3.0
+segmentation-models-pytorch==0.3.2
+opencv-contrib-python
train.py ADDED
@@ -0,0 +1,55 @@
+import torch
+import cv2
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+from utils import load_config, get_train_augs, get_valid_augs, train_fn, eval_fn, SegmentationDataset
+from model import SegmentationModel
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# set device for training
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# load config file
+config = load_config()
+
+# load train files into a dataframe
+df = pd.read_csv(config['files']['CSV_FILE'])
+
+train_df, valid_df = train_test_split(df, test_size=0.2, random_state=42)
+
+trainset = SegmentationDataset(train_df, get_train_augs(config['model']['IMAGE_SIZE']))
+
+validset = SegmentationDataset(valid_df, get_valid_augs(config['model']['IMAGE_SIZE']))
+
+print(f"Size of Trainset : {len(trainset)}")
+print(f"Size of Validset : {len(validset)}")
+
+trainloader = DataLoader(trainset, batch_size=config['model']['BATCH_SIZE'], shuffle=True)
+validloader = DataLoader(validset, batch_size=config['model']['BATCH_SIZE'])
+
+print(f"Total n of batches in trainloader: {len(trainloader)}")
+print(f"Total n of batches in validloader: {len(validloader)}")
+
+
+model = SegmentationModel()
+model.to(DEVICE)
+
+optimizer = torch.optim.Adam(model.parameters(), lr=config['model']['LR'])
+
+best_valid_loss = np.inf
+
+for i in tqdm(range(config['model']['EPOCHS'])):
+
+    train_loss = train_fn(trainloader, model, optimizer, DEVICE)
+    valid_loss = eval_fn(validloader, model, DEVICE)
+
+    if valid_loss < best_valid_loss:  # keep the checkpoint with the lowest valid loss
+        torch.save(model.state_dict(), 'best_model.pt')
+        print('SAVED-MODEL')
+        best_valid_loss = valid_loss
+    print(f"Epoch: {i+1} Train Loss: {train_loss} Valid Loss: {valid_loss}")
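`train.py` reads `config/config.yaml` via `load_config()`, but the config file is not included in this commit. From the lookups above, it presumably has the following shape; every value here is a placeholder, not the author's setting:

```yaml
# Hypothetical config/config.yaml -- keys inferred from train.py, values are guesses
files:
  CSV_FILE: train.csv   # CSV with image/mask path columns
model:
  IMAGE_SIZE: 320       # square resize used by the augmentations
  BATCH_SIZE: 16
  LR: 0.001
  EPOCHS: 25
```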
utils.py ADDED
@@ -0,0 +1,104 @@
+
+import cv2
+import torch
+import yaml
+import numpy as np
+import albumentations as A
+from torch.utils.data import Dataset
+
+
+def get_train_augs(IMAGE_SIZE):
+
+    return A.Compose([
+        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
+        A.HorizontalFlip(p=0.5),
+        A.VerticalFlip(p=0.5)
+    ])
+
+def get_valid_augs(IMAGE_SIZE):
+
+    return A.Compose([
+        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
+    ])
+
+def train_fn(data_loader, model, optimizer, DEVICE):
+
+    model.train()
+    total_loss = 0.0
+
+    for images, masks in data_loader:
+
+        images = images.to(DEVICE)
+        masks = masks.to(DEVICE)
+
+        optimizer.zero_grad()
+        logits, loss = model(images, masks)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+
+    return total_loss / len(data_loader)
+
+def eval_fn(data_loader, model, DEVICE):
+
+    model.eval()
+    total_loss = 0.0
+    with torch.no_grad():
+        for images, masks in data_loader:
+
+            images = images.to(DEVICE)
+            masks = masks.to(DEVICE)
+
+            logits, loss = model(images, masks)
+
+            total_loss += loss.item()
+
+    return total_loss / len(data_loader)
+
+def load_config():
+    config_file = 'config/config.yaml'
+
+    with open(config_file, 'r') as file:
+        config = yaml.safe_load(file)
+
+    return config
+
+
+class SegmentationDataset(Dataset):
+
+    def __init__(self, df, augmentations):
+
+        self.df = df
+        self.augmentations = augmentations
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+
+        row = self.df.iloc[idx]
+
+        image_path = row.images
+        mask_path = row.masks
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # (h, w)
+        # resize the mask to the same dimensions as the image
+        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # (h, w)
+        mask = np.expand_dims(mask, axis=-1)  # (h, w, 1)
+
+        if self.augmentations:
+            data = self.augmentations(image=image, mask=mask)
+            image = data['image']
+            mask = data['mask']
+
+        # (h, w, c) -> (c, h, w)
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+        mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)
+
+        image = torch.Tensor(image) / 255.0                 # scale pixels to [0, 1]
+        mask = torch.round(torch.Tensor(mask) / 255.0)      # binarize mask to {0, 1}
+
+        return image, mask
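`SegmentationDataset.__getitem__` reads file paths via `row.images` and `row.masks`, so the CSV referenced by `config['files']['CSV_FILE']` needs `images` and `masks` columns. A small usage sketch; the paths and the 320 image size are hypothetical:

```python
import pandas as pd
from utils import SegmentationDataset, get_train_augs

# hypothetical CSV layout: one row per sample, columns hold file paths
df = pd.DataFrame({
    'images': ['data/images/0001.jpg'],
    'masks':  ['data/masks/0001.png'],
})

ds = SegmentationDataset(df, get_train_augs(320))  # 320 = assumed IMAGE_SIZE
image, mask = ds[0]
print(image.shape, mask.shape)  # (3, 320, 320) in [0, 1]; (1, 320, 320) in {0, 1}
```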