{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading Packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['HF_HOME'] = '/data2/ketan/orc/HF_Cache'\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "# from transformers import SegformerConfig\n", "# from surya.model.detection.segformer import SegformerForRegressionMask\n", "from surya.input.processing import prepare_image_detection\n", "from surya.model.detection.segformer import load_processor , load_model\n", "from datasets import load_dataset\n", "from tqdm import tqdm\n", "from torch.utils.tensorboard import SummaryWriter\n", "import torch.nn.functional as F\n", "import numpy as np \n", "from surya.layout import parallel_get_regions\n", "import torch.nn.functional as F" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Initializing The Dataset And Model" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n", "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\") # You can choose you own dataset" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16\n" ] }, { "data": { "text/plain": [ "SegformerForRegressionMask(\n", " (segformer): SegformerModel(\n", " (encoder): SegformerEncoder(\n", " (patch_embeddings): ModuleList(\n", " (0): SegformerOverlapPatchEmbeddings(\n", " (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))\n", " (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (1): SegformerOverlapPatchEmbeddings(\n", " (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (2): SegformerOverlapPatchEmbeddings(\n", " (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (3): SegformerOverlapPatchEmbeddings(\n", " (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n", " (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " (block): ModuleList(\n", " (0): ModuleList(\n", " (0-2): 3 x SegformerLayer(\n", " (layer_norm_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (attention): SegformerAttention(\n", " (self): SegformerEfficientSelfAttention(\n", " (query): Linear(in_features=64, out_features=64, bias=True)\n", " (key): Linear(in_features=64, out_features=64, bias=True)\n", " (value): Linear(in_features=64, out_features=64, bias=True)\n", " (sr): Conv2d(64, 64, kernel_size=(8, 8), stride=(8, 8))\n", " (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (output): SegformerSelfOutput(\n", " (dense): Linear(in_features=64, out_features=64, bias=True)\n", " )\n", " )\n", " (layer_norm_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (mlp): SegformerMixFFN(\n", " (dense1): Linear(in_features=64, out_features=256, bias=True)\n", " (dwconv): SegformerDWConv(\n", " (dwconv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)\n", " )\n", " 
(intermediate_act_fn): GELUActivation()\n", " (dense2): Linear(in_features=256, out_features=64, bias=True)\n", " )\n", " )\n", " )\n", " (1): ModuleList(\n", " (0-3): 4 x SegformerLayer(\n", " (layer_norm_1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (attention): SegformerAttention(\n", " (self): SegformerEfficientSelfAttention(\n", " (query): Linear(in_features=128, out_features=128, bias=True)\n", " (key): Linear(in_features=128, out_features=128, bias=True)\n", " (value): Linear(in_features=128, out_features=128, bias=True)\n", " (sr): Conv2d(128, 128, kernel_size=(4, 4), stride=(4, 4))\n", " (layer_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (output): SegformerSelfOutput(\n", " (dense): Linear(in_features=128, out_features=128, bias=True)\n", " )\n", " )\n", " (layer_norm_2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)\n", " (mlp): SegformerMixFFN(\n", " (dense1): Linear(in_features=128, out_features=512, bias=True)\n", " (dwconv): SegformerDWConv(\n", " (dwconv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=512)\n", " )\n", " (intermediate_act_fn): GELUActivation()\n", " (dense2): Linear(in_features=512, out_features=128, bias=True)\n", " )\n", " )\n", " )\n", " (2): ModuleList(\n", " (0-8): 9 x SegformerLayer(\n", " (layer_norm_1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (attention): SegformerAttention(\n", " (self): SegformerEfficientSelfAttention(\n", " (query): Linear(in_features=320, out_features=320, bias=True)\n", " (key): Linear(in_features=320, out_features=320, bias=True)\n", " (value): Linear(in_features=320, out_features=320, bias=True)\n", " (sr): Conv2d(320, 320, kernel_size=(2, 2), stride=(2, 2))\n", " (layer_norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " )\n", " (output): SegformerSelfOutput(\n", " (dense): Linear(in_features=320, out_features=320, bias=True)\n", " )\n", " )\n", " (layer_norm_2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (mlp): SegformerMixFFN(\n", " (dense1): Linear(in_features=320, out_features=1280, bias=True)\n", " (dwconv): SegformerDWConv(\n", " (dwconv): Conv2d(1280, 1280, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1280)\n", " )\n", " (intermediate_act_fn): GELUActivation()\n", " (dense2): Linear(in_features=1280, out_features=320, bias=True)\n", " )\n", " )\n", " )\n", " (3): ModuleList(\n", " (0-2): 3 x SegformerLayer(\n", " (layer_norm_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (attention): SegformerAttention(\n", " (self): SegformerEfficientSelfAttention(\n", " (query): Linear(in_features=512, out_features=512, bias=True)\n", " (key): Linear(in_features=512, out_features=512, bias=True)\n", " (value): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " (output): SegformerSelfOutput(\n", " (dense): Linear(in_features=512, out_features=512, bias=True)\n", " )\n", " )\n", " (layer_norm_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " (mlp): SegformerMixFFN(\n", " (dense1): Linear(in_features=512, out_features=2048, bias=True)\n", " (dwconv): SegformerDWConv(\n", " (dwconv): Conv2d(2048, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2048)\n", " )\n", " (intermediate_act_fn): GELUActivation()\n", " (dense2): Linear(in_features=2048, out_features=512, bias=True)\n", " )\n", " )\n", " )\n", " )\n", " (layer_norm): ModuleList(\n", " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", " (1): LayerNorm((128,), 
eps=1e-05, elementwise_affine=True)\n", " (2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)\n", " (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n", " )\n", " )\n", " )\n", " (decode_head): SegformerForMaskDecodeHead(\n", " (linear_c): ModuleList(\n", " (0): SegformerForMaskMLP(\n", " (proj): Linear(in_features=64, out_features=192, bias=True)\n", " )\n", " (1): SegformerForMaskMLP(\n", " (proj): Linear(in_features=128, out_features=192, bias=True)\n", " )\n", " (2): SegformerForMaskMLP(\n", " (proj): Linear(in_features=320, out_features=192, bias=True)\n", " )\n", " (3): SegformerForMaskMLP(\n", " (proj): Linear(in_features=512, out_features=192, bias=True)\n", " )\n", " )\n", " (linear_fuse): Conv2d(768, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", " (batch_norm): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (activation): ReLU()\n", " (dropout): Dropout(p=0.1, inplace=False)\n", " (classifier): Conv2d(768, 12, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model = load_model(\"vikp/surya_layout2\").to(device)\n", "model.to(torch.float32)  # the checkpoint loads in float16; cast to float32 for training" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "def initialize_weights(model):\n", "    # Xavier-initialize every Conv2d/Linear layer and zero the biases.\n", "    # Note: this overwrites the pretrained weights, so training starts from a\n", "    # randomly initialized model; skip this cell to fine-tune from the\n", "    # pretrained checkpoint instead.\n", "    for module in model.modules():\n", "        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):\n", "            torch.nn.init.xavier_uniform_(module.weight)\n", "            if module.bias is not None:\n", "                torch.nn.init.zeros_(module.bias)\n", "\n", "initialize_weights(model)\n" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Helper Functions, Loss Function And Optimizer" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n", "log_dir = \"logs\"\n", "checkpoint_dir = \"checkpoints\"\n", "os.makedirs(log_dir, exist_ok=True)\n", "os.makedirs(checkpoint_dir, exist_ok=True)\n", "writer = SummaryWriter(log_dir=log_dir)\n" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def logits_to_mask(logits, labels, bboxes, original_size=(1200, 1200)):\n", "    # Build a binary target mask with the same shape as the logits: each\n", "    # ground-truth box is painted into the channel of its class.\n", "    batch_size, num_classes, height, width = logits.shape\n", "    mask = torch.zeros((batch_size, num_classes, height, width), dtype=torch.float32).to(logits.device)\n", "\n", "    for bbox, class_id in zip(bboxes, labels):\n", "        x_min, y_min, x_max, y_max = bbox\n", "\n", "        # scale the box from the original image resolution to the logits resolution\n", "        x_min = int(x_min * width / original_size[0])\n", "        y_min = int(y_min * height / original_size[1])\n", "        x_max = int(x_max * width / original_size[0])\n", "        y_max = int(y_max * height / original_size[1])\n", "\n", "        # clamp the scaled box to the mask bounds\n", "        x_min = max(0, min(x_min, width - 1))\n", "        y_min = max(0, min(y_min, height - 1))\n", "        x_max = max(0, min(x_max, width - 1))\n", "        y_max = max(0, min(y_max, height - 1))\n", "\n", "        if x_min < x_max and y_min < y_max:\n", "            # mark the box region as foreground in its class channel\n", "            mask[:, class_id, y_min:y_max, x_min:x_max] = torch.maximum(\n", "                mask[:, class_id, y_min:y_max, x_min:x_max], torch.tensor(1.0).to(logits.device)\n", "            )\n", "        else:\n", "            print(f\"Invalid bounding box after adjustment: {bbox}, adjusted to: {(x_min, y_min, x_max, y_max)}\")\n", "\n", "    return mask\n", "\n", "\n", "def loss_function(logits, mask):\n", "    # per-pixel MSE between the raw logits and the 0/1 target mask\n", "    loss_fn = torch.nn.MSELoss()\n", "    loss = loss_fn(logits, mask)\n", "    return loss" ] },
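{ "cell_type": "markdown", "metadata": {}, "source": [ "Optional sanity check: `logits_to_mask` only reads the shape and device of the logits, so the target mask can be inspected on a dummy tensor before training. The 300x300 spatial size below is purely illustrative; the real logits resolution depends on the processor's input size." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# A minimal sketch, assuming the first dataset item exposes 'bboxes' and 'category_ids'\n", "# exactly as used in the training loop. The dummy spatial size is arbitrary.\n", "sample = dataset[0]\n", "dummy_logits = torch.zeros((1, 12, 300, 300))  # 12 channels, matching the classifier head above\n", "sample_mask = logits_to_mask(dummy_logits, sample['category_ids'], sample['bboxes'])\n", "print(sample_mask.shape, sample_mask.unique())  # expect the logits shape and values {0, 1}" ] },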
"execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Epoch 1/5: 100%|██████████| 100/100 [01:30<00:00, 1.11it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Average Loss for Epoch 1: 0.0533\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 2/5: 100%|██████████| 100/100 [01:30<00:00, 1.11it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Average Loss for Epoch 2: 0.0189\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Epoch 3/5: 35%|███▌ | 35/100 [00:31<00:58, 1.12it/s]" ] } ], "source": [ "num_epochs = 5\n", "\n", "for param in model.parameters():\n", " param.requires_grad = True\n", "\n", "\n", "model.train()\n", "with torch.autograd.set_detect_anomaly(True):\n", "\n", " for epoch in range(num_epochs):\n", " running_loss = 0.0\n", " avg_loss = 0.0\n", "\n", " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n", " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n", " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n", " \n", " optimizer.zero_grad()\n", " outputs = model(pixel_values=images)\n", "\n", "\n", " logits = outputs.logits\n", "\n", " bboxes = item['bboxes']\n", " labels = item['category_ids']\n", " logits = torch.clamp(logits, min=-1e6, max=1e6)\n", " mask = logits_to_mask(logits, labels, bboxes)\n", "\n", " logits = logits.to(torch.float32)\n", " mask = mask.to(torch.float32)\n", " loss = loss_function(logits, mask)\n", "\n", " loss.backward()\n", "\n", " for name, param in model.named_parameters():\n", " if torch.isnan(param.grad).any():\n", " print(f\"NaN detected in gradients of {name}\")\n", " break\n", "\n", " torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n", " optimizer.step()\n", "\n", " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n", "\n", " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n", " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n", "\n", " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading The Checkpoint " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "checkpoint_path = '/data2/ketan/orc/surya-layout-fine-tune/checkpoints/model_epoch_5.pth' \n", "state_dict = torch.load(checkpoint_path,weights_only=True)\n", "\n", "model.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.to('cpu')\n", "model.save_pretrained(\"fine-tuned-surya-model-layout\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }