# Loading Packages

In [1]:
import os
# Point the Hugging Face cache at a local directory. The assignment sits before the
# surya / datasets imports below — presumably so those libraries see the value when
# they are first imported; confirm before reordering this cell.
# NOTE(review): machine-specific absolute path; adjust for your environment.
os.environ['HF_HOME'] = '/data2/ketan/orc/HF_Cache'
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# from transformers import SegformerConfig
# from surya.model.detection.segformer import SegformerForRegressionMask
from surya.input.processing import prepare_image_detection
from surya.model.detection.segformer import load_processor , load_model
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import numpy as np 
from surya.layout import parallel_get_regions
import torch.nn.functional as F

# Initializing The Dataset And Model

In [2]:
# Pick the compute device. GPU index 3 is preferred (the original setup), but hosts
# with fewer than four GPUs fall back to the first GPU instead of failing later with
# an invalid device ordinal; without CUDA everything runs on the CPU.
if torch.cuda.is_available():
    _gpu_index = 3 if torch.cuda.device_count() > 3 else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")
# Load the first 100 examples of the "train" split of the benchmark dataset.
dataset = load_dataset("vikp/publaynet_bench", split="train[:100]") # You can choose your own dataset

In [3]:
# Fetch the "vikp/surya_layout2" checkpoint, then place it on the selected device.
model = load_model("vikp/surya_layout2")
model = model.to(device)
# The loader reports dtype torch.float16 (see the cell output), so cast the weights
# to float32. Kept as the final expression so the notebook still displays the module tree.
model.to(torch.float32)

Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16


