# pix2pix-zero-01 / src/make_edit_direction.py
# Uploaded by ysharma (HF staff): "upload git code base", commit d950775 (2.51 kB).
import os, pdb
import argparse
import numpy as np
import torch
import requests
from PIL import Image
from diffusers import DDIMScheduler
from utils.edit_pipeline import EditingPipeline
## convert sentences to sentence embeddings
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode each sentence with *text_encoder* and return the mean embedding.

    Each sentence is tokenized on its own with ``padding="max_length"``,
    ``max_length=tokenizer.model_max_length`` and ``truncation=True``. This
    gives every sentence the same token length, so the per-sentence embeddings
    can be concatenated and averaged.

    Args:
        l_sentences: iterable of prompt strings. It must yield at least one
            sentence.
        tokenizer: callable tokenizer. It must expose ``model_max_length`` and
            return an object with ``input_ids``.
        text_encoder: called as ``text_encoder(input_ids, attention_mask=None)``.
            Element ``[0]`` of its output is used as the sentence embedding.
        device: device the token ids are moved to before encoding.

    Returns:
        The per-sentence embeddings, concatenated along dim 0 and averaged over
        that dim, with dim 0 kept as size 1. Presumably this is
        ``(1, seq_len, hidden)`` for a CLIP text encoder (TODO: confirm).

    Raises:
        ValueError: if *l_sentences* yields no sentences (nothing to average).
    """
    l_embeddings = []
    with torch.no_grad():
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            prompt_embeds = text_encoder(
                text_inputs.input_ids.to(device), attention_mask=None
            )[0]
            l_embeddings.append(prompt_embeds)
    if not l_embeddings:
        # Otherwise torch.cat([]) fails with an opaque RuntimeError.
        raise ValueError("load_sentence_embeddings needs at least one sentence")
    # torch.cat is the long-standing spelling; torch.concatenate is only an
    # alias available in newer PyTorch releases.
    return torch.cat(l_embeddings, dim=0).mean(dim=0, keepdim=True)
def embedding_output_path(sentence_file, output_folder):
    """Return the ``.pt`` path where the mean embedding of *sentence_file* is saved.

    The extension is dropped from the file's basename, so ``dir/cat.txt``
    maps to ``<output_folder>/cat.pt``.
    """
    # The original code used ``str.strip(".txt")``. That strips any of the
    # characters '.', 't', 'x' from both ends (``cat.txt`` -> ``ca``) rather
    # than removing the suffix.
    stem = os.path.splitext(os.path.basename(sentence_file))[0]
    return os.path.join(output_folder, stem + ".pt")


def save_mean_embedding(sentence_file, output_folder, pipe, label):
    """Embed every non-blank line of *sentence_file* and save the mean with torch.save.

    Does nothing if the output file already exists. *label* (e.g. "source" or
    "target") is only used in the skip message.
    """
    out_path = embedding_output_path(sentence_file, output_folder)
    if os.path.exists(out_path):
        print(f"Skipping {label} file {out_path} as it already exists")
        return
    with open(sentence_file, "r", encoding="utf-8") as f:
        # Skip blank lines; they would otherwise be embedded as empty prompts
        # and skew the mean.
        l_sents = [line.strip() for line in f if line.strip()]
    mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
    print(mean_emb.shape)
    torch.save(mean_emb, out_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file_source_sentences', required=True)
    parser.add_argument('--file_target_sentences', required=True)
    parser.add_argument('--output_folder', required=True)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    args = parser.parse_args()

    # load the model
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")

    # torch.save does not create missing directories.
    os.makedirs(args.output_folder, exist_ok=True)

    save_mean_embedding(args.file_source_sentences, args.output_folder, pipe, "source")
    save_mean_embedding(args.file_target_sentences, args.output_folder, pipe, "target")