ketanmore's picture
Upload folder using huggingface_hub
2720487 verified
raw
history blame
3.7 kB
from torch import nn
import torch
from typing import Optional, Tuple, Union
import collections
import math
from transformers import DonutSwinPreTrainedModel
from transformers.models.donut.modeling_donut_swin import DonutSwinPatchEmbeddings, DonutSwinEmbeddings, DonutSwinModel, \
DonutSwinEncoder
from surya.model.ordering.config import VariableDonutSwinConfig
class VariableDonutSwinEmbeddings(DonutSwinEmbeddings):
    """
    Construct the patch and position embeddings. Optionally, also the mask token.

    On top of the optional absolute position table, this variant can hold
    separate learned row and column tables (``config.use_2d_embeddings``).
    In ``forward`` every table is sliced to the size reported for the current
    input, so inputs producing fewer patches than the configured grid reuse
    the leading entries of each table.
    """

    def __init__(self, config, use_mask_token=False, **kwargs):
        """
        Build the embedding sub-modules and position tables.

        Args:
            config: Model configuration. This constructor reads ``embed_dim``,
                ``hidden_dropout_prob``, ``use_absolute_embeddings`` and
                ``use_2d_embeddings`` from it (and passes it on to
                ``DonutSwinPatchEmbeddings`` and the parent constructor).
            use_mask_token: When true, allocate a ``(1, 1, embed_dim)`` mask
                token parameter; otherwise ``mask_token`` is ``None``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__(config, use_mask_token)

        # The attributes below are (re)assigned after the parent constructor
        # has run, so this class fully controls their final values.
        self.patch_embeddings = DonutSwinPatchEmbeddings(config)
        num_patches = self.patch_embeddings.num_patches
        self.patch_grid = self.patch_embeddings.grid_size
        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.embed_dim)) if use_mask_token else None

        # Optional absolute position table, one entry per patch (plus one spare).
        self.position_embeddings = None
        if config.use_absolute_embeddings:
            self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.embed_dim))

        # Optional factorised 2D tables: one entry per grid row / grid column
        # (plus one spare each).
        self.row_embeddings = None
        self.column_embeddings = None
        if config.use_2d_embeddings:
            self.row_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[0] + 1, config.embed_dim))
            self.column_embeddings = nn.Parameter(torch.zeros(1, self.patch_grid[1] + 1, config.embed_dim))

        self.norm = nn.LayerNorm(config.embed_dim)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(
        self, pixel_values: Optional[torch.FloatTensor], bool_masked_pos: Optional[torch.BoolTensor] = None, **kwargs
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """
        Embed ``pixel_values`` into a normalised patch sequence with positions added.

        Args:
            pixel_values: Input forwarded unchanged to ``self.patch_embeddings``.
            bool_masked_pos: Optional mask selecting patches to replace with the
                mask token. Presumably shaped ``(batch_size, seq_len)`` so that it
                broadcasts against the embeddings after ``unsqueeze(-1)`` -- TODO
                confirm against callers.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            A 2-tuple ``(embeddings, output_dimensions)``. ``embeddings`` has
            shape ``(batch_size, seq_len, embed_dim)``; ``output_dimensions`` is
            passed through from ``self.patch_embeddings`` and is used here as
            ``(rows, columns)`` of the patch grid.

        Raises:
            ValueError: If ``bool_masked_pos`` is given but the module was
                constructed without a mask token.
        """
        embeddings, output_dimensions = self.patch_embeddings(pixel_values)
        # Layernorm across the last dimension (each patch is a single row)
        embeddings = self.norm(embeddings)
        batch_size, seq_len, embed_dim = embeddings.size()

        if bool_masked_pos is not None:
            if self.mask_token is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise ValueError(
                    "bool_masked_pos was provided, but this module was created with use_mask_token=False."
                )
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        if self.position_embeddings is not None:
            # Slice the table to the actual sequence length of this input.
            embeddings = embeddings + self.position_embeddings[:, :seq_len, :]

        if self.row_embeddings is not None and self.column_embeddings is not None:
            num_rows, num_cols = output_dimensions[0], output_dimensions[1]
            # Row indices run 0, 0, ..., 1, 1, ...: each row entry is repeated once per column.
            row_embeddings = self.row_embeddings[:, :num_rows, :].repeat_interleave(num_cols, dim=1)
            # Column indices run 0, 1, 2, ..., 0, 1, 2, ...: the column strip is tiled once per row.
            column_embeddings = self.column_embeddings[:, :num_cols, :].repeat(1, num_rows, 1)
            # Together these index a row-major flattening of the patch grid
            # (assumed to match the order produced by patch_embeddings).
            embeddings = embeddings + row_embeddings + column_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings, output_dimensions
class VariableDonutSwinModel(DonutSwinModel):
    """
    ``DonutSwinModel`` subclass that uses ``VariableDonutSwinEmbeddings`` for its
    embedding layer and ``VariableDonutSwinConfig`` as its configuration class.
    """

    config_class = VariableDonutSwinConfig

    def __init__(self, config, add_pooling_layer=True, use_mask_token=False, **kwargs):
        """
        Args:
            config: Model configuration; ``depths`` and ``embed_dim`` are read
                here, and the object is handed to the embeddings and encoder.
            add_pooling_layer: When true, ``self.pooler`` is an
                ``nn.AdaptiveAvgPool1d(1)``; otherwise it is ``None``.
            use_mask_token: Forwarded to ``VariableDonutSwinEmbeddings``.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        # The parent constructor only receives ``config``; everything that
        # depends on the other arguments is assigned below.
        super().__init__(config)
        self.config = config

        stage_count = len(config.depths)
        self.num_layers = stage_count
        # embed_dim scaled by 2 ** (number of stages - 1).
        self.num_features = int(config.embed_dim * 2 ** (stage_count - 1))

        variable_embeddings = VariableDonutSwinEmbeddings(config, use_mask_token=use_mask_token)
        self.embeddings = variable_embeddings
        # The encoder is built for the patch grid reported by the embeddings.
        self.encoder = DonutSwinEncoder(config, variable_embeddings.patch_grid)

        if add_pooling_layer:
            self.pooler = nn.AdaptiveAvgPool1d(1)
        else:
            self.pooler = None

        # Initialize weights and apply final processing
        self.post_init()