import datetime
import os
import re
from typing import Literal

import streamlit as st
import torch
from diffusers import (
    ControlNetModel,
    DDIMScheduler,
    DiffusionPipeline,
    EulerDiscreteScheduler,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)

PIPELINES = Literal["txt2img", "sketch2img"]

@st.cache_resource(max_entries=1)
def get_pipelines(name: PIPELINES, enable_cpu_offload: bool = False) -> DiffusionPipeline:
    pipe = None

    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
        # Load the fine-tuned attention processor weights (e.g. LoRA) from the repo root.
        pipe.unet.load_attn_procs("./")
        # Replace the safety checker with a no-op that flags no image as NSFW.
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            "Abhi5ingh/model_dresscode", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=controlnet,
            torch_dtype=torch.float16,
        )
        pipe.unet.load_attn_procs("./")
        pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    if pipe is None:
        raise ValueError(f"Pipeline not found: {name}")
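
    # Hedged sketch, not enabled: the imported EulerDiscreteScheduler or
    # DDIMScheduler could replace the pipeline's default scheduler here, e.g.
    #   pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    # The app currently keeps each pipeline's default scheduler.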

    if enable_cpu_offload:
        # Keep submodules on the CPU and move each to the GPU only while it
        # runs; slower per image, but needs far less VRAM.
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe
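
# Note: st.cache_resource(max_entries=1) keeps only the most recently requested
# pipeline, so switching between "txt2img" and "sketch2img" evicts the other and
# at most one pipeline holds GPU memory at a time.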


def generate(
    prompt,
    pipeline_name: PIPELINES,
    sketch_pil=None,
    num_inference_steps=30,
    negative_prompt=None,
    width=512,
    height=512,
    guidance_scale=7.5,
    controlnet_conditioning_scale=None,
    enable_cpu_offload=False,
):
    # Treat an empty string as "no negative prompt".
    negative_prompt = negative_prompt if negative_prompt else None
    # Drive a Streamlit progress bar from the diffusers step callback,
    # which is invoked as callback(step, timestep, latents).
    p = st.progress(0)
    callback = lambda step, *_: p.progress(step / num_inference_steps)
    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    torch.cuda.empty_cache()

    # Arguments shared by both pipelines.
    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        callback=callback,
        guidance_scale=guidance_scale,
    )
    print("kwargs", kwargs)

    if pipeline_name == "sketch2img" and sketch_pil is not None:
        # The ControlNet pipeline takes the conditioning sketch via `image`.
        kwargs["image"] = sketch_pil
        # Forward the conditioning scale only when set, else keep the default.
        if controlnet_conditioning_scale is not None:
            kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
    elif pipeline_name == "txt2img":
        kwargs.update(width=width, height=height)
    else:
        raise ValueError(
            f"Cannot generate image for pipeline {pipeline_name!r} and prompt {prompt!r}"
        )
    images = pipe(**kwargs).images
    image = images[0]

    os.makedirs("outputs", exist_ok=True)

    # Name the output after the prompt (whitespace collapsed, truncated) plus a
    # timestamp, so repeated prompts do not overwrite each other.
    filename = (
        "outputs/"
        + re.sub(r"\s+", "_", prompt)[:30]
        + f"_{datetime.datetime.now().timestamp()}"
    )
    image.save(f"{filename}.png")
    # Store the prompts next to the image so the result stays reproducible.
    with open(f"{filename}.txt", "w") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt: {negative_prompt}")
    return image
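

# Hedged usage sketch (an assumed calling page, not part of this module): a
# Streamlit front end might wire generate() up roughly like
#
#     prompt = st.text_input("Prompt")
#     if st.button("Generate"):
#         st.image(generate(prompt, "txt2img"))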