# Text-to-image search with ModernBERT

In [18]:
from _dataset.preprocess_images import *
import random

In [None]:

device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = VisionPreprocessor(device, param_dtype=torch.float32)

num_images = 25
input_directory = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/val2017"

# os.listdir() returns entries in arbitrary, filesystem-dependent order, so
# sort the listing to make "the first num_images" reproducible.
image_paths = sorted(
    os.path.join(input_directory, f)
    for f in os.listdir(input_directory)
    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
)

# Uncomment the shuffle to take a random subset instead of the first ones.
# random.shuffle(image_paths)
# Re-sort the chosen subset: load_image_embeddings() reads the saved .npy files
# in sorted filename order, and ranked indices are later mapped back through
# image_paths. This keeps both orders aligned, assuming process_directory()
# names each .npy after its source image -- TODO confirm.
image_paths = sorted(image_paths[:num_images])

# Print the selected image paths
print("Selected Image Paths:")
for path in image_paths:
    print(path)


In [None]:
import os
import shutil

# Folder that will receive one embedding file per selected image.
output_directory = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings"

# Wipe any previous run so stale .npy files cannot be picked up when the
# embeddings are loaded back, then (re)create the empty folder.
had_previous_run = os.path.exists(output_directory)
if had_previous_run:
    shutil.rmtree(output_directory)
    print("Existing directory cleared: " + output_directory)
os.makedirs(output_directory, exist_ok=True)

# Run the vision pipeline over the selected images and write the results.
pipeline.process_directory(image_paths, output_directory)
print("Image embeddings saved!")

In [None]:
from train import JointNetwork

def load_checkpoint_and_prepare_model(checkpoint_path, device="cuda"):
    """Rebuild a JointNetwork from a saved checkpoint and put it in inference mode.

    Args:
        checkpoint_path: file written by ``torch.save`` containing a
            ``'model_state_dict'`` entry with the JointNetwork weights.
        device: device (string or ``torch.device``) to load the weights onto
            and move the model to.

    Returns:
        The JointNetwork in eval mode, with an extra ``device`` attribute so
        callers can move their inputs onto the model's device.
    """
    target_device = torch.device(device)
    net = JointNetwork()
    # NOTE(review): weights_only=False unpickles arbitrary Python objects and
    # can execute code embedded in the file -- only load trusted checkpoints.
    saved = torch.load(checkpoint_path, map_location=target_device, weights_only=False)
    net.load_state_dict(saved['model_state_dict'])
    net.to(target_device)
    net.eval()
    net.device = target_device
    print(f"Model loaded successfully from {checkpoint_path}.")
    return net

def get_text_embedding(model, text_prompt):
    """Encode a text prompt into a unit-length embedding in the joint space.

    The prompt is tokenized with the text encoder's tokenizer, encoded,
    mean-pooled over dim 1, projected with ``model.text_projector`` and
    L2-normalised along dim 1.

    Args:
        model: a prepared JointNetwork exposing ``device``, ``text_encoder``
            (with a ``tokenizer``) and ``text_projector``.
        text_prompt: the query text.

    Returns:
        The normalised text embedding, with no autograd history
        (presumably shape (1, D) for a single prompt -- TODO confirm).
    """
    tokenized_text = model.text_encoder.tokenizer(text_prompt, return_tensors="pt").to(model.device)
    # Inference only: keep the encoder AND the projector inside no_grad so no
    # autograd graph is recorded and the result does not require grad.
    with torch.no_grad():
        # assumes the encoder returns (batch, seq_len, hidden) token features -- TODO confirm
        text_features = model.text_encoder(tokenized_text)
        # NOTE(review): plain mean over the token axis; padding tokens would be
        # averaged in if several prompts of different lengths were batched.
        pooled = text_features.mean(dim=1)
        text_features = model.text_projector(pooled)
        text_features = F.normalize(text_features, dim=1)
    return text_features

