# Gradio app for a Hugging Face Space (listing status: Running; file size: 4,538 bytes).
import os
import io
import random
import requests
import gradio as gr
import numpy as np
from PIL import Image
import replicate
# Largest seed offered by the UI and by seed randomization: the maximum
# signed 32-bit integer (2**31 - 1).
MAX_SEED = int(np.iinfo(np.int32).max)
def predict(replicate_api, prompt, lora_id, lora_scale=0.95, aspect_ratio="1:1", seed=-1, randomize_seed=True, guidance_scale=3.5, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the FLUX Dev LoRA model on Replicate.

    Args:
        replicate_api: The user's Replicate API token. It is used only for
            this request and is never written to the process environment.
        prompt: Text prompt sent to the model.
        lora_id: Optional Hugging Face LoRA path; blank or None disables LoRA.
        lora_scale: LoRA strength, only sent when ``lora_id`` is non-blank.
        aspect_ratio: Aspect ratio string forwarded to the model (e.g. "1:1").
        seed: Seed used when ``randomize_seed`` is False.
        randomize_seed: If True, draw a new seed in ``[0, MAX_SEED]``.
        guidance_scale: Guidance scale forwarded to the model.
        num_inference_steps: Step count forwarded to the model.
        progress: Gradio progress tracker (injected by Gradio; unused here).

    Returns:
        ``(image, seed)``: the first model output (a URL, or whatever the
        installed replicate client returns for it) and the seed actually used.

    Raises:
        gr.Error: If the token or prompt is missing, the Replicate call
            fails (e.g. invalid token, no credits), or no image is returned.
            Gradio shows the message to the user.
    """
    # Whitespace-only input counts as missing.
    api_token = (replicate_api or "").strip()
    if not api_token or not (prompt or "").strip():
        raise gr.Error("Missing necessary inputs: a Replicate API token and a prompt are required.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # NOTE(review): slider values may arrive as floats depending on the Gradio
    # version; int() ensures integer fields are sent as integers.
    seed = int(seed)

    input_params = {
        "prompt": prompt,
        "output_format": "jpg",
        "aspect_ratio": aspect_ratio,
        "num_inference_steps": int(num_inference_steps),
        "guidance_scale": guidance_scale,
        "seed": seed,
        "disable_safety_checker": True,
    }
    # Only send LoRA settings when a non-blank LoRA path was provided.
    if lora_id and lora_id.strip():
        input_params["hf_lora"] = lora_id.strip()
        input_params["lora_scale"] = lora_scale

    # A per-request client scopes the token to this call. Writing it into
    # os.environ would share it with every concurrent Gradio request in the
    # process: another user's request could pick it up, or have its own
    # token deleted mid-flight by a `finally` cleanup.
    client = replicate.Client(api_token=api_token)
    try:
        output = client.run(
            "lucataco/flux-dev-lora:a22c463f11808638ad5e2ebd582e07a469031f48dd567366fb4c6fdab91d614d",
            input=input_params,
        )
    except Exception as e:
        # Boundary with an external service: any failure (bad token, no
        # credits, network) is reported to the user instead of being
        # returned as a string that gr.Image cannot display.
        raise gr.Error(f"Error: {e}") from e

    if not output:
        raise gr.Error("Error: the model returned no images.")
    image = output[0]
    # NOTE(review): newer replicate clients may return file objects exposing
    # `.url` rather than plain URL strings; fall back to the item itself.
    return getattr(image, "url", image), seed
# Center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Prompts offered as one-click examples; selecting one fills the prompt box.
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cat holding a sign that says hello world",
    "an anime illustration of a wiener schnitzel",
]
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# FLUX Dev with Replicate API")
        # type="password" masks the API token so the secret is not displayed
        # in plain text on screen.
        replicate_api = gr.Text(
            label="Replicate API",
            show_label=True,
            max_lines=1,
            placeholder="Enter Replicate API",
            container=True,
            type="password",
        )
        prompt = gr.Text(
            label="Prompt",
            show_label=True,
            lines=2,
            max_lines=4,
            show_copy_button=True,
            placeholder="Enter your prompt",
            container=True,
        )
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                custom_lora = gr.Textbox(
                    label="Custom LoRA",
                    info="LoRA Hugging Face path (optional)",
                    placeholder="multimodalart/vintage-ads-flux",
                )
                lora_scale = gr.Slider(
                    label="LoRA Scale",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.95,
                )
            aspect_ratio = gr.Radio(
                label="Aspect ratio",
                value="1:1",
                choices=["1:1", "4:5", "2:3", "3:4", "9:16", "4:3", "16:9"],
            )
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=3.5,
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=28,
                )
        submit = gr.Button("Generate Image", scale=1)
        output = gr.Image(label="Output Image", show_label=True)
        # Examples only populate the prompt box. No `fn` is passed: predict
        # needs the API token and every setting, so it cannot be run or
        # cached from the prompt alone.
        gr.Examples(
            examples=examples,
            inputs=[prompt],
        )

    # Generate on button click or Enter in the prompt box; the seed actually
    # used is written back into the seed slider.
    gr.on(
        triggers=[submit.click, prompt.submit],
        fn=predict,
        inputs=[replicate_api, prompt, custom_lora, lora_scale, aspect_ratio, seed, randomize_seed, guidance_scale, num_inference_steps],
        outputs=[output, seed],
    )

demo.launch()