# fashionsd / sdfile.py
# Last commit: 9d157fe "Update sdfile.py" (Abhi5ingh), 2.97 kB
import gc
import datetime
import os
import re
from typing import Literal
import streamlit as st
import torch
from diffusers import (
StableDiffusionPipeline,
StableDiffusionControlNetPipeline,
ControlNetModel,
EulerDiscreteScheduler,
DDIMScheduler,
)
# Identifiers of the supported diffusion pipelines: plain text-to-image and
# sketch-conditioned (ControlNet) generation.
PIPELINES = Literal[("txt2img", "sketch2img")]
@st.cache_resource(max_entries=1)
def get_pipelines(
    name: PIPELINES,
    enable_cpu_offload=False,
) -> "StableDiffusionPipeline | StableDiffusionControlNetPipeline":
    """Load the Stable Diffusion pipeline selected by ``name``.

    Streamlit caches the result with ``max_entries=1``. Only the most recently
    requested (name, enable_cpu_offload) combination stays resident, so
    switching between them reloads the weights.

    Args:
        name: ``"txt2img"`` for plain text-to-image, or ``"sketch2img"`` for the
            ControlNet pipeline conditioned on the ``Abhi5ingh/model_dresscode``
            ControlNet.
        enable_cpu_offload: If true, use diffusers model CPU offloading instead
            of moving the whole pipeline to CUDA.

    Returns:
        A ``StableDiffusionPipeline`` (txt2img) or a
        ``StableDiffusionControlNetPipeline`` (sketch2img), with fp16 weights.

    Raises:
        ValueError: if ``name`` is not a known pipeline. This is a subclass of
            ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    base_model = "runwayml/stable-diffusion-v1-5"

    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            base_model, torch_dtype=torch.float16
        )
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            "Abhi5ingh/model_dresscode", torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            base_model, controlnet=controlnet, torch_dtype=torch.float16
        )
    else:
        raise ValueError(f"Pipeline not Found {name}")

    # Both variants load the same UNet attention-processor weights from the
    # current working directory (presumably fine-tuned LoRA weights; confirm).
    pipe.unet.load_attn_procs("./")
    # Replace the safety checker with a pass-through that never flags an image.
    pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe
def generate(
    prompt,
    pipeline_name: PIPELINES,
    sketch_pil=None,
    num_inference_steps=30,
    negative_prompt=None,
    width=512,
    height=512,
    guidance_scale=7.5,
    controlnet_conditioning_scale=None,
    enable_cpu_offload=False,
):
    """Generate one image with the selected pipeline, save it, and return it.

    The image is written to ``outputs/<prompt-slug>_<timestamp>.png``. A
    ``.txt`` sidecar next to it records the prompt and negative prompt. A
    Streamlit progress bar tracks the denoising steps.

    Args:
        prompt: Text prompt.
        pipeline_name: ``"txt2img"`` or ``"sketch2img"``.
        sketch_pil: Conditioning image, required for ``"sketch2img"``.
        num_inference_steps: Number of denoising steps.
        negative_prompt: Optional negative prompt. Falsy values become ``None``.
        width: Output width, only forwarded for ``"txt2img"``.
        height: Output height, only forwarded for ``"txt2img"``.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: ControlNet strength for ``"sketch2img"``.
            When ``None``, the pipeline's own default is used.
        enable_cpu_offload: Forwarded to ``get_pipelines``.

    Returns:
        The first entry of the pipeline output's ``.images``.

    Raises:
        Exception: if ``pipeline_name`` is unsupported, or ``"sketch2img"`` is
            requested without a sketch. This is checked before any model
            weights are loaded.
    """
    negative_prompt = negative_prompt if negative_prompt else None

    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    if pipeline_name == "sketch2img" and sketch_pil:
        # The ControlNet pipeline receives its conditioning image as ``image``;
        # ``sketch_pil`` is not a parameter of the pipeline call.
        kwargs["image"] = sketch_pil
        # None is not a valid scale; omit it so the pipeline default applies.
        if controlnet_conditioning_scale is not None:
            kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
    elif pipeline_name == "txt2img":
        kwargs.update(width=width, height=height)
    else:
        raise Exception(
            f"Cannot generate image for pipeline {pipeline_name} and {prompt}")

    progress_bar = st.progress(0)
    # The callback's step index can reach or pass num_inference_steps with
    # some schedulers. Clamp to the 0..1 range st.progress accepts.
    kwargs["callback"] = lambda step, *_: progress_bar.progress(
        min(step / num_inference_steps, 1.0)
    )

    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    torch.cuda.empty_cache()
    print("kwargs", kwargs)

    image = pipe(**kwargs).images[0]
    progress_bar.progress(1.0)

    os.makedirs("outputs", exist_ok=True)
    # Keep only word characters and hyphens from the prompt. This stops user
    # text from injecting path separators (e.g. "../") or characters that are
    # invalid in filenames.
    slug = re.sub(r"[^\w-]+", "_", prompt)[:30]
    filename = os.path.join(
        "outputs", f"{slug}_{datetime.datetime.now().timestamp()}"
    )
    image.save(f"{filename}.png")
    # Explicit UTF-8 so non-ASCII prompts do not fail on platforms whose
    # default encoding cannot represent them.
    with open(f"{filename}.txt", "w", encoding="utf-8") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")
    return image