# NOTE: removed non-code extraction artifacts ("Spaces:" / "Sleeping" / "Sleeping")
# Copyright (c) Facebook, Inc. and its affiliates. | |
# Modified from: https://github.com/facebookresearch/detectron2/blob/master/demo/demo.py | |
from transformers import pipeline | |
import torchvision | |
from PIL import Image | |
from models.t2i_pipeline import StableDiffusionPipelineSpatialAware | |
import torchvision.io as vision_io | |
import torch.nn.functional as F | |
import torch | |
import tqdm | |
import numpy as np | |
import cv2 | |
import warnings | |
import time | |
import tempfile | |
import argparse | |
import glob | |
import multiprocessing as mp | |
import os | |
import random | |
# fmt: off | |
import sys | |
sys.path.insert(1, os.path.join(sys.path[0], '..')) | |
# fmt: on | |
warnings.filterwarnings("ignore") | |
# constants | |
WINDOW_NAME = "demo" | |
def generate_image(pipe, overall_prompt, latents, get_latents=False, num_inference_steps=50, fg_masks=None,
                   fg_masked_latents=None, frozen_steps=0, frozen_prompt=None, custom_attention_mask=None, fg_prompt=None):
    '''
    Run the spatially-aware image diffusion pipeline for one prompt.

    pipe: the diffusion pipeline object (called a second time when get_latents is True)
    overall_prompt: text prompt driving the generation
    latents: input noise from which the denoising starts
    get_latents: if True, also run a latent-output pass and return its result
    Returns the generated PIL image, or (image, video_latents) when get_latents is True.
    '''
    result = pipe(overall_prompt, latents=latents, num_inference_steps=num_inference_steps, frozen_mask=fg_masks,
                  frozen_steps=frozen_steps, latents_all_input=fg_masked_latents, frozen_prompt=frozen_prompt,
                  custom_attention_mask=custom_attention_mask, output_type='pil',
                  fg_prompt=fg_prompt, make_attention_mask_2d=True, attention_mask_block_diagonal=True)
    image = result.images[0]
    # Debug snapshot of the last generated image.
    torch.save(image, "img.pt")
    if not get_latents:
        return image
    # Second pass with the same prompt/noise, asking the pipeline for raw latents.
    latent_result = pipe(overall_prompt, latents=latents,
                         num_inference_steps=num_inference_steps, output_type="latent")
    video_latents = latent_result.images
    torch.save(video_latents, "img_latents.pt")
    return image, video_latents
def save_frames(path):
    '''
    Decode demo3/<path>.mp4 and write every frame as a PNG under demo3/<path>/.

    path: base name of the video (without the .mp4 extension)
    '''
    video, audio, video_info = vision_io.read_video(
        f"demo3/{path}.mp4", pts_unit='sec')
    os.makedirs(f"demo3/{path}", exist_ok=True)
    # read_video returns frames stacked along dim 0 in H x W x C layout,
    # so each frame converts directly to a PIL image.
    for i, frame in enumerate(video):
        img = Image.fromarray(frame.numpy().astype('uint8'))
        img.save(f"demo3/{path}/frame_{i:04d}.png")
def create_boxes(img_width=96, img_height=96, object_sizes=(20, 30, 40, 50, 60),
                 start_x=3, start_y=4, grid=3):
    '''
    Build binary spatial masks for a grid of square boxes at several sizes.

    For each size in object_sizes, a grid x grid arrangement of square boxes is
    laid out starting at (start_x, start_y), with spacing chosen so the grid
    roughly spans the image; boxes are clipped at the image border.

    img_width, img_height: mask dimensions (defaults preserve the original 96x96)
    object_sizes: square box side lengths to sweep over
    start_x, start_y: top-left offset of the first box
    grid: number of boxes per row/column
    Returns a list of torch tensors of shape (1, 1, img_height, img_width),
    each 1.0 inside its box and 0.0 elsewhere; len = len(object_sizes) * grid**2.
    '''
    sbboxes = []
    for object_size in object_sizes:
        obj_width, obj_height = object_size, object_size
        # Spacing between boxes: leftover space after the grid and the start
        # offset, split over the (grid - 1) gaps. May go negative for large
        # boxes, which makes the boxes overlap — same as the original layout.
        gaps = max(grid - 1, 1)
        spacing_horizontal = (img_width - grid * obj_width - start_x) // gaps
        spacing_vertical = (img_height - grid * obj_height - start_y) // gaps
        for i in range(grid):
            for j in range(grid):
                x_start = start_x + i * (obj_width + spacing_horizontal)
                y_start = start_y + j * (obj_height + spacing_vertical)
                # Clip to the image so border boxes stay valid slices.
                x_end = min(x_start + obj_width, img_width)
                y_end = min(y_start + obj_height, img_height)
                sbboxes.append([x_start, y_start, x_end, y_end])
    masks_list = []
    for x0, y0, x1, y1 in sbboxes:
        # Mask tensor now follows img_height/img_width instead of a second
        # hard-coded 96, keeping it consistent with the box layout above.
        smask = torch.zeros(1, 1, img_height, img_width)
        smask[0, 0, y0:y1, x0:x1] = 1.0
        masks_list.append(smask)
    return masks_list
