|
from diffusers import UnCLIPPipeline, DiffusionPipeline |
|
import torch |
|
import os |
|
from lora_diffusion.cli_lora_pti import * |
|
from lora_diffusion.lora import * |
|
from PIL import Image |
|
import numpy as np |
|
import json |
|
from lora_dataset import PivotalTuningDatasetCapation as PVD |
|
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"} |
|
|
|
UNET_EXTENDED_TARGET_REPLACE = {"ResnetBlock2D", "CrossAttention", "Attention", "GEGLU"} |
|
|
|
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"} |
|
|
|
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"} |
|
|
|
DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE |
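
# These sets mirror lora_diffusion's target_replace_module defaults: they name
# the module classes inside the UNet / text encoder whose inner linear layers
# get LoRA adapters injected, and which are scanned when extracting or saving
# LoRA weights.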
|
|
|
def save_all( |
|
unet, |
|
text_encoder, |
|
save_path, |
|
placeholder_token_ids=None, |
|
placeholder_tokens=None, |
|
save_lora=True, |
|
save_ti=True, |
|
target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE, |
|
target_replace_module_unet=DEFAULT_TARGET_REPLACE, |
|
safe_form=True, |
|
): |
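    """
    Save LoRA weights and/or learned placeholder-token embeddings.

    With safe_form=False, the legacy path is used: the textual-inversion
    embeddings and the UNet / text-encoder LoRA weights are written to
    separate files. With safe_form=True (the default), everything is gathered
    and passed to save_safeloras_with_embeds, which returns a single state
    dict instead of writing to disk.
    """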
|
if not safe_form: |
|
|
|
if save_ti: |
|
            ti_path = _ti_lora_path(save_path)
|
learned_embeds_dict = {} |
|
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): |
|
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] |
|
                print(
                    f"Current learned embeddings for {tok} (id {tok_id}):",
                    learned_embeds[:4],
                )
|
learned_embeds_dict[tok] = learned_embeds.detach().cpu() |
|
|
|
torch.save(learned_embeds_dict, ti_path) |
|
print("Ti saved to ", ti_path) |
|
|
|
|
|
if save_lora: |
|
|
|
save_lora_weight( |
|
unet, save_path, target_replace_module=target_replace_module_unet |
|
) |
|
print("Unet saved to ", save_path) |
|
|
|
save_lora_weight( |
|
text_encoder, |
|
_text_lora_path(save_path), |
|
target_replace_module=target_replace_module_text, |
|
) |
|
print("Text Encoder saved to ", _text_lora_path(save_path)) |
|
|
|
else: |
|
assert save_path.endswith( |
|
".safetensors" |
|
), f"Save path : {save_path} should end with .safetensors" |
|
|
|
loras = {} |
|
embeds = {} |
|
|
|
if save_lora: |
|
|
|
loras["unet"] = (unet, target_replace_module_unet) |
|
loras["text_encoder"] = (text_encoder, target_replace_module_text) |
|
|
|
if save_ti: |
|
for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids): |
|
learned_embeds = text_encoder.get_input_embeddings().weight[tok_id] |
|
                print(
                    f"Current learned embeddings for {tok} (id {tok_id}):",
                    learned_embeds[:4],
                )
|
embeds[tok] = learned_embeds.detach().cpu() |
|
|
|
return save_safeloras_with_embeds(loras, embeds, save_path) |
|
|
|
def save_safeloras_with_embeds(
    modelmap={},
    embeds={},
    outpath="./lora.safetensors",
):
    """
    Collects the LoRAs from multiple modules, together with any learned token
    embeddings, into a single state dict of weights and metadata.

    Note that this function does not write a .safetensors file itself:
    `outpath` is unused and kept only for interface compatibility, and the
    caller is expected to serialize the returned state.

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }
    """
|
weights = {} |
|
metadata = {} |
|
|
|
for name, (model, target_replace_module) in modelmap.items(): |
|
metadata[name] = json.dumps(list(target_replace_module)) |
|
|
|
for i, (_up, _down) in enumerate( |
|
extract_lora_as_tensor(model, target_replace_module) |
|
): |
|
rank = _down.shape[0] |
|
|
|
metadata[f"{name}:{i}:rank"] = str(rank) |
|
weights[f"{name}:{i}:up"] = _up |
|
weights[f"{name}:{i}:down"] = _down |
|
|
|
for token, tensor in embeds.items(): |
|
metadata[token] = EMBED_FLAG |
|
weights[token] = tensor |
|
|
|
    sorted_dict = {key: value for key, value in sorted(weights.items())}

    state = {}
    state["weights"] = sorted_dict
    state["metadata"] = metadata

    return state
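

# Illustrative sketch (an assumption, not part of the training path): the state
# returned by save_safeloras_with_embeds can be written out with
# `safetensors.torch.save_file`, provided the collected tensors are contiguous
# and on CPU, e.g.:
#
#   from safetensors.torch import save_file
#
#   state = save_safeloras_with_embeds(loras, embeds)
#   save_file(state["weights"], "./lora.safetensors", metadata=state["metadata"])
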
|
def perform_tuning( |
|
unet, |
|
vae, |
|
text_encoder, |
|
dataloader, |
|
num_steps, |
|
scheduler, |
|
optimizer, |
|
save_steps: int, |
|
placeholder_token_ids, |
|
placeholder_tokens, |
|
save_path, |
|
lr_scheduler_lora, |
|
lora_unet_target_modules, |
|
lora_clip_target_modules, |
|
mask_temperature, |
|
out_name: str, |
|
tokenizer, |
|
test_image_path: str, |
|
cached_latents: bool, |
|
log_wandb: bool = False, |
|
wandb_log_prompt_cnt: int = 10, |
|
class_token: str = "person", |
|
train_inpainting: bool = False, |
|
): |
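    """
    LoRA fine-tuning loop: steps the optimizer / LR scheduler over the
    dataloader, checkpoints via save_all every `save_steps` steps (optionally
    logging loss and CLIP-based evaluation metrics to wandb), and returns the
    final collected state.
    """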
|
|
|
progress_bar = tqdm(range(num_steps)) |
|
progress_bar.set_description("Steps") |
|
global_step = 0 |
|
|
|
weight_dtype = torch.float16 |
|
|
|
unet.train() |
|
text_encoder.train() |
|
|
|
if log_wandb: |
|
preped_clip = prepare_clip_model_sets() |
|
|
|
loss_sum = 0.0 |
|
|
|
for epoch in range(math.ceil(num_steps / len(dataloader))): |
|
for batch in dataloader: |
|
lr_scheduler_lora.step() |
|
|
|
optimizer.zero_grad() |
|
|
|
loss = loss_step( |
|
batch, |
|
unet, |
|
vae, |
|
text_encoder, |
|
scheduler, |
|
train_inpainting=train_inpainting, |
|
                t_mutliplier=0.8,  # (sic) keyword spelling matches lora_diffusion's loss_step signature
|
mixed_precision=True, |
|
mask_temperature=mask_temperature, |
|
cached_latents=cached_latents, |
|
) |
|
loss_sum += loss.detach().item() |
|
|
|
loss.backward() |
|
torch.nn.utils.clip_grad_norm_( |
|
itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0 |
|
) |
|
optimizer.step() |
|
progress_bar.update(1) |
|
logs = { |
|
"loss": loss.detach().item(), |
|
"lr": lr_scheduler_lora.get_last_lr()[0], |
|
} |
|
progress_bar.set_postfix(**logs) |
|
|
|
global_step += 1 |
|
|
|
if global_step % save_steps == 0: |
|
save_all( |
|
unet, |
|
text_encoder, |
|
placeholder_token_ids=placeholder_token_ids, |
|
placeholder_tokens=placeholder_tokens, |
|
save_path=os.path.join( |
|
save_path, f"step_{global_step}.safetensors" |
|
), |
|
target_replace_module_text=lora_clip_target_modules, |
|
target_replace_module_unet=lora_unet_target_modules, |
|
) |
|
moved = ( |
|
torch.tensor(list(itertools.chain(*inspect_lora(unet).values()))) |
|
.mean() |
|
.item() |
|
) |
|
|
|
print("LORA Unet Moved", moved) |
|
moved = ( |
|
torch.tensor( |
|
list(itertools.chain(*inspect_lora(text_encoder).values())) |
|
) |
|
.mean() |
|
.item() |
|
) |
|
|
|
print("LORA CLIP Moved", moved) |
|
|
|
if log_wandb: |
|
with torch.no_grad(): |
|
pipe = StableDiffusionPipeline( |
|
vae=vae, |
|
text_encoder=text_encoder, |
|
tokenizer=tokenizer, |
|
unet=unet, |
|
scheduler=scheduler, |
|
safety_checker=None, |
|
feature_extractor=None, |
|
) |
|
|
|
|
|
images = [] |
|
for file in os.listdir(test_image_path): |
|
if file.endswith(".png") or file.endswith(".jpg"): |
|
images.append( |
|
Image.open(os.path.join(test_image_path, file)) |
|
) |
|
|
|
wandb.log({"loss": loss_sum / save_steps}) |
|
loss_sum = 0.0 |
|
wandb.log( |
|
evaluate_pipe( |
|
pipe, |
|
target_images=images, |
|
class_token=class_token, |
|
learnt_token="".join(placeholder_tokens), |
|
n_test=wandb_log_prompt_cnt, |
|
n_step=50, |
|
clip_model_sets=preped_clip, |
|
) |
|
) |
|
|
|
if global_step >= num_steps: |
|
break |
|
|
|
return save_all( |
|
unet, |
|
text_encoder, |
|
placeholder_token_ids=placeholder_token_ids, |
|
placeholder_tokens=placeholder_tokens, |
|
save_path=os.path.join(save_path, f"{out_name}.safetensors"), |
|
target_replace_module_text=lora_clip_target_modules, |
|
target_replace_module_unet=lora_unet_target_modules, |
|
) |
|
|
|
|
|
def train( |
|
images, |
|
caption, |
|
pretrained_model_name_or_path: str, |
|
train_text_encoder: bool = True, |
|
pretrained_vae_name_or_path: str = None, |
|
revision: Optional[str] = None, |
|
perform_inversion: bool = True, |
|
use_template: Literal[None, "object", "style"] = None, |
|
train_inpainting: bool = False, |
|
placeholder_tokens: str = "", |
|
placeholder_token_at_data: Optional[str] = None, |
|
initializer_tokens: Optional[str] = None, |
|
seed: int = 42, |
|
resolution: int = 512, |
|
color_jitter: bool = True, |
|
train_batch_size: int = 1, |
|
sample_batch_size: int = 1, |
|
max_train_steps_tuning: int = 1000, |
|
max_train_steps_ti: int = 1000, |
|
save_steps: int = 100, |
|
gradient_accumulation_steps: int = 4, |
|
gradient_checkpointing: bool = False, |
|
lora_rank: int = 4, |
|
lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"}, |
|
lora_clip_target_modules={"CLIPAttention"}, |
|
lora_dropout_p: float = 0.0, |
|
lora_scale: float = 1.0, |
|
use_extended_lora: bool = False, |
|
clip_ti_decay: bool = True, |
|
learning_rate_unet: float = 1e-4, |
|
learning_rate_text: float = 1e-5, |
|
learning_rate_ti: float = 5e-4, |
|
continue_inversion: bool = False, |
|
continue_inversion_lr: Optional[float] = None, |
|
use_face_segmentation_condition: bool = False, |
|
cached_latents: bool = True, |
|
use_mask_captioned_data: bool = False, |
|
mask_temperature: float = 1.0, |
|
scale_lr: bool = False, |
|
lr_scheduler: str = "linear", |
|
lr_warmup_steps: int = 0, |
|
lr_scheduler_lora: str = "linear", |
|
lr_warmup_steps_lora: int = 0, |
|
weight_decay_ti: float = 0.00, |
|
weight_decay_lora: float = 0.001, |
|
use_8bit_adam: bool = False, |
|
device="cuda:0", |
|
extra_args: Optional[dict] = None, |
|
log_wandb: bool = False, |
|
wandb_log_prompt_cnt: int = 10, |
|
wandb_project_name: str = "new_pti_project", |
|
wandb_entity: str = "new_pti_entity", |
|
proxy_token: str = "person", |
|
enable_xformers_memory_efficient_attention: bool = False, |
|
out_name: str = "final_lora", |
|
): |
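    """
    Pivotal tuning entry point: optionally runs textual inversion on the
    placeholder tokens first (perform_inversion), then injects LoRA into the
    UNet (and optionally the text encoder) and hands off to perform_tuning.
    Returns the state produced by the final save_all call.
    """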
|
|
|
|
|
torch.manual_seed(seed) |
|
|
|
|
|
if len(placeholder_tokens) == 0: |
|
placeholder_tokens = [] |
|
print("PTI : Placeholder Tokens not given, using null token") |
|
else: |
|
placeholder_tokens = placeholder_tokens.split("|") |
|
|
|
        assert (
            sorted(placeholder_tokens) == placeholder_tokens
        ), f"Placeholder tokens should be sorted. Use something like {'|'.join(sorted(placeholder_tokens))}"
|
|
|
if initializer_tokens is None: |
|
print("PTI : Initializer Tokens not given, doing random inits") |
|
initializer_tokens = ["<rand-0.017>"] * len(placeholder_tokens) |
|
else: |
|
initializer_tokens = initializer_tokens.split("|") |
|
|
|
assert len(initializer_tokens) == len( |
|
placeholder_tokens |
|
), "Unequal Initializer token for Placeholder tokens." |
|
|
|
    if proxy_token is not None:
        class_token = proxy_token
    # NOTE: the next line unconditionally overrides proxy_token, so class_token
    # (only used for the wandb evaluation prompts) ends up being the
    # concatenated initializer tokens.
    class_token = "".join(initializer_tokens)
|
|
|
if placeholder_token_at_data is not None: |
|
tok, pat = placeholder_token_at_data.split("|") |
|
token_map = {tok: pat} |
|
|
|
else: |
|
token_map = {"DUMMY": "".join(placeholder_tokens)} |
|
|
|
print("PTI : Placeholder Tokens", placeholder_tokens) |
|
print("PTI : Initializer Tokens", initializer_tokens) |
|
|
|
|
|
text_encoder, vae, unet, tokenizer, placeholder_token_ids = get_models( |
|
pretrained_model_name_or_path, |
|
pretrained_vae_name_or_path, |
|
revision, |
|
placeholder_tokens, |
|
initializer_tokens, |
|
device=device, |
|
) |
|
|
|
noise_scheduler = DDPMScheduler.from_config( |
|
pretrained_model_name_or_path, subfolder="scheduler" |
|
) |
|
|
|
if gradient_checkpointing: |
|
unet.enable_gradient_checkpointing() |
|
|
|
if enable_xformers_memory_efficient_attention: |
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
|
if is_xformers_available(): |
|
unet.enable_xformers_memory_efficient_attention() |
|
else: |
|
raise ValueError( |
|
"xformers is not available. Make sure it is installed correctly" |
|
) |
|
|
|
if scale_lr: |
|
unet_lr = learning_rate_unet * gradient_accumulation_steps * train_batch_size |
|
text_encoder_lr = ( |
|
learning_rate_text * gradient_accumulation_steps * train_batch_size |
|
) |
|
ti_lr = learning_rate_ti * gradient_accumulation_steps * train_batch_size |
|
else: |
|
unet_lr = learning_rate_unet |
|
text_encoder_lr = learning_rate_text |
|
ti_lr = learning_rate_ti |
|
|
|
train_dataset = PVD( |
|
images=images, |
|
caption=caption, |
|
token_map=token_map, |
|
use_template=use_template, |
|
tokenizer=tokenizer, |
|
size=resolution, |
|
color_jitter=color_jitter, |
|
use_face_segmentation_condition=use_face_segmentation_condition, |
|
use_mask_captioned_data=use_mask_captioned_data, |
|
train_inpainting=train_inpainting, |
|
) |
|
|
|
train_dataset.blur_amount = 200 |
|
|
|
if train_inpainting: |
|
assert not cached_latents, "Cached latents not supported for inpainting" |
|
|
|
train_dataloader = inpainting_dataloader( |
|
train_dataset, train_batch_size, tokenizer, vae, text_encoder |
|
) |
|
else: |
|
print(cached_latents) |
|
train_dataloader = text2img_dataloader( |
|
train_dataset, |
|
train_batch_size, |
|
tokenizer, |
|
vae, |
|
text_encoder, |
|
cached_latents=cached_latents, |
|
) |
|
|
|
    # Mask of token-embedding rows to keep frozen: everything except the newly
    # added placeholder tokens.
    index_no_updates = torch.arange(len(tokenizer)) != -1

    for tok_id in placeholder_token_ids:
        index_no_updates[tok_id] = False
|
|
|
unet.requires_grad_(False) |
|
vae.requires_grad_(False) |
|
|
|
params_to_freeze = itertools.chain( |
|
text_encoder.text_model.encoder.parameters(), |
|
text_encoder.text_model.final_layer_norm.parameters(), |
|
text_encoder.text_model.embeddings.position_embedding.parameters(), |
|
) |
|
for param in params_to_freeze: |
|
param.requires_grad = False |
|
|
|
if cached_latents: |
|
vae = None |
|
|
|
if perform_inversion: |
|
ti_optimizer = optim.AdamW( |
|
text_encoder.get_input_embeddings().parameters(), |
|
lr=ti_lr, |
|
betas=(0.9, 0.999), |
|
eps=1e-08, |
|
weight_decay=weight_decay_ti, |
|
) |
|
|
|
lr_scheduler = get_scheduler( |
|
lr_scheduler, |
|
optimizer=ti_optimizer, |
|
num_warmup_steps=lr_warmup_steps, |
|
num_training_steps=max_train_steps_ti, |
|
) |
|
|
|
train_inversion( |
|
unet, |
|
vae, |
|
text_encoder, |
|
train_dataloader, |
|
max_train_steps_ti, |
|
cached_latents=cached_latents, |
|
accum_iter=gradient_accumulation_steps, |
|
scheduler=noise_scheduler, |
|
index_no_updates=index_no_updates, |
|
optimizer=ti_optimizer, |
|
lr_scheduler=lr_scheduler, |
|
save_steps=save_steps, |
|
placeholder_tokens=placeholder_tokens, |
|
placeholder_token_ids=placeholder_token_ids, |
|
save_path="./tmps", |
|
test_image_path="./tmps", |
|
log_wandb=log_wandb, |
|
wandb_log_prompt_cnt=wandb_log_prompt_cnt, |
|
class_token=class_token, |
|
train_inpainting=train_inpainting, |
|
mixed_precision=False, |
|
tokenizer=tokenizer, |
|
clip_ti_decay=clip_ti_decay, |
|
) |
|
|
|
del ti_optimizer |
|
|
|
|
|
if not use_extended_lora: |
|
unet_lora_params, _ = inject_trainable_lora( |
|
unet, |
|
r=lora_rank, |
|
target_replace_module=lora_unet_target_modules, |
|
dropout_p=lora_dropout_p, |
|
scale=lora_scale, |
|
) |
|
else: |
|
print("PTI : USING EXTENDED UNET!!!") |
|
lora_unet_target_modules = ( |
|
lora_unet_target_modules | UNET_EXTENDED_TARGET_REPLACE |
|
) |
|
print("PTI : Will replace modules: ", lora_unet_target_modules) |
|
|
|
unet_lora_params, _ = inject_trainable_lora_extended( |
|
unet, r=lora_rank, target_replace_module=lora_unet_target_modules |
|
) |
|
print(f"PTI : has {len(unet_lora_params)} lora") |
|
|
|
print("PTI : Before training:") |
|
inspect_lora(unet) |
|
|
|
params_to_optimize = [ |
|
{"params": itertools.chain(*unet_lora_params), "lr": unet_lr}, |
|
] |
|
|
|
text_encoder.requires_grad_(False) |
|
|
|
if continue_inversion: |
|
params_to_optimize += [ |
|
{ |
|
"params": text_encoder.get_input_embeddings().parameters(), |
|
"lr": continue_inversion_lr |
|
if continue_inversion_lr is not None |
|
else ti_lr, |
|
} |
|
] |
|
text_encoder.requires_grad_(True) |
|
params_to_freeze = itertools.chain( |
|
text_encoder.text_model.encoder.parameters(), |
|
text_encoder.text_model.final_layer_norm.parameters(), |
|
text_encoder.text_model.embeddings.position_embedding.parameters(), |
|
) |
|
for param in params_to_freeze: |
|
param.requires_grad = False |
|
else: |
|
text_encoder.requires_grad_(False) |
|
if train_text_encoder: |
|
text_encoder_lora_params, _ = inject_trainable_lora( |
|
text_encoder, |
|
target_replace_module=lora_clip_target_modules, |
|
r=lora_rank, |
|
) |
|
params_to_optimize += [ |
|
{ |
|
"params": itertools.chain(*text_encoder_lora_params), |
|
"lr": text_encoder_lr, |
|
} |
|
] |
|
inspect_lora(text_encoder) |
|
|
|
lora_optimizers = optim.AdamW(params_to_optimize, weight_decay=weight_decay_lora) |
|
|
|
unet.train() |
|
if train_text_encoder: |
|
text_encoder.train() |
|
|
|
train_dataset.blur_amount = 70 |
|
|
|
lr_scheduler_lora = get_scheduler( |
|
lr_scheduler_lora, |
|
optimizer=lora_optimizers, |
|
num_warmup_steps=lr_warmup_steps_lora, |
|
num_training_steps=max_train_steps_tuning, |
|
) |
|
|
|
return perform_tuning( |
|
unet, |
|
vae, |
|
text_encoder, |
|
train_dataloader, |
|
max_train_steps_tuning, |
|
cached_latents=cached_latents, |
|
scheduler=noise_scheduler, |
|
optimizer=lora_optimizers, |
|
save_steps=save_steps, |
|
placeholder_tokens=placeholder_tokens, |
|
placeholder_token_ids=placeholder_token_ids, |
|
save_path="./tmps", |
|
lr_scheduler_lora=lr_scheduler_lora, |
|
lora_unet_target_modules=lora_unet_target_modules, |
|
lora_clip_target_modules=lora_clip_target_modules, |
|
mask_temperature=mask_temperature, |
|
tokenizer=tokenizer, |
|
out_name=out_name, |
|
test_image_path="./tmps", |
|
log_wandb=log_wandb, |
|
wandb_log_prompt_cnt=wandb_log_prompt_cnt, |
|
class_token=class_token, |
|
train_inpainting=train_inpainting, |
|
) |
|
|
|
def semantic_karlo(prompt, output_dir, num_initial_image, bg_preprocess=False): |
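    """
    Generate `num_initial_image` images of `prompt` with Karlo (unCLIP),
    cycling through front / overhead / side / back view prompts, optionally
    whiting out the background with carvekit, and save them to `output_dir`
    as instance{i}.png.
    """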
|
pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16) |
|
pipe = pipe.to('cuda') |
|
    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]
|
|
|
if bg_preprocess: |
|
|
|
import cv2 |
|
from carvekit.api.high import HiInterface |
|
interface = HiInterface(object_type="object", |
|
batch_size_seg=5, |
|
batch_size_matting=1, |
|
device='cuda' if torch.cuda.is_available() else 'cpu', |
|
seg_mask_size=640, |
|
matting_mask_size=2048, |
|
trimap_prob_threshold=231, |
|
trimap_dilation=30, |
|
trimap_erosion_iters=5, |
|
fp16=False) |
|
|
|
|
|
    for i in range(num_initial_image):
        # The "white background" hint is only appended to the first (front) view.
        t = ", white background"
        if i == 0:
            prompt_ = f"{view_prompt[i % 4]}{prompt}{t}"
        else:
            prompt_ = f"{view_prompt[i % 4]}{prompt}"
|
|
|
image = pipe(prompt_).images[0] |
|
fn=f"instance{i}.png" |
|
os.makedirs(output_dir,exist_ok=True) |
|
|
|
        if bg_preprocess:
            # carvekit returns an RGBA cut-out with a transparent background;
            # use its alpha channel as the foreground mask and paint the
            # background white.
            img_without_background = interface([image])
            mask = np.array(img_without_background[0])[:, :, -1] > 127
            image = np.array(image)
            image[~mask] = [255, 255, 255]
|
|
|
|
|
image = Image.fromarray(np.array(image)) |
|
|
|
image.save(os.path.join(output_dir,fn)) |
|
|
|
|
|
def semantic_sd(prompt, output_dir, num_initial_image, bg_preprocess=False): |
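    """
    Same as semantic_karlo, but the initial views are generated with Stable
    Diffusion v1.5 instead of Karlo.
    """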
|
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") |
|
pipe = pipe.to('cuda') |
|
    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]
|
|
|
if bg_preprocess: |
|
|
|
import cv2 |
|
from carvekit.api.high import HiInterface |
|
interface = HiInterface(object_type="object", |
|
batch_size_seg=5, |
|
batch_size_matting=1, |
|
device='cuda' if torch.cuda.is_available() else 'cpu', |
|
seg_mask_size=640, |
|
matting_mask_size=2048, |
|
trimap_prob_threshold=231, |
|
trimap_dilation=30, |
|
trimap_erosion_iters=5, |
|
fp16=False) |
|
|
|
|
|
    for i in range(num_initial_image):
        # The "white background" hint is only appended to the first (front) view.
        t = ", white background"
        if i == 0:
            prompt_ = f"{view_prompt[i % 4]}{prompt}{t}"
        else:
            prompt_ = f"{view_prompt[i % 4]}{prompt}"
|
|
|
image = pipe(prompt_).images[0] |
|
fn=f"instance{i}.png" |
|
os.makedirs(output_dir,exist_ok=True) |
|
|
|
        if bg_preprocess:
            # carvekit returns an RGBA cut-out with a transparent background;
            # use its alpha channel as the foreground mask and paint the
            # background white.
            img_without_background = interface([image])
            mask = np.array(img_without_background[0])[:, :, -1] > 127
            image = np.array(image)
            image[~mask] = [255, 255, 255]
|
|
|
|
|
image = Image.fromarray(np.array(image)) |
|
|
|
image.save(os.path.join(output_dir,fn)) |
|
|
|
def semantic_coding(images, cfgs, sd, initial):
    """
    Run pivotal tuning on the generated `images` and rewrite the prompt on the
    `sd` wrapper so that it uses the learned placeholder token <0>.
    """
    ti_step = cfgs.pop("ti_step")
    pt_step = cfgs.pop("pt_step")

    prompt = cfgs["sd"]["prompt"]

    if initial == "":
        initial = None

    state = train(
        images=images,
        caption=initial,
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        gradient_checkpointing=True,
        scale_lr=True,
        lora_rank=1,
        cached_latents=False,
        save_steps=max(ti_step, pt_step) + 1,
        max_train_steps_ti=ti_step,
        max_train_steps_tuning=pt_step,
        use_template="object",
        lr_warmup_steps=0,
        lr_warmup_steps_lora=100,
        placeholder_tokens="<0>",
        initializer_tokens=initial,
        continue_inversion=True,
        continue_inversion_lr=1e-4,
        device="cuda:0",
    )
    # Point the downstream prompt at the learned token <0>.
    if initial is not None:
        sd.prompt = prompt.replace(initial, "<0>")
    else:
        sd.prompt = "a <0>"
    return state
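

# Illustrative sketch of how this module might be driven end to end. The names
# below (`cfg`, `sd_wrapper`, the image-loading loop) are assumptions for the
# example, not part of this file; `cfg` only needs the keys used above
# ("ti_step", "pt_step", and cfg["sd"]["prompt"]), and `sd_wrapper` can be any
# object with a writable `.prompt` attribute.
#
#   out_dir = "./initial_images"
#   semantic_karlo("a corgi", out_dir, num_initial_image=4, bg_preprocess=True)
#   images = [
#       Image.open(os.path.join(out_dir, f)) for f in sorted(os.listdir(out_dir))
#   ]
#   state = semantic_coding(images, cfg, sd_wrapper, initial="corgi")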