In [1]:
import gradio as gr
import sys
import os 
import tqdm
sys.path.append(os.path.abspath(os.path.join("", "..")))
import torch
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
from utils import load_models, save_model_w2w, save_model_for_diffusers
from sampling import sample_weights

  from .autonotebook import tqdm as notebook_tqdm


[2024-06-28 00:45:26,702] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Shared runtime configuration used by every later cell.
# Note: `global` at module level is a no-op, so these are plain module globals.
# Functions that *reassign* one of them still need their own `global` statement.
#
# Prefer the first GPU, but fall back to CPU so the notebook can at least start
# on machines without CUDA (creating a CUDA generator would otherwise raise).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Re-seeded per request inside inference() for reproducible latents.
generator = torch.Generator(device=device)

In [3]:
def _load_bf16(path):
    """Load the tensor stored at ``path``, cast it to bfloat16 and move it to ``device``."""
    return torch.load(path).bfloat16().to(device)


# Tensors consumed by sample_weights(): cast to bfloat16 on the compute device.
mean = _load_bf16("files/mean.pt")
std = _load_bf16("files/std.pt")
v = _load_bf16("files/V.pt")
proj = _load_bf16("files/proj_1000pc.pt")

# Auxiliary objects are kept exactly as they were saved (no dtype/device change).
df = torch.load("files/identity_df.pt")
weight_dimensions = torch.load("files/weight_dimensions.pt")

In [4]:
# Load the diffusion components once; sample_model() later replaces only `unet`.
(
    unet,
    vae,
    text_encoder,
    tokenizer,
    noise_scheduler,
) = load_models(device)

Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 10.79it/s]





In [5]:
# Sampled weights2weights network. A module-level `global` statement would not
# bind the name, so initialise it explicitly; it stays None until
# sample_model() runs (i.e. until "Sample New Model" is clicked).
network = None

In [6]:
def sample_model():
    """Sample a new identity model and install it for inference().

    Reloads a clean UNet via ``load_models`` and draws a new weight sample with
    ``sample_weights`` (using the first 1000 columns of ``v``). The results are
    stored in the module globals ``unet`` and ``network``.
    """
    global unet
    global network
    # Release the previous model before loading a new one. The old `network`
    # was built from the old `unet` and may keep it alive, so drop both
    # references, then collect and flush the CUDA cache so that memory can be
    # reclaimed. Assigning None (instead of `del`) also keeps this safe to
    # retry if a previous reload failed after the name was cleared.
    network = None
    unet = None
    gc.collect()
    torch.cuda.empty_cache()  # no-op when CUDA was never initialised

    # load_models() also returns vae/text_encoder/tokenizer/scheduler; only the
    # fresh UNet is kept, the rest are discarded.
    unet, _, _, _, _ = load_models(device)
    network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
 


In [7]:
@torch.no_grad()
def inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    """Generate one 512x512 image with the currently sampled model.

    Runs classifier-free-guided denoising with the global ``unet`` inside the
    ``network`` context created by ``sample_model``, then decodes the final
    latents with ``vae``.

    Args:
        prompt: Positive text prompt (the UI asks for it to contain "sks person").
        negative_prompt: Text for the unconditional branch of guidance.
        guidance_scale: Classifier-free guidance weight.
        ddim_steps: Number of scheduler steps; coerced to int, must be >= 1.
        seed: Seed for the latent noise; coerced to int.

    Returns:
        A one-element list holding the generated ``PIL.Image.Image``
        (the shape expected by the output ``gr.Gallery``).

    Raises:
        gr.Error: If no model has been sampled yet or ``ddim_steps`` < 1.
    """
    global generator  # reassigned below; all other globals are only read

    # `network` only exists once "Sample New Model" has been clicked; report
    # that in the UI instead of failing with a NameError.
    network = globals().get("network")
    if network is None:
        raise gr.Error("No model sampled yet - click 'Sample New Model' first.")

    # Widget values may arrive as floats; the scheduler and generator need ints.
    ddim_steps = int(ddim_steps)
    if ddim_steps < 1:
        raise gr.Error("Inference Steps must be at least 1.")
    generator = generator.manual_seed(int(seed))

    latents = torch.randn(
        (1, unet.config.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    # truncation=True keeps a long negative prompt at the same length as the
    # positive one, so both embeddings can be concatenated below.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Batch [uncond, cond] so a single UNet call per step serves both branches.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for t in tqdm.tqdm(noise_scheduler.timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
        # The sampled weights are only active inside the `network` context.
        with network:
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the VAE latent scaling; 0.18215 is the fallback used previously.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    latents = 1 / scaling_factor * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]

    image = Image.fromarray((image * 255).round().astype("uint8"))

    return [image]

In [8]:
css = ''
with gr.Blocks(css=css) as demo:
    gr.Markdown("# <em>weights2weights</em> Demo")
    gr.Markdown("Sample a new identity model with <em>weights2weights</em>, then generate images of that identity using prompts that contain 'sks person'.")
    with gr.Row():
        with gr.Column():
            # NOTE(review): no event handler consumes `files` (only the clear
            # button references it), so uploaded photos are not inverted yet.
            files = gr.Files(
                        label="Upload a photo of your face to invert, or sample a new model",
                        file_types=["image"]
                    )
            uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)

            sample = gr.Button("Sample New Model")

            with gr.Column(visible=False) as clear_button:
                remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
            prompt = gr.Textbox(label="Prompt",
                       info="Make sure to include 'sks person'" ,
                       placeholder="sks person",
                       value="sks person")
            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
            seed = gr.Number(value=5, precision=0, label="Seed", interactive=True)
            cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
            # At least one denoising step is required; 0 cannot be scheduled.
            steps = gr.Slider(label="Inference Steps", precision=0, value=50, step=1, minimum=1, maximum=100, interactive=True)

            submit = gr.Button("Submit")

        with gr.Column():
            gallery = gr.Gallery(label="Generated Images")

        # Sampling replaces the global unet/network used by inference().
        sample.click(fn=sample_model)

        submit.click(fn=inference,
                    inputs=[prompt, negative_prompt, cfg, steps, seed],
                    outputs=gallery)

# share=True also publishes a temporary public gradio.live URL for this demo.
demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://bc89b27b9704787832.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00,  8.95it/s]
Traceback (most recent call last):
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/routes.py", line 437, in run_predict
    output = await app.get_blocks().process_api(
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py", line 1352, in process_api
    result = await self.call_function(
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/gradio/blocks.py", line 1077, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-packages/anyio/_backends/_asyncio.py", line 2134, in run_sync_in_worker_thread
    return await future
  File "/home/amil/anaconda3/envs/dblora2/lib/python3.8/site-pac


