{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading Packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# HF_HOME is set before importing `surya`/`datasets` because the Hugging Face\n",
    "# libraries resolve their cache paths at import time. A value already exported\n",
    "# in the environment takes precedence over this machine-specific default.\n",
    "os.environ.setdefault('HF_HOME', '/data2/ketan/orc/HF_Cache')\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader\n",
    "from surya.input.processing import prepare_image_detection\n",
    "from surya.model.detection.segformer import load_processor, load_model\n",
    "from datasets import load_dataset\n",
    "from tqdm import tqdm\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "from surya.layout import parallel_get_regions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Initializing The Dataset And Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42\n",
    "torch.manual_seed(SEED)  # dropout in train mode is stochastic; seed for reproducible runs\n",
    "\n",
    "# The GPU index is machine-specific; adjust it for your setup.\n",
    "device = torch.device(\"cuda:3\" if torch.cuda.is_available() else \"cpu\")\n",
    "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\")  # You can choose your own dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded detection model vikp/surya_layout2 on device cuda with dtype torch.float16\n"
     ]
    }
   ],
   "source": [
    "# The checkpoint loads in float16 (see the log below); move it to `device` and cast\n",
    "# to float32 in one call so the optimizer updates full-precision weights.\n",
    "model = load_model(\"vikp/surya_layout2\").to(device=device, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Helper Functions, Loss Function And Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
    "log_dir = \"logs\"\n",
    "checkpoint_dir = \"checkpoints\"\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
"os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "writer = SummaryWriter(log_dir=log_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def logits_to_mask(logits, labels, bboxes, original_size=(1200, 1200)):\n",
    "    \"\"\"Rasterize ground-truth boxes into a per-class target mask at logit resolution.\n",
    "\n",
    "    Args:\n",
    "        logits: model output, unpacked below as (batch, num_classes, height, width).\n",
    "            Only its shape and device are used.\n",
    "        labels: one class index per box, selecting the mask channel it is drawn into.\n",
    "        bboxes: (x_min, y_min, x_max, y_max) boxes in the pixel space of `original_size`.\n",
    "        original_size: (width, height) of the image the boxes refer to.\n",
    "\n",
    "    Returns:\n",
    "        float32 tensor shaped like `logits`: 1.0 inside each box (in its class\n",
    "        channel), 0.0 elsewhere. Every batch element receives the same boxes, so\n",
    "        this is only meaningful for a batch of one image.\n",
    "    \"\"\"\n",
    "    batch_size, num_classes, height, width = logits.shape\n",
    "    mask = torch.zeros((batch_size, num_classes, height, width), dtype=torch.float32, device=logits.device)\n",
    "    scale_x = width / original_size[0]\n",
    "    scale_y = height / original_size[1]\n",
    "\n",
    "    for bbox, class_id in zip(bboxes, labels):\n",
    "        x_min, y_min, x_max, y_max = bbox\n",
    "\n",
    "        # Floor the start and ceil the end so every logit cell the box touches is\n",
    "        # covered; truncating both ends would shrink boxes and drop sub-cell ones.\n",
    "        x_min = int(np.floor(x_min * scale_x))\n",
    "        y_min = int(np.floor(y_min * scale_y))\n",
    "        x_max = int(np.ceil(x_max * scale_x))\n",
    "        y_max = int(np.ceil(y_max * scale_y))\n",
    "\n",
    "        # Slice ends are exclusive, so clamp to [0, width] / [0, height]; clamping\n",
    "        # the end to width - 1 would never cover the last column/row.\n",
    "        x_min = max(0, min(x_min, width))\n",
    "        y_min = max(0, min(y_min, height))\n",
    "        x_max = max(0, min(x_max, width))\n",
    "        y_max = max(0, min(y_max, height))\n",
    "\n",
    "        if x_min < x_max and y_min < y_max:\n",
    "            mask[:, class_id, y_min:y_max, x_min:x_max] = 1.0\n",
    "        else:\n",
    "            print(f\"Invalid bounding box after adjustment: {bbox}, adjusted to: {(x_min, y_min, x_max, y_max)}\")\n",
    "\n",
    "    return mask\n",
    "\n",
    "\n",
    "def loss_function(logits, mask):\n",
    "    \"\"\"Mean cross-entropy between per-pixel class logits and the box mask.\n",
    "\n",
    "    `mask` has the same shape as `logits` and is used as class-probability targets.\n",
    "    NOTE(review): pixels outside every box have an all-zero target and so add no\n",
    "    loss, and pixels where boxes of different classes overlap get targets summing\n",
    "    to more than 1. If the layout head is meant to be multi-label (sigmoid per\n",
    "    class), BCEWithLogitsLoss may fit better -- confirm against how Surya decodes\n",
    "    these logits.\n",
    "    \"\"\"\n",
    "    loss_fn = torch.nn.CrossEntropyLoss()\n",
    "    return loss_fn(logits, mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fine-Tuning Process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1/5: 100%|██████████| 100/100 [01:46<00:00, 1.07s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss for Epoch 1: 0.3322\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2/5: 100%|██████████| 100/100 [01:51<00:00, 1.11s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss for Epoch 2: 0.3311\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3/5: 100%|██████████| 100/100 [01:51<00:00, 1.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss for Epoch 3: 0.3197\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4/5: 100%|██████████| 100/100 [01:42<00:00, 1.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss for Epoch 4: 0.3106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5/5: 100%|██████████| 100/100 [01:46<00:00, 1.06s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss for Epoch 5: 0.3160\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 5\n",
    "\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = True\n",
    "\n",
    "# load_processor() takes no per-sample input, so load it once instead of per item.\n",
    "processor = load_processor()\n",
    "\n",
    "# Autograd anomaly detection is a debugging aid that slows every backward pass;\n",
    "# wrap the loop in `torch.autograd.set_detect_anomaly(True)` only when chasing NaNs.\n",
    "model.train()\n",
    "for epoch in range(num_epochs):\n",
    "    running_loss = 0.0\n",
    "\n",
    "    for item in tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\"):\n",
    "        images = [prepare_image_detection(img=item['image'], processor=processor)]\n",
    "        images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        logits = model(pixel_values=images).logits\n",
    "\n",
    "        # Scale boxes by the real source-image size (PIL `.size` is (width, height))\n",
    "        # rather than the 1200x1200 default; assumes the dataset's boxes are in that\n",
    "        # image's pixel coordinates.\n",
    "        mask = logits_to_mask(logits, item['category_ids'], item['bboxes'], original_size=item['image'].size)\n",
    "\n",
    "        loss = loss_function(logits.to(torch.float32), mask)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "\n",
    "    # Mean per-sample loss over the epoch.\n",
    "    avg_loss = running_loss / max(len(dataset), 1)\n",
    "    writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
    "    print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
    "\n",
    "    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))"
   ]
  },
  {
"cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading The Checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the final epoch's weights from the directory the training loop saved to.\n",
    "# map_location=device re-homes the tensors, so this also works when the GPU the\n",
    "# checkpoint was written from is absent (`device` falls back to CPU).\n",
    "checkpoint_path = os.path.join(checkpoint_dir, f\"model_epoch_{num_epochs}.pth\")\n",
    "state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)\n",
    "\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to('cpu')\n",
    "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}