# Loading Packages

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# from transformers import SegformerConfig
# from surya.model.detection.segformer import SegformerForRegressionMask
from surya.input.processing import prepare_image_detection
from surya.model.detection.segformer import load_processor , load_model
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
import numpy as np 
from surya.layout import parallel_get_regions

# Initializing The Dataset And Model

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by the other cells shown here (the
# training loop reads `model.device` instead) - confirm whether it is needed.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# First 100 training samples of the benchmark; swap in your own dataset here.
dataset = load_dataset("vikp/publaynet_bench", split="train[:100]")

# Pretrained layout checkpoint that will be fine-tuned below.
model = load_model("vikp/surya_layout2")

# Helper Functions, Loss Function And Optimizer

In [None]:

# Adam over every model parameter with a small learning rate.
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Output locations: TensorBoard event files and per-epoch checkpoints.
log_dir = "logs"
checkpoint_dir = "checkpoints"
for directory in (log_dir, checkpoint_dir):
    # exist_ok=True keeps re-running this cell from failing.
    os.makedirs(directory, exist_ok=True)

writer = SummaryWriter(log_dir=log_dir)

def logits_to_bboxes(logits, image):
    """Convert the model's output masks into bounding boxes.

    The model emits heatmap logits rather than boxes, so the logits are
    resized to a fixed 300x300 grid, converted to float32 NumPy arrays and
    handed to ``parallel_get_regions``, which extracts the regions.

    Args:
        logits: Tensor of model logits; indexed here as (sample, heatmap, ...).
        image: Source image whose ``.size`` is passed on as ``orig_size``
            (presumably a PIL image - confirm against the dataset).

    Returns:
        A list with the ``.bbox`` of every entry in the detected regions.

    Note:
        The logits are detached before region extraction, so the returned
        boxes carry no autograd history.
    """
    target_size = (300, 300)
    resized = F.interpolate(logits, size=target_size, mode='bilinear', align_corners=False)
    heatmap_array = resized.cpu().detach().numpy().astype(np.float32)

    # Flatten to one 2-D map per (sample, heatmap) pair, sample-major order.
    heatmaps = [heatmap for sample in heatmap_array for heatmap in sample]
    regions = parallel_get_regions(
        heatmaps=heatmaps,
        orig_size=image.size,
        id2label=model.config.id2label,
    )

    return [region.bbox for region in regions.bboxes]


def loss_function(predicted_boxes=None, target_boxes=None):
    """Placeholder for the task-specific training loss.

    The model has no built-in loss, so one must be implemented here to suit
    the dataset and the requirements. The training loop calls this as
    ``loss_function(predicted_boxes, target_boxes)`` and then calls
    ``.backward()`` and ``.item()`` on the result, so the implementation must
    return a scalar tensor that supports both.

    Args:
        predicted_boxes: Boxes produced by ``logits_to_bboxes`` for one sample.
        target_boxes: Ground-truth boxes for the same sample.

    Raises:
        NotImplementedError: Always, until a real loss is supplied.

    NOTE(review): ``logits_to_bboxes`` detaches the logits, so a loss built
    only from its output has no gradient path back to the model - confirm
    the loss is computed from something differentiable.
    """
    raise NotImplementedError(
        "loss_function is a placeholder: implement a loss for your dataset "
        "that returns a scalar tensor."
    )

# Fine-Tuning Process

In [None]:
num_epochs = 5

# Build the image processor once instead of once per sample; it is only used
# as an argument to prepare_image_detection below.
processor = load_processor()

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses for this epoch

    for item in tqdm(dataset, desc=f"Epoch {epoch + 1}/{num_epochs}"):

        # One image per step, stacked into a batch of size 1 and matched to
        # the model's dtype and device.
        images = [prepare_image_detection(img=item['image'], processor=processor)]
        images = torch.stack(images, dim=0).to(model.dtype).to(model.device)

        optimizer.zero_grad()
        outputs = model(pixel_values=images)

        # Predicted boxes come from the output masks; targets are read from
        # the sample's 'bboxes' field.
        predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])
        target_boxes = item['bboxes']

        loss = loss_function(predicted_boxes, target_boxes)

        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Mean loss over the samples seen this epoch.
    avg_loss = running_loss / len(dataset)
    writer.add_scalar('Training Loss', avg_loss, epoch + 1)
    print(f"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}")

    # Keep one checkpoint per epoch so any epoch can be restored later.
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth"))

# Loading The Checkpoint 

In [None]:
# NOTE(review): the training loop in this file saves
# checkpoints/model_epoch_<n>.pth for n = 1..num_epochs (5 as written), so an
# epoch-350 file only exists after a longer run - adjust the name as needed.
checkpoint_path = 'checkpoints/model_epoch_350.pth'

# map_location="cpu" lets a checkpoint saved from a GPU model load on a
# machine without CUDA; load_state_dict then copies the values into the
# model's existing parameters. weights_only=True restricts unpickling to
# tensor data.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

model.load_state_dict(state_dict)

In [None]:
# Move the weights to the CPU first, then export the fine-tuned model to a
# local directory. Module.to() returns the same module, so `model` itself is
# moved as well.
cpu_model = model.to('cpu')
cpu_model.save_pretrained("fine-tuned-surya-model-layout")