# weights2weights / sampling.py
# Uploaded by multimodalart in commit 8483373 ("Upload 200 files"); original size 1.52 kB.
import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from transformers import CLIPTextModel
from peft import PeftModel, LoraConfig
from lora_w2w import LoRAw2w
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
from peft.utils.save_and_load import load_peft_weights, set_peft_model_state_dict
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
warnings.filterwarnings("ignore")
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
UNet2DConditionModel,
PNDMScheduler,
StableDiffusionPipeline
)
######## Sampling utilities
def sample_weights(unet, proj, mean, std, v, device, factor=1.0, n_components=1000):
    """Sample a random coefficient vector and wrap it in a ``LoRAw2w`` network.

    Fits an independent Gaussian to each of the first ``n_components``
    columns of ``proj`` (per-column mean and standard deviation taken over
    dim 0), draws one vector of shape ``(1, n_components)`` from it, and
    builds a rank-1 ``LoRAw2w`` network on ``unet`` from that vector.

    Args:
        unet: UNet passed through to ``LoRAw2w``.
        proj: 2-D tensor whose rows are treated as observations and whose
            columns are treated as components (presumably projections of
            training weights onto principal components -- confirm with caller).
        mean, std, v: Forwarded unchanged to ``LoRAw2w``.
        device: Device for the sampled vector and the returned network.
        factor: Multiplier applied to every per-component standard deviation
            (0 yields the per-component mean; must be non-negative).
        n_components: Number of leading columns of ``proj`` to sample.
            Defaults to 1000, the value that used to be hard-coded.

    Returns:
        The ``LoRAw2w`` network, moved to ``device`` as ``torch.bfloat16``.

    Raises:
        ValueError: If ``proj`` is not 2-D or has fewer than
            ``n_components`` columns.
    """
    if proj.dim() != 2 or proj.shape[1] < n_components:
        raise ValueError(
            f"proj must be a 2-D tensor with at least {n_components} columns, "
            f"got shape {tuple(proj.shape)}"
        )

    # Per-component statistics over the rows of proj.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    # Drops only this function's reference; the tensor is freed only if the
    # caller holds no other reference. empty_cache then hands cached CUDA
    # blocks (e.g. temporaries from mean/std) back to the driver.
    del proj
    torch.cuda.empty_cache()

    # One vectorised draw replaces n_components scalar torch.normal calls and
    # per-element writes into a (possibly GPU) buffer. Statistics are cast to
    # float32 to match the dtype of the original zero-initialised buffer.
    # no_grad keeps the result detached, as the scalar-argument draws were.
    # Note: the random stream differs from the old per-element loop, so a
    # fixed seed will not reproduce samples drawn by the previous version.
    with torch.no_grad():
        mu = m[:n_components].float()
        sigma = factor * standev[:n_components].float()
        sample = torch.normal(mu, sigma).unsqueeze(0).to(device)

    # Load the sampled coefficients into a LoRA network on the UNet.
    network = LoRAw2w(
        sample, mean, std, v,
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)
    return network