import os
import argparse

import numpy as np
import torch
import requests
from PIL import Image
from diffusers import DDIMScheduler

from utils.edit_pipeline import EditingPipeline

## convert sentences to sentence embeddings
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode each sentence with the text encoder and return the mean embedding."""
    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            # per-sentence embedding of shape (1, max_length, hidden_dim)
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
    # average over all sentences, keeping a leading batch dimension of 1
    return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
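# Standalone usage sketch (commented out so it does not run with the script):
# load_sentence_embeddings only needs a CLIP tokenizer and text encoder, so it can be
# tried without the full EditingPipeline. The model id below is an illustrative
# assumption, not something this script requires.
#
#   from transformers import CLIPTokenizer, CLIPTextModel
#   tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#   enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")
#   emb = load_sentence_embeddings(["a photo of a cat", "a cat"], tok, enc)
#   print(emb.shape)  # torch.Size([1, 77, 768]) for this encoder
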
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file_source_sentences', required=True)
    parser.add_argument('--file_target_sentences', required=True)
    parser.add_argument('--output_folder', required=True)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    args = parser.parse_args()

    # load the model
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")
    os.makedirs(args.output_folder, exist_ok=True)

    # mean embedding of the source sentences
    bname_src = os.path.splitext(os.path.basename(args.file_source_sentences))[0]
    outf_src = os.path.join(args.output_folder, bname_src + ".pt")
    if os.path.exists(outf_src):
        print(f"Skipping source file {outf_src} as it already exists")
    else:
        with open(args.file_source_sentences, "r") as f:
            l_sents = [x.strip() for x in f.readlines()]
        mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
        print(mean_emb.shape)
        torch.save(mean_emb, outf_src)

    # mean embedding of the target sentences
    bname_tgt = os.path.splitext(os.path.basename(args.file_target_sentences))[0]
    outf_tgt = os.path.join(args.output_folder, bname_tgt + ".pt")
    if os.path.exists(outf_tgt):
        print(f"Skipping target file {outf_tgt} as it already exists")
    else:
        with open(args.file_target_sentences, "r") as f:
            l_sents = [x.strip() for x in f.readlines()]
        mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
        print(mean_emb.shape)
        torch.save(mean_emb, outf_tgt)
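
# Example invocation (a sketch: the sentence files, output folder, and script name
# below are placeholder assumptions, not assets that ship with this code):
#
#   python make_edit_direction.py \
#       --file_source_sentences sentences_cat.txt \
#       --file_target_sentences sentences_dog.txt \
#       --output_folder output/edit_directions
#
# Each .txt file is read as one sentence per line; the script saves one .pt file per
# input (here sentences_cat.pt and sentences_dog.pt) holding the mean text embedding.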