{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading Packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import DataLoader\n", "# from transformers import SegformerConfig\n", "# from surya.model.detection.segformer import SegformerForRegressionMask\n", "from surya.input.processing import prepare_image_detection\n", "from surya.model.detection.segformer import load_processor , load_model\n", "from datasets import load_dataset\n", "from tqdm import tqdm\n", "from torch.utils.tensorboard import SummaryWriter\n", "import torch.nn.functional as F\n", "import numpy as np \n", "from surya.layout import parallel_get_regions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Initializing The Dataset And Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\") # You can choose you own dataset\n", "model = load_model(\"vikp/surya_layout2\") " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Helper Functions, Loss Function And Optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n", "log_dir = \"logs\"\n", "checkpoint_dir = \"checkpoints\"\n", "os.makedirs(log_dir, exist_ok=True)\n", "os.makedirs(checkpoint_dir, exist_ok=True)\n", "writer = SummaryWriter(log_dir=log_dir)\n", "\n", "def logits_to_bboxes(logits,image) : # This function is useful for converting the mask into bounding boxes.(The model does not provide bounding boxes.)\n", " correct_shape = (300, 300) \n", " logits_temp = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)\n", " logits_temp = logits_temp.cpu().detach().numpy().astype(np.float32)\n", "\n", " heatmap_count = logits_temp.shape[1]\n", " heatmaps = [logits_temp[i][k] for i in range(logits_temp.shape[0]) for k in range(heatmap_count)]\n", " regions = parallel_get_regions(heatmaps=heatmaps, orig_size=image.size, id2label=model.config.id2label)\n", "\n", " final_bboxes = []\n", " for i in regions.bboxes :\n", " final_bboxes.append(i.bbox)\n", " return final_bboxes\n", "\n", "\n", "def loss_function(): # This model does not have inbuild loss function, So we have to define it according to our dataset and the Requirements.\n", " pass" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Fine-Tuning Process" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "num_epochs = 5\n", "for epoch in range(num_epochs):\n", " model.train()\n", " running_loss = 0.0\n", " avg_loss = 0.0\n", "\n", " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n", "\n", " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n", " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n", " \n", " optimizer.zero_grad()\n", " outputs = model(pixel_values=images)\n", "\n", " predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])\n", " target_boxes = item['bboxes']\n", "\n", " loss = loss_function(predicted_boxes,target_boxes)\n", "\n", " loss.backward()\n", " optimizer.step()\n", " running_loss += loss.item()\n", 
"\n", " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n", "\n", " avg_loss = running_loss / len(dataset)\n", " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n", " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n", "\n", " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Loading The Checkpoint " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "checkpoint_path = 'checkpoints/model_epoch_350.pth' \n", "state_dict = torch.load(checkpoint_path,weights_only=True)\n", "\n", "model.load_state_dict(state_dict)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model.to('cpu')\n", "model.save_pretrained(\"fine-tuned-surya-model-layout\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }