import io
import logging
import os
import random

import gradio as gr
import spaces
from langchain_groq import ChatGroq
from langchain.schema import HumanMessage, SystemMessage
from PIL import Image
import numpy as np
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)

# Set up API keys
GROQ_API_KEY = os.environ.get('GROQ_API_KEY')

# Set up LLM used to expand short user prompts into detailed image prompts.
llm = ChatGroq(temperature=0, model_name='llama-3.1-8b-instant', groq_api_key=GROQ_API_KEY)

# NOTE(review): MAX_SEED and MAX_IMAGE_SIZE are not referenced elsewhere in this script.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Initialize the schnell client
client = InferenceClient("black-forest-labs/FLUX.1-schnell")

# Few-shot examples: (user prompt, detailed prompt, path to a pre-rendered example image).
# The first two fields prime the LLM; all three populate the UI when an example is clicked.
few_shot_examples = [
    (
        "Create a birthday card for friend",
        "A vibrant birthday card with a colorful confetti background, featuring a large, "
        "playful 'Happy Birthday!' in the center. The card has a fun, festive theme with "
        "balloons, streamers, and a cupcake with a single lit candle. The message inside "
        "reads, 'Wishing you a day full of laughter and joy!'",
        "example1.webp",
    ),
    (
        "An educational infographic showing the stages of the water cycle with bright, "
        "engaging visuals.",
        "An educational infographic illustrating the water cycle. The diagram shows labeled "
        "stages including evaporation, condensation, precipitation, and collection, with "
        "arrows guiding the flow. The colors are bright and engaging, with clouds, raindrops, "
        "and a sun. The design is simple and clear, suitable for a classroom setting.",
        "example2.webp",
    ),
]


def generate_detailed_prompt(user_input):
    """Expand a short image idea into a detailed text-to-image prompt.

    The LLM is primed with the (input, output) pairs from ``few_shot_examples``
    and instructed to answer in at most three sentences.

    Args:
        user_input: The short description entered by the user.

    Returns:
        The text content of the LLM's reply (the detailed prompt).
    """
    examples_text = "\n\n".join(
        f"Input: {example_input}\nOutput: {example_output}"
        for example_input, example_output, _ in few_shot_examples
    )
    system_message = SystemMessage(
        content=(
            "You are an AI assistant specialized in generating detailed image prompts. "
            "Given a simple description, create an elaborate and detailed prompt that can be "
            "used to generate high-quality images. Your response should be concise and no "
            "longer than 3 sentences.\n"
            "Use the following examples as a guide for the level of detail and creativity "
            "expected:\n\n" + examples_text
        )
    )
    human_message = HumanMessage(
        content=f"Generate a detailed image prompt based on this input, using no more than 3 sentences: {user_input}"
    )
    # `invoke` replaces the deprecated `llm([...])` call style on LangChain chat models.
    response = llm.invoke([system_message, human_message])
    return response.content


def generate_image(prompt, width=1024, height=1024):
    """Render ``prompt`` with the FLUX.1-schnell inference client.

    Args:
        prompt: Text prompt to render.
        width: Requested image width in pixels.
        height: Requested image height in pixels.

    Returns:
        A PIL image, or ``None`` if the request (or decoding the result) failed.
        Failures are logged with a traceback rather than raised.
    """
    try:
        result = client.text_to_image(prompt, width=width, height=height)
        if isinstance(result, Image.Image):
            return result
        # Fallback for a client that returns encoded image bytes instead of a PIL image.
        return Image.open(io.BytesIO(result))
    except Exception as e:
        # Best-effort: callers treat None as "no image"; keep the traceback in the logs.
        logger.exception("Error generating image: %s", e)
        return None


def process_prompt(user_prompt):
    """Return ``(user_prompt, detailed_prompt)`` for the given user prompt."""
    detailed_prompt = generate_detailed_prompt(user_prompt)
    return user_prompt, detailed_prompt


def select_prompt(original_prompt, detailed_prompt, choice):
    """Return the original prompt if ``choice == "Original"``, otherwise the detailed one."""
    return original_prompt if choice == "Original" else detailed_prompt


