File size: 1,949 Bytes
d950775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os, pdb

import argparse
import numpy as np
import torch
import requests
from PIL import Image

from diffusers import DDIMScheduler
from utils.edit_directions import construct_direction
from utils.edit_pipeline import EditingPipeline


if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_str', type=str, required=True)
    parser.add_argument('--random_seed', default=0)
    parser.add_argument('--task_name', type=str, default='cat2dog')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--xa_guidance', default=0.15, type=float)
    parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
    parser.add_argument('--use_float_16', action='store_true')
    args = parser.parse_args()

    os.makedirs(args.results_folder, exist_ok=True)

    if args.use_float_16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    # make the input noise map
    torch.cuda.manual_seed(args.random_seed)
    x = torch.randn((1,4,64,64), device="cuda")

    # Make the editing pipeline
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    rec_pil, edit_pil = pipe(args.prompt_str, 
        num_inference_steps=args.num_ddim_steps,
        x_in=x,
        edit_dir=construct_direction(args.task_name),
        guidance_amount=args.xa_guidance,
        guidance_scale=args.negative_guidance_scale,
        negative_prompt="" # use the empty string for the negative prompt
    )
    
    edit_pil[0].save(os.path.join(args.results_folder, f"edit.png"))
    rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction.png"))