T2I / app.py
girishwangikar's picture
Update app.py
2bad515 verified
raw
history blame
5.06 kB
import io
import os
import random

import gradio as gr
import numpy as np
import spaces
from huggingface_hub import InferenceClient
from langchain.schema import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from PIL import Image
# Set up API keys
# Read from the environment; os.environ.get returns None when the variable is unset.
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')
# Set up LLM
# Chat model that generate_detailed_prompt calls to expand short prompts (temperature=0).
llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY)
# Largest signed 32-bit integer. Not referenced anywhere in the visible part of this file.
MAX_SEED = np.iinfo(np.int32).max
# Not referenced in the visible part of this file; presumably an upper bound on
# image width/height in pixels — TODO confirm before relying on it.
MAX_IMAGE_SIZE = 1024
# Initialize the schnell client
# Inference client for the FLUX.1-schnell model; generate_image calls its text_to_image method.
client = InferenceClient("black-forest-labs/FLUX.1-schnell")
# Few-shot examples with detailed prompts and image paths.
# Each entry is a 3-tuple: (short user prompt, detailed prompt, image file path).
# The list is read by generate_detailed_prompt when building the system message
# and is also passed to gr.Examples as its `examples` argument.
few_shot_examples = [
("Create a birthday card for friend",
"A vibrant birthday card with a colorful confetti background, featuring a large, playful 'Happy Birthday!' in the center. The card has a fun, festive theme with balloons, streamers, and a cupcake with a single lit candle. The message inside reads, 'Wishing you a day full of laughter and joy!'",
"example1.webp"),
("An educational infographic showing the stages of the water cycle with bright, engaging visuals.",
"An educational infographic illustrating the water cycle. The diagram shows labeled stages including evaporation, condensation, precipitation, and collection, with arrows guiding the flow. The colors are bright and engaging, with clouds, raindrops, and a sun. The design is simple and clear, suitable for a classroom setting.",
"example2.webp"),
]
def generate_detailed_prompt(user_input):
    """Expand a short image description into a detailed image prompt.

    Builds a system message from fixed instructions followed by the
    (short prompt, detailed prompt) pairs in ``few_shot_examples``, sends it
    to ``llm`` together with the user's request, and returns the reply text.

    Args:
        user_input: Short description typed by the user.

    Returns:
        The ``content`` attribute of the chat model's response.
    """
    instructions = (
        "\n"
        "You are an AI assistant specialized in generating detailed image prompts.\n"
        "Given a simple description, create an elaborate and detailed prompt that can be used to generate high-quality images.\n"
        "Your response should be concise and no longer than 3 sentences.\n"
        "Use the following examples as a guide for the level of detail and creativity expected:\n"
    )
    # Each example is (short prompt, detailed prompt, image path). Only the two
    # text fields are shown to the model; the image path is irrelevant here.
    # Both fields must be unpacked: the previous code discarded the detailed
    # prompt and then referenced an undefined name, raising NameError.
    shots = "\n\n".join(
        f"Input: {short}\nOutput: {detailed}"
        for short, detailed, _ in few_shot_examples
    )
    system_message = SystemMessage(content=instructions + shots)
    human_message = HumanMessage(
        content=f"Generate a detailed image prompt based on this input, using no more than 3 sentences: {user_input}"
    )
    # invoke() is the supported entry point; calling the model object directly
    # is deprecated in LangChain.
    response = llm.invoke([system_message, human_message])
    return response.content
def generate_image(prompt, width=1024, height=1024):
    """Generate an image for ``prompt`` with the module-level inference client.

    Args:
        prompt: Text prompt sent to ``client.text_to_image``.
        width: Requested image width, forwarded to the client.
        height: Requested image height, forwarded to the client.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if anything fails.
        Failures are printed rather than raised so the UI callback does not
        crash.
    """
    try:
        result = client.text_to_image(prompt, width=width, height=height)
        if isinstance(result, Image.Image):
            return result
        # Anything else is treated as raw encoded image bytes. This path needs
        # the `io` module, which the file previously never imported, so it
        # always fell into the except branch and returned None.
        return Image.open(io.BytesIO(result))
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and return None.
        print(f"Error generating image: {str(e)}")
        return None
def process_prompt(user_prompt):
    """Handle a click on "Generate Detailed Prompt".

    Args:
        user_prompt: Text currently in the initial-prompt textbox.

    Returns:
        A 4-tuple, one value per output component of the click handler:
        the unchanged user prompt, the generated detailed prompt, and two
        ``gr.update(visible=True)`` values that reveal the prompt-selection
        row and the "Generate Image" button.
    """
    detailed_prompt = generate_detailed_prompt(user_prompt)
    # The click handler lists four outputs, so four values are required.
    # Returning only the two prompts left the selection row and the image
    # button permanently hidden and made Gradio reject the return value.
    return (
        user_prompt,
        detailed_prompt,
        gr.update(visible=True),
        gr.update(visible=True),
    )
