import datetime
import gc
import os
import re
from typing import Literal, Union

import streamlit as st
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerDiscreteScheduler,
    DDIMScheduler,
)

PIPELINES = Literal["txt2img", "sketch2img"]

# Base Stable Diffusion checkpoint shared by both pipelines.
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# ControlNet weights used by the sketch-to-image pipeline.
CONTROLNET_MODEL_ID = "Abhi5ingh/model_dresscode"
# Directory holding the fine-tuned UNet attention processors.
# Raw string on purpose: in a normal literal "...\venv" turns "\v" into a
# vertical-tab character, so the path could never be found.
# Can be overridden with the ATTN_PROCS_PATH environment variable.
ATTN_PROCS_PATH = os.environ.get(
    "ATTN_PROCS_PATH", r"D:\PycharmProjects\pythonProject\venv"
)


def _disabled_safety_checker(images, **kwargs):
    """Stand-in safety checker: return the images untouched, none flagged."""
    return images, [False] * len(images)


def _load_pipeline(name):
    """Build the (uncached, not yet device-placed) pipeline for ``name``.

    Raises:
        ValueError: if ``name`` is not one of ``PIPELINES``.
    """
    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(
            BASE_MODEL_ID, torch_dtype=torch.float16
        )
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MODEL_ID, torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            BASE_MODEL_ID, controlnet=controlnet, torch_dtype=torch.float16
        )
    else:
        raise ValueError(f"Pipeline not Found {name}")

    pipe.unet.load_attn_procs(ATTN_PROCS_PATH)
    # NOTE: the NSFW safety checker is deliberately disabled for both pipelines.
    pipe.safety_checker = _disabled_safety_checker
    return pipe


@st.cache_resource(max_entries=1)
def get_pipelines(
    name: PIPELINES,
    enable_cpu_offload=False,
) -> Union[StableDiffusionPipeline, StableDiffusionControlNetPipeline]:
    """Return the pipeline for ``name``, cached by Streamlit.

    Only one entry is cached (``max_entries=1``), so requesting a different
    pipeline (or a different offload setting) evicts the previous one.

    Args:
        name: ``"txt2img"`` or ``"sketch2img"``.
        enable_cpu_offload: if true, use diffusers' model CPU offloading;
            otherwise move the whole pipeline to ``"cuda"``.

    Raises:
        ValueError: if ``name`` is not a known pipeline.
    """
    pipe = _load_pipeline(name)
    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe


def _save_output(image, prompt, negative_prompt):
    """Write ``image`` as PNG plus a sidecar .txt with the prompts under ./outputs."""
    os.makedirs("outputs", exist_ok=True)
    # Prompts routinely contain characters that are illegal in file names
    # ("/", ":", "?", ...); keep only word characters and dashes.
    slug = re.sub(r"[^\w-]+", "_", prompt)[:30]
    filename = os.path.join(
        "outputs", f"{slug}_{datetime.datetime.now().timestamp()}"
    )
    image.save(f"{filename}.png")
    with open(f"{filename}.txt", "w", encoding="utf-8") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")


def generate(
    prompt,
    pipeline_name: PIPELINES,
    sketch_pil=None,
    num_inference_steps=30,
    negative_prompt=None,
    width=512,
    height=512,
    guidance_scale=7.5,
    controlnet_conditioning_scale=None,
    enable_cpu_offload=False,
):
    """Generate one image, show progress in Streamlit, save it, and return it.

    Args:
        prompt: text prompt.
        pipeline_name: ``"txt2img"`` or ``"sketch2img"``.
        sketch_pil: conditioning image; required for ``"sketch2img"``.
        num_inference_steps: number of denoising steps.
        negative_prompt: optional negative prompt; empty string means none.
        width, height: output size, used only by ``"txt2img"``.
        guidance_scale: classifier-free guidance scale.
        controlnet_conditioning_scale: ControlNet strength for
            ``"sketch2img"``; ``None`` keeps the pipeline's default.
        enable_cpu_offload: forwarded to :func:`get_pipelines`.

    Returns:
        The first image produced by the pipeline (presumably a PIL image with
        the pipeline's default output type).

    Raises:
        ValueError: unknown ``pipeline_name``, or ``"sketch2img"`` without
            a sketch.
    """
    # Treat "" (e.g. an empty Streamlit text box) the same as no negative prompt.
    negative_prompt = negative_prompt if negative_prompt else None

    # Validate before loading a multi-GB pipeline.
    if pipeline_name not in ("txt2img", "sketch2img") or (
        pipeline_name == "sketch2img" and sketch_pil is None
    ):
        raise ValueError(
            f"Cannot generate image for pipeline {pipeline_name} and {prompt}"
        )

    p = st.progress(0)

    def callback(step, *_):
        # diffusers reports a 0-based step index; +1 lets the bar reach 100%.
        p.progress(min((step + 1) / num_inference_steps, 1.0))

    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    # A previously cached pipeline may just have been evicted; collect it so
    # empty_cache() can actually return its GPU memory.
    gc.collect()
    torch.cuda.empty_cache()

    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        callback=callback,
        guidance_scale=guidance_scale,
    )
    if pipeline_name == "sketch2img":
        # The ControlNet pipeline takes its conditioning image as `image`.
        kwargs["image"] = sketch_pil
        # Only forward the scale when given, so None never reaches the
        # pipeline in place of its numeric default.
        if controlnet_conditioning_scale is not None:
            kwargs["controlnet_conditioning_scale"] = controlnet_conditioning_scale
    else:
        kwargs.update(width=width, height=height)
    print("kwargs", kwargs)

    image = pipe(**kwargs).images[0]
    _save_output(image, prompt, negative_prompt)
    return image