import gc
import warnings

from transformers.activations import ACT2FN
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer

warnings.filterwarnings("ignore", message="torch.utils._pytree._register_pytree_node is deprecated")

import math
from typing import Optional, Tuple, Union

from transformers import SegformerConfig, SegformerForSemanticSegmentation, SegformerDecodeHead, \
    SegformerPreTrainedModel
from surya.model.detection.processor import SegformerImageProcessor
import torch
from torch import nn
from transformers.modeling_outputs import SemanticSegmenterOutput, BaseModelOutput

from surya.settings import settings


def load_model(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT, device=settings.TORCH_DEVICE_DETECTION, dtype=settings.MODEL_DTYPE_DETECTION):
    config = SegformerConfig.from_pretrained(checkpoint)
    model = SegformerForRegressionMask.from_pretrained(checkpoint, torch_dtype=dtype, config=config)
    if "mps" in device:
        print("Warning: MPS may have poor results. This is a bug with MPS, see here - https://github.com/pytorch/pytorch/issues/84936")

    model = model.to(device)
    model = model.eval()
    print(f"Loaded detection model {checkpoint} on device {device} with dtype {dtype}")
    return model


def load_processor(checkpoint=settings.DETECTOR_MODEL_CHECKPOINT):
    processor = SegformerImageProcessor.from_pretrained(checkpoint)
    return processor


class SegformerForMaskMLP(nn.Module):
    def __init__(self, config: SegformerConfig, input_dim, output_dim):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.flatten(2).transpose(1, 2)
        hidden_states = self.proj(hidden_states)
        return hidden_states


class SegformerForMaskDecodeHead(SegformerDecodeHead):
    def __init__(self, config):
        super().__init__(config)
        decoder_layer_hidden_size = getattr(config, "decoder_layer_hidden_size", config.decoder_hidden_size)

        # linear layers which will unify the channel dimension of each of the encoder blocks to the same config.decoder_hidden_size
        mlps = []
        for i in range(config.num_encoder_blocks):
            mlp = SegformerForMaskMLP(config, input_dim=config.hidden_sizes[i], output_dim=decoder_layer_hidden_size)
            mlps.append(mlp)
        self.linear_c = nn.ModuleList(mlps)

        # the following 3 layers implement the ConvModule of the original implementation
        self.linear_fuse = nn.Conv2d(
            in_channels=decoder_layer_hidden_size * config.num_encoder_blocks,
            out_channels=config.decoder_hidden_size,
            kernel_size=1,
            bias=False,
        )
        self.batch_norm = nn.BatchNorm2d(config.decoder_hidden_size)
        self.activation = nn.ReLU()

        self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

        self.config = config

    def forward(self, encoder_hidden_states: torch.FloatTensor) -> torch.Tensor:
        batch_size = encoder_hidden_states[-1].shape[0]
        all_hidden_states = ()
        for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
            if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
                height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
                encoder_hidden_state = (
                    encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
                )

            # unify channel dimension
            height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
            encoder_hidden_state = mlp(encoder_hidden_state)
            encoder_hidden_state = encoder_hidden_state.permute(0, 2, 1)
            encoder_hidden_state = encoder_hidden_state.reshape(batch_size, -1, height, width)
            # upsample
            encoder_hidden_state = encoder_hidden_state.contiguous()
            encoder_hidden_state = nn.functional.interpolate(
                encoder_hidden_state, size=encoder_hidden_states[0].size()[2:], mode="bilinear", align_corners=False
            )
            all_hidden_states += (encoder_hidden_state,)

        hidden_states = self.linear_fuse(torch.cat(all_hidden_states[::-1], dim=1))
        hidden_states = self.batch_norm(hidden_states)
        hidden_states = self.activation(hidden_states)

        # logits are of shape (batch_size, num_labels, height/4, width/4)
        logits = self.classifier(hidden_states)

        return logits
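
# Illustrative sketch (not used anywhere in surya): shows the shape flow through
# SegformerForMaskDecodeHead with a default 4-stage SegformerConfig. The 64x64
# input size and the random features below are assumptions for demonstration only.
def _example_decode_head_shapes():
    config = SegformerConfig()  # defaults: hidden_sizes=[32, 64, 160, 256], decoder_hidden_size=256
    head = SegformerForMaskDecodeHead(config).eval()
    # Fake encoder outputs at strides 4, 8, 16 and 32 of a 64x64 input.
    features = [
        torch.randn(1, channels, 64 // stride, 64 // stride)
        for channels, stride in zip(config.hidden_sizes, (4, 8, 16, 32))
    ]
    with torch.no_grad():
        logits = head(features)
    # Every stage is projected to decoder_layer_hidden_size, upsampled to the
    # stride-4 grid, concatenated, fused and classified: (1, num_labels, 16, 16).
    return logits.shape
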
class SegformerOverlapPatchEmbeddings(nn.Module):
    """Construct the overlapping patch embeddings."""

    def __init__(self, patch_size, stride, num_channels, hidden_size):
        super().__init__()
        self.proj = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=stride,
            padding=patch_size // 2,
        )

        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, pixel_values):
        embeddings = self.proj(pixel_values)
        _, _, height, width = embeddings.shape
        # (batch_size, num_channels, height, width) -> (batch_size, num_channels, height*width) -> (batch_size, height*width, num_channels)
        # this can be fed to a Transformer layer
        embeddings = embeddings.flatten(2).transpose(1, 2)
        embeddings = self.layer_norm(embeddings)
        return embeddings, height, width


class SegformerEfficientSelfAttention(nn.Module):
    """SegFormer's efficient self-attention mechanism. Employs the sequence reduction process introduced in the
    [PvT paper](https://arxiv.org/abs/2102.12122)."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({self.hidden_size}) is not a multiple of the number of attention "
                f"heads ({self.num_attention_heads})"
            )

        self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(self.hidden_size, self.all_head_size)
        self.key = nn.Linear(self.hidden_size, self.all_head_size)
        self.value = nn.Linear(self.hidden_size, self.all_head_size)

        self.sr_ratio = sequence_reduction_ratio
        if sequence_reduction_ratio > 1:
            self.sr = nn.Conv2d(
                hidden_size, hidden_size, kernel_size=sequence_reduction_ratio, stride=sequence_reduction_ratio
            )
            self.layer_norm = nn.LayerNorm(hidden_size)

    def transpose_for_scores(self, hidden_states):
        new_shape = hidden_states.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        hidden_states = hidden_states.view(new_shape)
        return hidden_states.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        height,
        width,
        output_attentions=False,
    ):
        query_layer = self.transpose_for_scores(self.query(hidden_states))

        if self.sr_ratio > 1:
            batch_size, seq_len, num_channels = hidden_states.shape
            # Reshape to (batch_size, num_channels, height, width)
            hidden_states = hidden_states.permute(0, 2, 1).reshape(batch_size, num_channels, height, width)
            # Apply sequence reduction
            hidden_states = self.sr(hidden_states)
            # Reshape back to (batch_size, seq_len, num_channels)
            hidden_states = hidden_states.reshape(batch_size, num_channels, -1).permute(0, 2, 1)
            hidden_states = self.layer_norm(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        context_layer = torch.matmul(attention_probs, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs
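
# Illustrative sketch (not called by surya): with sequence_reduction_ratio=8 the
# keys/values for a 64x64 token grid are pooled down to an 8x8 grid, so each of
# the 4096 queries attends over only 64 positions. The sizes are assumptions.
def _example_sequence_reduction():
    config = SegformerConfig()
    attention = SegformerEfficientSelfAttention(
        config, hidden_size=32, num_attention_heads=1, sequence_reduction_ratio=8
    )
    hidden_states = torch.randn(1, 64 * 64, 32)
    with torch.no_grad():
        context, probs = attention(hidden_states, height=64, width=64, output_attentions=True)
    return context.shape, probs.shape  # (1, 4096, 32), (1, 1, 4096, 64)
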
class SegformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # patch embeddings
        embeddings = []
        for i in range(config.num_encoder_blocks):
            embeddings.append(
                SegformerOverlapPatchEmbeddings(
                    patch_size=config.patch_sizes[i],
                    stride=config.strides[i],
                    num_channels=config.num_channels if i == 0 else config.hidden_sizes[i - 1],
                    hidden_size=config.hidden_sizes[i],
                )
            )
        self.patch_embeddings = nn.ModuleList(embeddings)

        # Transformer blocks
        blocks = []
        cur = 0
        for i in range(config.num_encoder_blocks):
            # each block consists of layers
            layers = []
            if i != 0:
                cur += config.depths[i - 1]
            for j in range(config.depths[i]):
                layers.append(
                    SegformerLayer(
                        config,
                        hidden_size=config.hidden_sizes[i],
                        num_attention_heads=config.num_attention_heads[i],
                        sequence_reduction_ratio=config.sr_ratios[i],
                        mlp_ratio=config.mlp_ratios[i],
                    )
                )
            blocks.append(nn.ModuleList(layers))

        self.block = nn.ModuleList(blocks)

        # Layer norms
        self.layer_norm = nn.ModuleList(
            [nn.LayerNorm(config.hidden_sizes[i]) for i in range(config.num_encoder_blocks)]
        )

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Union[Tuple, BaseModelOutput]:
        all_hidden_states = () if output_hidden_states else None

        batch_size = pixel_values.shape[0]

        hidden_states = pixel_values
        for idx, x in enumerate(zip(self.patch_embeddings, self.block, self.layer_norm)):
            embedding_layer, block_layer, norm_layer = x
            # first, obtain patch embeddings
            hidden_states, height, width = embedding_layer(hidden_states)
            # second, send embeddings through blocks
            for i, blk in enumerate(block_layer):
                layer_outputs = blk(hidden_states, height, width, output_attentions)
                hidden_states = layer_outputs[0]
            # third, apply layer norm
            hidden_states = norm_layer(hidden_states)
            # fourth, optionally reshape back to (batch_size, num_channels, height, width)
            if idx != len(self.patch_embeddings) - 1 or (
                idx == len(self.patch_embeddings) - 1 and self.config.reshape_last_stage
            ):
                hidden_states = hidden_states.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
            all_hidden_states = all_hidden_states + (hidden_states,)

        return all_hidden_states


class SegformerSelfOutput(nn.Module):
    def __init__(self, config, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        return hidden_states


class SegformerAttention(nn.Module):
    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio):
        super().__init__()
        self.self = SegformerEfficientSelfAttention(
            config=config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.output = SegformerSelfOutput(config, hidden_size=hidden_size)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_outputs = self.self(hidden_states, height, width, output_attentions)

        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
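
# Illustrative sketch (not called by surya): the encoder emits one feature map per
# stage at strides 4, 8, 16 and 32, which is exactly what the decode head consumes.
# The 128x128 input and random weights are assumptions for demonstration only.
def _example_encoder_stage_shapes():
    config = SegformerConfig()  # defaults: depths=[2, 2, 2, 2], hidden_sizes=[32, 64, 160, 256]
    encoder = SegformerEncoder(config).eval()
    with torch.no_grad():
        features = encoder(torch.randn(1, 3, 128, 128), output_hidden_states=True)
    return [tuple(f.shape) for f in features]
    # [(1, 32, 32, 32), (1, 64, 16, 16), (1, 160, 8, 8), (1, 256, 4, 4)]
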
class SegformerDWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, hidden_states, height, width):
        batch_size, seq_len, num_channels = hidden_states.shape
        hidden_states = hidden_states.transpose(1, 2).view(batch_size, num_channels, height, width)
        hidden_states = self.dwconv(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        return hidden_states


class SegformerMixFFN(nn.Module):
    def __init__(self, config, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features
        self.dense1 = nn.Linear(in_features, hidden_features)
        self.dwconv = SegformerDWConv(hidden_features)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act
        self.dense2 = nn.Linear(hidden_features, out_features)

    def forward(self, hidden_states, height, width):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.dwconv(hidden_states, height, width)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.dense2(hidden_states)
        return hidden_states


class SegformerLayer(nn.Module):
    """This corresponds to the Block class in the original implementation."""

    def __init__(self, config, hidden_size, num_attention_heads, sequence_reduction_ratio, mlp_ratio):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(hidden_size)
        self.attention = SegformerAttention(
            config,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            sequence_reduction_ratio=sequence_reduction_ratio,
        )
        self.layer_norm_2 = nn.LayerNorm(hidden_size)
        mlp_hidden_size = int(hidden_size * mlp_ratio)
        self.mlp = SegformerMixFFN(config, in_features=hidden_size, hidden_features=mlp_hidden_size)

    def forward(self, hidden_states, height, width, output_attentions=False):
        self_attention_outputs = self.attention(
            self.layer_norm_1(hidden_states),  # in Segformer, layernorm is applied before self-attention
            height,
            width,
            output_attentions=output_attentions,
        )

        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        # first residual connection (drop path / stochastic depth is removed in this stripped-down version)
        hidden_states = attention_output + hidden_states

        mlp_output = self.mlp(self.layer_norm_2(hidden_states), height, width)

        # second residual connection
        layer_output = mlp_output + hidden_states

        outputs = (layer_output,) + outputs

        return outputs
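
# Illustrative sketch (not called by surya): one pre-norm SegformerLayer maps a
# (batch, height*width, hidden_size) sequence back to the same shape; the MixFFN's
# 3x3 depthwise conv is what mixes in spatial/positional information between tokens.
# The sizes below are assumptions for demonstration only.
def _example_layer_roundtrip():
    config = SegformerConfig()
    layer = SegformerLayer(
        config, hidden_size=32, num_attention_heads=1, sequence_reduction_ratio=8, mlp_ratio=4
    )
    hidden_states = torch.randn(1, 64 * 64, 32)
    with torch.no_grad():
        (output,) = layer(hidden_states, height=64, width=64)
    return output.shape  # (1, 4096, 32) -- unchanged, thanks to the two residual connections
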
class SegformerModel(SegformerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # hierarchical Transformer encoder
        self.encoder = SegformerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        encoder_outputs = self.encoder(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return encoder_outputs


class SegformerForRegressionMask(SegformerForSemanticSegmentation):
    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.segformer = SegformerModel(config)
        self.decode_head = SegformerForMaskDecodeHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        **kwargs
    ) -> Union[Tuple, SemanticSegmenterOutput]:
        encoder_hidden_states = self.segformer(
            pixel_values,
            output_attentions=False,
            output_hidden_states=True,  # we need the intermediate hidden states
            return_dict=False,
        )

        logits = self.decode_head(encoder_hidden_states)
        # Apply sigmoid to get 0-1 output
        sigmoid_logits = torch.special.expit(logits)

        return SemanticSegmenterOutput(
            loss=None,
            logits=sigmoid_logits,
            hidden_states=None,
            attentions=None,
        )
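
if __name__ == "__main__":
    # Minimal smoke test (assumes the detection checkpoint configured in
    # surya.settings can be downloaded). A random image stands in for real input;
    # real usage should go through load_processor() to prepare pixel_values.
    model = load_model()
    pixel_values = torch.randn(1, 3, 512, 512, dtype=model.dtype, device=model.device)
    with torch.inference_mode():
        output = model(pixel_values=pixel_values)
    # Sigmoid masks, one channel per label, at 1/4 of the input resolution.
    print(output.logits.shape)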