def select_prompt(original_prompt, detailed_prompt, choice):
    """Return the prompt matching the user's radio selection.

    Args:
        original_prompt: The prompt exactly as the user typed it.
        detailed_prompt: The expanded prompt.
        choice: Selected option; only the exact string "Original" picks the
            original prompt, every other value picks the detailed one.

    Returns:
        ``original_prompt`` or ``detailed_prompt``.
    """
    if choice == "Original":
        return original_prompt
    return detailed_prompt
def on_example_click(example):
    """Populate the UI from one of the few-shot examples.

    Args:
        example: Either a ``(user_prompt, detailed_prompt, image_path)``
            sequence, or just the short prompt string. The string form is
            what ``gr.Examples`` passes when its only input component is the
            prompt textbox; it is resolved against ``few_shot_examples``.

    Returns:
        A 6-tuple: user prompt, detailed prompt, the loaded example image,
        two ``gr.update(visible=True)`` values, and a ``gr.update`` that
        resets the prompt-choice radio to "Original".

    Raises:
        ValueError: If ``example`` is a string that matches no entry in
            ``few_shot_examples``.
    """
    if isinstance(example, str):
        # A bare prompt string cannot be unpacked into three fields, so look
        # up the full example tuple by its short prompt.
        match = next((row for row in few_shot_examples if row[0] == example), None)
        if match is None:
            raise ValueError(f"No few-shot example found for prompt: {example!r}")
        example = match
    user_prompt, detailed_prompt, image_path = example
    # Read the pixel data while the file is open so the handle can be closed;
    # the image stays usable after the with-block.
    with Image.open(image_path) as image:
        image.load()
    return (
        user_prompt,
        detailed_prompt,
        image,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=["Original", "Detailed"], value="Original"),
    )
# Stylesheet passed to gr.Blocks(css=...). The selectors target the elem_id
# values assigned to components in the layout (col-container, title, prompt, result).
css = """
#col-container {max-width: 800px; margin: 0 auto; padding: 20px; border-radius: 10px; box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);}
#title {text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;}
#prompt, #result {margin-bottom: 20px;}
"""
# Build the Gradio UI and wire its events.
# NOTE(review): indentation has been lost in this copy of the file, so the exact
# nesting of the `with` blocks below (which components sit in which Row/Column)
# cannot be read from here — confirm against the deployed app before editing.
with gr.Blocks(css=css, theme='gradio/soft') as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# AI-Enhanced Image Generation", elem_id="title")
with gr.Row():
# Short prompt typed by the user, and the button that expands it.
prompt = gr.Textbox(label="Initial Prompt", placeholder="Enter your prompt", elem_id="prompt")
generate_button = gr.Button("Generate Detailed Prompt")
# Created hidden; it is listed as an output of the handlers below so they can reveal it.
with gr.Row(visible=False) as prompt_selection_row:
detailed_prompt = gr.Textbox(label="Detailed Prompt", elem_id="detailed_prompt")
prompt_choice = gr.Radio(["Original", "Detailed"], label="Choose Prompt", value="Original")
# Also created hidden, and also listed as a handler output.
generate_image_button = gr.Button("Generate Image", visible=False)
result = gr.Image(label="Generated Image", elem_id="result")
# Example entries come from few_shot_examples; on_example_click supplies the six outputs.
# NOTE(review): only `prompt` is listed in inputs while each example has three
# fields — presumably fn receives just the prompt text; verify against the Gradio docs.
examples = gr.Examples(
examples=few_shot_examples,
inputs=[prompt],
outputs=[prompt, detailed_prompt, result, prompt_selection_row, generate_image_button, prompt_choice],
fn=on_example_click,
cache_examples=True
)
# Expand the typed prompt. process_prompt must return one value per component in outputs (four).
generate_button.click(
process_prompt,
inputs=[prompt],
outputs=[prompt, detailed_prompt, prompt_selection_row, generate_image_button],
api_name="generate_detailed_prompt"
)
# Generate an image from whichever prompt the radio selects (default 1024x1024 from generate_image).
generate_image_button.click(
lambda p, d, c: generate_image(select_prompt(p, d, c)),
inputs=[prompt, detailed_prompt, prompt_choice],
outputs=[result],
api_name="generate_image"
)
# Start the app.
demo.launch()