{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Image search with modernBERT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from _dataset.preprocess_images import *\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "pipeline = VisionPreprocessor(device, param_dtype=torch.float32)\n",
    "\n",
    "num_images = 25\n",
    "input_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017\"\n",
    "image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n",
    "\n",
    "# Shuffle and take the first 25 images\n",
    "# random.shuffle(image_paths)\n",
    "image_paths = image_paths[:num_images]\n",
    "\n",
    "# Print the selected image paths\n",
    "print(\"Selected Image Paths:\")\n",
    "for path in image_paths:\n",
    "    print(path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "\n",
    "# Specify the output directory\n",
    "output_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
    "\n",
    "# Clear the vision embeddings directory if it exists, otherwise create it\n",
    "if os.path.exists(output_directory):\n",
    "    shutil.rmtree(output_directory)\n",
    "    print(f\"Existing directory cleared: {output_directory}\")\n",
    "os.makedirs(output_directory, exist_ok=True)\n",
    "\n",
    "# Process all images in the input directory\n",
    "pipeline.process_directory(image_paths, output_directory)\n",
    "print(\"Image embeddings saved!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from train import JointNetwork\n",
    "\n",
    "def load_checkpoint_and_prepare_model(checkpoint_path, device=\"cuda\"):\n",
    "    \"\"\"Load trained JointNetwork() from checkpoint\"\"\"\n",
    "    device = torch.device(device)\n",
    "    model = JointNetwork()\n",
    "    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n",
    "    model.load_state_dict(checkpoint['model_state_dict'])\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    model.device = device\n",
    "    print(f\"Model loaded successfully from {checkpoint_path}.\")\n",
    "    return model\n",
    "\n",
    "def get_text_embedding(model, text_prompt):\n",
    "    \"\"\"Encode a text prompt to get its embedding using the modernBERT encoder.\"\"\"\n",
    "    tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors=\"pt\").to(model.device)\n",
    "    with torch.no_grad():\n",
    "        text_features = model.text_encoder(tokenized_text)\n",
    "        text_features = model.text_projector(text_features.mean(dim=1))\n",
    "        text_features = F.normalize(text_features, dim=1)\n",
    "    return text_features\n",
    "\n",
    "def load_image_embeddings(model, embeddings_dir):\n",
    "    \"\"\"Load all precomputed image embeddings from the specified directory.\"\"\"\n",
    "    vision_embeddings = []\n",
    "    for file in sorted(os.listdir(embeddings_dir)):\n",
    "        if file.endswith(\".npy\"):\n",
    "            image_encoding = torch.tensor(np.load(os.path.join(embeddings_dir, file)), dtype=torch.float32).to(model.device)\n",
    "            vision_pooled = image_encoding.mean(dim=0).unsqueeze(0)\n",
    "            vision_embedded = model.vision_projector(vision_pooled)\n",
    "            vision_embedded = F.normalize(vision_embedded, dim=1)\n",
    "            vision_embeddings.append(vision_embedded)\n",
    "    \n",
    "    if len(vision_embeddings) == 0:\n",
    "        raise ValueError(\"No vision embeddings found in the specified directory.\")\n",
    "    print(f\"Vision embeddings loaded successfully from {embeddings_dir}.\")\n",
    "    return torch.stack(vision_embeddings).squeeze(1)\n",
    "\n",
    "def compare_text_to_images(text_embedding, vision_embeddings):\n",
    "    \"\"\"Compare a text embedding against a batch of image embeddings using cosine similarity.\"\"\"\n",
    "    cosine_similarities = torch.matmul(text_embedding, vision_embeddings.T).squeeze(0)\n",
    "    similarity_scores = cosine_similarities.cpu().detach().numpy()\n",
    "    ranked_indices = similarity_scores.argsort()[::-1]  # Sort in descending order\n",
    "    return ranked_indices, similarity_scores\n",
    "\n",
    "\n",
    "\n",
    "# Paths and settings\n",
    "checkpoint_path = \"/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth\"\n",
    "embeddings_dir = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n",
    "\n",
    "# Load the model and precomputed vision embeddings\n",
    "model = load_checkpoint_and_prepare_model(checkpoint_path)\n",
    "vision_embeddings = load_image_embeddings(model, embeddings_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from PIL import Image\n",
    "\n",
    "def display_images_from_paths(image_paths, num_images=5):\n",
    "\n",
    "    num_images = min(num_images, len(image_paths))\n",
    "    if num_images == 0:\n",
    "        print(\"No images found in the directory.\")\n",
    "        return\n",
    "\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    for i, image_path in enumerate(image_paths[:num_images]):\n",
    "        img = Image.open(image_path)\n",
    "        plt.subplot(1, num_images, i + 1)\n",
    "        plt.imshow(img)\n",
    "        plt.axis('off')  \n",
    "        plt.title(f\"{os.path.basename(image_path).split('.')[0]}\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Example usage\n",
    "# random.shuffle(image_paths)\n",
    "display_images_from_paths(image_paths, num_images=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths and settings\n",
    "text_prompt = \"cars driving down the road\"\n",
    "# text_prompt = \"stuffed brown teddy bear\"\n",
    "\n",
    "\n",
    "# Load the model and embeddings\n",
    "text_embedding = get_text_embedding(model, text_prompt)\n",
    "\n",
    "# Perform comparison and display results\n",
    "ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)\n",
    "print(f\"\\nTop 5 Most Similar Images:\")\n",
    "for idx in ranked_indices[:5]:\n",
    "    print(f\"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensure ranked_indices is converted to a Python list\n",
    "selected_image_paths = [image_paths[idx] for idx in ranked_indices[:10]]\n",
    "\n",
    "# Display the top N ranked images\n",
    "display_images_from_paths(selected_image_paths, num_images=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hf-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}