def on_example_click(example):
    """Populate the UI from one of ``few_shot_examples``.

    Args:
        example: Either a full ``(user_prompt, detailed_prompt, image_path)`` tuple,
            or just the user-prompt string. ``gr.Examples`` passes the string because
            the prompt textbox is its only input. A string is looked up in
            ``few_shot_examples``.

    Returns:
        A 6-tuple for the Examples outputs:
        ``(prompt, detailed_prompt, image, row_update, button_update, radio_update)``.

    Raises:
        gr.Error: If a string does not match any example's user prompt.
    """
    if isinstance(example, str):
        match = next((ex for ex in few_shot_examples if ex[0] == example), None)
        if match is None:
            raise gr.Error("Unknown example prompt.")
        example = match
    user_prompt, detailed_prompt, image_path = example
    image = Image.open(image_path)
    return (
        user_prompt,
        detailed_prompt,
        image,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(choices=["Original", "Detailed"], value="Original"),
    )


def _on_generate_detailed_prompt(user_prompt):
    """Click handler: expand the prompt and reveal the prompt-selection controls.

    Wraps ``process_prompt`` and adds visibility updates for the selection row and
    the "Generate Image" button. Those are the two extra outputs wired to this event,
    and Gradio errors when a handler returns fewer values than it has outputs.
    """
    if not user_prompt or not user_prompt.strip():
        raise gr.Error("Please enter a prompt first.")
    original, detailed = process_prompt(user_prompt)
    return original, detailed, gr.update(visible=True), gr.update(visible=True)


def _on_generate_image(original_prompt, detailed_prompt, choice):
    """Click handler: render the chosen prompt, surfacing failures to the user."""
    chosen = select_prompt(original_prompt, detailed_prompt, choice)
    if not chosen or not chosen.strip():
        raise gr.Error("The selected prompt is empty.")
    image = generate_image(chosen)
    if image is None:
        raise gr.Error("Image generation failed. Please try again.")
    return image


css = """
#col-container {max-width: 800px; margin: 0 auto; padding: 20px; border-radius: 10px; box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);}
#title {text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;}
#prompt, #result {margin-bottom: 20px;}
"""

with gr.Blocks(css=css, theme='gradio/soft') as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# AI-Enhanced Image Generation", elem_id="title")

        with gr.Row():
            prompt = gr.Textbox(label="Initial Prompt", placeholder="Enter your prompt", elem_id="prompt")
            generate_button = gr.Button("Generate Detailed Prompt")

        # Hidden until a detailed prompt exists (revealed by the handlers above).
        with gr.Row(visible=False) as prompt_selection_row:
            detailed_prompt = gr.Textbox(label="Detailed Prompt", elem_id="detailed_prompt")
            prompt_choice = gr.Radio(["Original", "Detailed"], label="Choose Prompt", value="Original")

        generate_image_button = gr.Button("Generate Image", visible=False)
        result = gr.Image(label="Generated Image", elem_id="result")

        examples = gr.Examples(
            # One column per input component. Only the user prompt is passed to the
            # handler, which looks the rest of the example up in few_shot_examples.
            examples=[[example_prompt] for example_prompt, _, _ in few_shot_examples],
            inputs=[prompt],
            outputs=[prompt, detailed_prompt, result, prompt_selection_row, generate_image_button, prompt_choice],
            fn=on_example_click,
            # The handler only reads local files (no model calls), so run it on click
            # instead of caching. Caching would also try to store layout updates.
            cache_examples=False,
            run_on_click=True,
        )

    generate_button.click(
        _on_generate_detailed_prompt,
        inputs=[prompt],
        outputs=[prompt, detailed_prompt, prompt_selection_row, generate_image_button],
        api_name="generate_detailed_prompt",
    )

    generate_image_button.click(
        _on_generate_image,
        inputs=[prompt, detailed_prompt, prompt_choice],
        outputs=[result],
        api_name="generate_image",
    )

demo.launch()