SegformerForRegressionMask(
 (segformer): SegformerModel(
 (encoder): SegformerEncoder(
 (patch_embeddings): ModuleList(
 (0): SegformerOverlapPatchEmbeddings(
 (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
 (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
 )
 (1): SegformerOverlapPatchEmbeddings(
 (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
 )
 (2): SegformerOverlapPatchEmbeddings(
 (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
 )
 (3): SegformerOverlapPatchEmbeddings(
 (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
 (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
 )
 )
 (block): ModuleList(
 (0): ModuleList(
 (0-2): 3 x SegformerLayer(
 (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affi

In [4]:
def initialize_weights(model):
    """Re-initialise every convolutional and linear layer of ``model`` in place.

    The weight of each ``torch.nn.Conv2d`` / ``torch.nn.Linear`` submodule is
    redrawn with Xavier-uniform initialisation; its bias, when present, is set
    to zero. All other module types are left untouched.

    Args:
        model: any ``torch.nn.Module``; its submodules are visited via ``modules()``.

    Returns:
        None. The model is modified in place.
    """
    reinit_types = (torch.nn.Conv2d, torch.nn.Linear)
    for layer in model.modules():
        if not isinstance(layer, reinit_types):
            continue
        torch.nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            torch.nn.init.zeros_(layer.bias)

# NOTE(review): `model` was loaded from the "vikp/surya_layout2" checkpoint in an
# earlier cell; this call overwrites every Conv2d/Linear weight in it with fresh
# Xavier-uniform values and zeroes their biases, so those layers start training from
# scratch. Confirm this is intended — if the goal is to fine-tune the pretrained
# weights, this call should probably be removed.
initialize_weights(model)


# Helper Functions, Loss Function And Optimizer

In [5]:
# Adam over every model parameter with a fixed learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Output locations: TensorBoard event files and saved checkpoints.
log_dir = "logs"
checkpoint_dir = "checkpoints"
for _output_dir in (log_dir, checkpoint_dir):
    os.makedirs(_output_dir, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)




In [6]:
def logits_to_mask(logits, labels, bboxes, original_size=(1200, 1200)):
    """Rasterise bounding boxes into a binary target mask shaped like ``logits``.

    Args:
        logits: 4-D tensor ``(batch, num_classes, height, width)``. Only its shape
            and device are used.
        labels: iterable of class indices, one per box; each selects the channel
            of the mask that the box is drawn into.
        bboxes: iterable of ``(x_min, y_min, x_max, y_max)`` boxes, expressed in
            the coordinate system described by ``original_size``.
        original_size: ``(width, height)`` of the coordinate system the boxes
            live in; x values are scaled by ``width / original_size[0]`` and
            y values by ``height / original_size[1]``.

    Returns:
        A float32 tensor on ``logits.device`` with the same shape as ``logits``:
        1.0 inside every valid box (on the box's class channel, for every batch
        element) and 0.0 elsewhere. Boxes that collapse to an empty region after
        scaling are reported with a printed message and skipped.
    """
    batch_size, num_classes, height, width = logits.shape
    # Allocate directly on the target device instead of creating on CPU and copying.
    mask = torch.zeros(
        (batch_size, num_classes, height, width),
        dtype=torch.float32,
        device=logits.device,
    )

    for bbox, class_id in zip(bboxes, labels):
        x_min, y_min, x_max, y_max = bbox

        # Rescale from original-image coordinates to mask coordinates.
        x_min = int(x_min * width / original_size[0])
        y_min = int(y_min * height / original_size[1])
        x_max = int(x_max * width / original_size[0])
        y_max = int(y_max * height / original_size[1])

        # The corners are used as slice bounds, and slice ends are exclusive, so the
        # valid range is [0, extent]. Clamping the upper corner to ``extent - 1``
        # would make the last row/column of the mask impossible to label.
        x_min = max(0, min(x_min, width))
        y_min = max(0, min(y_min, height))
        x_max = max(0, min(x_max, width))
        y_max = max(0, min(y_max, height))

        if x_min < x_max and y_min < y_max:
            # The mask only ever holds 0.0 or 1.0, so a plain assignment is
            # equivalent to taking the element-wise maximum with 1.0.
            mask[:, class_id, y_min:y_max, x_min:x_max] = 1.0
        else:
            print(f"Invalid bounding box after adjustment: {bbox}, adjusted to: {(x_min, y_min, x_max, y_max)}")

    return mask


def loss_function(logits, mask):
    """Return the mean squared error between ``logits`` and ``mask``.

    Uses the default "mean" reduction, so the result is a scalar tensor.
    """
    return torch.nn.functional.mse_loss(logits, mask)

# Fine-Tuning Process

In [7]:
num_epochs = 5

# Make every parameter trainable.
for param in model.parameters():
    param.requires_grad = True

# Build the image processor once: load_processor() takes no per-sample input, so
# rebuilding it on every training step is pure overhead.
processor = load_processor()

model.train()
# Anomaly detection is left enabled for the whole run, as in the original loop.
with torch.autograd.set_detect_anomaly(True):

    for epoch in range(num_epochs):
        # Exponential moving average (decay 0.9) of the per-step loss, restarted
        # every epoch. Note it is a smoothed value, not an arithmetic mean.
        avg_loss = 0.0

        for idx, item in enumerate(tqdm(dataset, desc=f"Epoch {epoch + 1}/{num_epochs}")):
            # One sample per step: preprocess the image and add a batch dimension.
            images = [prepare_image_detection(img=item['image'], processor=processor)]
            images = torch.stack(images, dim=0).to(model.dtype).to(model.device)

            optimizer.zero_grad()
            outputs = model(pixel_values=images)
            logits = outputs.logits

            bboxes = item['bboxes']
            labels = item['category_ids']
            logits = torch.clamp(logits, min=-1e6, max=1e6)
            # NOTE(review): the default original_size=(1200, 1200) is used here; if the
            # boxes are in the coordinates of the source image and that image is not
            # 1200x1200, the actual (width, height) should be passed — confirm.
            mask = logits_to_mask(logits, labels, bboxes)

            logits = logits.to(torch.float32)
            mask = mask.to(torch.float32)
            loss = loss_function(logits, mask)

            loss.backward()

            for name, param in model.named_parameters():
                # A parameter that received no gradient has grad=None, and
                # torch.isnan(None) would raise a TypeError.
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f"NaN detected in gradients of {name}")
                    break

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            step_loss = loss.item()
            avg_loss = (0.9 * avg_loss + 0.1 * step_loss) if idx > 0 else step_loss

        # Once per epoch: log the smoothed loss and save a checkpoint.
        writer.add_scalar('Training Loss', avg_loss, epoch + 1)
        print(f"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}")

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth"))


Epoch 1/5: 100%|██████████| 100/100 [01:30<00:00, 1.11it/s]


Average Loss for Epoch 1: 0.0533


Epoch 2/5: 100%|██████████| 100/100 [01:30<00:00, 1.11it/s]


Average Loss for Epoch 2: 0.0189


Epoch 3/5: 35%|███▌ | 35/100 [00:31<00:58, 1.12it/s]

# Loading The Checkpoint 

In [None]:
# NOTE(review): machine-specific absolute path; point this at your own checkpoint file.
checkpoint_path = '/data2/ketan/orc/surya-layout-fine-tune/checkpoints/model_epoch_5.pth' 
# Deserialize onto the CPU regardless of the device the tensors were saved from, so
# the checkpoint can be restored on hosts that lack that GPU. load_state_dict then
# copies the values into the model's existing parameters.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

model.load_state_dict(state_dict)



In [None]:
# Move the model to the CPU, then write it out to a local directory.
export_dir = "fine-tuned-surya-model-layout"
model = model.to('cpu')
model.save_pretrained(export_dir)