def objects_list(): | |
objects_settings = [ | |
("apple", "on a table"), | |
("ball", "in a park"), | |
("cat", "on a couch"), | |
("dog", "in a backyard"), | |
("elephant", "in a jungle"), | |
("fountain pen", "on a desk"), | |
("guitar", "on a stage"), | |
("helicopter", "in the sky"), | |
("island", "in the sea"), | |
("jar", "on a shelf"), | |
("kite", "in the sky"), | |
("lamp", "in a room"), | |
("motorbike", "on a road"), | |
("notebook", "on a table"), | |
("owl", "on a tree"), | |
("piano", "in a hall"), | |
("queen", "in a castle"), | |
("robot", "in a lab"), | |
("snake", "in a forest"), | |
("tent", "in the mountains"), | |
("umbrella", "on a beach"), | |
("violin", "in an orchestra"), | |
("wheel", "in a garage"), | |
("xylophone", "in a music class"), | |
("yacht", "in a marina"), | |
("zebra", "in a savannah"), | |
("aeroplane", "in the clouds"), | |
("bridge", "over a river"), | |
("computer", "in an office"), | |
("dragon", "in a cave"), | |
("egg", "in a nest"), | |
("flower", "in a garden"), | |
("globe", "in a library"), | |
("hat", "on a rack"), | |
("ice cube", "in a glass"), | |
("jewelry", "in a box"), | |
("kangaroo", "in a desert"), | |
("lion", "in a den"), | |
("mug", "on a counter"), | |
("nest", "on a branch"), | |
("octopus", "in the ocean"), | |
("parrot", "in a rainforest"), | |
("quilt", "on a bed"), | |
("rose", "in a vase"), | |
("ship", "in a dock"), | |
("train", "on the tracks"), | |
("utensils", "in a kitchen"), | |
("vase", "on a window sill"), | |
("watch", "in a store"), | |
("x-ray", "in a hospital"), | |
("yarn", "in a basket"), | |
("zeppelin", "above a city"), | |
] | |
objects_settings.extend([ | |
("muffin", "on a bakery shelf"), | |
("notebook", "on a student's desk"), | |
("owl", "in a tree"), | |
("piano", "in a concert hall"), | |
("quill", "on parchment"), | |
("robot", "in a factory"), | |
("snake", "in the grass"), | |
("telescope", "in an observatory"), | |
("umbrella", "at the beach"), | |
("violin", "in an orchestra"), | |
("whale", "in the ocean"), | |
("xylophone", "in a music store"), | |
("yacht", "in a marina"), | |
("zebra", "on a savanna"), | |
# Kitchen items | |
("spoon", "in a drawer"), | |
("plate", "in a cupboard"), | |
("cup", "on a shelf"), | |
("frying pan", "on a stove"), | |
("jar", "in the refrigerator"), | |
# Office items | |
("computer", "in an office"), | |
("printer", "by a desk"), | |
("chair", "around a conference table"), | |
("lamp", "on a workbench"), | |
("calendar", "on a wall"), | |
# Outdoor items | |
("bicycle", "on a street"), | |
("tent", "in a campsite"), | |
("fire", "in a fireplace"), | |
("mountain", "in the distance"), | |
("river", "through the woods"), | |
# and so on ... | |
]) | |
# To expedite the generation, you can combine themes and objects: | |
themes = [ | |
("wild animals", ["tiger", "lion", "cheetah", | |
"giraffe", "hippopotamus"], "in the wild"), | |
("household items", ["sofa", "tv", "clock", | |
"vase", "photo frame"], "in a living room"), | |
("clothes", ["shirt", "pants", "shoes", | |
"hat", "jacket"], "in a wardrobe"), | |
("musical instruments", ["drum", "trumpet", | |
"harp", "saxophone", "tuba"], "in a band"), | |
("cosmic entities", ["planet", "star", | |
"comet", "nebula", "asteroid"], "in space"), | |
# ... add more themes | |
] | |
# Using the themes to extend our list | |
for theme_name, theme_objects, theme_location in themes: | |
for theme_object in theme_objects: | |
objects_settings.append((theme_object, theme_location)) | |
# Sports equipment | |
objects_settings.extend([ | |
("basketball", "on a court"), | |
("golf ball", "on a golf course"), | |
("tennis racket", "on a tennis court"), | |
("baseball bat", "in a stadium"), | |
("hockey stick", "on an ice rink"), | |
("football", "on a field"), | |
("skateboard", "in a skatepark"), | |
("boxing gloves", "in a boxing ring"), | |
("ski", "on a snowy slope"), | |
("surfboard", "on a beach shore"), | |
]) | |
# Toys and games | |
objects_settings.extend([ | |
("teddy bear", "on a child's bed"), | |
("doll", "in a toy store"), | |
("toy car", "on a carpet"), | |
("board game", "on a table"), | |
("yo-yo", "in a child's hand"), | |
("kite", "in the sky on a windy day"), | |
("Lego bricks", "on a construction table"), | |
("jigsaw puzzle", "partially completed"), | |
("rubik's cube", "on a shelf"), | |
("action figure", "on display"), | |
]) | |
# Transportation | |
objects_settings.extend([ | |
("bus", "at a bus stop"), | |
("motorcycle", "on a road"), | |
("helicopter", "landing on a pad"), | |
("scooter", "on a sidewalk"), | |
("train", "at a station"), | |
("bicycle", "parked by a post"), | |
("boat", "in a harbor"), | |
("tractor", "on a farm"), | |
("airplane", "taking off from a runway"), | |
("submarine", "below sea level"), | |
]) | |
# Medieval theme | |
objects_settings.extend([ | |
("castle", "on a hilltop"), | |
("knight", "riding a horse"), | |
("bow and arrow", "in an archery range"), | |
("crown", "in a treasure chest"), | |
("dragon", "flying over mountains"), | |
("shield", "next to a warrior"), | |
("dagger", "on a wooden table"), | |
("torch", "lighting a dark corridor"), | |
("scroll", "sealed with wax"), | |
("cauldron", "with bubbling potion"), | |
]) | |
# Modern technology | |
objects_settings.extend([ | |
("smartphone", "on a charger"), | |
("laptop", "in a cafe"), | |
("headphones", "around a neck"), | |
("camera", "on a tripod"), | |
("drone", "flying over a park"), | |
("USB stick", "plugged into a computer"), | |
("watch", "on a wrist"), | |
("microphone", "on a podcast desk"), | |
("tablet", "with a digital pen"), | |
("VR headset", "ready for gaming"), | |
]) | |
# Nature | |
objects_settings.extend([ | |
("tree", "in a forest"), | |
("flower", "in a garden"), | |
("mountain", "on a horizon"), | |
("cloud", "in a blue sky"), | |
("waterfall", "in a scenic location"), | |
("beach", "next to an ocean"), | |
("cactus", "in a desert"), | |
("volcano", "erupting with lava"), | |
("coral", "under the sea"), | |
("moon", "in a night sky"), | |
]) | |
prompts = [f"A {obj} {setting}" for obj, setting in objects_settings] | |
return objects_settings | |
if __name__ == "__main__":
    SAVE_DIR = "/scr/image/"
    save_path = "img43-att_mask"
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Fixed seed so every prompt/mask combination starts from identical noise.
    random_latents = torch.randn(
        [1, 4, 96, 96], generator=torch.Generator().manual_seed(1)).to(torch_device)
    try:
        # Preferred: load weights through the shared scratch cache directory.
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32", cache_dir="/gscratch/scrubbed/anasery/").to(torch_device)
    except Exception:
        # Fall back to the default HF cache when the scratch dir is unavailable.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
        pipe = StableDiffusionPipelineSpatialAware.from_pretrained(
            "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float, variant="fp32").to(torch_device)
    fg_object = "apple"  # fg object stores the object to be masked
    # overall prompt stores the prompt
    overall_prompt = f"An {fg_object} on plate"
    os.makedirs(f"{SAVE_DIR}/{overall_prompt}", exist_ok=True)
    masks_list = create_boxes()
    obj_settings = objects_list()
    # Sweep the tail of the prompt list (earlier entries already generated).
    for obj_setting in obj_settings[120:]:
        fg_object = obj_setting[0]
        overall_prompt = f"A {obj_setting[0]} {obj_setting[1]}"
        print(overall_prompt)
        # Randomly pick 3 mask placements per prompt.
        selected_mask_ids = random.sample(range(len(masks_list)), 3)
        for mask_id in selected_mask_ids:
            os.makedirs(
                f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}", exist_ok=True)
            torchvision.utils.save_image(
                masks_list[mask_id][0][0], f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/mask.png")
            # Sweep how many early denoising steps keep the mask frozen.
            for frozen_steps in range(0, 5):
                img = generate_image(pipe, overall_prompt, random_latents, get_latents=False, num_inference_steps=50, fg_masks=masks_list[mask_id].to(
                    torch_device), fg_masked_latents=None, frozen_steps=frozen_steps, frozen_prompt=None, fg_prompt=fg_object)
                img.save(
                    f"{SAVE_DIR}/{overall_prompt}/mask{mask_id}/{frozen_steps}.png")