File size: 2,965 Bytes
5b344d3 9d157fe 5b344d3 9d157fe 5b344d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gc
import datetime
import os
import re
from typing import Literal
import streamlit as st
import torch
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
EulerDiscreteScheduler,
DDIMScheduler,
)
# The pipeline names this module knows how to build/run; used as the type of
# ``get_pipelines(name=...)`` and ``generate(pipeline_name=...)``.
PIPELINES = Literal["txt2img", "sketch2img"]
@st.cache_resource(max_entries=1)
def get_pipelines(
    name: PIPELINES,
    enable_cpu_offload=False,
) -> "StableDiffusionPipeline | StableDiffusionControlNetPipeline":
    """Build (and cache) the Stable Diffusion pipeline identified by ``name``.

    Both pipelines start from ``runwayml/stable-diffusion-v1-5`` in float16,
    load attention-processor weights from the current directory into the
    UNet, and replace the safety checker with a pass-through that flags no
    image. ``"sketch2img"`` additionally attaches the
    ``Abhi5ingh/model_dresscode`` ControlNet.

    The ``st.cache_resource(max_entries=1)`` decorator keeps at most one
    cached result, keyed on the arguments, so switching ``name`` or
    ``enable_cpu_offload`` replaces the previously cached pipeline.

    Args:
        name: ``"txt2img"`` or ``"sketch2img"``.
        enable_cpu_offload: If true, call ``enable_model_cpu_offload()``
            instead of moving the whole pipeline to ``"cuda"``.

    Returns:
        A ``StableDiffusionPipeline`` for ``"txt2img"`` or a
        ``StableDiffusionControlNetPipeline`` for ``"sketch2img"``.

    Raises:
        ValueError: If ``name`` is not a known pipeline name (a subclass of
            ``Exception``, so existing broad handlers still catch it).
    """
    base_model = "runwayml/stable-diffusion-v1-5"

    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model, torch_dtype=torch.float16
        )
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            "Abhi5ingh/model_dresscode", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model, controlnet=controlnet, torch_dtype=torch.float16
        )
    else:
        raise ValueError(f"Pipeline not Found {name}")

    # Attention-processor weights saved in the working directory
    # (presumably the project's fine-tuned LoRA weights -- confirm).
    pipe.unet.load_attn_procs("./")
    # Pass-through safety checker: return the images unchanged and report
    # no image as flagged.
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe
def generate(
    prompt,
    pipeline_name: PIPELINES,
    sketch_pil=None,
    num_inference_steps=30,
    negative_prompt=None,
    width=512,
    height=512,
    guidance_scale=7.5,
    controlnet_conditioning_scale=None,
    enable_cpu_offload=False,
):
    """Generate one image with the named pipeline, save it, and return it.

    A Streamlit progress bar is advanced from the pipeline's per-step
    callback. The image is written to ``outputs/<prompt>_<timestamp>.png``
    and the prompt / negative prompt to a matching ``.txt`` file.

    Args:
        prompt: Text prompt.
        pipeline_name: ``"txt2img"`` or ``"sketch2img"``.
        sketch_pil: Conditioning image, required for ``"sketch2img"``
            (presumably a PIL image, per the name -- confirm with callers).
        num_inference_steps: Number of denoising steps; also scales the
            progress bar.
        negative_prompt: Optional negative prompt; empty values become
            ``None``.
        width: Output width, used only by ``"txt2img"``.
        height: Output height, used only by ``"txt2img"``.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        controlnet_conditioning_scale: ControlNet strength for
            ``"sketch2img"``; ``None`` keeps the pipeline's own default.
        enable_cpu_offload: Forwarded to ``get_pipelines``.

    Returns:
        The first image of the pipeline output's ``.images``.

    Raises:
        ValueError: If ``"sketch2img"`` is requested without a sketch, or
            ``pipeline_name`` is unknown.
    """
    # Treat an empty negative prompt (e.g. a blank text box) as "none".
    negative_prompt = negative_prompt if negative_prompt else None

    progress_bar = st.progress(0)

    def callback(step, *_):
        # +1 so the last step fills the bar when step indices start at 0;
        # clamp so the value never exceeds 1.0.
        progress_bar.progress(min((step + 1) / num_inference_steps, 1.0))

    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    torch.cuda.empty_cache()

    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        callback=callback,
        guidance_scale=guidance_scale,
    )
    print("kwargs", kwargs)

    if pipeline_name == "sketch2img" and sketch_pil is not None:
        # The ControlNet pipeline receives its conditioning image as ``image``.
        kwargs.update(image=sketch_pil)
        if controlnet_conditioning_scale is not None:
            # Only override when given; passing None would clobber the
            # pipeline's numeric default.
            kwargs.update(controlnet_conditioning_scale=controlnet_conditioning_scale)
    elif pipeline_name == "txt2img":
        kwargs.update(width=width, height=height)
    else:
        raise ValueError(
            f"Cannot generate image for pipeline {pipeline_name} and {prompt}")

    images = pipe(**kwargs).images
    image = images[0]

    os.makedirs("outputs", exist_ok=True)
    # Keep only word characters and hyphens so the prompt cannot inject path
    # separators (e.g. "/" or "..") into the output filename.
    safe_prompt = re.sub(r"[^\w-]+", "_", prompt)[:30]
    filename = os.path.join(
        "outputs", f"{safe_prompt}_{datetime.datetime.now().timestamp()}"
    )
    image.save(f"{filename}.png")
    with open(f"{filename}.txt", "w", encoding="utf-8") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")
    return image