{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Image search with modernBERT" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "from _dataset.preprocess_images import *\n", "import random" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "pipeline = VisionPreprocessor(device, param_dtype=torch.float32)\n", "\n", "num_images = 25\n", "input_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017\"\n", "image_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]\n", "\n", "# Shuffle and take the first 25 images\n", "# random.shuffle(image_paths)\n", "image_paths = image_paths[:num_images]\n", "\n", "# Print the selected image paths\n", "print(\"Selected Image Paths:\")\n", "for path in image_paths:\n", " print(path)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "\n", "# Specify the output directory\n", "output_directory = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n", "\n", "# Clear the vision embeddings directory if it exists, otherwise create it\n", "if os.path.exists(output_directory):\n", " shutil.rmtree(output_directory)\n", " print(f\"Existing directory cleared: {output_directory}\")\n", "os.makedirs(output_directory, exist_ok=True)\n", "\n", "# Process all images in the input directory\n", "pipeline.process_directory(image_paths, output_directory)\n", "print(\"Image embeddings saved!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from train import JointNetwork\n", "\n", "def load_checkpoint_and_prepare_model(checkpoint_path, device=\"cuda\"):\n", " \"\"\"Load trained JointNetwork() from checkpoint\"\"\"\n", " device = torch.device(device)\n", " model = JointNetwork()\n", " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n", " model.load_state_dict(checkpoint['model_state_dict'])\n", " model.to(device)\n", " model.eval()\n", " model.device = device\n", " print(f\"Model loaded successfully from {checkpoint_path}.\")\n", " return model\n", "\n", "def get_text_embedding(model, text_prompt):\n", " \"\"\"Encode a text prompt to get its embedding using the modernBERT encoder.\"\"\"\n", " tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors=\"pt\").to(model.device)\n", " with torch.no_grad():\n", " text_features = model.text_encoder(tokenized_text)\n", " text_features = model.text_projector(text_features.mean(dim=1))\n", " text_features = F.normalize(text_features, dim=1)\n", " return text_features\n", "\n", "def load_image_embeddings(model, embeddings_dir):\n", " \"\"\"Load all precomputed image embeddings from the specified directory.\"\"\"\n", " vision_embeddings = []\n", " for file in sorted(os.listdir(embeddings_dir)):\n", " if file.endswith(\".npy\"):\n", " image_encoding = torch.tensor(np.load(os.path.join(embeddings_dir, file)), dtype=torch.float32).to(model.device)\n", " vision_pooled = image_encoding.mean(dim=0).unsqueeze(0)\n", " vision_embedded = model.vision_projector(vision_pooled)\n", " vision_embedded = F.normalize(vision_embedded, dim=1)\n", " vision_embeddings.append(vision_embedded)\n", " \n", " if len(vision_embeddings) == 0:\n", " raise ValueError(\"No vision embeddings found in the specified directory.\")\n", " print(f\"Vision embeddings loaded successfully from {embeddings_dir}.\")\n", " return torch.stack(vision_embeddings).squeeze(1)\n", "\n", "def compare_text_to_images(text_embedding, vision_embeddings):\n", " \"\"\"Compare a text embedding against a batch of image embeddings using cosine similarity.\"\"\"\n", " cosine_similarities = torch.matmul(text_embedding, vision_embeddings.T).squeeze(0)\n", " similarity_scores = cosine_similarities.cpu().detach().numpy()\n", " ranked_indices = similarity_scores.argsort()[::-1] # Sort in descending order\n", " return ranked_indices, similarity_scores\n", "\n", "\n", "\n", "# Paths and settings\n", "checkpoint_path = \"/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth\"\n", "embeddings_dir = \"/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings\"\n", "\n", "# Load the model and precomputed vision embeddings\n", "model = load_checkpoint_and_prepare_model(checkpoint_path)\n", "vision_embeddings = load_image_embeddings(model, embeddings_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import os\n", "from PIL import Image\n", "\n", "def display_images_from_paths(image_paths, num_images=5):\n", "\n", " num_images = min(num_images, len(image_paths))\n", " if num_images == 0:\n", " print(\"No images found in the directory.\")\n", " return\n", "\n", " plt.figure(figsize=(12, 8))\n", " for i, image_path in enumerate(image_paths[:num_images]):\n", " img = Image.open(image_path)\n", " plt.subplot(1, num_images, i + 1)\n", " plt.imshow(img)\n", " plt.axis('off') \n", " plt.title(f\"{os.path.basename(image_path).split('.')[0]}\")\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "# Example usage\n", "# random.shuffle(image_paths)\n", "display_images_from_paths(image_paths, num_images=10)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Paths and settings\n", "text_prompt = \"cars driving down the road\"\n", "# text_prompt = \"stuffed brown teddy bear\"\n", "\n", "\n", "# Load the model and embeddings\n", "text_embedding = get_text_embedding(model, text_prompt)\n", "\n", "# Perform comparison and display results\n", "ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)\n", "print(f\"\\nTop 5 Most Similar Images:\")\n", "for idx in ranked_indices[:5]:\n", " print(f\"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Ensure ranked_indices is converted to a Python list\n", "selected_image_paths = [image_paths[idx] for idx in ranked_indices[:10]]\n", "\n", "# Display the top N ranked images\n", "display_images_from_paths(selected_image_paths, num_images=4)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "hf-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 2 }