import gc

import warnings

from transformers.activations import ACT2FN
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")

import math
from typing import Optional, Tuple, Union

from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, \
    SegformerPreTrainedModel
from surya.model.detection.processor import SegformerImageProcessor
import torch
from torch import nn

from transformers.modeling_outputs import SemanticSegmenterOutput, BaseModelOutput
from surya.settings import settings


def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_DETECTION, dtype=settings.MODEL_DTYPE_DETECTION):
    config = SegformerConfig.from_pretrained(checkpoint)
    model = SegformerForRegressionMask.from_pretrained(checkpoint, torch_dtype=dtype, config=config)
    if "mps" in device:
        print("Warning: MPS may have poor results. This is a bug with MPS, see here - https://github.com/pytorch/pytorch/issues/84936")

    model = model.to(device)
    model = model.eval()
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return model


def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    processor = SegformerImageProcessor.from_pretrained(checkpoint)
    return processor
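
# Example usage (a minimal sketch, not part of the module itself): "image" is a hypothetical PIL image
# supplied by the caller, and the processor is assumed to follow the standard Hugging Face image-processor
# call signature (returning a dict with a "pixel_values" tensor).
#
#   model = load_model()
#   processor = load_processor()
#   inputs = processor(images=[image], return_tensors="pt")
#   pixel_values = inputs["pixel_values"].to(model.device).to(model.dtype)
#   with torch.inference_mode():
#       outputs = model(pixel_values=pixel_values)
#   heatmaps = outputs.logits  # sigmoid-activated maps, (batch, num_labels, H/4, W/4) for the default strides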


class SegformerForMaskMLP(nn.Module):
    """Project a (batch, channels, height, width) feature map to (batch, height * width, output_dim)."""

    def __init__(self, config: SegformerConfig, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class SegformerForMaskDecodeHead(SegformerDecodeHead):
    def __init__(self, config):
        super().__init__(config)
        decoder_layer_hidden_size = getattr(config, "decoder_layer_hidden_size", config.decoder_hidden_size)

        # One MLP per encoder stage, each projecting that stage's channels to a common width.
        mlps = []
        for i in range(config.num_encoder_blocks):
            mlp = SegformerForMaskMLP(config, input_dim=config.hidden_sizes[i], output_dim=decoder_layer_hidden_size)
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        # Fuse the concatenated per-stage features with a 1x1 convolution.
        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_layer_hidden_size * config.num_encoder_blocks,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        batch_size = encoder_hidden_states[-1].shape[0]

        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
            if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
                height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
                encoder_hidden_state = (
                    encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
                )

            # Project to the common channel width, then restore the (batch, channels, height, width) layout.
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)

            # Upsample every stage to the resolution of the first (highest-resolution) stage.
            encoder_hidden_state = encoder_hidden_state.contiguous()
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
            )
            all_hidden_states += (encoder_hidden_state,)

        hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)

        logits = self.classifier(hidden_states)

        return logits
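
# Shape walkthrough (an illustrative assumption using SegFormer-B0-style values: strides (4, 2, 2, 2),
# hidden_sizes (32, 64, 160, 256), decoder_hidden_size 256 with decoder_layer_hidden_size falling back to it,
# and a 512x512 input):
#   encoder stages emit (B, 32, 128, 128), (B, 64, 64, 64), (B, 160, 32, 32), (B, 256, 16, 16)
#   each is projected to 256 channels and upsampled to 128x128, then concatenated -> (B, 1024, 128, 128)
#   linear_fuse + batch_norm + ReLU -> (B, 256, 128, 128)
#   classifier -> (B, num_labels, 128, 128), i.e. logits at 1/4 of the input resolution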


class SegformerOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )

        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        embeddings = self.proj(pixel_values)
        _, _, height, width = embeddings.shape

        # (batch, channels, height, width) -> (batch, height * width, channels)
        embeddings = embeddings.flatten(2).transpose(1, 2)
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width
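
# Example (illustrative numbers, assuming the first stage of a default SegFormer config: patch_size=7,
# stride=4, num_channels=3, hidden_size=32, and a 512x512 input):
#   proj: (B, 3, 512, 512) -> (B, 32, 128, 128)    # 7x7 conv, stride 4, padding 3
#   flatten/transpose:      -> (B, 16384, 32)      # one token per stride-4 location
#   returns (embeddings, 128, 128) so later blocks can restore the spatial layout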


class SegformerEfficientSelfAttention(nn.Module):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the [PvT
    paper](https://arxiv.org/abs/2102.12122)."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            self.sr = nn.Conv2d(
                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
            )
            self.layer_norm = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, hidden_states):
        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        hidden_states = hidden_states.view(new_shape)
        return hidden_states.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        height,
        width,
        output_attentions=False,
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if self.sr_ratio > 1:
            batch_size, seq_len, num_channels = hidden_states.shape
            # Reshape to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
            # Apply sequence reduction with a strided convolution
            hidden_states = self.sr(hidden_states)
            # Reshape back to (batch_size, reduced_seq_len, num_channels)
            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
            hidden_states = self.layer_norm(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
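
# Why the sequence reduction helps (an illustrative assumption using the first stage of a default config:
# height = width = 128, sr_ratio = 8):
#   queries keep the full sequence length N = 128 * 128 = 16384
#   keys/values are downsampled by the strided conv to (128 / 8) * (128 / 8) = 256 positions
#   so the attention matrix is 16384 x 256 instead of 16384 x 16384, roughly N^2 / sr_ratio^2 entries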


class SegformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Overlapping patch embeddings, one per encoder stage.
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                SegformerOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
            )
        self.patch_embeddings = nn.ModuleList(embeddings)

        # Transformer blocks, grouped per stage.
        blocks = []
        cur = 0
        for i in range(config.num_encoder_blocks):
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    SegformerLayer(
                        config,
                        hidden_size=config.hidden_sizes[i],
                        num_attention_heads=config.num_attention_heads[i],
                        sequence_reduction_ratio=config.sr_ratios[i],
                        mlp_ratio=config.mlp_ratios[i],
                    )
                )
            blocks.append(nn.ModuleList(layers))

        self.block = nn.ModuleList(blocks)

        # Per-stage layer norms.
        self.layer_norm = nn.ModuleList(
            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        # The per-stage feature maps are always collected, since the decode head consumes all of them.
        all_hidden_states = ()

        batch_size = pixel_values.shape[0]

        hidden_states = pixel_values
        for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
            embedding_layer, block_layer, norm_layer = x

            # First, obtain the patch embeddings for this stage.
            hidden_states, height, width = embedding_layer(hidden_states)

            # Second, send the embeddings through the stage's transformer blocks.
            for i, blk in enumerate(block_layer):
                layer_outputs = blk(hidden_states, height, width, output_attentions)
                hidden_states = layer_outputs[0]

            # Third, apply the per-stage layer norm.
            hidden_states = norm_layer(hidden_states)

            # Fourth, reshape back to (batch_size, num_channels, height, width), except optionally for the last stage.
            if idx != len(self.patch_embeddings) - 1 or (
                idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
            ):
                hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
            all_hidden_states = all_hidden_states + (hidden_states,)

        return all_hidden_states


class SegformerSelfOutput(nn.Module):
    def __init__(self, config, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        return hidden_states


class SegformerAttention(nn.Module):
    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.self = SegformerEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_outputs = self.self(hidden_states, height, width, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]
        return outputs


class SegformerDWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        # Depthwise 3x3 convolution (groups=dim means one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        batch_size, seq_len, num_channels = hidden_states.shape
        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        hidden_states = self.dwconv(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        return hidden_states


class SegformerMixFFN(nn.Module):
    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = SegformerDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)

    def forward(self, hidden_states, height, width):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states


class SegformerLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio, mlp_ratio):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.attention = SegformerAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.layer_norm_2 = nn.LayerNorm(hidden_size)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),
            height,
            width,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]

        hidden_states = attention_output + hidden_states

        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
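
# In pseudocode, one SegformerLayer is a pre-norm transformer block with a Mix-FFN:
#   x = x + Attention(LayerNorm1(x))
#   x = x + MixFFN(LayerNorm2(x))
# where MixFFN(h) = Linear2(act(DWConv(Linear1(h)))); the depthwise conv stands in for positional encodings.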


class SegformerModel(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.encoder = SegformerEncoder(config)

        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base
        class PreTrainedModel.
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return encoder_outputs


class SegformerForRegressionMask(SegformerForSemanticSegmentation):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.segformer = SegformerModel(config)
        self.decode_head = SegformerForMaskDecodeHead(config)

        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        encoder_hidden_states = self.segformer(
            pixel_values,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=False,
        )

        logits = self.decode_head(encoder_hidden_states)

        # Apply a sigmoid so the returned "logits" field holds per-pixel probabilities in [0, 1].
        sigmoid_logits = torch.special.expit(logits)

        return SemanticSegmenterOutput(
            loss=None,
            logits=sigmoid_logits,
            hidden_states=None,
            attentions=None,
        )
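
# A minimal post-processing sketch (an assumption, not the exact post-processing surya applies downstream):
# the returned SemanticSegmenterOutput.logits field already holds sigmoid probabilities, so binary masks can
# be obtained by thresholding, e.g.
#
#   probs = model(pixel_values=pixel_values).logits                  # (batch, num_labels, H/4, W/4)
#   probs = nn.functional.interpolate(probs, size=pixel_values.shape[-2:],
#                                     mode="bilinear", align_corners=False)
#   masks = probs > 0.5                                              # 0.5 is an illustrative threshold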