import os
import random
import re

import pandas as pd
import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import BlipForConditionalGeneration, BlipProcessor

# Sub-folders of the dataset root. The trailing slash is kept on purpose: the
# folder string is used verbatim as the prefix of the ``file_name`` column.
THUMBS_UP_FOLDER = "images/"
PERSON_FOLDER = "srimanth_dataset/"

# Token substituted for generic person words in captions of the person folder.
PERSON_TOKEN = "srimanth"

BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"

# Whole-word patterns: a plain ``str.replace("man", ...)`` would also rewrite
# the inside of words such as "woman".
_MAN_RE = re.compile(r"\bman\b")
_PERSON_RE = re.compile(r"\bperson\b")


def _adjust_caption(caption, folder):
    """Post-process a raw caption according to the folder it came from.

    * ``"images/"``: make sure the caption mentions "thumbs up", adding it at
      a randomly chosen end (prefix or suffix) when it is missing.
    * ``"srimanth_dataset/"``: replace the whole word "man" (or, when "man"
      is absent, the whole word "person") with the person token.
    * Any folder: strip the spurious "arafed " token.
    """
    if folder == THUMBS_UP_FOLDER and "thumbs up" not in caption:
        # random.choice is only consumed when the phrase is actually missing,
        # so the seeded random stream advances once per modified caption.
        if random.choice([True, False]):
            caption = "thumbs up " + caption
        else:
            caption = caption + " thumbs up"
    elif folder == PERSON_FOLDER:
        if _MAN_RE.search(caption):
            caption = _MAN_RE.sub(PERSON_TOKEN, caption)
        elif _PERSON_RE.search(caption):
            caption = _PERSON_RE.sub(PERSON_TOKEN, caption)

    # For some reason, the model puts the word "arafed" for a human.
    return caption.replace("arafed ", "")


def caption_images(image_paths, processor, model, folder):
    """Caption every image in *image_paths* and return metadata records.

    Args:
        image_paths: Paths of the image files to caption.
        processor: Callable that turns a PIL image into model inputs
            (called with ``return_tensors="pt"``) and provides ``decode``.
        model: Captioning model exposing ``generate``. Inputs are moved to
            ``model.device`` when that attribute exists, else to ``"cuda"``.
        folder: Folder prefix (``"images/"`` or ``"srimanth_dataset/"``)
            that selects the caption post-processing and is prepended to the
            file name in the output records.

    Returns:
        A list of ``{"file_name": folder + <basename>, "text": <caption>}``
        dicts, one per input path, in input order.
    """
    # Follow the model's device instead of hard-coding "cuda" so inputs and
    # weights always live on the same device.
    device = getattr(model, "device", "cuda")

    records = []
    for img_path in tqdm(image_paths):
        # Close the underlying file as soon as the RGB copy is made.
        with Image.open(img_path) as raw_image:
            pil_image = raw_image.convert("RGB")

        # Unconditional image captioning (no text prompt is supplied).
        inputs = processor(pil_image, return_tensors="pt").to(device)
        generated = model.generate(**inputs)
        caption = processor.decode(generated[0], skip_special_tokens=True)

        records.append(
            {
                "file_name": folder + os.path.basename(img_path),
                "text": _adjust_caption(caption, folder),
            }
        )
    return records


def create_thumbs_up_person_dataset(
    path,
    cache_dir="/l/vision/v5/sragas/hf_models/",
    device=None,
):
    """Caption both image folders under *path* and write the metadata CSVs.

    The person folder is captioned first, then the thumbs-up folder. The
    combined records are written to ``<path>/metadata.csv`` and to
    ``metadata_srimanth_plain.csv`` in the current working directory.

    Args:
        path: Dataset root containing ``images/`` and ``srimanth_dataset/``.
        cache_dir: Cache directory passed to the model's ``from_pretrained``.
        device: Device to run the model on. ``None`` selects ``"cuda"`` when
            available and ``"cpu"`` otherwise.
    """
    random.seed(15)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): the processor is loaded without ``cache_dir`` while the
    # model uses it; kept as in the original -- confirm whether intentional.
    processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
    model = BlipForConditionalGeneration.from_pretrained(
        BLIP_CHECKPOINT,
        cache_dir=cache_dir,
        torch_dtype=torch.float32,
    ).to(device)

    thumbs_dir = os.path.join(path, THUMBS_UP_FOLDER)
    person_dir = os.path.join(path, PERSON_FOLDER)

    # Both listings are sorted: os.listdir order is arbitrary, and the seeded
    # prefix/suffix choice is only reproducible if images arrive in a fixed
    # order.
    image_paths = [
        os.path.join(thumbs_dir, name) for name in sorted(os.listdir(thumbs_dir))
    ]
    person_paths = [
        os.path.join(person_dir, name) for name in sorted(os.listdir(person_dir))
    ]

    records = caption_images(person_paths, processor, model, PERSON_FOLDER)
    records.extend(caption_images(image_paths, processor, model, THUMBS_UP_FOLDER))

    # Explicit columns keep the CSV header present even when no images exist.
    metadata = pd.DataFrame(records, columns=["file_name", "text"])
    metadata.to_csv(os.path.join(path, "metadata.csv"), index=False)
    metadata.to_csv("metadata_srimanth_plain.csv", index=False)


if __name__ == "__main__":
    images_dir = "/l/vision/v5/sragas/easel_ai/thumbs_up_srimanth_plain/"
    create_thumbs_up_person_dataset(images_dir)