# Trinket! Stable Inversion
A cheap alternative to finetuning stable diffusion, by [crumb](https://twitter.com/aicrumb)

Fine-tunes the token embedding layer, as Textual Inversion does, but trains on CLIP text/image pairs instead of using the reconstruction loss from Stable Diffusion. This has lower memory requirements and is (sometimes) faster than fine-tuning the traditional way.

In [1]:
# Install OpenAI CLIP from source and bitsandbytes (used later for its 8-bit AdamW optimizer).
!pip install git+https://github.com/openai/CLIP -q
!pip install bitsandbytes -q

In [2]:
import bitsandbytes as bnb
import torchvision
from torchvision import transforms
from tqdm.auto import *
from torch import nn, optim
from torch.nn import functional as F
from PIL import Image
import requests
from io import BytesIO
import torch
import random
import clip
import pandas as pd


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
For effortless bug reporting copy-paste your error into this form: https://docs.google.com/forms/d/e/1FAIpQLScPB8emS3Thkp66nvqwmjTEgxp8Y9ufuWTzFyr9kJ5AoI47dQ/viewform?usp=sf_link
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 111
CUDA SETUP: Loading binary /usr/local/lib/python3.7/dist-packages/bitsandbytes/libbitsandbytes_cuda111.so...


  f'{candidate_env_vars["LD_LIBRARY_PATH"]} did not contain '


In [3]:
# the stable diffusion model uses L/14 by default
# clip.load returns two values; only the model is kept, the second one is discarded.
# NOTE(review): the discarded value is presumably CLIP's image preprocessing transform — confirm
# that skipping it (and its normalization) is intended.
clip_model, _ = clip.load("ViT-L/14", jit=False)
clip_model = clip_model.cuda()
print("Loaded CLIP")

Loaded CLIP


In [4]:
# Training hyperparameters, read by the training loop further down.
# NOTE(review): the original remark below mentions a contrastive loss, but the training loop
# currently uses a mean-squared error between image and text embeddings — confirm which is intended.
bs = 16 # need a large batch size for the contrastive loss to work properly
steps = 64 # batches drawn per epoch
epochs = 4 # passes over the prompt/url lists (the iterators are restarted every epoch)
lr = 1e-3 # learning rate handed to the optimizer

In [5]:
# Ultimately you just need >1000 prompt+url pairs in sequences named `prompts` and `urls`;
# however you load them is fine. This is how I scraped them, so this is how I load them.
# You can scrape danbooru/safebooru/others with pybooru using this script from waifu-diffusion:
# https://github.com/harubaru/waifu-diffusion/blob/main/danbooru_data/scrape.py
df = pd.read_csv("/content/genshin.csv")
prompts = df['prompt']  # optionally truncate with [:steps*bs]
urls = df['url']  # optionally truncate with [:steps*bs]
df.head()

Unnamed: 0.1,Unnamed: 0,url,prompt
0,0,https://cdn.donmai.us/original/51/bb/51bb44d69...,genshin_impact boo_tao_(genshin_impact) hu_tao...
1,1,https://cdn.donmai.us/original/4f/8e/4f8e5eaba...,genshin_impact arlecchino_(genshin_impact) cap...
2,2,https://cdn.donmai.us/original/c8/54/c85428d89...,genshin_impact fischl_(ein_immernachtstraum)_(...
3,3,https://cdn.donmai.us/original/e9/fc/e9fcc788e...,genshin_impact eula_(genshin_impact) 1girl :o ...
4,4,https://cdn.donmai.us/original/8a/63/8a639d21b...,genshin_impact kuki_shinobu 1girl breasts brid...


In [6]:
print(len(prompts))

1200


In [7]:
# Make sure the token embedding table tracks gradients; it is the only tensor handed to the optimizer.
clip_model.token_embedding.weight.requires_grad = True
# Plain Adam alternative, kept for reference:
# opt = optim.Adam([clip_model.token_embedding.weight], lr)
# 8-bit AdamW from bitsandbytes, updating only the token embedding weights.
opt = bnb.optim.AdamW8bit([clip_model.token_embedding.weight], lr)

In [8]:
# functions from another project of mine that's a bit messy
def fix_to_224(pil_image):
    """Resize an image so its shorter side is 224 px, keeping the aspect ratio.

    The longer side is scaled by the same factor and truncated to an integer.
    Square images (and landscape ones) take the branch that pins the height.
    """
    short_side = 224
    w, h = pil_image.size
    if w < h:
        # Portrait: pin the width, scale the height.
        target = (short_side, int(short_side * h / w))
    else:
        # Landscape or square: pin the height, scale the width.
        target = (int(short_side * w / h), short_side)
    return pil_image.resize(target)
def to_tensor_and_center_crop(pil_image):
    """Convert a PIL image to a tensor and center-crop it to 224x224."""
    tvf = torchvision.transforms.functional
    return tvf.center_crop(tvf.to_tensor(pil_image), (224, 224))
def fix(img):
    """Resize `img` with `fix_to_224`, then tensorize and center-crop it to 224x224."""
    resized = fix_to_224(img)
    return to_tensor_and_center_crop(resized)

# Module-level iterators consumed by get_batch; the training loop re-creates them every epoch.
iter_prompts = iter(prompts)
iter_urls = iter(urls)
def get_batch(size=8, step=1, steps_per_epoch=128, epoch=1, total_epochs=1):
    """Download up to `size` images and encode them and their prompts with CLIP.

    URL/prompt pairs are pulled from the module-level `iter_urls` / `iter_prompts`
    iterators. A pair is skipped as a whole when its image cannot be downloaded or
    decoded (or its prompt cannot be processed), so images and prompts always stay
    aligned.

    Args:
        size: number of samples wanted in the batch.
        step: index of the current step within the epoch.
        steps_per_epoch: number of steps in one epoch.
        epoch: current epoch; the progress formula below treats it as 1-based
            (epoch=1 maps to the start of training).
        total_epochs: total number of epochs in the run.

    Returns:
        (x, y): the CLIP image embeddings and the CLIP text embeddings of the batch.
        The batch may hold fewer than `size` samples if the iterators run out.

    Raises:
        RuntimeError: if not a single usable sample could be collected.
        KeyboardInterrupt: re-raised after printing a notice, so that interrupting
            the kernel actually stops the caller.
    """
    x = []
    y = []
    total_steps = total_epochs*steps_per_epoch
    # Fraction of the whole run completed so far; drives the tag-dropout schedule.
    current_fraction = (epoch*steps_per_epoch-steps_per_epoch+step) / total_steps

    while len(x) < size:
        # Advance both iterators together so a failed download cannot shift
        # prompts relative to urls.
        try:
            url = next(iter_urls)
            p = next(iter_prompts)
        except StopIteration:
            # Data exhausted: stop instead of spinning forever.
            break

        try:
            # Timeout so a stalled server cannot hang training indefinitely.
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            img = Image.open(BytesIO(response.content))
            img = img.convert("RGB")
            img = fix(img)

            # comment out these lines if you aren't using danbooru tags
            p = p.split(" ")
            random.shuffle(p)
            if current_fraction > 0.5: # halfway through training, start dropping half of the tags
                p = p[:len(p)//2]
            if current_fraction > 0.75: # three quarters through, halve again (a quarter of the tags remain)
                p = p[:len(p)//2]
            p = " ".join(p)
        except KeyboardInterrupt:
            print('Interrupted')
            raise
        except Exception:
            # Dead url, undecodable image or malformed prompt: skip this pair.
            continue

        # Append only once both halves of the pair are ready, keeping x and y the same length.
        x.append(img.unsqueeze(0))
        y.append(p)

    if not x:
        raise RuntimeError("get_batch could not collect any samples (iterators exhausted or every download failed)")

    x = torch.cat(x, 0)
    # The image embedding does not depend on the token embeddings being trained,
    # so no autograd graph is needed for it; this saves GPU memory.
    with torch.no_grad():
        x = clip_model.encode_image(x.cuda())

    y = clip.tokenize(y, truncate=True)
    y = clip_model.encode_text(y.cuda())
    return x, y

In [None]:
# Per-step loss values, plotted after training.
losses = []

for epoch in trange(epochs):
    pbar = trange(steps)
    # Restart the module-level iterators that get_batch reads, so every epoch
    # begins at the first prompt/url pair.
    iter_prompts = iter(prompts)
    iter_urls = iter(urls)

    for i in pbar:
        # get_batch's progress formula expects a 1-based epoch, while trange yields
        # 0-based values. Passing `epoch` unchanged made the progress fraction negative
        # in the first epoch and kept it below 0.75, so the tag-dropout schedule
        # almost never triggered.
        x, y = get_batch(bs, i, steps, epoch + 1, epochs)
        # Mean squared error between image and text embeddings.
        loss = (x-y).pow(2).mean()
        # loss = spherical_distance_loss(x,y).mean() # TODO: try whatever loss they use in the CLIP paper instead
        loss.backward()
        opt.step()
        opt.zero_grad()

        loss_value = loss.item()
        pbar.set_description(str(loss_value))
        losses.append(loss_value)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/64 [00:00<?, ?it/s]

  "Palette images with Transparency expressed in bytes should be "


  0%|          | 0/64 [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt

# hacky ema plot
# Smooth the per-step losses with an exponentially weighted mean (alpha=0.05).
losses_df = pd.DataFrame(losses, columns=['loss'])
losses_plot = losses_df.ewm(alpha=0.05).mean()['loss']
# Draw the smoothed curve first, then the raw losses on the same axes.
plt.plot(losses_plot)
plt.plot(losses)

In [None]:
# Save the token embedding weight object as-is to a local file.
torch.save(clip_model.token_embedding.weight, "token_embeddings.pt")

In [None]:
#@markdown 🔔
# Colab only: play a short sound in the browser when execution reaches this cell.
from google.colab import output
output.eval_js('new Audio("https://freesound.org/data/previews/80/80921_1022651-lq.ogg").play()')

### Upload to 🤗

In [None]:
#@markdown Log in
# Install the Hugging Face Hub client, then open its interactive notebook login prompt.
!pip install huggingface-hub -q
from huggingface_hub import notebook_login
notebook_login()

In [None]:
from huggingface_hub import HfApi
api = HfApi()

#@markdown go ahead and create the repository on 🤗 before running this
my_repo = "user/my-stable-inversion" #@param {type:"string"}

# Upload the saved embedding file to the root of the model repository named above.
api.upload_file(
    path_or_fileobj="token_embeddings.pt",
    path_in_repo="token_embeddings.pt",
    repo_id=my_repo,
    repo_type="model",
)

Loading a stable inversion into a diffusers-based notebook like [Doohickey](https://github.com/aicrumb/doohickey) might look something like this:

```
from huggingface_hub import hf_hub_download

stable_inversion = "user/my-stable-inversion" #@param {type:"string"}
if len(stable_inversion)>1:
    g = hf_hub_download(repo_id=stable_inversion, filename="token_embeddings.pt")
    text_encoder.text_model.embeddings.token_embedding.weight = torch.load(g)
```