# app.py — last change "Update app.py" by ovi054 (commit 08edc45, verified), 4.54 kB
import os
import io
import random
import requests
import gradio as gr
import numpy as np
from PIL import Image
import replicate
# Upper bound for seeds: the largest value a signed 32-bit integer can hold.
MAX_SEED = np.iinfo(np.dtype("int32")).max
def predict(replicate_api, prompt, lora_id, lora_scale=0.95, aspect_ratio="1:1", seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the FLUX Dev LoRA model hosted on Replicate.

    Args:
        replicate_api: The caller's Replicate API token. Surrounding
            whitespace is ignored.
        prompt: Text prompt sent to the model.
        lora_id: Optional LoRA identifier, sent to the model as ``hf_lora``.
            ``None``, empty, or whitespace-only means "no LoRA".
        lora_scale: Sent as ``lora_scale``; only included when a LoRA is given.
        aspect_ratio: Sent to the model as ``aspect_ratio``.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is replaced by a random integer
            in ``[0, MAX_SEED]``.
        guidance_scale: Sent to the model as ``guidance_scale``.
        num_inference_steps: Sent to the model as ``num_inference_steps``.
        progress: Gradio progress tracker; not read by this function.

    Returns:
        A ``(image, seed)`` tuple: the first item of the model output and the
        seed that was actually sent.

    Raises:
        gr.Error: If the token or prompt is missing, or if the Replicate call
            fails for any reason (invalid token, no credits, model error,
            empty output, ...).
    """
    token = (replicate_api or "").strip()
    if not token or not prompt:
        # Raise instead of returning the message: the first output is an image
        # component, which cannot display an error string, and -1 is not a
        # valid value for the seed slider.
        raise gr.Error("Error: Missing necessary inputs.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    input_params = {
        "prompt": prompt,
        "output_format": "jpg",
        "aspect_ratio": aspect_ratio,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "seed": seed,
        "disable_safety_checker": True,
    }

    # Only send the LoRA fields when the user actually supplied a LoRA.
    lora = (lora_id or "").strip()
    if lora:
        input_params["hf_lora"] = lora
        input_params["lora_scale"] = lora_scale

    # Use a client bound to this request's token rather than writing the token
    # to os.environ: the environment is shared by every concurrent request, so
    # one user's call could otherwise run with (or delete) another user's key.
    client = replicate.Client(api_token=token)
    try:
        output = client.run(
            "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d",
            input=input_params,
        )
        image = output[0]
    except Exception as e:
        # UI boundary: surface any failure to the user as a Gradio error.
        raise gr.Error(f"Error: {str(e)}") from e

    return image, seed
# Page-level CSS: centre the main column and cap its width.
css="""
#col-container {
margin: 0 auto;
max-width: 520px;
}
"""

# Sample prompts offered to the user below the form.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]
# Build the UI: token + prompt inputs, collapsible advanced settings, a
# generate button, the output image, and clickable example prompts.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX Dev with Replicate API")

        # The API token is a secret, so mask it while it is typed or pasted.
        replicate_api = gr.Text(
            label="Replicate API",
            show_label=True,
            max_lines=1,
            placeholder="Enter Replicate API",
            container=True,
            type="password",
        )
        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            lines=2,
            max_lines=4,
            show_copy_button=True,
            placeholder="Enter your prompt",
            container=True,
        )

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                custom_lora = gr.Textbox(
                    label="Custom LoRA",
                    info="LoRA Hugging Face path (optional)",
                    placeholder="multimodalart/vintage-ads-flux",
                )
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.95,
                )
            aspect_ratio = gr.Radio(
                label="Aspect ratio",
                value="1:1",
                choices=["1:1", "4:5", "2:3", "3:4", "9:16", "4:3", "16:9"],
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )

        submit = gr.Button("Generate Image", scale=1)
        output = gr.Image(label="Output Image", show_label=True)

        # Clicking an example only fills in the prompt box. No callback is
        # attached: generation needs the user's API token, which an example
        # cannot provide.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )

    # Run the model when the button is clicked or Enter is pressed in the
    # prompt box; the seed slider is updated with the seed that was used.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[replicate_api, prompt, custom_lora, lora_scale, aspect_ratio, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[output, seed],
    )

demo.launch()