"""Gradio demo: instruction-based image editing with FLUX.1-dev plus an edit transformer."""

import gradio as gr
import spaces
import torch
from diffusers import FluxControlPipeline, FluxTransformer2DModel

#####################################
# Load the model(s) at import time  #
#####################################
# The pipeline is built once at module level so every request reuses it.
# Note: `.to("cuda")` places it on the GPU (not the CPU). On a ZeroGPU Space
# the `spaces` package is presumably what makes this safe before a GPU is
# attached — confirm; running outside Spaces requires a CUDA device.
path = "sayakpaul/FLUX.1-dev-edit-v0"
edit_transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16)
pipeline = FluxControlPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=edit_transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")


#####################################
# The function for our Gradio app   #
#####################################
@spaces.GPU(duration=120)
def generate(prompt, input_image, *, guidance_scale=30.0, num_inference_steps=50, seed=0):
    """Edit `input_image` according to the text instruction `prompt`.

    Args:
        prompt: Editing instruction, e.g. "Turn the color of the mushroom to gray".
        input_image: PIL image from the Gradio `Image` component (`type="pil"`).
            It is used as the control image, and its size sets the output size.
        guidance_scale: Guidance strength forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        seed: Seed for a private CPU generator, so results are reproducible
            without reseeding torch's global RNG.

    Returns:
        The first image produced by the pipeline.

    Raises:
        gr.Error: If no image was provided, so the user sees a readable
            message instead of an AttributeError.
    """
    if input_image is None:
        raise gr.Error("Please upload an image to edit.")

    # Uploaded files may be RGBA / palette / grayscale; the pipeline is fed
    # a 3-channel image.
    control_image = input_image.convert("RGB")

    # A dedicated generator yields the same CPU random stream as the old
    # `torch.manual_seed(seed)` call, without mutating global RNG state.
    generator = torch.Generator("cpu").manual_seed(seed)

    output_image = pipeline(
        control_image=control_image,
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        max_sequence_length=512,
        height=control_image.height,
        width=control_image.width,
        generator=generator,
    ).images[0]
    return output_image


def launch_app():
    """Build and return the Gradio Blocks UI.

    Despite its name, this does not start the server; the caller is
    expected to call `.launch()` on the returned `gr.Blocks`.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # Flux Control Editing

            This demo uses the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
            pipeline with an edit transformer from [Sayak Paul](https://huggingface.co/sayakpaul).

            **Acknowledgements**:
            - [Sayak Paul](https://huggingface.co/sayakpaul) for open-sourcing FLUX.1-dev-edit-v0
            - [black-forest-labs](https://huggingface.co/black-forest-labs) for FLUX.1-dev
            """
        )

        with gr.Row():
            prompt = gr.Textbox(
                label="Prompt",
                placeholder="e.g. 'Edit a certain thing in the image'",
            )
            input_image = gr.Image(
                label="Image",
                type="pil",
            )

        generate_button = gr.Button("Generate")
        output_image = gr.Image(label="Edited Image")

        # Only the prompt and image are wired in; the remaining generate()
        # parameters keep their keyword defaults.
        generate_button.click(
            fn=generate,
            inputs=[prompt, input_image],
            outputs=[output_image],
        )

        gr.Examples(
            examples=[
                ["Turn the color of the mushroom to gray", "mushroom.jpg"],
                ["Make the mushroom polka-dotted", "mushroom.jpg"],
            ],
            inputs=[prompt, input_image],
        )

    return demo


if __name__ == "__main__":
    demo = launch_app()
    demo.launch()