import math
import os
import types

import clip
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from .lseg_blocks import FeatureFusionBlock, Interpolate, _make_encoder, FeatureFusionBlock_custom, forward_vit


class depthwise_clipseg_conv(nn.Module):
    def __init__(self):
        super(depthwise_clipseg_conv, self).__init__()
        # A single 1-in/1-out 3x3 kernel, shared across all input channels.
        self.depthwise = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def depthwise_clipseg(self, x, channels):
        # Convolve each channel separately with the shared kernel, then re-stack.
        x = torch.cat([self.depthwise(x[:, i].unsqueeze(1)) for i in range(channels)], dim=1)
        return x

    def forward(self, x):
        channels = x.shape[1]
        out = self.depthwise_clipseg(x, channels)
        return out


class depthwise_conv(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1):
        super(depthwise_conv, self).__init__()
        self.depthwise = nn.Conv2d(1, 1, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        C, H, W = x.shape[1:]
        # Fold channels into the batch dimension so the shared 1-channel kernel
        # is applied to every channel independently, then restore the layout.
        x = x.reshape(-1, 1, H, W)
        x = self.depthwise(x)
        x = x.view(-1, C, H, W)
        return x
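
# Illustrative only (not part of the original source): depthwise_conv applies
# one shared 1-in/1-out kernel to every channel, so shapes are preserved, e.g.
#
#   conv = depthwise_conv(kernel_size=3, stride=1, padding=1)
#   out = conv(torch.randn(2, 8, 16, 16))   # -> torch.Size([2, 8, 16, 16])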


class depthwise_block(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
        super(depthwise_block, self).__init__()
        self.depthwise = depthwise_conv(kernel_size=kernel_size, stride=stride, padding=padding)
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x, act=True):
        x = self.depthwise(x)
        if act:
            x = self.activation(x)
        return x


class bottleneck_block(nn.Module):
    def __init__(self, kernel_size=3, stride=1, padding=1, activation='relu'):
        super(bottleneck_block, self).__init__()
        self.depthwise = depthwise_conv(kernel_size=kernel_size, stride=stride, padding=padding)
        if activation == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'lrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

    def forward(self, x, act=True):
        # Despite the name, this is a channel-wise max, broadcast back onto
        # every channel as a residual shortcut.
        sum_layer = x.max(dim=1, keepdim=True)[0]
        x = self.depthwise(x)
        x = x + sum_layer
        if act:
            x = self.activation(x)
        return x


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device("cpu"))

        # Full training checkpoints store the weights under the "model" key.
        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)
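
    # Illustrative only (not part of the original source): load() accepts either
    # a bare state_dict or a full training checkpoint (hypothetical file names):
    #
    #   model.load("lseg_state_dict.pth")   # plain state_dict
    #   model.load("lseg_checkpoint.pth")   # {"model": ..., "optimizer": ...}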


def _make_fusion_block(features, use_bn):
    return FeatureFusionBlock_custom(
        features,
        activation=nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
    )


class LSeg(BaseModel):
    def __init__(
        self,
        head,
        features=256,
        backbone="clip_vitl16_384",
        readout="project",
        channels_last=False,
        use_bn=False,
        **kwargs,
    ):
        super(LSeg, self).__init__()

        self.channels_last = channels_last

        # Transformer blocks at which intermediate activations are hooked out.
        hooks = {
            "clip_vitl16_384": [5, 11, 17, 23],
            "clipRN50x16_vitl16_384": [5, 11, 17, 23],
            "clip_vitb32_384": [2, 5, 8, 11],
        }

        self.clip_pretrained, self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks[backbone],
            use_readout=readout,
        )

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        # CLIP's temperature, exp(log(1/0.07)); note that the trailing .exp()
        # turns the Parameter into a plain tensor, so the scale is fixed, not learned.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)).exp()
        # Output width must match the CLIP text-embedding dimension of the backbone.
        if backbone in ["clipRN50x16_vitl16_384"]:
            self.out_c = 768
        else:
            self.out_c = 512
        self.scratch.head1 = nn.Conv2d(features, self.out_c, kernel_size=1)

        self.arch_option = kwargs["arch_option"]
        if self.arch_option == 1:
            self.scratch.head_block = bottleneck_block(activation=kwargs["activation"])
            self.block_depth = kwargs["block_depth"]
        elif self.arch_option == 2:
            self.scratch.head_block = depthwise_block(activation=kwargs["activation"])
            self.block_depth = kwargs["block_depth"]

        self.scratch.output_conv = head

        # self.labels must be set by the subclass (see LSegNet) before this runs.
        self.text = clip.tokenize(self.labels)

    def forward(self, x, labelset=''):
        # Default to the label set tokenized at construction time.
        if labelset == '':
            text = self.text
        else:
            text = clip.tokenize(labelset)

        if self.channels_last:
            # contiguous() is not in-place; keep the returned tensor.
            x = x.contiguous(memory_format=torch.channels_last)

        layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        text = text.to(x.device)
        self.logit_scale = self.logit_scale.to(x.device)
        text_features = self.clip_pretrained.encode_text(text)

        image_features = self.scratch.head1(path_1)

        imshape = image_features.shape
        # Flatten spatial positions so every pixel embedding is one row.
        image_features = image_features.permute(0, 2, 3, 1).reshape(-1, self.out_c)

        # L2-normalize both modalities, as in CLIP.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Per-pixel similarity against every label embedding; cast to fp16 to
        # match the CLIP text features.
        logits_per_image = self.logit_scale * image_features.half() @ text_features.t()

        out = logits_per_image.float().view(imshape[0], imshape[2], imshape[3], -1).permute(0, 3, 1, 2)

        if self.arch_option in [1, 2]:
            for _ in range(self.block_depth - 1):
                out = self.scratch.head_block(out)
            # Final block is applied without the activation.
            out = self.scratch.head_block(out, False)

        out = self.scratch.output_conv(out)

        return out
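
    # Illustrative only (not part of the original source): labelset can override
    # the labels fixed at construction time, enabling zero-shot label swaps:
    #
    #   logits = model(image, labelset=['sky', 'tree', 'other'])  # hypothetical labels
    #   seg = logits.argmax(dim=1)  # one channel per label -> integer label map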


class LSegNet(LSeg):
    """Network for semantic segmentation."""

    def __init__(self, labels, path=None, scale_factor=0.5, crop_size=480, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256
        kwargs["use_bn"] = True

        # Plain attributes are safe to set before Module.__init__ runs; in
        # particular, self.labels must exist before LSeg.__init__ tokenizes it.
        self.crop_size = crop_size
        self.scale_factor = scale_factor
        self.labels = labels

        head = nn.Sequential(
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)
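

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes a
    # CUDA device (forward casts image features to fp16 to match CLIP's text
    # encoder) and that the CLIP backbone weights can be downloaded; the labels
    # below are placeholders.
    net = LSegNet(
        labels=["sky", "building", "other"],
        arch_option=0,
        block_depth=0,
        activation="lrelu",
    ).cuda().eval()
    with torch.no_grad():
        logits = net(torch.randn(1, 3, 480, 480, device="cuda"))
    # One output channel per label; argmax over dim 1 gives a hard label map.
    print(logits.shape, logits.argmax(dim=1).shape)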