# peekaboo-demo/src/image_generation.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py
# fmt: off
import os
import sys
sys.path.insert(1, os.path.join(sys.path[0], '..'))  # make sibling packages (e.g. `models`) importable
# fmt: on

import random
import warnings

import torch
import torchvision
import torchvision.io as vision_io
from PIL import Image

from models.t2i_pipeline import StableDiffusionPipelineSpatialAware

warnings.filterwarnings("ignore")
def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
    '''
    Main wrapper around the spatially-aware diffusion pipeline.
    latents: the input noise that generation starts from
    get_latents: if True, also return the latents from a second, unconstrained pass
    '''
    image = pipe(overall_prompt, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
                 frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt,
                 custom_attention_mask=custom_attention_mask, output_type='pil',
                 fg_prompt=fg_prompt, make_attention_mask_2d=True, attention_mask_block_diagonal=True).images[0]
    torch.save(image, "img.pt")  # debug checkpoint of the generated PIL image

    if get_latents:
        video_latents = pipe(overall_prompt, latents=latents,
                             num_inference_steps=num_inference_steps, output_type="latent").images
        torch.save(video_latents, "img_latents.pt")
        return image, video_latents

    return image
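# Example (sketch, assuming `pipe` is loaded as in the driver below and `mask` is a
# (1, 1, 96, 96) float tensor; `device` stands in for the target torch device):
#   latents = torch.randn([1, 4, 96, 96], generator=torch.Generator().manual_seed(0))
#   img = generate_image(pipe, "A cat on a couch", latents.to(device),
#                        fg_masks=mask.to(device), frozen_steps=2, fg_prompt="cat")
#   img.save("cat.png")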
def save_frames(path):
    '''Dump every frame of demo3/<path>.mp4 to demo3/<path>/frame_XXXX.png.'''
    video, audio, video_info = vision_io.read_video(
        f"demo3/{path}.mp4", pts_unit='sec')

    # Number of frames
    num_frames = video.size(0)

    # read_video returns frames as T x H x W x C uint8 tensors, which is already
    # the layout PIL expects, so no permute is needed.
    os.makedirs(f"demo3/{path}", exist_ok=True)
    for i in range(num_frames):
        frame = video[i, :, :, :].numpy()
        img = Image.fromarray(frame.astype('uint8'))
        img.save(f"demo3/{path}/frame_{i:04d}.png")
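# Example: save_frames("clip1") expects demo3/clip1.mp4 to exist and writes
# demo3/clip1/frame_0000.png, frame_0001.png, ...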
def create_boxes():
    '''Build a bank of square foreground masks: a 3x3 grid of positions for each
    of five object sizes, on the 96x96 latent canvas.'''
    img_width = 96
    img_height = 96
    sbboxes = []
    for object_size in [20, 30, 40, 50, 60]:
        obj_width, obj_height = object_size, object_size
        # top-left corner of the grid
        start_x = 3
        start_y = 4
        # total extent of the 3x3 grid of objects, excluding spacing
        total_obj_width = 3 * obj_width
        total_obj_height = 3 * obj_height
        # split the remaining space into the two gaps between grid columns/rows
        spacing_horizontal = (img_width - total_obj_width - start_x) // 2
        spacing_vertical = (img_height - total_obj_height - start_y) // 2
        for i in range(3):
            for j in range(3):
                x_start = start_x + i * (obj_width + spacing_horizontal)
                y_start = start_y + j * (obj_height + spacing_vertical)
                # clamp to the image bounds so oversized boxes stay inside the canvas
                x_end = min(x_start + obj_width, img_width)
                y_end = min(y_start + obj_height, img_height)
                sbboxes.append([x_start, y_start, x_end, y_end])

    masks_list = []
    for mask_id, sbbox in enumerate(sbboxes):
        smask = torch.zeros(1, 1, img_height, img_width)
        smask[0, 0, sbbox[1]:sbbox[3], sbbox[0]:sbbox[2]] = 1.0
        masks_list.append(smask)
        # torchvision.utils.save_image(smask, f"{SAVE_DIR}/masks/mask_{mask_id}.png")  # optional: save masks as images
    return masks_list
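# Example: create_boxes() yields 45 masks (5 sizes x 9 grid positions), each a
# (1, 1, 96, 96) float tensor with ones inside the box and zeros elsewhere.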
def objects_list():
    '''Return a list of (object, setting) pairs used to build generation prompts.'''
    objects_settings = [
("apple", "on a table"),
("ball", "in a park"),
("cat", "on a couch"),
("dog", "in a backyard"),
("elephant", "in a jungle"),
("fountain pen", "on a desk"),
("guitar", "on a stage"),
("helicopter", "in the sky"),
("island", "in the sea"),
("jar", "on a shelf"),
("kite", "in the sky"),
("lamp", "in a room"),
("motorbike", "on a road"),
("notebook", "on a table"),
("owl", "on a tree"),
("piano", "in a hall"),
("queen", "in a castle"),
("robot", "in a lab"),
("snake", "in a forest"),
("tent", "in the mountains"),
("umbrella", "on a beach"),
("violin", "in an orchestra"),
("wheel", "in a garage"),
("xylophone", "in a music class"),
("yacht", "in a marina"),
("zebra", "in a savannah"),
("aeroplane", "in the clouds"),
("bridge", "over a river"),
("computer", "in an office"),
("dragon", "in a cave"),
("egg", "in a nest"),
("flower", "in a garden"),
("globe", "in a library"),
("hat", "on a rack"),
("ice cube", "in a glass"),
("jewelry", "in a box"),
("kangaroo", "in a desert"),
("lion", "in a den"),
("mug", "on a counter"),
("nest", "on a branch"),
("octopus", "in the ocean"),
("parrot", "in a rainforest"),
("quilt", "on a bed"),
("rose", "in a vase"),
("ship", "in a dock"),
("train", "on the tracks"),
("utensils", "in a kitchen"),
("vase", "on a window sill"),
("watch", "in a store"),
("x-ray", "in a hospital"),
("yarn", "in a basket"),
("zeppelin", "above a city"),
]
objects_settings.extend([
("muffin", "on a bakery shelf"),
("notebook", "on a student's desk"),
("owl", "in a tree"),
("piano", "in a concert hall"),
("quill", "on parchment"),
("robot", "in a factory"),
("snake", "in the grass"),
("telescope", "in an observatory"),
("umbrella", "at the beach"),
("violin", "in an orchestra"),
("whale", "in the ocean"),
("xylophone", "in a music store"),
("yacht", "in a marina"),
("zebra", "on a savanna"),
# Kitchen items
("spoon", "in a drawer"),
("plate", "in a cupboard"),
("cup", "on a shelf"),
("frying pan", "on a stove"),
("jar", "in the refrigerator"),
# Office items
("computer", "in an office"),
("printer", "by a desk"),
("chair", "around a conference table"),
("lamp", "on a workbench"),
("calendar", "on a wall"),
# Outdoor items
("bicycle", "on a street"),
("tent", "in a campsite"),
("fire", "in a fireplace"),
("mountain", "in the distance"),
("river", "through the woods"),
# and so on ...
])
    # To grow the list quickly, themes group several objects under a shared setting:
themes = [
("wild animals", ["tiger", "lion", "cheetah",
"giraffe", "hippopotamus"], "in the wild"),
("household items", ["sofa", "tv", "clock",
"vase", "photo frame"], "in a living room"),
("clothes", ["shirt", "pants", "shoes",
"hat", "jacket"], "in a wardrobe"),
("musical instruments", ["drum", "trumpet",
"harp", "saxophone", "tuba"], "in a band"),
("cosmic entities", ["planet", "star",
"comet", "nebula", "asteroid"], "in space"),
# ... add more themes
]
# Using the themes to extend our list
for theme_name, theme_objects, theme_location in themes:
for theme_object in theme_objects:
objects_settings.append((theme_object, theme_location))
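    # e.g. the "wild animals" theme contributes ("tiger", "in the wild"),
    # ("lion", "in the wild"), and so on.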
# Sports equipment
objects_settings.extend([
("basketball", "on a court"),
("golf ball", "on a golf course"),
("tennis racket", "on a tennis court"),
("baseball bat", "in a stadium"),
("hockey stick", "on an ice rink"),
("football", "on a field"),
("skateboard", "in a skatepark"),
("boxing gloves", "in a boxing ring"),
("ski", "on a snowy slope"),
("surfboard", "on a beach shore"),
])
# Toys and games
objects_settings.extend([
("teddy bear", "on a child's bed"),
("doll", "in a toy store"),
("toy car", "on a carpet"),
("board game", "on a table"),
("yo-yo", "in a child's hand"),
("kite", "in the sky on a windy day"),
("Lego bricks", "on a construction table"),
("jigsaw puzzle", "partially completed"),
("rubik's cube", "on a shelf"),
("action figure", "on display"),
])
# Transportation
objects_settings.extend([
("bus", "at a bus stop"),
("motorcycle", "on a road"),
("helicopter", "landing on a pad"),
("scooter", "on a sidewalk"),
("train", "at a station"),
("bicycle", "parked by a post"),
("boat", "in a harbor"),
("tractor", "on a farm"),
("airplane", "taking off from a runway"),
("submarine", "below sea level"),
])
# Medieval theme
objects_settings.extend([
("castle", "on a hilltop"),
("knight", "riding a horse"),
("bow and arrow", "in an archery range"),
("crown", "in a treasure chest"),
("dragon", "flying over mountains"),
("shield", "next to a warrior"),
("dagger", "on a wooden table"),
("torch", "lighting a dark corridor"),
("scroll", "sealed with wax"),
("cauldron", "with bubbling potion"),
])
# Modern technology
objects_settings.extend([
("smartphone", "on a charger"),
("laptop", "in a cafe"),
("headphones", "around a neck"),
("camera", "on a tripod"),
("drone", "flying over a park"),
("USB stick", "plugged into a computer"),
("watch", "on a wrist"),
("microphone", "on a podcast desk"),
("tablet", "with a digital pen"),
("VR headset", "ready for gaming"),
])
# Nature
objects_settings.extend([
("tree", "in a forest"),
("flower", "in a garden"),
("mountain", "on a horizon"),
("cloud", "in a blue sky"),
("waterfall", "in a scenic location"),
("beach", "next to an ocean"),
("cactus", "in a desert"),
("volcano", "erupting with lava"),
("coral", "under the sea"),
("moon", "in a night sky"),
])
    return objects_settings
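# Example: objects_list()[0] == ("apple", "on a table"), which the driver below
# turns into the prompt "An apple on a table".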
if __name__ == "__main__":
    SAVE_DIR = "/scr/image/"
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fixed-seed starting noise, reused across all prompts and masks for comparability.
    random_latents = torch.randn(
        [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
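    # A 96x96 latent decodes to a 768x768 image (the SD 2.1 VAE upsamples by 8x).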
    # Prefer a local cache directory; fall back to the default Hugging Face cache.
    try:
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32",
            cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception:
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)
    masks_list = create_boxes()

    obj_settings = objects_list()  # full bank of (object, setting) pairs
    for obj_setting in obj_settings[120:]:  # resume from index 120
        fg_object = obj_setting[0]  # the object whose region is pinned by the mask
        # naive article choice; adequate for most objects in the list
        article = "An" if fg_object[0].lower() in "aeiou" else "A"
        overall_prompt = f"{article} {fg_object} {obj_setting[1]}"
        print(overall_prompt)

        # Randomly pick 3 masks from the bank for this prompt.
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            os.makedirs(f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}", exist_ok=True)
            torchvision.utils.save_image(
                masks_list[mask_id][0][0], f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/mask.png")

            # Sweep the number of frozen denoising steps to see how strongly the
            # mask constrains the final layout.
            for frozen_steps in range(0, 5):
                img = generate_image(pipe, overall_prompt, random_latents, get_latents=False,
                                     num_inference_steps=50, fg_masks=masks_list[mask_id].to(torch_device),
                                     fg_masked_latents=None, frozen_steps=frozen_steps,
                                     frozen_prompt=None, fg_prompt=fg_object)
                img.save(f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/{frozen_steps}.png")