def load_image_embeddings(model, embeddings_dir):
    """Load precomputed image features and map them into the joint embedding space.

    Every ``.npy`` file in ``embeddings_dir`` is read in sorted filename
    order, mean-pooled over axis 0, passed through ``model.vision_projector``
    and L2-normalised along dim 1. Row i of the result corresponds to the
    i-th ``.npy`` file in sorted order.

    Args:
        model: a prepared JointNetwork exposing ``device`` and ``vision_projector``.
        embeddings_dir: directory containing the saved ``.npy`` feature files.

    Returns:
        Tensor with one normalised embedding per file, stacked along dim 0,
        with no autograd history.

    Raises:
        ValueError: if the directory contains no ``.npy`` files.
    """
    npy_files = sorted(f for f in os.listdir(embeddings_dir) if f.endswith(".npy"))
    if not npy_files:
        raise ValueError("No vision embeddings found in the specified directory.")

    vision_embeddings = []
    # Inference only: without no_grad every projector call would record an
    # autograd graph that stays alive as long as the returned tensor does.
    with torch.no_grad():
        for file_name in npy_files:
            image_encoding = torch.tensor(
                np.load(os.path.join(embeddings_dir, file_name)), dtype=torch.float32
            ).to(model.device)
            # presumably (num_patches, dim) patch features -- TODO confirm
            vision_pooled = image_encoding.mean(dim=0, keepdim=True)
            vision_embedded = model.vision_projector(vision_pooled)
            vision_embeddings.append(F.normalize(vision_embedded, dim=1))

    print(f"Vision embeddings loaded successfully from {embeddings_dir}.")
    # Each entry has a leading dim of 1, so concatenating gives one row per file.
    return torch.cat(vision_embeddings, dim=0)

def compare_text_to_images(text_embedding, vision_embeddings):
    """Rank image embeddings by their similarity to a single text embedding.

    The score is the dot product of the text embedding with each image
    embedding; when both sides are L2-normalised (as the loaders in this
    notebook produce) that is the cosine similarity.

    Returns:
        ``(ranked_indices, similarity_scores)`` as NumPy arrays:
        ``ranked_indices`` lists image positions from best to worst match and
        ``similarity_scores[i]`` is the score of image ``i``.
    """
    scores = (text_embedding @ vision_embeddings.T).squeeze(0)
    similarity_scores = scores.detach().cpu().numpy()
    # argsort is ascending; reverse it so the best match comes first.
    ranked_indices = similarity_scores.argsort()[::-1]
    return ranked_indices, similarity_scores



# Trained JointNetwork weights, and the folder of per-image .npy features
# written by the embedding cell above.
checkpoint_path = "/home/nolan4/projects/hf-contest/checkpoints/model_checkpoint_20250109_102039.pth"
embeddings_dir = "/mnt/nvme/shared_A/datasets/coco-image-caption/versions/1/val2017/vision_embeddings"

# Restore the model first; projecting the stored image features needs it.
model = load_checkpoint_and_prepare_model(checkpoint_path)
vision_embeddings = load_image_embeddings(model, embeddings_dir)

In [None]:
import matplotlib.pyplot as plt
import os
from PIL import Image

def display_images_from_paths(image_paths, num_images=5):
    """Show up to ``num_images`` images from ``image_paths`` side by side in one row.

    Each subplot is titled with the file name without its extension.

    Args:
        image_paths: sequence of image file paths, shown in the given order.
        num_images: maximum number of images to show; values <= 0 show nothing.
    """
    num_images = min(num_images, len(image_paths))
    # "<= 0" also catches a negative count, which would otherwise reach
    # plt.subplot(1, num_images, ...) and raise.
    if num_images <= 0:
        print("No images found in the directory.")
        return

    plt.figure(figsize=(12, 8))
    for i, image_path in enumerate(image_paths[:num_images]):
        # The context manager closes the file handle; imshow converts the PIL
        # image to an array while it is still open.
        with Image.open(image_path) as img:
            plt.subplot(1, num_images, i + 1)
            plt.imshow(img)
        plt.axis('off')
        # splitext keeps dots inside the stem ("a.b.jpg" -> "a.b"), unlike split('.')[0].
        plt.title(os.path.splitext(os.path.basename(image_path))[0])

    plt.tight_layout()
    plt.show()

# Preview the first ten selected images.
# NOTE: shuffling image_paths in place here would break the mapping from
# ranked embedding indices back to paths used in the later cells.
# random.shuffle(image_paths)
display_images_from_paths(image_paths, num_images=10)

In [None]:
# Query text; swap the commented line in to try another search.
text_prompt = "cars driving down the road"
# text_prompt = "stuffed brown teddy bear"

# Embed the query with the model loaded earlier.
text_embedding = get_text_embedding(model, text_prompt)

# Score every precomputed image embedding against the query and list the best five.
ranked_indices, similarity_scores = compare_text_to_images(text_embedding, vision_embeddings)
print("\nTop 5 Most Similar Images:")
for idx in ranked_indices[:5]:
    print(f"Image Index: {idx}, Similarity Score: {similarity_scores[idx]:.4f}")

In [None]:
# Map the ten best-scoring embedding positions back to image files. This relies
# on image_paths being in the same order as the .npy files that
# load_image_embeddings() read (sorted by filename) -- TODO confirm.
selected_image_paths = [image_paths[i] for i in ranked_indices[:10]]

# Show the best four of those matches.
display_images_from_paths(selected_image_paths, num_images=4)